//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" using namespace mlir; #define DEBUG_TYPE "memref-transforms" #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") //===----------------------------------------------------------------------===// // MemRefMultiBufferOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply( transform::TransformResults &transformResults, transform::TransformState &state) { SmallVector results; IRRewriter rewriter(getContext()); for (Operation *op : state.getPayloadOps(getTarget())) { bool canApplyMultiBuffer = true; auto target = cast(op); LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";); // Skip allocations not used in a loop. for (Operation *user : target->getUsers()) { if (isa(user)) continue; auto loop = user->getParentOfType(); if (!loop) { LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n"; DBGS() << "----due to user: " << *user;); canApplyMultiBuffer = false; break; } } if (!canApplyMultiBuffer) { LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";); continue; } auto newBuffer = memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis()); if (failed(newBuffer)) { LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";); return emitSilenceableFailure(target->getLoc()) << "op failed to multibuffer"; } results.push_back(*newBuffer); } transformResults.set(cast(getResult()), results); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // MemRefExtractAddressComputationsOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefExtractAddressComputationsOp::applyToOne( Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { if (!target->hasTrait()) { auto diag = this->emitOpError("requires isolated-from-above targets"); diag.attachNote(target->getLoc()) << "non-isolated target"; return DiagnosedSilenceableFailure::definiteFailure(); } MLIRContext *ctx = getContext(); RewritePatternSet patterns(ctx); memref::populateExtractAddressComputationsPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) return emitDefaultDefiniteFailure(target); results.push_back(target); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // MemRefMakeLoopIndependentOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne( Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Gather IVs. SmallVector ivs; Operation *nextOp = target; for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) { nextOp = nextOp->getParentOfType(); if (!nextOp) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "could not find " << i << "-th enclosing loop"; diag.attachNote(target->getLoc()) << "target op"; return diag; } ivs.push_back(cast(nextOp).getInductionVar()); } // Rewrite IR. IRRewriter rewriter(target->getContext()); FailureOr replacement = failure(); if (auto allocaOp = dyn_cast(target)) { replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs); } else { DiagnosedSilenceableFailure diag = emitSilenceableError() << "unsupported target op"; diag.attachNote(target->getLoc()) << "target op"; return diag; } if (failed(replacement)) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "could not make target op loop-independent"; diag.attachNote(target->getLoc()) << "target op"; return diag; } results.push_back(replacement->getDefiningOp()); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// namespace { class MemRefTransformDialectExtension : public transform::TransformDialectExtension< MemRefTransformDialectExtension> { public: using Base::Base; void init() { declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc" >(); } }; } // namespace #define GET_OP_CLASSES #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc" void mlir::memref::registerTransformDialectExtension( DialectRegistry ®istry) { registry.addExtensions(); }