//===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/Operation.h" #include using namespace mlir; //===----------------------------------------------------------------------===// // AttrTypeWalker //===----------------------------------------------------------------------===// WalkResult AttrTypeWalker::walkImpl(Attribute attr, WalkOrder order) { return walkImpl(attr, attrWalkFns, order); } WalkResult AttrTypeWalker::walkImpl(Type type, WalkOrder order) { return walkImpl(type, typeWalkFns, order); } template WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns, WalkOrder order) { // Check if we've already walk this element before. auto key = std::make_pair(element.getAsOpaquePointer(), (int)order); auto it = visitedAttrTypes.find(key); if (it != visitedAttrTypes.end()) return it->second; visitedAttrTypes.try_emplace(key, WalkResult::advance()); // If we are walking in post order, walk the sub elements first. if (order == WalkOrder::PostOrder) { if (walkSubElements(element, order).wasInterrupted()) return visitedAttrTypes[key] = WalkResult::interrupt(); } // Walk this element, bailing if skipped or interrupted. for (auto &walkFn : llvm::reverse(walkFns)) { WalkResult walkResult = walkFn(element); if (walkResult.wasInterrupted()) return visitedAttrTypes[key] = WalkResult::interrupt(); if (walkResult.wasSkipped()) return WalkResult::advance(); } // If we are walking in pre-order, walk the sub elements last. if (order == WalkOrder::PreOrder) { if (walkSubElements(element, order).wasInterrupted()) return WalkResult::interrupt(); } return WalkResult::advance(); } template WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) { WalkResult result = WalkResult::advance(); auto walkFn = [&](auto element) { if (element && !result.wasInterrupted()) result = walkImpl(element, order); }; interface.walkImmediateSubElements(walkFn, walkFn); return result.wasInterrupted() ? result : WalkResult::advance(); } //===----------------------------------------------------------------------===// /// AttrTypeReplacer //===----------------------------------------------------------------------===// void AttrTypeReplacer::addReplacement(ReplaceFn fn) { attrReplacementFns.emplace_back(std::move(fn)); } void AttrTypeReplacer::addReplacement(ReplaceFn fn) { typeReplacementFns.push_back(std::move(fn)); } void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { // Functor that replaces the given element if the new value is different, // otherwise returns nullptr. auto replaceIfDifferent = [&](auto element) { auto replacement = replace(element); return (replacement && replacement != element) ? replacement : nullptr; }; // Update the attribute dictionary. if (replaceAttrs) { if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary())) op->setAttrs(cast(newAttrs)); } // If we aren't updating locations or types, we're done. if (!replaceTypes && !replaceLocs) return; // Update the location. if (replaceLocs) { if (Attribute newLoc = replaceIfDifferent(op->getLoc())) op->setLoc(cast(newLoc)); } // Update the result types. if (replaceTypes) { for (OpResult result : op->getResults()) if (Type newType = replaceIfDifferent(result.getType())) result.setType(newType); } // Update any nested block arguments. for (Region ®ion : op->getRegions()) { for (Block &block : region) { for (BlockArgument &arg : block.getArguments()) { if (replaceLocs) { if (Attribute newLoc = replaceIfDifferent(arg.getLoc())) arg.setLoc(cast(newLoc)); } if (replaceTypes) { if (Type newType = replaceIfDifferent(arg.getType())) arg.setType(newType); } } } } } void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { op->walk([&](Operation *nestedOp) { replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes); }); } template static void updateSubElementImpl(T element, AttrTypeReplacer &replacer, SmallVectorImpl &newElements, FailureOr &changed) { // Bail early if we failed at any point. if (failed(changed)) return; // Guard against potentially null inputs. We always map null to null. if (!element) { newElements.push_back(nullptr); return; } // Replace the element. if (T result = replacer.replace(element)) { newElements.push_back(result); if (result != element) changed = true; } else { changed = failure(); } } template T AttrTypeReplacer::replaceSubElements(T interface) { // Walk the current sub-elements, replacing them as necessary. SmallVector newAttrs; SmallVector newTypes; FailureOr changed = false; interface.walkImmediateSubElements( [&](Attribute element) { updateSubElementImpl(element, *this, newAttrs, changed); }, [&](Type element) { updateSubElementImpl(element, *this, newTypes, changed); }); if (failed(changed)) return nullptr; // If any sub-elements changed, use the new elements during the replacement. T result = interface; if (*changed) result = interface.replaceImmediateSubElements(newAttrs, newTypes); return result; } /// Shared implementation of replacing a given attribute or type element. template T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) { const void *opaqueElement = element.getAsOpaquePointer(); auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement); if (!inserted) return T::getFromOpaquePointer(it->second); T result = element; WalkResult walkResult = WalkResult::advance(); for (auto &replaceFn : llvm::reverse(replaceFns)) { if (std::optional> newRes = replaceFn(element)) { std::tie(result, walkResult) = *newRes; break; } } // If an error occurred, return nullptr to indicate failure. if (walkResult.wasInterrupted() || !result) { attrTypeMap[opaqueElement] = nullptr; return nullptr; } // Handle replacing sub-elements if this element is also a container. if (!walkResult.wasSkipped()) { // Replace the sub elements of this element, bailing if we fail. if (!(result = replaceSubElements(result))) { attrTypeMap[opaqueElement] = nullptr; return nullptr; } } attrTypeMap[opaqueElement] = result.getAsOpaquePointer(); return result; } Attribute AttrTypeReplacer::replace(Attribute attr) { return replaceImpl(attr, attrReplacementFns); } Type AttrTypeReplacer::replace(Type type) { return replaceImpl(type, typeReplacementFns); } //===----------------------------------------------------------------------===// // AttrTypeImmediateSubElementWalker //===----------------------------------------------------------------------===// void AttrTypeImmediateSubElementWalker::walk(Attribute element) { if (element) walkAttrsFn(element); } void AttrTypeImmediateSubElementWalker::walk(Type element) { if (element) walkTypesFn(element); }