summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichal Terepeta <michal.terepeta@gmail.com>2021-11-18 09:41:57 +0100
committerAlex Zinenko <zinenko@google.com>2021-11-18 09:42:57 +0100
commit54c99842079997b0fe208acdab01e540c0d81b51 (patch)
treeecb31cef4e67441899804062e4e2bcb0c7370203
parentb10562612f2e67bda7813ae9586f0afd37dc3c29 (diff)
downloadllvm-54c99842079997b0fe208acdab01e540c0d81b51.tar.gz
[mlir][Python] Fix generation of accessors for Optional
Previously, in case there was only one `Optional` operand/result within the list, we would always return `None` from the accessor, e.g., for a single optional result we would generate: ``` return self.operation.results[0] if len(self.operation.results) > 1 else None ``` But what we really want is to return `None` only if the length of `results` is smaller than the total number of element groups (i.e., the optional operand/result is in fact missing). This commit also renames a few local variables in the generator to make the distinction between `isVariadic()` and `isVariableLength()` a bit more clear. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D113855
-rw-r--r--mlir/python/mlir/dialects/_linalg_ops_ext.py9
-rw-r--r--mlir/test/mlir-tblgen/op-python-bindings.td2
-rw-r--r--mlir/test/python/dialects/linalg/ops.py5
-rw-r--r--mlir/test/python/dialects/python_test.py18
-rw-r--r--mlir/test/python/python_test_ops.td5
-rw-r--r--mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp31
6 files changed, 42 insertions, 28 deletions
diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
index b7641c0a4b53..d6c57547ee16 100644
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py
@@ -36,15 +36,6 @@ class FillOp:
OpView.__init__(self, op)
linalgDialect = Context.current.get_dialect_descriptor("linalg")
fill_builtin_region(linalgDialect, self.operation)
- # TODO: self.result is None. When len(results) == 1 we expect it to be
- # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug
- # in the generator of _linalg_ops_gen.py where we have:
- # ```
- # def result(self):
- # return self.operation.results[0] \
- # if len(self.operation.results) > 1 else None
- # ```
-
class InitTensorOp:
"""Extends the linalg.init_tensor op."""
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index becce13050a1..aa9977e047f1 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -304,7 +304,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: @builtins.property
// CHECK: def optional(self):
- // CHECK: return self.operation.operands[1] if len(self.operation.operands) > 2 else None
+ // CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]
}
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index d788292f3424..e5b96c260eaa 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -68,10 +68,7 @@ def testFill():
@builtin.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32))
def fill_tensor(out):
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
- # TODO: FillOp.result is None. When len(results) == 1 we expect it to
- # be results[0] as per _linalg_ops_gen.py. This seems like an
- # orthogonal bug in the generator of _linalg_ops_gen.py.
- return linalg.FillOp(output=out, value=zero).results[0]
+ return linalg.FillOp(output=out, value=zero).result
# CHECK-LABEL: func @fill_buffer
# CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 2267b59cd4d7..f9da91fba4cd 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -207,3 +207,21 @@ def resultTypesDefinedByTraits():
print(implied.flt.type)
# CHECK: index
print(implied.index.type)
+
+
+# CHECK-LABEL: TEST: testOptionalOperandOp
+@run
+def testOptionalOperandOp():
+ with Context() as ctx, Location.unknown():
+ test.register_python_test_dialect(ctx)
+
+ module = Module.create()
+ with InsertionPoint(module.body):
+
+ op1 = test.OptionalOperandOp(None)
+ # CHECK: op1.input is None: True
+ print(f"op1.input is None: {op1.input is None}")
+
+ op2 = test.OptionalOperandOp(op1)
+ # CHECK: op2.input is None: False
+ print(f"op2.input is None: {op2.input is None}")
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 0f947e7e536b..6ee71dbf8b12 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -76,4 +76,9 @@ def FirstAttrDeriveAttrOp : TestOp<"first_attr_derive_attr_op",
let results = (outs AnyType:$one, AnyType:$two, AnyType:$three);
}
+def OptionalOperandOp : TestOp<"optional_operand_op"> {
+ let arguments = (ins Optional<AnyType>:$input);
+ let results = (outs I32:$result);
+}
+
#endif // PYTHON_TEST_OPS
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 8babff25db07..fb634a1be395 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -109,10 +109,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
+/// This works if we have only one variable-length group (and it's the optional
+/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
+/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
@builtins.property
def {0}(self):
- return self.operation.{1}s[{3}] if len(self.operation.{1}s) > {2} else None
+ return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
/// Template for the variadic group accessor in the single variadic group case:
@@ -311,7 +314,7 @@ static std::string attrSizedTraitForKind(const char *kind) {
/// `operand` or `result` and is used verbatim in the emitted code.
static void emitElementAccessors(
const Operator &op, raw_ostream &os, const char *kind,
- llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
+ llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
llvm::function_ref<int(const Operator &)> getNumElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement) {
@@ -326,12 +329,12 @@ static void emitElementAccessors(
llvm::StringRef(kind).drop_front());
std::string attrSizedTrait = attrSizedTraitForKind(kind);
- unsigned numVariadic = getNumVariadic(op);
+ unsigned numVariableLength = getNumVariableLength(op);
- // If there is only one variadic element group, its size can be inferred from
- // the total number of elements. If there are none, the generation is
- // straightforward.
- if (numVariadic <= 1) {
+ // If there is only one variable-length element group, its size can be
+ // inferred from the total number of elements. If there are none, the
+ // generation is straightforward.
+ if (numVariableLength <= 1) {
bool seenVariableLength = false;
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
@@ -364,7 +367,7 @@ static void emitElementAccessors(
const NamedTypeConstraint &element = getElement(op, i);
if (!element.name.empty()) {
os << llvm::formatv(opVariadicEqualPrefixTemplate,
- sanitizeName(element.name), kind, numVariadic,
+ sanitizeName(element.name), kind, numVariableLength,
numPrecedingSimple, numPrecedingVariadic);
os << llvm::formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate
@@ -414,20 +417,20 @@ static const NamedTypeConstraint &getResult(const Operator &op, int i) {
/// Emits accessors to Op operands.
static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
- auto getNumVariadic = [](const Operator &oper) {
+ auto getNumVariableLengthOperands = [](const Operator &oper) {
return oper.getNumVariableLengthOperands();
};
- emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands,
- getOperand);
+ emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
+ getNumOperands, getOperand);
}
/// Emits accessors Op results.
static void emitResultAccessors(const Operator &op, raw_ostream &os) {
- auto getNumVariadic = [](const Operator &oper) {
+ auto getNumVariableLengthResults = [](const Operator &oper) {
return oper.getNumVariableLengthResults();
};
- emitElementAccessors(op, os, "result", getNumVariadic, getNumResults,
- getResult);
+ emitElementAccessors(op, os, "result", getNumVariableLengthResults,
+ getNumResults, getResult);
}
/// Emits accessors to Op attributes.