//===- PassManagerTest.cpp - PassManager unit tests -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Pass/PassManager.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "gtest/gtest.h" #include using namespace mlir; using namespace mlir::detail; namespace { /// Analysis that operates on any operation. struct GenericAnalysis { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenericAnalysis) GenericAnalysis(Operation *op) : isFunc(isa(op)) {} const bool isFunc; }; /// Analysis that operates on a specific operation. struct OpSpecificAnalysis { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpSpecificAnalysis) OpSpecificAnalysis(func::FuncOp op) : isSecret(op.getName() == "secret") {} const bool isSecret; }; /// Simple pass to annotate a func::FuncOp with the results of analysis. struct AnnotateFunctionPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateFunctionPass) void runOnOperation() override { func::FuncOp op = getOperation(); Builder builder(op->getParentOfType()); auto &ga = getAnalysis(); auto &sa = getAnalysis(); op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc)); op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret)); } }; TEST(PassManagerTest, OpSpecificAnalysis) { MLIRContext context; context.loadDialect(); Builder builder(&context); // Create a module with 2 functions. OwningOpRef module(ModuleOp::create(UnknownLoc::get(&context))); for (StringRef name : {"secret", "not_secret"}) { auto func = func::FuncOp::create( builder.getUnknownLoc(), name, builder.getFunctionType(std::nullopt, std::nullopt)); func.setPrivate(); module->push_back(func); } // Instantiate and run our pass. auto pm = PassManager::on(&context); pm.addNestedPass(std::make_unique()); LogicalResult result = pm.run(module.get()); EXPECT_TRUE(succeeded(result)); // Verify that each function got annotated with expected attributes. for (func::FuncOp func : module->getOps()) { ASSERT_TRUE(isa(func->getAttr("isFunc"))); EXPECT_TRUE(cast(func->getAttr("isFunc")).getValue()); bool isSecret = func.getName() == "secret"; ASSERT_TRUE(isa(func->getAttr("isSecret"))); EXPECT_EQ(cast(func->getAttr("isSecret")).getValue(), isSecret); } } namespace { struct InvalidPass : Pass { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass) InvalidPass() : Pass(TypeID::get(), StringRef("invalid_op")) {} StringRef getName() const override { return "Invalid Pass"; } void runOnOperation() override {} bool canScheduleOn(RegisteredOperationName opName) const override { return true; } /// A clone method to create a copy of this pass. std::unique_ptr clonePass() const override { return std::make_unique( *static_cast(this)); } }; } // namespace TEST(PassManagerTest, InvalidPass) { MLIRContext context; context.allowUnregisteredDialects(); // Create a module OwningOpRef module(ModuleOp::create(UnknownLoc::get(&context))); // Add a single "invalid_op" operation OpBuilder builder(&module->getBodyRegion()); OperationState state(UnknownLoc::get(&context), "invalid_op"); builder.insert(Operation::create(state)); // Register a diagnostic handler to capture the diagnostic so that we can // check it later. std::unique_ptr diagnostic; context.getDiagEngine().registerHandler([&](Diagnostic &diag) { diagnostic = std::make_unique(std::move(diag)); }); // Instantiate and run our pass. auto pm = PassManager::on(&context); pm.nest("invalid_op").addPass(std::make_unique()); LogicalResult result = pm.run(module.get()); EXPECT_TRUE(failed(result)); ASSERT_TRUE(diagnostic.get() != nullptr); EXPECT_EQ( diagnostic->str(), "'invalid_op' op trying to schedule a pass on an unregistered operation"); // Check that clearing the pass manager effectively removed the pass. pm.clear(); result = pm.run(module.get()); EXPECT_TRUE(succeeded(result)); // Check that adding the pass at the top-level triggers a fatal error. ASSERT_DEATH(pm.addPass(std::make_unique()), "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a " "PassManager intended to run on 'builtin.module', did you " "intend to nest?"); } } // namespace