summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStanislav Funiak <stano@cerebras.net>2022-01-04 08:03:29 +0530
committerUday Bondhugula <uday@polymagelabs.com>2022-01-04 08:03:45 +0530
commitb4130e9eadfe46b4d3380c40ce8c3e900a0fd21b (patch)
treef8358248b443205dc06c6eaf5ec904071233ce4c
parent138803e017739c81b43b73631c7096bfc4d097d8 (diff)
downloadllvm-b4130e9eadfe46b4d3380c40ce8c3e900a0fd21b.tar.gz
[MLIR][PDL] Integration test of multi-root matching and related fixes.
This diff adds an integration test to multi-root PDL matching. It consists of two subtests: 1) A 1-layer perceptron with split forward / backward operations. 2) A 2-layer perceptron with fused forward / backward operations. These tests use a collection of hand-written patterns and TensorFlow operations to be matched. The first test has a DAG / SSA dominant resulting match; the second does not and is therefore stored in a graph region. This diff also includes two bug fixes: 1) Mark the pdl_interp dialect as a dependent in the TestPDLByteCodePass. This is needed, because we create ops from that dialect as a part of the PDL-to-PDLInterp lowering. 2) Fix of the starting index in the liveness range for the ForEach operations (bug exposed by the integration test). Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D116082
-rw-r--r--mlir/lib/Rewrite/ByteCode.cpp24
-rw-r--r--mlir/test/Integration/Dialect/PDL/CPU/multiroot.mlir294
-rw-r--r--mlir/test/lib/Rewrite/TestPDLByteCode.cpp6
3 files changed, 318 insertions, 6 deletions
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 765c47b2ed0c..d6a07f9067fe 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -551,10 +551,22 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
// finding the minimal number of overlapping live ranges. This is essentially
// a simplified form of register allocation where we don't necessarily have a
// limited number of registers, but we still want to minimize the number used.
- DenseMap<Operation *, unsigned> opToIndex;
- matcherFunc.getBody().walk([&](Operation *op) {
- opToIndex.insert(std::make_pair(op, opToIndex.size()));
- });
+ DenseMap<Operation *, unsigned> opToFirstIndex;
+ DenseMap<Operation *, unsigned> opToLastIndex;
+
+ // A custom walk that marks the first and the last index of each operation.
+ // The entry marks the beginning of the liveness range for this operation,
+ // followed by nested operations, followed by the end of the liveness range.
+ unsigned index = 0;
+ llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
+ opToFirstIndex.try_emplace(op, index++);
+ for (Region &region : op->getRegions())
+ for (Block &block : region.getBlocks())
+ for (Operation &nested : block)
+ walk(&nested);
+ opToLastIndex.try_emplace(op, index++);
+ };
+ walk(matcherFunc);
// Liveness info for each of the defs within the matcher.
ByteCodeLiveRange::Allocator allocator;
@@ -578,8 +590,8 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
// Set indices for the range of this block that the value is used.
auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
defRangeIt->second.liveness->insert(
- opToIndex[firstUseOrDef],
- opToIndex[info->getEndOperation(value, firstUseOrDef)],
+ opToFirstIndex[firstUseOrDef],
+ opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
/*dummyValue*/ 0);
// Check to see if this value is a range type.
diff --git a/mlir/test/Integration/Dialect/PDL/CPU/multiroot.mlir b/mlir/test/Integration/Dialect/PDL/CPU/multiroot.mlir
new file mode 100644
index 000000000000..be496ed3a675
--- /dev/null
+++ b/mlir/test/Integration/Dialect/PDL/CPU/multiroot.mlir
@@ -0,0 +1,294 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -test-pdl-bytecode-pass -split-input-file | FileCheck %s
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// 1-layer perceptron with split fwd/bwd operations
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ // fc_fwd
+ pdl.pattern : benefit(1) {
+ %in_type = pdl.type
+ %out_type = pdl.type
+ %weight_type = pdl.type
+ %rxact = pdl.operand : %in_type
+ %weight = pdl.operand : %weight_type
+
+ %attr0 = pdl.attribute false
+ %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
+
+ pdl.rewrite %op0 {
+ %op1 = pdl.operation "kernel.FcFwd" (%rxact, %weight : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+ %val1 = pdl.result 0 of %op1 // txact
+ pdl.replace %op0 with (%val1 : !pdl.value) // tf.MatMul
+ }
+ }
+
+ // fc_bwd
+ pdl.pattern : benefit(4) {
+ %in_type = pdl.type
+ %out_type = pdl.type
+ %weight_type = pdl.type
+ %const_type = pdl.type
+ %rxact = pdl.operand : %in_type
+ %rxdelta = pdl.operand : %out_type
+ %weight = pdl.operand : %weight_type
+
+ %attr0 = pdl.attribute true
+ %attr1 = pdl.attribute false
+ %op0 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%weight_type : !pdl.type)
+ %val0 = pdl.result 0 of %op0
+ %op1 = pdl.operation "tf.Const" -> (%const_type : !pdl.type)
+ %val1 = pdl.result 0 of %op1
+ %op2 = pdl.operation "tf.Mul" (%val0, %val1 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
+ %val2 = pdl.result 0 of %op2
+ %op3 = pdl.operation "tf.Sub" (%weight, %val2 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
+
+ pdl.rewrite %op3 {
+ %op4 = pdl.operation "kernel.FcBwd" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
+ %val4 = pdl.result 0 of %op4 // weight_out
+ pdl.replace %op3 with (%val4 : !pdl.value) // tf.Sub
+ pdl.erase %op2 // tf.Mul
+ pdl.erase %op1 // tf.Const
+ pdl.erase %op0 // tf.MatMul
+ }
+ }
+
+ // softmax_cross_entropy
+ pdl.pattern : benefit(6) {
+ %in_type = pdl.type
+ %label_type = pdl.type
+ %loss_type = pdl.type
+ %mean_loss_type = pdl.type
+ %mean_const_type = pdl.type
+ %mul_const_type = pdl.type
+ %rxact = pdl.operand : %in_type
+ %rxlabel = pdl.operand : %label_type
+
+ %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type)
+ %val0_0 = pdl.result 0 of %op0 // loss
+ %val0_1 = pdl.result 1 of %op0 // gradient
+ %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type)
+ %val1 = pdl.result 0 of %op1
+ %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type)
+ %val2 = pdl.result 0 of %op2
+ %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type)
+ %val3 = pdl.result 0 of %op3
+ %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type)
+ %val4 = pdl.result 0 of %op4
+ %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type)
+
+ pdl.rewrite { // roots: %op2, %op5
+ %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type)
+ %val6_0 = pdl.result 0 of %op6 // txloss
+ %val6_1 = pdl.result 1 of %op6 // txdelta
+ pdl.replace %op5 with (%val6_1 : !pdl.value) // tf.Mul
+ pdl.erase %op4 // tf.Const
+ pdl.erase %op3 // tf.PreventGradient
+ pdl.replace %op2 with (%val6_0 : !pdl.value) // tf.Mean
+ pdl.erase %op1 // tf.Const
+ pdl.erase %op0 // tf.SparseSoftmaxCrossEntropyWithLogits
+ }
+ }
+}
+
+// CHECK-LABEL: test.mlp_split
+// CHECK: %[[FWD:.*]] = "kernel.FcFwd"(%arg0, %arg2) : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32>
+// CHECK: %[[SM:.*]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FWD]], %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>)
+// CHECK: %[[BWD:.*]] = "kernel.FcBwd"(%arg0, %[[SM]]#1, %arg2) : (tensor<2x20xf32>, tensor<2x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32>
+// CHECK: return %[[SM:.*]]#0, %[[BWD]] : tensor<f32>, tensor<20x10xf32>
+module @ir attributes { test.mlp_split } {
+ func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<20x10xf32>) -> (tensor<f32>, tensor<20x10xf32>) {
+ %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
+ %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
+ %3 = "tf.MatMul"(%arg0, %arg2) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32>
+ %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%3, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>)
+ %4 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
+ %5 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32>
+ %6 = "tf.Mul"(%5, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32>
+ %7 = "tf.MatMul"(%arg0, %6) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x10xf32>) -> tensor<20x10xf32>
+ %8 = "tf.Mul"(%7, %1) : (tensor<20x10xf32>, tensor<f32>) -> tensor<20x10xf32>
+ %9 = "tf.Sub"(%arg2, %8) : (tensor<20x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32>
+ return %4, %9 : tensor<f32>, tensor<20x10xf32>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// 2-layer perceptron with fused fwd/bwd operations
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+
+ // gradient descent
+ pdl.pattern : benefit(3) {
+ %const_type = pdl.type
+ %param_type = pdl.type
+ %param = pdl.operand : %param_type
+ %gradient = pdl.operand : %param_type
+
+ %attr0 = pdl.attribute
+ %op0 = pdl.operation "tf.Const" {"value" = %attr0} -> (%const_type : !pdl.type)
+ %val0 = pdl.result 0 of %op0
+ %op1 = pdl.operation "tf.Mul" (%gradient, %val0 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
+ %val1 = pdl.result 0 of %op1
+ %op2 = pdl.operation "tf.Sub" (%param, %val1 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
+
+ pdl.rewrite %op2 {
+ %op3 = pdl.operation "kernel.GD" (%param, %gradient : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
+ %val3 = pdl.result 0 of %op3
+ pdl.replace %op2 with (%val3 : !pdl.value) // tf.Sub
+ pdl.erase %op1 // tf.Mul
+ }
+ }
+
+ // first FC
+ pdl.pattern : benefit(8) {
+ %in_type = pdl.type
+ %out_type = pdl.type
+ %weight_type = pdl.type
+ %bias_type = pdl.type
+ %rxact = pdl.operand : %in_type
+ %rxdelta = pdl.operand : %out_type
+ %weight = pdl.operand : %weight_type
+ %bias = pdl.operand : %bias_type
+
+ %attr0 = pdl.attribute false
+ %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
+ %val0 = pdl.result 0 of %op0
+ %op1 = pdl.operation "tf.BiasAdd" (%val0, %bias : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+ %val1 = pdl.result 0 of %op1
+ %op2 = pdl.operation "tf.Relu" (%val1 : !pdl.value) -> (%out_type : !pdl.type)
+ %val2 = pdl.result 0 of %op2
+ %op3 = pdl.operation "tf.ReluGrad" (%rxdelta, %val2 : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+ %val3 = pdl.result 0 of %op3
+ %attr1 = pdl.attribute true
+ %op4 = pdl.operation "tf.MatMul" (%rxact, %val3 : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type)
+ %val4 = pdl.result 0 of %op4
+ %op5 = pdl.operation "kernel.GD" (%weight, %val4 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
+ %op6 = pdl.operation "tf.BiasAddGrad" (%val3 : !pdl.value) -> (%bias_type : !pdl.type)
+ %val6 = pdl.result 0 of %op6
+ %op7 = pdl.operation "kernel.GD" (%bias, %val6 : !pdl.value, !pdl.value) -> (%bias_type : !pdl.type)
+
+ pdl.rewrite { // roots: %op2, %op5, %op7
+ %op8 = pdl.operation "kernel.FcWithBias" (%rxact, %rxdelta, %weight, %bias : !pdl.value, !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %weight_type, %bias_type : !pdl.type, !pdl.type, !pdl.type)
+ %val8_0 = pdl.result 0 of %op8 // txact
+ %val8_1 = pdl.result 1 of %op8 // weight_out
+ %val8_2 = pdl.result 2 of %op8 // bias_out
+ pdl.replace %op7 with (%val8_2 : !pdl.value) // kernel.GD
+ pdl.erase %op6 // tf.BiasAddGrad
+ pdl.replace %op5 with (%val8_1 : !pdl.value) // kernel.GD
+ pdl.erase %op4 // tf.MatMul
+ pdl.erase %op3 // tf.ReluGrad
+ pdl.replace %op2 with (%val8_0 : !pdl.value) // tf.Relu
+ pdl.erase %op1 // tf.BiasAdd
+ pdl.erase %op0 // tf.MatMul
+ }
+ }
+
+ // second FC
+ pdl.pattern : benefit(4) {
+ %in_type = pdl.type
+ %out_type = pdl.type
+ %weight_type = pdl.type
+ %rxact = pdl.operand : %in_type
+ %rxdelta = pdl.operand : %out_type
+ %weight = pdl.operand : %weight_type
+
+ %attr0 = pdl.attribute false
+ %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
+ %attr1 = pdl.attribute true
+ %op1 = pdl.operation "tf.MatMul" (%rxdelta, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%in_type : !pdl.type)
+ %op2 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type)
+ %val2 = pdl.result 0 of %op2
+ %op3 = pdl.operation "kernel.GD" (%weight, %val2 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
+
+ pdl.rewrite { // roots: %op0, %op1, %op3
+ %op4 = pdl.operation "kernel.Fc" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %in_type, %weight_type : !pdl.type, !pdl.type, !pdl.type)
+ %val4_0 = pdl.result 0 of %op4 // txact
+ %val4_1 = pdl.result 1 of %op4 // txdelta
+ %val4_2 = pdl.result 2 of %op4 // weight_out
+ pdl.replace %op3 with (%val4_2 : !pdl.value) // Sgd
+ pdl.erase %op2 // tf.MatMul
+ pdl.replace %op1 with (%val4_1 : !pdl.value) // tf.MatMul
+ pdl.replace %op0 with (%val4_0 : !pdl.value) // tf.MatMul
+ }
+ }
+
+ // softmax_cross_entropy
+ pdl.pattern : benefit(6) {
+ %in_type = pdl.type
+ %label_type = pdl.type
+ %loss_type = pdl.type
+ %mean_loss_type = pdl.type
+ %mean_const_type = pdl.type
+ %mul_const_type = pdl.type
+ %rxact = pdl.operand : %in_type
+ %rxlabel = pdl.operand : %label_type
+
+ %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type)
+ %val0_0 = pdl.result 0 of %op0 // loss
+ %val0_1 = pdl.result 1 of %op0 // gradient
+ %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type)
+ %val1 = pdl.result 0 of %op1
+ %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type)
+ %val2 = pdl.result 0 of %op2
+ %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type)
+ %val3 = pdl.result 0 of %op3
+ %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type)
+ %val4 = pdl.result 0 of %op4
+ %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type)
+
+ pdl.rewrite { // roots: %op2, %op5
+ %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type)
+ %val6_0 = pdl.result 0 of %op6 // txloss
+ %val6_1 = pdl.result 1 of %op6 // txdelta
+ pdl.replace %op5 with (%val6_1 : !pdl.value) // tf.Mul
+ pdl.erase %op4 // tf.Const
+ pdl.erase %op3 // tf.PreventGradient
+ pdl.replace %op2 with (%val6_0 : !pdl.value) // tf.Mean
+ pdl.erase %op1 // tf.Const
+ pdl.erase %op0 // tf.SparseSoftmaxCrossEntropyWithLogits
+ }
+ }
+}
+
+// CHECK-LABEL: test.mlp_fused
+// CHECK: %[[FC2:.*]]:3 = "kernel.Fc"(%[[FC1:.*]]#0, %[[SM:.*]]#1, %arg4) : (tensor<2x256xf32>, tensor<2x10xf32>, tensor<256x10xf32>) -> (tensor<2x10xf32>, tensor<2x256xf32>, tensor<256x10xf32>)
+// CHECK: %[[SM]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FC2]]#0, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>)
+// CHECK: %[[FC1]]:3 = "kernel.FcWithBias"(%arg0, %[[FC2]]#1, %arg3, %arg2) : (tensor<2x20xf32>, tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>) -> (tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>)
+module @ir attributes { test.mlp_fused } {
+ func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<256xf32>, %arg3: tensor<20x256xf32>, %arg4: tensor<256x10xf32>) -> () { // tensor<f32>, tensor<256xf32>, tensor<20x256xf32>, tensor<256x10xf32>) {
+ // The replacement operations fuse forward and backward pass; therefore, the
+ // resulting graph is not a DAG. To address this, we wrap the operations in
+ // a graph region.
+ "test.graph_region"() ({
+ %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
+ %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
+ %3 = "tf.MatMul"(%arg0, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x256xf32>) -> tensor<2x256xf32>
+ %4 = "tf.BiasAdd"(%3, %arg2) {data_format = "NHWC"} : (tensor<2x256xf32>, tensor<256xf32>) -> tensor<2x256xf32>
+ %5 = "tf.Relu"(%4) : (tensor<2x256xf32>) -> tensor<2x256xf32>
+ %6 = "tf.MatMul"(%5, %arg4) {transpose_a = false, transpose_b = false} : (tensor<2x256xf32>, tensor<256x10xf32>) -> tensor<2x10xf32>
+ %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%6, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>)
+ %7 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
+ %8 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32>
+ %9 = "tf.Mul"(%8, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32>
+ %10 = "tf.MatMul"(%9, %arg4) {transpose_a = false, transpose_b = true} : (tensor<2x10xf32>, tensor<256x10xf32>) -> tensor<2x256xf32>
+ %11 = "tf.MatMul"(%5, %9) {transpose_a = true, transpose_b = false} : (tensor<2x256xf32>, tensor<2x10xf32>) -> tensor<256x10xf32>
+ %12 = "tf.ReluGrad"(%10, %5) : (tensor<2x256xf32>, tensor<2x256xf32>) -> tensor<2x256xf32>
+ %13 = "tf.BiasAddGrad"(%12) {data_format = "NHWC"} : (tensor<2x256xf32>) -> tensor<256xf32>
+ %14 = "tf.MatMul"(%arg0, %12) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x256xf32>) -> tensor<20x256xf32>
+ %15 = "tf.Mul"(%14, %1) : (tensor<20x256xf32>, tensor<f32>) -> tensor<20x256xf32>
+ %16 = "tf.Sub"(%arg3, %15) : (tensor<20x256xf32>, tensor<20x256xf32>) -> tensor<20x256xf32>
+ %17 = "tf.Mul"(%13, %1) : (tensor<256xf32>, tensor<f32>) -> tensor<256xf32>
+ %18 = "tf.Sub"(%arg2, %17) : (tensor<256xf32>, tensor<256xf32>) -> tensor<256xf32>
+ %19 = "tf.Mul"(%11, %1) : (tensor<256x10xf32>, tensor<f32>) -> tensor<256x10xf32>
+ %20 = "tf.Sub"(%arg4, %19) : (tensor<256x10xf32>, tensor<256x10xf32>) -> tensor<256x10xf32>
+ }) : () -> ()
+ return
+ }
+}
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index ef62d73978d8..748e54822718 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -75,6 +76,11 @@ struct TestPDLByteCodePass
StringRef getDescription() const final {
return "Test PDL ByteCode functionality";
}
+ void getDependentDialects(DialectRegistry &registry) const override {
+ // Mark the pdl_interp dialect as a dependent. This is needed, because we
+ // create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
+ registry.insert<pdl_interp::PDLInterpDialect>();
+ }
void runOnOperation() final {
ModuleOp module = getOperation();