summaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer/Transforms/AbstractResult.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Optimizer/Transforms/AbstractResult.cpp')
-rw-r--r--flang/lib/Optimizer/Transforms/AbstractResult.cpp116
1 files changed, 69 insertions, 47 deletions
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index dcc6e902fd84..df00c17863d9 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -28,6 +28,8 @@ namespace fir {
#define DEBUG_TYPE "flang-abstract-result-opt"
+using namespace mlir;
+
namespace fir {
namespace {
@@ -40,7 +42,7 @@ static mlir::Type getResultArgumentType(mlir::Type resultType,
return fir::BoxType::get(type);
return fir::ReferenceType::get(type);
})
- .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
+ .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
return fir::ReferenceType::get(type);
})
.Default([](mlir::Type) -> mlir::Type {
@@ -75,16 +77,18 @@ static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
shouldBoxResult;
}
-class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
+template <typename Op>
+class CallConversion : public mlir::OpRewritePattern<Op> {
public:
- using OpRewritePattern::OpRewritePattern;
- CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
- : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+ CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
+ : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
+
mlir::LogicalResult
- matchAndRewrite(fir::CallOp callOp,
- mlir::PatternRewriter &rewriter) const override {
- auto loc = callOp.getLoc();
- auto result = callOp->getResult(0);
+ matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto result = op->getResult(0);
if (!result.hasOneUse()) {
mlir::emitError(loc,
"calls with abstract result must have exactly one user");
@@ -109,50 +113,74 @@ public:
// TODO: This should be generalized for derived types, and it is
// architecture and OS dependent.
bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
- fir::CallOp newCallOp;
+ Op newOp;
if (isResultBuiltinCPtr) {
- auto recTy = result.getType().dyn_cast<fir::RecordType>();
+ auto recTy = result.getType().template dyn_cast<fir::RecordType>();
newResultTypes.emplace_back(recTy.getTypeList()[0].second);
}
- if (callOp.getCallee()) {
+
+ // fir::CallOp specific handling.
+ if constexpr (std::is_same_v<Op, fir::CallOp>) {
+ if (op.getCallee()) {
+ llvm::SmallVector<mlir::Value> newOperands;
+ if (!isResultBuiltinCPtr)
+ newOperands.emplace_back(arg);
+ newOperands.append(op.getOperands().begin(), op.getOperands().end());
+ newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
+ newResultTypes, newOperands);
+ } else {
+ // Indirect calls.
+ llvm::SmallVector<mlir::Type> newInputTypes;
+ if (!isResultBuiltinCPtr)
+ newInputTypes.emplace_back(argType);
+ for (auto operand : op.getOperands().drop_front())
+ newInputTypes.push_back(operand.getType());
+ auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
+ newResultTypes);
+
+ llvm::SmallVector<mlir::Value> newOperands;
+ newOperands.push_back(
+ rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
+ if (!isResultBuiltinCPtr)
+ newOperands.push_back(arg);
+ newOperands.append(op.getOperands().begin() + 1,
+ op.getOperands().end());
+ newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
+ newResultTypes, newOperands);
+ }
+ }
+
+ // fir::DispatchOp specific handling.
+ if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
llvm::SmallVector<mlir::Value> newOperands;
if (!isResultBuiltinCPtr)
newOperands.emplace_back(arg);
- newOperands.append(callOp.getOperands().begin(),
- callOp.getOperands().end());
- newCallOp = rewriter.create<fir::CallOp>(loc, *callOp.getCallee(),
- newResultTypes, newOperands);
- } else {
- // Indirect calls.
- llvm::SmallVector<mlir::Type> newInputTypes;
- if (!isResultBuiltinCPtr)
- newInputTypes.emplace_back(argType);
- for (auto operand : callOp.getOperands().drop_front())
- newInputTypes.push_back(operand.getType());
- auto newFuncTy = mlir::FunctionType::get(callOp.getContext(),
- newInputTypes, newResultTypes);
+ unsigned passArgShift = newOperands.size();
+ newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
- llvm::SmallVector<mlir::Value> newOperands;
- newOperands.push_back(rewriter.create<fir::ConvertOp>(
- loc, newFuncTy, callOp.getOperand(0)));
- if (!isResultBuiltinCPtr)
- newOperands.push_back(arg);
- newOperands.append(callOp.getOperands().begin() + 1,
- callOp.getOperands().end());
- newCallOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
- newResultTypes, newOperands);
+ fir::DispatchOp newDispatchOp;
+ if (op.getPassArgPos())
+ newOp = rewriter.create<fir::DispatchOp>(
+ loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
+ op.getOperands()[0], newOperands,
+ rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift));
+ else
+ newOp = rewriter.create<fir::DispatchOp>(
+ loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
+ op.getOperands()[0], newOperands, nullptr);
}
+
if (isResultBuiltinCPtr) {
mlir::Value save = saveResult.getMemref();
- auto module = callOp->getParentOfType<mlir::ModuleOp>();
+ auto module = op->template getParentOfType<mlir::ModuleOp>();
fir::KindMapping kindMap = fir::getKindMapping(module);
FirOpBuilder builder(rewriter, kindMap);
mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
builder, loc, save, result.getType());
- rewriter.create<fir::StoreOp>(loc, newCallOp->getResult(0), saveAddr);
+ rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
}
- callOp->dropAllReferences();
- rewriter.eraseOp(callOp);
+ op->dropAllReferences();
+ rewriter.eraseOp(op);
return mlir::success();
}
@@ -289,17 +317,11 @@ public:
return true;
});
target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
- if (dispatch->getNumResults() != 1)
- return true;
- auto resultType = dispatch->getResult(0).getType();
- if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
- TODO(dispatch.getLoc(), "dispatchOp with abstract results");
- return false;
- }
- return true;
+ return !hasAbstractResult(dispatch.getFunctionType());
});
- patterns.insert<CallOpConversion>(context, shouldBoxResult);
+ patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
+ patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
patterns.insert<SaveResultOpConversion>(context);
patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
if (mlir::failed(