summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStanislav Funiak <stano@cerebras.net>2022-01-04 08:03:26 +0530
committerUday Bondhugula <uday@polymagelabs.com>2022-01-04 08:03:44 +0530
commit138803e017739c81b43b73631c7096bfc4d097d8 (patch)
tree7c05e121316b669f9f24b8109234aed83580f329
parent2692eae57428e1136ab58ac4004883245d0623ca (diff)
downloadllvm-138803e017739c81b43b73631c7096bfc4d097d8.tar.gz
[MLIR][PDL] Make predicate order deterministic.
The tree merging of pattern predicates places the predicates in an unordered set. When the predicates are sorted, they are taken in the set order, not the insertion order. This results in nondeterministic behavior. One solution to this problem would be to use `SetVector`. However, the value `SetVector` does not provide a `find` function for fast O(1) lookups and stores the predicates twice -- once in the set and once in the vector, which is undesirable, because we store patternToAnswer in each predicate. A simpler solution is to store the tie breaking ID (which follows the insertion order), and use this ID to break any ties when comparing predicates. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D116081
-rw-r--r--mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp19
1 files changed, 14 insertions, 5 deletions
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 24b2f19e58c2..9fd5de11a83d 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -721,6 +721,11 @@ struct OrderedPredicate {
/// opposed to those shared across patterns.
unsigned secondary = 0;
+ /// The tie breaking ID, used to preserve a deterministic (insertion) order
+ /// among all the predicates with the same priority, depth, and position /
+ /// predicate dependency.
+ unsigned id = 0;
+
/// A map between a pattern operation and the answer to the predicate question
/// within that pattern.
DenseMap<Operation *, Qualifier *> patternToAnswer;
@@ -733,12 +738,13 @@ struct OrderedPredicate {
// * lower depth
// * lower position dependency
// * lower predicate dependency
+ // * lower tie breaking ID
auto *rhsPos = rhs.position;
return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
- rhsPos->getKind(), rhs.question->getKind()) >
+ rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
std::make_tuple(rhs.primary, rhs.secondary,
position->getOperationDepth(), position->getKind(),
- question->getKind());
+ question->getKind(), id);
}
};
@@ -903,6 +909,9 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
auto it = uniqued.insert(predicate);
it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
predicate.answer);
+ // Mark the insertion order (0-based indexing).
+ if (it.second)
+ it.first->id = uniqued.size() - 1;
}
}
@@ -939,9 +948,9 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
ordered.reserve(uniqued.size());
for (auto &ip : uniqued)
ordered.push_back(&ip);
- std::stable_sort(
- ordered.begin(), ordered.end(),
- [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; });
+ llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
+ return *lhs < *rhs;
+ });
// Build the matchers for each of the pattern predicate lists.
std::unique_ptr<MatcherNode> root;