//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

#include <optional>

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_FINALIZINGBUFFERIZE
#define GEN_PASS_DEF_BUFFERIZATIONBUFFERIZE
#define GEN_PASS_DEF_ONESHOTBUFFERIZE
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferize"

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//

static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(isa<BaseMemRefType>(inputs[0].getType()));
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Registers conversions into BufferizeTypeConverter.
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Keep all types unchanged.
  addConversion([](Type type) { return type; });
  // Convert RankedTensorType to MemRefType.
  addConversion([](RankedTensorType type) -> Type {
    return MemRefType::get(type.getShape(), type.getElementType());
  });
  // Convert UnrankedTensorType to UnrankedMemRefType.
  addConversion([](UnrankedTensorType type) -> Type {
    return UnrankedMemRefType::get(type.getElementType(), 0);
  });
  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);
  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");

    if (auto inputType = dyn_cast<BaseMemRefType>(inputs[0].getType())) {
      // MemRef to MemRef cast.
      assert(inputType != type && "expected different types");
      // Unranked to ranked and ranked to unranked casts must be explicit.
      auto rankedDestType = dyn_cast<MemRefType>(type);
      if (!rankedDestType)
        return nullptr;
      FailureOr<Value> replacement =
          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
      if (failed(replacement))
        return nullptr;
      return *replacement;
    }

    if (isa<TensorType>(inputs[0].getType())) {
      // Tensor to MemRef cast.
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    }

    llvm_unreachable("only tensor/memref input types supported");
  });
}

void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}
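// Illustrative sketch (types shown for exposition only, not exercised by this
// file): the conversions registered above map tensor types to memref types
// with a default layout, e.g. tensor<4x?xf32> -> memref<4x?xf32> and
// tensor<*xf32> -> memref<*xf32>. Where a value is still needed in its
// original form across the conversion boundary, the materialization hooks
// bridge the gap by inserting bufferization.to_tensor /
// bufferization.to_memref ops, for example:
//
//   %t = bufferization.to_tensor %m : memref<4x?xf32>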
namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, thus, this op becomes an identity.
class BufferizeToTensorOp
    : public OpConversionPattern<bufferization::ToTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getMemref());
    return success();
  }
};
} // namespace

namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, thus, this op becomes an identity.
class BufferizeToMemrefOp
    : public OpConversionPattern<bufferization::ToMemrefOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getTensor());
    return success();
  }
};
} // namespace

void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
                                                         patterns.getContext());
}

namespace {
struct FinalizingBufferizePass
    : public bufferization::impl::FinalizingBufferizeBase<
          FinalizingBufferizePass> {
  using FinalizingBufferizeBase<
      FinalizingBufferizePass>::FinalizingBufferizeBase;

  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);

    // If all result types are legal, and all block arguments are legal
    // (ensured by func conversion above), then all types in the program are
    // legal.
    //
    // We also check that the operand types are legal to avoid creating invalid
    // IR. For example, this prevents
    // populateEliminateBufferizeMaterializationsPatterns from updating the
    // types of the operands to a return op without updating the enclosing
    // function.
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    if (failed(applyFullConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};

static LayoutMapOption parseLayoutMapOption(const std::string &s) {
  if (s == "fully-dynamic-layout-map")
    return LayoutMapOption::FullyDynamicLayoutMap;
  if (s == "identity-layout-map")
    return LayoutMapOption::IdentityLayoutMap;
  if (s == "infer-layout-map")
    return LayoutMapOption::InferLayoutMap;
  llvm_unreachable("invalid layout map option");
}

static OneShotBufferizationOptions::AnalysisHeuristic
parseHeuristicOption(const std::string &s) {
  if (s == "bottom-up")
    return OneShotBufferizationOptions::AnalysisHeuristic::BottomUp;
  if (s == "top-down")
    return OneShotBufferizationOptions::AnalysisHeuristic::TopDown;
  llvm_unreachable("invalid analysis heuristic option");
}
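// Illustrative invocation (a sketch; the textual option names below are
// assumptions inferred from the strings parsed above, not guaranteed by this
// file):
//
//   mlir-opt input.mlir \
//     -one-shot-bufferize="bufferize-function-boundaries=1 \
//         function-boundary-type-conversion=identity-layout-map \
//         analysis-heuristic=bottom-up"
//
// Unrecognized option values hit the llvm_unreachable paths above.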
struct OneShotBufferizePass
    : public bufferization::impl::OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() = default;

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
    registerAllocationOpInterfaceExternalModels(registry);
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating
      // the pass.
      opt.allowReturnAllocs = allowReturnAllocs;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
      opt.copyBeforeWrite = copyBeforeWrite;
      opt.createDeallocs = createDeallocs;
      opt.dumpAliasSets = dumpAliasSets;
      opt.setFunctionBoundaryTypeConversion(
          parseLayoutMapOption(functionBoundaryTypeConversion));
      if (mustInferMemorySpace)
        opt.defaultMemorySpace = std::nullopt;
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
      opt.noAnalysisFuncFilter = noAnalysisFuncFilter;

      // Configure type converter.
      LayoutMapOption unknownTypeConversionOption =
          parseLayoutMapOption(unknownTypeConversion);
      opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
                                       const BufferizationOptions &options) {
        auto tensorType = cast<TensorType>(value.getType());
        if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
          return bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorType, memorySpace);
        assert(unknownTypeConversionOption ==
                   LayoutMapOption::FullyDynamicLayoutMap &&
               "invalid layout map option");
        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                  memorySpace);
      };

      // Configure op filter.
      OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
        // Filter may be specified via options.
        if (this->dialectFilter.hasValue())
          return llvm::is_contained(this->dialectFilter,
                                    op->getDialect()->getNamespace());
        // No filter specified: All other ops are allowed.
        return true;
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }

    BufferizationStatistics statistics;
    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
        signalPassFailure();
        return;
      }
    } else {
      assert(opt.noAnalysisFuncFilter.empty() &&
             "invalid combination of bufferization flags");
      if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
        signalPassFailure();
        return;
      }
    }

    // Set pass statistics.
    this->numBufferAlloc = statistics.numBufferAlloc;
    this->numBufferDealloc = statistics.numBufferDealloc;
    this->numTensorInPlace = statistics.numTensorInPlace;
    this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;

    if (opt.testAnalysisOnly)
      return;

    OpPassManager cleanupPipeline("builtin.module");
    cleanupPipeline.addPass(createCanonicalizerPass());
    cleanupPipeline.addPass(createCSEPass());
    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
    (void)runPipeline(cleanupPipeline, moduleOp);
  }

private:
  std::optional<OneShotBufferizationOptions> options;
};
} // namespace
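// Illustrative example (types shown for exposition only): for an op that has
// no registered bufferization, unknown-type-conversion=identity-layout-map
// maps a tensor<?x?xf32> value to memref<?x?xf32>, whereas
// fully-dynamic-layout-map maps it to
// memref<?x?xf32, strided<[?, ?], offset: ?>> so that any operand layout can
// be accepted without inserting a copy.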
namespace {
struct BufferizationBufferizePass
    : public bufferization::impl::BufferizationBufferizeBase<
          BufferizationBufferizePass> {
  void runOnOperation() override {
    BufferizationOptions options = getPartialBufferizationOptions();
    options.opFilter.allowDialect<BufferizationDialect>();

    if (failed(bufferizeOp(getOperation(), options)))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::bufferization::createBufferizationBufferizePass() {
  return std::make_unique<BufferizationBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  return std::make_unique<OneShotBufferizePass>();
}

std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  return std::make_unique<OneShotBufferizePass>(options);
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  return std::make_unique<FinalizingBufferizePass>();
}
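// Minimal usage sketch (client-side code, assumed for illustration; `ctx` and
// `moduleOp` are not defined in this file): the factory functions above are
// typically used to schedule bufferization on a module:
//
//   OneShotBufferizationOptions options;
//   options.bufferizeFunctionBoundaries = true;
//   PassManager pm(ctx);
//   pm.addPass(bufferization::createOneShotBufferizePass(options));
//   if (failed(pm.run(moduleOp)))
//     /* handle failure, e.g. emit a diagnostic */;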
//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//

static bool isaTensor(Type t) { return isa<TensorType>(t); }

/// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) {
  if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
    bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
    bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
    return hasTensorArg || hasTensorResult;
  }

  bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
  bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
  return hasTensorResult || hasTensorOperand;
}

namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter,
                              public RewriterBase::Listener {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        const OpFilter *opFilter,
                        BufferizationStatistics *statistics)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist), analysisState(options), opFilter(opFilter),
        statistics(statistics) {
    setListener(this);
  }

protected:
  void notifyOperationRemoved(Operation *op) override {
    // TODO: Walk can be removed when D144193 has landed.
    op->walk([&](Operation *op) {
      erasedOps.insert(op);
      // Erase if present.
      toMemrefOps.erase(op);
    });
  }

  void notifyOperationInserted(Operation *op) override {
    erasedOps.erase(op);

    // Gather statistics about allocs and deallocs.
    if (statistics) {
      if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op)) {
        statistics->numBufferAlloc += static_cast<int64_t>(
            sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
        statistics->numBufferDealloc += static_cast<int64_t>(
            sideEffectingOp.hasEffect<MemoryEffects::Free>());
      }
    }

    // Keep track of to_memref ops.
    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    // Skip to_tensor ops.
    if (isa<ToTensorOp>(op))
      return;

    // Skip non-tensor ops.
    if (!hasTensorSemantics(op))
      return;

    // Skip ops that are not allowed to be bufferized.
    auto const &options = analysisState.getOptions();
    if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op)))
      return;

    // Add op to worklist.
    worklist.push_back(op);
  }

private:
  /// A set of all erased ops.
  DenseSet<Operation *> &erasedOps;

  /// A set of all to_memref ops.
  DenseSet<Operation *> &toMemrefOps;

  /// The worklist of ops to be bufferized.
  SmallVector<Operation *> &worklist;

  /// The analysis state. Used for debug assertions and access to the
  /// bufferization options.
  const AnalysisState analysisState;

  /// An extra op filter for bufferization.
  const OpFilter *opFilter;

  /// Bufferization statistics for debugging.
  BufferizationStatistics *statistics;
};
} // namespace
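// Illustrative example (IR shown for exposition only): towards the end of
// bufferizeOp (below), remaining to_memref(to_tensor(x)) pairs are folded
// away, e.g.
//
//   %t = bufferization.to_tensor %m : memref<8xf32>
//   %m2 = bufferization.to_memref %t : memref<8xf32>
//
// collapses to a direct use of %m, possibly via a cast (or a reallocation and
// copy when the layouts are incompatible).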
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options,
                                         bool copyBeforeWrite,
                                         const OpFilter *opFilter,
                                         BufferizationStatistics *statistics) {
  if (copyBeforeWrite) {
    AnalysisState state(options);
    if (failed(insertTensorCopies(op, state)))
      return failure();
  }

  // Keep track of to_memref ops.
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather all bufferizable ops in top-to-bottom order.
  //
  // We should ideally know the exact memref type of all operands when
  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We are currently missing patterns for layout maps to
  // canonicalize away (or canonicalize to more precise layouts).
  //
  // FuncOps must be bufferized before their bodies, so add them to the
  // worklist first.
  SmallVector<Operation *> worklist;
  op->walk([&](func::FuncOp funcOp) {
    if (hasTensorSemantics(funcOp))
      worklist.push_back(funcOp);
  });
  op->walk([&](Operation *op) {
    if (hasTensorSemantics(op) && !isa<func::FuncOp>(op))
      worklist.push_back(op);
  });

  // Keep track of all erased ops.
  DenseSet<Operation *> erasedOps;

  // Bufferize all ops.
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist, options, opFilter, statistics);
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *nextOp = worklist[i];
    // Skip ops that were erased.
    if (erasedOps.contains(nextOp))
      continue;
    // Skip ops that are not bufferizable or not allowed.
    auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
    if (!bufferizableOp)
      continue;
    if (opFilter && !opFilter->isOpAllowed(nextOp))
      continue;
    // Skip ops that no longer have tensor semantics.
    if (!hasTensorSemantics(nextOp))
      continue;
    // Bufferize the op.
    LLVM_DEBUG(llvm::dbgs()
               << "//===-------------------------------------------===//\n"
               << "IR after bufferizing: " << nextOp->getName() << "\n");
    rewriter.setInsertionPoint(nextOp);
    if (failed(bufferizableOp.bufferize(rewriter, options))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "failed to bufferize\n"
                 << "//===-------------------------------------------===//\n");
      return nextOp->emitError("failed to bufferize op");
    }
    LLVM_DEBUG(llvm::dbgs()
               << *op
               << "\n//===-------------------------------------------===//\n");
  }

  // Fold all to_memref(to_tensor(x)) pairs.
  for (Operation *op : toMemrefOps) {
    rewriter.setInsertionPoint(op);
    (void)bufferization::foldToMemrefToTensorPair(rewriter,
                                                  cast<ToMemrefOp>(op));
  }

  // Remove all dead to_tensor ops.
  op->walk([&](ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty()) {
      rewriter.eraseOp(toTensorOp);
      return WalkResult::skip();
    }
    return WalkResult::advance();
  });

  /// Check the result of bufferization. Return an error if an op was not
  /// bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)
    return success();

  for (Operation *op : worklist) {
    // Skip ops that are entirely gone.
    if (erasedOps.contains(op))
      continue;
    // Ops that no longer have tensor semantics (because they were updated
    // in-place) are allowed.
    if (!hasTensorSemantics(op))
      continue;
    // Skip ops that are not allowed.
    if (!options.isOpAllowed(op))
      continue;
    if (opFilter && !opFilter->isOpAllowed(op))
      continue;
    // Ops without any uses and no side effects will fold away.
    if (op->getUses().empty() && isMemoryEffectFree(op))
      continue;
    // ToTensorOps/ToMemrefOps are allowed in the output.
    if (isa<ToTensorOp, ToMemrefOp>(op))
      continue;
    return op->emitError("op was not bufferized");
  }

  return success();
}

BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions options;
  options.allowUnknownOps = true;
  options.createDeallocs = false;
  options.enforceAliasingInvariants = false;
  options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
                                      const BufferizationOptions &options) {
    return getMemRefTypeWithStaticIdentityLayout(
        cast<TensorType>(value.getType()), memorySpace);
  };
  options.opFilter.allowDialect<BufferizationDialect>();
  return options;
}
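// Minimal downstream sketch (hypothetical pass and dialect, shown for
// illustration only): a dialect-specific partial bufferization pass typically
// reuses these options and narrows the op filter to its own dialect, mirroring
// BufferizationBufferizePass above:
//
//   void MyDialectBufferizePass::runOnOperation() {
//     BufferizationOptions options = getPartialBufferizationOptions();
//     options.opFilter.allowDialect<MyDialect>(); // hypothetical dialect
//     if (failed(bufferizeOp(getOperation(), options)))
//       signalPassFailure();
//   }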