diff options
Diffstat (limited to 'flang/lib/Optimizer/Transforms/AbstractResult.cpp')
-rw-r--r-- | flang/lib/Optimizer/Transforms/AbstractResult.cpp | 116 |
1 files changed, 69 insertions, 47 deletions
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp index dcc6e902fd84..df00c17863d9 100644 --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -28,6 +28,8 @@ namespace fir { #define DEBUG_TYPE "flang-abstract-result-opt" +using namespace mlir; + namespace fir { namespace { @@ -40,7 +42,7 @@ static mlir::Type getResultArgumentType(mlir::Type resultType, return fir::BoxType::get(type); return fir::ReferenceType::get(type); }) - .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type { + .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type { return fir::ReferenceType::get(type); }) .Default([](mlir::Type) -> mlir::Type { @@ -75,16 +77,18 @@ static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { shouldBoxResult; } -class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> { +template <typename Op> +class CallConversion : public mlir::OpRewritePattern<Op> { public: - using OpRewritePattern::OpRewritePattern; - CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) - : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} + using mlir::OpRewritePattern<Op>::OpRewritePattern; + + CallConversion(mlir::MLIRContext *context, bool shouldBoxResult) + : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {} + mlir::LogicalResult - matchAndRewrite(fir::CallOp callOp, - mlir::PatternRewriter &rewriter) const override { - auto loc = callOp.getLoc(); - auto result = callOp->getResult(0); + matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto result = op->getResult(0); if (!result.hasOneUse()) { mlir::emitError(loc, "calls with abstract result must have exactly one user"); @@ -109,50 +113,74 @@ public: // TODO: This should be generalized for derived types, and it is // architecture and OS dependent. bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); - fir::CallOp newCallOp; + Op newOp; if (isResultBuiltinCPtr) { - auto recTy = result.getType().dyn_cast<fir::RecordType>(); + auto recTy = result.getType().template dyn_cast<fir::RecordType>(); newResultTypes.emplace_back(recTy.getTypeList()[0].second); } - if (callOp.getCallee()) { + + // fir::CallOp specific handling. + if constexpr (std::is_same_v<Op, fir::CallOp>) { + if (op.getCallee()) { + llvm::SmallVector<mlir::Value> newOperands; + if (!isResultBuiltinCPtr) + newOperands.emplace_back(arg); + newOperands.append(op.getOperands().begin(), op.getOperands().end()); + newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(), + newResultTypes, newOperands); + } else { + // Indirect calls. + llvm::SmallVector<mlir::Type> newInputTypes; + if (!isResultBuiltinCPtr) + newInputTypes.emplace_back(argType); + for (auto operand : op.getOperands().drop_front()) + newInputTypes.push_back(operand.getType()); + auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes, + newResultTypes); + + llvm::SmallVector<mlir::Value> newOperands; + newOperands.push_back( + rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0))); + if (!isResultBuiltinCPtr) + newOperands.push_back(arg); + newOperands.append(op.getOperands().begin() + 1, + op.getOperands().end()); + newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, + newResultTypes, newOperands); + } + } + + // fir::DispatchOp specific handling. + if constexpr (std::is_same_v<Op, fir::DispatchOp>) { llvm::SmallVector<mlir::Value> newOperands; if (!isResultBuiltinCPtr) newOperands.emplace_back(arg); - newOperands.append(callOp.getOperands().begin(), - callOp.getOperands().end()); - newCallOp = rewriter.create<fir::CallOp>(loc, *callOp.getCallee(), - newResultTypes, newOperands); - } else { - // Indirect calls. - llvm::SmallVector<mlir::Type> newInputTypes; - if (!isResultBuiltinCPtr) - newInputTypes.emplace_back(argType); - for (auto operand : callOp.getOperands().drop_front()) - newInputTypes.push_back(operand.getType()); - auto newFuncTy = mlir::FunctionType::get(callOp.getContext(), - newInputTypes, newResultTypes); + unsigned passArgShift = newOperands.size(); + newOperands.append(op.getOperands().begin() + 1, op.getOperands().end()); - llvm::SmallVector<mlir::Value> newOperands; - newOperands.push_back(rewriter.create<fir::ConvertOp>( - loc, newFuncTy, callOp.getOperand(0))); - if (!isResultBuiltinCPtr) - newOperands.push_back(arg); - newOperands.append(callOp.getOperands().begin() + 1, - callOp.getOperands().end()); - newCallOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, - newResultTypes, newOperands); + fir::DispatchOp newDispatchOp; + if (op.getPassArgPos()) + newOp = rewriter.create<fir::DispatchOp>( + loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), + op.getOperands()[0], newOperands, + rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift)); + else + newOp = rewriter.create<fir::DispatchOp>( + loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), + op.getOperands()[0], newOperands, nullptr); } + if (isResultBuiltinCPtr) { mlir::Value save = saveResult.getMemref(); - auto module = callOp->getParentOfType<mlir::ModuleOp>(); + auto module = op->template getParentOfType<mlir::ModuleOp>(); fir::KindMapping kindMap = fir::getKindMapping(module); FirOpBuilder builder(rewriter, kindMap); mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( builder, loc, save, result.getType()); - rewriter.create<fir::StoreOp>(loc, newCallOp->getResult(0), saveAddr); + rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr); } - callOp->dropAllReferences(); - rewriter.eraseOp(callOp); + op->dropAllReferences(); + rewriter.eraseOp(op); return mlir::success(); } @@ -289,17 +317,11 @@ public: return true; }); target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { - if (dispatch->getNumResults() != 1) - return true; - auto resultType = dispatch->getResult(0).getType(); - if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) { - TODO(dispatch.getLoc(), "dispatchOp with abstract results"); - return false; - } - return true; + return !hasAbstractResult(dispatch.getFunctionType()); }); - patterns.insert<CallOpConversion>(context, shouldBoxResult); + patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult); + patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult); patterns.insert<SaveResultOpConversion>(context); patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); if (mlir::failed( |