//===- TensorToLinalgPass.cpp - Tensor to Linalg Passes -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass to convert Tensor dialect to Linalg dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" #include "mlir/Conversion/TensorToLinalg/TensorToLinalg.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" namespace mlir { #define GEN_PASS_DEF_CONVERTTENSORTOLINALG #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; namespace { /// A pass converting MLIR Tensor operations into the Linalg dialect. class ConvertTensorToLinalgPass : public impl::ConvertTensorToLinalgBase { void runOnOperation() override { auto &context = getContext(); ConversionTarget target(context); target .addLegalDialect(); target.addIllegalOp(); RewritePatternSet patterns(&context); populateTensorToLinalgPatterns(patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); } }; } // namespace std::unique_ptr> mlir::createConvertTensorToLinalgPass() { return std::make_unique(); }