diff options
author | Valentin Clement <clementval@gmail.com> | 2022-11-28 16:49:25 +0100 |
---|---|---|
committer | Valentin Clement <clementval@gmail.com> | 2022-11-28 16:49:51 +0100 |
commit | afb34cf3077a38007fcebe17dc384532207283fa (patch) | |
tree | 15e052536d66414eaade0e6ec9f12dddf83fd719 | |
parent | 50caf6936ba91b4cc45ffa4e3591f0dcf0c4e387 (diff) | |
download | llvm-afb34cf3077a38007fcebe17dc384532207283fa.tar.gz |
[flang] Hanlde disptach op in abstract result pass
Update the call conversion pattern to support fir.dispatch
operation as well. The first operand of fir.dispatch op is always the
polymoprhic object. The pass_arg_pos attribute needs to be shifted when
the result is added as argument.
Reviewed By: jeanPerier
Differential Revision: https://reviews.llvm.org/D138799
-rw-r--r-- | flang/lib/Optimizer/Dialect/FIRType.cpp | 2 | ||||
-rw-r--r-- | flang/lib/Optimizer/Transforms/AbstractResult.cpp | 116 | ||||
-rw-r--r-- | flang/test/Fir/abstract-results.fir | 19 |
3 files changed, 89 insertions, 48 deletions
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp index eb9b9afae90e..89a806c0474a 100644 --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -960,7 +960,7 @@ bool fir::hasAbstractResult(mlir::FunctionType ty) { if (ty.getNumResults() == 0) return false; auto resultType = ty.getResult(0); - return resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>(); + return resultType.isa<fir::SequenceType, fir::BaseBoxType, fir::RecordType>(); } /// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp index dcc6e902fd84..df00c17863d9 100644 --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -28,6 +28,8 @@ namespace fir { #define DEBUG_TYPE "flang-abstract-result-opt" +using namespace mlir; + namespace fir { namespace { @@ -40,7 +42,7 @@ static mlir::Type getResultArgumentType(mlir::Type resultType, return fir::BoxType::get(type); return fir::ReferenceType::get(type); }) - .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type { + .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type { return fir::ReferenceType::get(type); }) .Default([](mlir::Type) -> mlir::Type { @@ -75,16 +77,18 @@ static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { shouldBoxResult; } -class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> { +template <typename Op> +class CallConversion : public mlir::OpRewritePattern<Op> { public: - using OpRewritePattern::OpRewritePattern; - CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) - : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} + using mlir::OpRewritePattern<Op>::OpRewritePattern; + + CallConversion(mlir::MLIRContext *context, bool shouldBoxResult) + : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {} + mlir::LogicalResult - matchAndRewrite(fir::CallOp callOp, - mlir::PatternRewriter &rewriter) const override { - auto loc = callOp.getLoc(); - auto result = callOp->getResult(0); + matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto result = op->getResult(0); if (!result.hasOneUse()) { mlir::emitError(loc, "calls with abstract result must have exactly one user"); @@ -109,50 +113,74 @@ public: // TODO: This should be generalized for derived types, and it is // architecture and OS dependent. bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); - fir::CallOp newCallOp; + Op newOp; if (isResultBuiltinCPtr) { - auto recTy = result.getType().dyn_cast<fir::RecordType>(); + auto recTy = result.getType().template dyn_cast<fir::RecordType>(); newResultTypes.emplace_back(recTy.getTypeList()[0].second); } - if (callOp.getCallee()) { + + // fir::CallOp specific handling. + if constexpr (std::is_same_v<Op, fir::CallOp>) { + if (op.getCallee()) { + llvm::SmallVector<mlir::Value> newOperands; + if (!isResultBuiltinCPtr) + newOperands.emplace_back(arg); + newOperands.append(op.getOperands().begin(), op.getOperands().end()); + newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(), + newResultTypes, newOperands); + } else { + // Indirect calls. + llvm::SmallVector<mlir::Type> newInputTypes; + if (!isResultBuiltinCPtr) + newInputTypes.emplace_back(argType); + for (auto operand : op.getOperands().drop_front()) + newInputTypes.push_back(operand.getType()); + auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes, + newResultTypes); + + llvm::SmallVector<mlir::Value> newOperands; + newOperands.push_back( + rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0))); + if (!isResultBuiltinCPtr) + newOperands.push_back(arg); + newOperands.append(op.getOperands().begin() + 1, + op.getOperands().end()); + newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, + newResultTypes, newOperands); + } + } + + // fir::DispatchOp specific handling. + if constexpr (std::is_same_v<Op, fir::DispatchOp>) { llvm::SmallVector<mlir::Value> newOperands; if (!isResultBuiltinCPtr) newOperands.emplace_back(arg); - newOperands.append(callOp.getOperands().begin(), - callOp.getOperands().end()); - newCallOp = rewriter.create<fir::CallOp>(loc, *callOp.getCallee(), - newResultTypes, newOperands); - } else { - // Indirect calls. - llvm::SmallVector<mlir::Type> newInputTypes; - if (!isResultBuiltinCPtr) - newInputTypes.emplace_back(argType); - for (auto operand : callOp.getOperands().drop_front()) - newInputTypes.push_back(operand.getType()); - auto newFuncTy = mlir::FunctionType::get(callOp.getContext(), - newInputTypes, newResultTypes); + unsigned passArgShift = newOperands.size(); + newOperands.append(op.getOperands().begin() + 1, op.getOperands().end()); - llvm::SmallVector<mlir::Value> newOperands; - newOperands.push_back(rewriter.create<fir::ConvertOp>( - loc, newFuncTy, callOp.getOperand(0))); - if (!isResultBuiltinCPtr) - newOperands.push_back(arg); - newOperands.append(callOp.getOperands().begin() + 1, - callOp.getOperands().end()); - newCallOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, - newResultTypes, newOperands); + fir::DispatchOp newDispatchOp; + if (op.getPassArgPos()) + newOp = rewriter.create<fir::DispatchOp>( + loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), + op.getOperands()[0], newOperands, + rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift)); + else + newOp = rewriter.create<fir::DispatchOp>( + loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), + op.getOperands()[0], newOperands, nullptr); } + if (isResultBuiltinCPtr) { mlir::Value save = saveResult.getMemref(); - auto module = callOp->getParentOfType<mlir::ModuleOp>(); + auto module = op->template getParentOfType<mlir::ModuleOp>(); fir::KindMapping kindMap = fir::getKindMapping(module); FirOpBuilder builder(rewriter, kindMap); mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( builder, loc, save, result.getType()); - rewriter.create<fir::StoreOp>(loc, newCallOp->getResult(0), saveAddr); + rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr); } - callOp->dropAllReferences(); - rewriter.eraseOp(callOp); + op->dropAllReferences(); + rewriter.eraseOp(op); return mlir::success(); } @@ -289,17 +317,11 @@ public: return true; }); target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { - if (dispatch->getNumResults() != 1) - return true; - auto resultType = dispatch->getResult(0).getType(); - if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) { - TODO(dispatch.getLoc(), "dispatchOp with abstract results"); - return false; - } - return true; + return !hasAbstractResult(dispatch.getFunctionType()); }); - patterns.insert<CallOpConversion>(context, shouldBoxResult); + patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult); + patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult); patterns.insert<SaveResultOpConversion>(context); patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); if (mlir::failed( diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir index 14c59a656974..374c0d18753b 100644 --- a/flang/test/Fir/abstract-results.fir +++ b/flang/test/Fir/abstract-results.fir @@ -244,6 +244,25 @@ func.func @_QPtest_return_cptr() { // FUNC-BOX: fir.store %[[VAL]] to %[[ADDR]] : !fir.ref<i64> } +// FUNC-REF-LABEL: func @dispatch( +// FUNC-REF-SAME: %[[ARG0:.*]]: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"} +// FUNC-BOX-LABEL: func @dispatch( +// FUNC-BOX-SAME: %[[ARG0:.*]]: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"} +func.func @dispatch(%arg0: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}) { + %buffer = fir.alloca !fir.type<t{x:f32}> + %res = fir.dispatch "ret_array"(%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) -> !fir.type<t{x:f32}> {pass_arg_pos = 0 : i32} + fir.save_result %res to %buffer : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>> + return + // FUNC-REF: %[[buffer:.*]] = fir.alloca !fir.type<t{x:f32}> + // FUNC-REF: fir.dispatch "ret_array"(%[[ARG0]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%[[buffer]], %[[ARG0]] : !fir.ref<!fir.type<t{x:f32}>>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32} + // FUNC-REF-NOT: fir.save_result + + // FUNC-BOX: %[[buffer:.*]] = fir.alloca !fir.type<t{x:f32}> + // FUNC-BOX: %[[box:.*]] = fir.embox %[[buffer]] : (!fir.ref<!fir.type<t{x:f32}>>) -> !fir.box<!fir.type<t{x:f32}>> + // FUNC-BOX: fir.dispatch "ret_array"(%[[ARG0]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%[[box]], %[[ARG0]] : !fir.box<!fir.type<t{x:f32}>>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32} + // FUNC-BOX-NOT: fir.save_result +} + // ------------------------ Test fir.address_of rewrite ------------------------ func.func private @takesfuncarray((i32) -> !fir.array<?xf32>) |