summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorValentin Clement <clementval@gmail.com>2022-11-28 16:49:25 +0100
committerValentin Clement <clementval@gmail.com>2022-11-28 16:49:51 +0100
commitafb34cf3077a38007fcebe17dc384532207283fa (patch)
tree15e052536d66414eaade0e6ec9f12dddf83fd719
parent50caf6936ba91b4cc45ffa4e3591f0dcf0c4e387 (diff)
downloadllvm-afb34cf3077a38007fcebe17dc384532207283fa.tar.gz
[flang] Hanlde disptach op in abstract result pass
Update the call conversion pattern to support fir.dispatch operation as well. The first operand of fir.dispatch op is always the polymoprhic object. The pass_arg_pos attribute needs to be shifted when the result is added as argument. Reviewed By: jeanPerier Differential Revision: https://reviews.llvm.org/D138799
-rw-r--r--flang/lib/Optimizer/Dialect/FIRType.cpp2
-rw-r--r--flang/lib/Optimizer/Transforms/AbstractResult.cpp116
-rw-r--r--flang/test/Fir/abstract-results.fir19
3 files changed, 89 insertions, 48 deletions
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index eb9b9afae90e..89a806c0474a 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -960,7 +960,7 @@ bool fir::hasAbstractResult(mlir::FunctionType ty) {
if (ty.getNumResults() == 0)
return false;
auto resultType = ty.getResult(0);
- return resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>();
+ return resultType.isa<fir::SequenceType, fir::BaseBoxType, fir::RecordType>();
}
/// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index dcc6e902fd84..df00c17863d9 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -28,6 +28,8 @@ namespace fir {
#define DEBUG_TYPE "flang-abstract-result-opt"
+using namespace mlir;
+
namespace fir {
namespace {
@@ -40,7 +42,7 @@ static mlir::Type getResultArgumentType(mlir::Type resultType,
return fir::BoxType::get(type);
return fir::ReferenceType::get(type);
})
- .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
+ .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
return fir::ReferenceType::get(type);
})
.Default([](mlir::Type) -> mlir::Type {
@@ -75,16 +77,18 @@ static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
shouldBoxResult;
}
-class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
+template <typename Op>
+class CallConversion : public mlir::OpRewritePattern<Op> {
public:
- using OpRewritePattern::OpRewritePattern;
- CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
- : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+ CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
+ : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
+
mlir::LogicalResult
- matchAndRewrite(fir::CallOp callOp,
- mlir::PatternRewriter &rewriter) const override {
- auto loc = callOp.getLoc();
- auto result = callOp->getResult(0);
+ matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto result = op->getResult(0);
if (!result.hasOneUse()) {
mlir::emitError(loc,
"calls with abstract result must have exactly one user");
@@ -109,50 +113,74 @@ public:
// TODO: This should be generalized for derived types, and it is
// architecture and OS dependent.
bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
- fir::CallOp newCallOp;
+ Op newOp;
if (isResultBuiltinCPtr) {
- auto recTy = result.getType().dyn_cast<fir::RecordType>();
+ auto recTy = result.getType().template dyn_cast<fir::RecordType>();
newResultTypes.emplace_back(recTy.getTypeList()[0].second);
}
- if (callOp.getCallee()) {
+
+ // fir::CallOp specific handling.
+ if constexpr (std::is_same_v<Op, fir::CallOp>) {
+ if (op.getCallee()) {
+ llvm::SmallVector<mlir::Value> newOperands;
+ if (!isResultBuiltinCPtr)
+ newOperands.emplace_back(arg);
+ newOperands.append(op.getOperands().begin(), op.getOperands().end());
+ newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
+ newResultTypes, newOperands);
+ } else {
+ // Indirect calls.
+ llvm::SmallVector<mlir::Type> newInputTypes;
+ if (!isResultBuiltinCPtr)
+ newInputTypes.emplace_back(argType);
+ for (auto operand : op.getOperands().drop_front())
+ newInputTypes.push_back(operand.getType());
+ auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
+ newResultTypes);
+
+ llvm::SmallVector<mlir::Value> newOperands;
+ newOperands.push_back(
+ rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
+ if (!isResultBuiltinCPtr)
+ newOperands.push_back(arg);
+ newOperands.append(op.getOperands().begin() + 1,
+ op.getOperands().end());
+ newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
+ newResultTypes, newOperands);
+ }
+ }
+
+ // fir::DispatchOp specific handling.
+ if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
llvm::SmallVector<mlir::Value> newOperands;
if (!isResultBuiltinCPtr)
newOperands.emplace_back(arg);
- newOperands.append(callOp.getOperands().begin(),
- callOp.getOperands().end());
- newCallOp = rewriter.create<fir::CallOp>(loc, *callOp.getCallee(),
- newResultTypes, newOperands);
- } else {
- // Indirect calls.
- llvm::SmallVector<mlir::Type> newInputTypes;
- if (!isResultBuiltinCPtr)
- newInputTypes.emplace_back(argType);
- for (auto operand : callOp.getOperands().drop_front())
- newInputTypes.push_back(operand.getType());
- auto newFuncTy = mlir::FunctionType::get(callOp.getContext(),
- newInputTypes, newResultTypes);
+ unsigned passArgShift = newOperands.size();
+ newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
- llvm::SmallVector<mlir::Value> newOperands;
- newOperands.push_back(rewriter.create<fir::ConvertOp>(
- loc, newFuncTy, callOp.getOperand(0)));
- if (!isResultBuiltinCPtr)
- newOperands.push_back(arg);
- newOperands.append(callOp.getOperands().begin() + 1,
- callOp.getOperands().end());
- newCallOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
- newResultTypes, newOperands);
+ fir::DispatchOp newDispatchOp;
+ if (op.getPassArgPos())
+ newOp = rewriter.create<fir::DispatchOp>(
+ loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
+ op.getOperands()[0], newOperands,
+ rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift));
+ else
+ newOp = rewriter.create<fir::DispatchOp>(
+ loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
+ op.getOperands()[0], newOperands, nullptr);
}
+
if (isResultBuiltinCPtr) {
mlir::Value save = saveResult.getMemref();
- auto module = callOp->getParentOfType<mlir::ModuleOp>();
+ auto module = op->template getParentOfType<mlir::ModuleOp>();
fir::KindMapping kindMap = fir::getKindMapping(module);
FirOpBuilder builder(rewriter, kindMap);
mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
builder, loc, save, result.getType());
- rewriter.create<fir::StoreOp>(loc, newCallOp->getResult(0), saveAddr);
+ rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
}
- callOp->dropAllReferences();
- rewriter.eraseOp(callOp);
+ op->dropAllReferences();
+ rewriter.eraseOp(op);
return mlir::success();
}
@@ -289,17 +317,11 @@ public:
return true;
});
target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
- if (dispatch->getNumResults() != 1)
- return true;
- auto resultType = dispatch->getResult(0).getType();
- if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
- TODO(dispatch.getLoc(), "dispatchOp with abstract results");
- return false;
- }
- return true;
+ return !hasAbstractResult(dispatch.getFunctionType());
});
- patterns.insert<CallOpConversion>(context, shouldBoxResult);
+ patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
+ patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
patterns.insert<SaveResultOpConversion>(context);
patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
if (mlir::failed(
diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir
index 14c59a656974..374c0d18753b 100644
--- a/flang/test/Fir/abstract-results.fir
+++ b/flang/test/Fir/abstract-results.fir
@@ -244,6 +244,25 @@ func.func @_QPtest_return_cptr() {
// FUNC-BOX: fir.store %[[VAL]] to %[[ADDR]] : !fir.ref<i64>
}
+// FUNC-REF-LABEL: func @dispatch(
+// FUNC-REF-SAME: %[[ARG0:.*]]: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}
+// FUNC-BOX-LABEL: func @dispatch(
+// FUNC-BOX-SAME: %[[ARG0:.*]]: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}
+func.func @dispatch(%arg0: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}) {
+ %buffer = fir.alloca !fir.type<t{x:f32}>
+ %res = fir.dispatch "ret_array"(%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) -> !fir.type<t{x:f32}> {pass_arg_pos = 0 : i32}
+ fir.save_result %res to %buffer : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>
+ return
+ // FUNC-REF: %[[buffer:.*]] = fir.alloca !fir.type<t{x:f32}>
+ // FUNC-REF: fir.dispatch "ret_array"(%[[ARG0]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%[[buffer]], %[[ARG0]] : !fir.ref<!fir.type<t{x:f32}>>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
+ // FUNC-REF-NOT: fir.save_result
+
+ // FUNC-BOX: %[[buffer:.*]] = fir.alloca !fir.type<t{x:f32}>
+ // FUNC-BOX: %[[box:.*]] = fir.embox %[[buffer]] : (!fir.ref<!fir.type<t{x:f32}>>) -> !fir.box<!fir.type<t{x:f32}>>
+ // FUNC-BOX: fir.dispatch "ret_array"(%[[ARG0]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%[[box]], %[[ARG0]] : !fir.box<!fir.type<t{x:f32}>>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
+ // FUNC-BOX-NOT: fir.save_result
+}
+
// ------------------------ Test fir.address_of rewrite ------------------------
func.func private @takesfuncarray((i32) -> !fir.array<?xf32>)