//===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Interfaces/DestinationStyleOpInterface.h" using namespace mlir; namespace mlir { #include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc" } // namespace mlir OpOperandVector::operator SmallVector() { SmallVector result; result.reserve(this->size()); llvm::transform(*this, std::back_inserter(result), [](OpOperand *opOperand) { return opOperand->get(); }); return result; } namespace { size_t getNumTensorResults(Operation *op) { size_t numTensorResults = 0; for (auto t : op->getResultTypes()) { if (isa(t)) { ++numTensorResults; } } return numTensorResults; } } // namespace LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) { DestinationStyleOpInterface dstStyleOp = cast(op); SmallVector outputTensorOperands; for (OpOperand *operand : dstStyleOp.getDpsInitOperands()) { Type type = operand->get().getType(); if (isa(type)) { outputTensorOperands.push_back(operand); } else if (!isa(type)) { return op->emitOpError("expected that operand #") << operand->getOperandNumber() << " is a ranked tensor or a ranked memref"; } } // Verify the number of tensor results matches the number of output tensors. if (getNumTensorResults(op) != outputTensorOperands.size()) return op->emitOpError("expected the number of tensor results (") << getNumTensorResults(op) << ") to be equal to the number of output tensors (" << outputTensorOperands.size() << ")"; for (OpOperand *opOperand : outputTensorOperands) { OpResult result = dstStyleOp.getTiedOpResult(opOperand); if (result.getType() != opOperand->get().getType()) return op->emitOpError("expected type of operand #") << opOperand->getOperandNumber() << " (" << opOperand->get().getType() << ")" << " to match type of corresponding result (" << result.getType() << ")"; } return success(); }