//===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a prototype GPU code generator for the sparse compiler.
// The objective is to eventually use the right combination of
// direct code generation and library calls into vendor-specific
// highly optimized sparse libraries (e.g. cuSparse for CUDA).
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "LoopEmitter.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Marks the given top module as a GPU container module.
static void markAsGPUContainer(ModuleOp topModule) {
  topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
                     UnitAttr::get(topModule->getContext()));
}

/// Constructs a new GPU module (for GPU kernels) inside the given top module,
/// or returns an existing GPU module if one was built previously.
static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
  for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
    return op; // existing
  markAsGPUContainer(topModule);
  builder.setInsertionPointToStart(&topModule.getBodyRegion().front());
  return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
                                          "sparse_kernels");
}

/// Constructs a new GPU kernel in the given GPU module.
static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
                                 SmallVectorImpl<Value> &args) {
  // Get a unique kernel name. Not very creative,
  // but we simply try kernel0, kernel1, etc.
  unsigned kernelNumber = 0;
  SmallString<16> kernelName;
  do {
    kernelName.clear();
    ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
  } while (gpuModule.lookupSymbol(kernelName));
  // Then we insert a new kernel with the given arguments into the module.
  builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front());
  SmallVector<Type> argsTp;
  for (unsigned i = 0, e = args.size(); i < e; i++)
    argsTp.push_back(args[i].getType());
  FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
  auto gpuFunc =
      builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
  gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
                   builder.getUnitAttr());
  return gpuFunc;
}

/// Constructs code to launch GPU kernel.
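/// As a rough sketch (the exact textual IR may differ), a launch with
/// numThreads = 32 is expected to produce something along the lines of:
///   %token = gpu.launch_func async [...] @sparse_kernels::@kernel0
///            blocks in (%c1, %c1, %c1) threads in (%c32, %c1, %c1)
///            args(...)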
static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
                              SmallVectorImpl<Value> &args,
                              SmallVectorImpl<Value> &tokens,
                              unsigned numThreads) {
  Location loc = gpuFunc->getLoc();
  Value none = TypedValue<::mlir::IntegerType>{};
  Value one = constantIndex(builder, loc, 1);
  Value numT = constantIndex(builder, loc, numThreads);
  gpu::KernelDim3 gridSize = {one, one, one};
  gpu::KernelDim3 blckSize = {numT, one, one};
  return builder
      .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
                                 /*dynSharedMemSz*/ none, args,
                                 builder.getType<gpu::AsyncTokenType>(), tokens)
      .getAsyncToken();
}

/// Maps the provided ranked host buffer into the device address space.
/// Writes from the host are guaranteed to be visible to device kernels
/// that are launched afterwards. Writes from the device are guaranteed
/// to be visible on the host after synchronizing with the device kernel
/// completion. Needs to cast the buffer to an unranked buffer.
static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
                                   Value mem) {
  MemRefType memTp = cast<MemRefType>(mem.getType());
  UnrankedMemRefType resTp =
      UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
  Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
  builder.create<gpu::HostRegisterOp>(loc, cast);
  return cast;
}

/// Unmaps the provided buffer, expecting the casted buffer.
static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
                                    Value cast) {
  builder.create<gpu::HostUnregisterOp>(loc, cast);
}

/// Generates first wait in an asynchronous chain.
static Value genFirstWait(OpBuilder &builder, Location loc) {
  Type tokenType = builder.getType<gpu::AsyncTokenType>();
  return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange())
      .getAsyncToken();
}

/// Generates last, blocking wait in an asynchronous chain.
static void genBlockingWait(OpBuilder &builder, Location loc,
                            ValueRange operands) {
  builder.create<gpu::WaitOp>(loc, Type(), operands);
}

/// Allocates memory on the device.
/// TODO: A `host_shared` attribute could be used to indicate that
///       the buffer is visible to both host and device, but lowering
///       that feature does not seem to be fully supported yet.
static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
                                   Value token) {
  auto tp = cast<ShapedType>(mem.getType());
  auto elemTp = tp.getElementType();
  auto shape = tp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
    if (shape[r] == ShapedType::kDynamic) {
      Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
      dynamicSizes.push_back(dimOp);
    }
  }
  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
                                      token, dynamicSizes, ValueRange());
}

/// Allocates a void buffer on the device with given size.
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
                                   Value token) {
  const auto memTp =
      MemRefType::get({ShapedType::kDynamic}, builder.getI8Type());
  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
                                      token, size, ValueRange());
}

/// Deallocates memory from the device.
static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
                              Value token) {
  return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem)
      .getAsyncToken();
}

/// Copies memory between host and device (direction is implicit).
static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
                           Value src, Value token) {
  return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src)
      .getAsyncToken();
}

/// Generates an alloc/copy pair.
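/// The device allocation and the host-to-device copy are chained through
/// async tokens; only the token of the final copy is appended to `tokens`,
/// so callers must wait on `tokens` before using the returned device buffer.
/// Roughly (sketch only, exact IR may differ):
///   %first    = gpu.wait async
///   %mem, %t0 = gpu.alloc async [%first] ...
///   %t1       = gpu.memcpy async [%t0] %mem, %hostMem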
static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
                          SmallVectorImpl<Value> &tokens) {
  Value firstToken = genFirstWait(builder, loc);
  auto alloc = genAllocMemRef(builder, loc, b, firstToken);
  Value devMem = alloc.getResult(0);
  Value depToken = alloc.getAsyncToken(); // copy-after-alloc
  tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
  return devMem;
}

/// Generates a memref from tensor operation.
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
                               Value tensor) {
  auto tensorType = tensor.getType().cast<ShapedType>();
  auto memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
}

/// Prepares the outlined arguments, passing scalars and buffers in. Here we
/// assume that the first buffer is the one allocated for output. We create
/// a set of properly chained asynchronous allocation/copy pairs to increase
/// overlap before launching the kernel.
/// TODO: the output assumption may be a bit too brittle
static Value genParametersIn(OpBuilder &builder, Location loc,
                             SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens,
                             bool useHostRegistrationForOut) {
  Value out;
  // Scalars are passed by value.
  for (Value s : scalars)
    args.push_back(s);
  // Buffers need to be made visible on the device.
  for (Value b : buffers) {
    if (useHostRegistrationForOut) {
      out = genHostRegisterMemref(builder, loc, b);
      args.push_back(b);
      useHostRegistrationForOut = false;
      continue;
    }
    args.push_back(genAllocCopy(builder, loc, b, tokens));
  }
  return out;
}

/// Finalizes the outlined arguments. The output buffer is copied depending
/// on the kernel token and then deallocated. All other buffers are simply
/// deallocated. Then we wait for all operations to complete.
static void genParametersOut(OpBuilder &builder, Location loc, Value out,
                             Value kernelToken, SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens) {
  unsigned base = scalars.size();
  for (unsigned i = base, e = args.size(); i < e; i++) {
    Value firstToken;
    if (i == base) {
      // Assumed output parameter: unregister or copy-out.
      if (out) {
        genHostUnregisterMemref(builder, loc, out);
        out = Value();
        continue;
      }
      firstToken =
          genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
    } else {
      firstToken = genFirstWait(builder, loc);
    }
    tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
  }
}

/// Constructs code for new GPU kernel.
static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
                       scf::ParallelOp forallOp,
                       SmallVectorImpl<Value> &constants,
                       SmallVectorImpl<Value> &scalars,
                       SmallVectorImpl<Value> &buffers) {
  Location loc = gpuFunc->getLoc();
  Block &block = gpuFunc.getBody().front();
  rewriter.setInsertionPointToStart(&block);

  // Re-generate the constants, recapture all arguments.
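  // The kernel block arguments are assumed to appear in the same order in
  // which genParametersIn() passed them: all scalars first, then all buffers.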
  unsigned arg = 0;
  IRMapping irMap;
  for (Value c : constants)
    irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
  for (Value s : scalars)
    irMap.map(s, block.getArgument(arg++));
  for (Value b : buffers)
    irMap.map(b, block.getArgument(arg++));

  // Assume 1-dimensional grid/block configuration (only x dimension),
  // so that:
  //   row = blockIdx.x * blockDim.x + threadIdx.x
  //   inc = blockDim.x * gridDim.x
  Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
  Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
  Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
  Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
  Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
  Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);

  // Construct the iteration over the computational space that
  // accounts for the fact that the total number of threads and
  // the amount of work to be done usually do not match precisely.
  //   for (r = row; r < N; r += inc) {
  //     <loop-body>
  //   }
  Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
  rewriter.cloneRegionBefore(forallOp.getLoopBody(), forOp.getLoopBody(),
                             forOp.getLoopBody().begin(), irMap);

  // Done.
  rewriter.setInsertionPointAfter(forOp);
  rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
}

//===----------------------------------------------------------------------===//
// Library helper methods.
//===----------------------------------------------------------------------===//

/// Helper to detect a * b.
static bool matchMulOfArgs(linalg::GenericOp op, Value val) {
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
      Value a = op.getBlock()->getArguments()[0];
      Value b = op.getBlock()->getArguments()[1];
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
    }
  }
  return false;
}

/// Helper to detect x = x + a * b
static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments()[2];
      return (def->getOperand(0) == x &&
              matchMulOfArgs(op, def->getOperand(1))) ||
             (def->getOperand(1) == x &&
              matchMulOfArgs(op, def->getOperand(0)));
    }
  }
  return false;
}

/// Test for sorted COO with suitable data and coordinates types.
static bool isAdmissibleCOO(SparseTensorType &aTp) {
  return aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
         aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
         (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
          aTp.getCrdWidth() == 64);
}

/// Test for CSR with suitable data and coordinates types.
static bool isAdmissibleCSR(SparseTensorType &aTp) {
  return aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
         aTp.isUniqueLvl(1) &&
         (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
          aTp.getCrdWidth() == 64);
}

/// Generates the first positions/coordinates of a sparse matrix.
static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
                               bool isCOO, bool enableRT) {
  if (isCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT)
      return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
    return genToCoordinatesBuffer(builder, loc, a);
  }
  // CSR uses positions.
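  // For CSR, the level-1 positions array serves as the row pointer array
  // expected by the library.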
  return genToPositions(builder, loc, a, 1);
}

/// Generates the second coordinates of a sparse matrix.
static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
                           bool isCOO, bool enableRT) {
  if (isCOO && !enableRT)
    return Value(); // nothing needed
  return genToCoordinates(builder, loc, a, 1, /*cooStart=*/0);
}

/// Generates the sparse matrix multiplication.
static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
                           Type tokenTp, Value token, Value szY, Value szX,
                           Value nnzA, Value rowA, Value colA, Value valA,
                           bool isCOO, bool enableRT) {
  if (isCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT)
      return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
                                              szY, szX, nnzA, rowA, colA, valA);
    llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
  }
  return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, szY,
                                          szX, nnzA, rowA, colA, valA);
}

/// Match and rewrite SpMV kernel.
static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value x = op.getOperand(1);
  Value y = op.getOperand(2); // we have y = Ax
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format and dense vectors for now.
  bool isCOO = false;
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType xTp = getSparseTensorType(x);
  SparseTensorType yTp = getSparseTensorType(y);
  if (xTp.hasEncoding() || yTp.hasEncoding())
    return failure();
  if (isAdmissibleCOO(aTp)) {
    isCOO = true;
    // TODO: CreateCooAoSOp was deprecated, find another way
    if (!enableRT)
      return failure();
  } else if (isAdmissibleCSR(aTp)) {
    isCOO = false;
  } else {
    return failure();
  }

  // Start sparse kernel and copy data from host to device.
  //   a : memR/memC/memV -> rowA,colA,valA
  //   x : memX           -> vecX
  //   y : memY           -> vecY
  Value nnzA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
  Value memV = genToValues(rewriter, loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value memX = genTensorToMemref(rewriter, loc, x);
  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
  Value memY = genTensorToMemref(rewriter, loc, y);
  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense vector handles.
  Type indexTp = rewriter.getIndexType();
  Type handleTp = rewriter.getType<gpu::SparseHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  auto env =
      rewriter.create<gpu::CreateSparseEnvOp>(loc, handleTp, tokenTp, token);
  Value handle = env.getResult(0);
  token = env.getAsyncToken();
  Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szY,
                               szX, nnzA, rowA, colA, valA, isCOO, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dvecX = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
                                                   token, vecX, szX);
  Value dnX = dvecX.getResult(0);
  token = dvecX.getAsyncToken();
  auto dvecY = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
                                                   token, vecY, szY);
  Value dnY = dvecY.getResult(0);
  token = dvecY.getAsyncToken();

  // Precompute buffersize for SpMV.
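  // This follows the usual cuSPARSE-style workflow: query the required
  // external workspace size first, allocate that buffer on the device,
  // and only then launch the actual SpMV computation.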
  auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
      loc, indexTp, tokenTp, token, handle, spMatA, dnX, dnY);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SpMV.
  auto spmvComp = rewriter.create<gpu::SpMVOp>(loc, tokenTp, token, handle,
                                               spMatA, dnX, dnY, buffer);
  token = spmvComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnVecOp>(loc, tokenTp, token, dnX)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnVecOp>(loc, tokenTp, token, dnY)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySparseEnvOp>(loc, tokenTp, token, handle)
              .getAsyncToken();
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();
  token = genFirstWait(rewriter, loc);
  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, vecX, token);
  token = genDeallocMemRef(rewriter, loc, vecY, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
  return success();
}

/// Match and rewrite SpMM kernel.
static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  return failure(); // TODO: implement
}

//===----------------------------------------------------------------------===//
// Rewriting rules for direct code generation.
//===----------------------------------------------------------------------===//

/// Proof-of-concept rewriter. This rule generates a GPU implementation
/// for each outermost forall loop generated by the sparse compiler.
/// TODO: currently works with parallelization-strategy=dense-outer-loop,
///       but give this its own flags in the future
struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;

  ForallRewriter(MLIRContext *context, unsigned nT)
      : OpRewritePattern(context), numThreads(nT){};

  LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
                                PatternRewriter &rewriter) const override {
    // Reject inadmissible loop form.
    // Essentially only accept a loop, generated by the sparse compiler,
    // of the form
    //   forall (i = 0; i < N; i++)
    // so that cyclic scheduling over the threads is easy.
    if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
        forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
        !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
        !matchPattern(forallOp.getStep()[0], m_One()))
      return failure();
    // Collect every value that is computed outside the parallel loop.
    SetVector<Value> invariants; // stable iteration!
    forallOp->walk([&](Operation *op) {
      // Collect all values of admissible ops.
      for (OpOperand &o : op->getOpOperands()) {
        Value val = o.get();
        Block *block;
        if (auto arg = dyn_cast<BlockArgument>(val))
          block = arg.getOwner();
        else
          block = val.getDefiningOp()->getBlock();
        if (!isNestedIn(block, forallOp))
          invariants.insert(val);
      }
    });
    // Outline the outside values as proper parameters. Fail when sharing
    // a value between host and device is not straightforward.
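    // Loop-invariant values are partitioned into three kinds: constants are
    // re-materialized inside the kernel, int/index/float scalars are passed
    // by value, and memrefs are passed as device buffers.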
    SmallVector<Value> constants;
    SmallVector<Value> scalars;
    SmallVector<Value> buffers;
    for (Value val : invariants) {
      Type tp = val.getType();
      if (val.getDefiningOp<arith::ConstantOp>())
        constants.push_back(val);
      else if (isa<FloatType>(tp) || tp.isIntOrIndex())
        scalars.push_back(val);
      else if (isa<MemRefType>(tp))
        buffers.push_back(val);
      else
        return failure(); // don't know how to share
    }
    // Pass outlined non-constant values.
    // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
    //       keep the feature at all (either through a heuristic or compiler
    //       option for gpu codegen).
    Location loc = forallOp->getLoc();
    SmallVector<Value> args;
    SmallVector<Value> tokens;
    Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
                                /*useHostRegistrationForOut=*/false);
    // Set up GPU module and construct GPU function.
    auto saveIp = rewriter.saveInsertionPoint();
    ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
    auto gpuModule = genGPUModule(rewriter, topModule);
    auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
    genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
    // Generate code that launches the kernel asynchronously, blocking on all
    // open tokens and yielding a new token for the output.
    // TODO: Passing in tokens to the launch op does not seem to be properly
    //       lowered to cubin yet, hence the current blocking wait.
    rewriter.restoreInsertionPoint(saveIp);
    genBlockingWait(rewriter, loc, tokens);
    tokens.clear();
    Value kernelToken =
        genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
    // Finalize the outlined arguments.
    genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
                     tokens);
    genBlockingWait(rewriter, loc, tokens);
    rewriter.eraseOp(forallOp);
    return success();
  }

private:
  // Helper method to see if block appears in given loop.
  static bool isNestedIn(Block *block, scf::ParallelOp forallOp) {
    for (Operation *o = block->getParentOp(); o; o = o->getParentOp()) {
      if (o == forallOp)
        return true;
    }
    return false;
  }

  unsigned numThreads;
};

//===----------------------------------------------------------------------===//
// Rewriting rules for library recognition and code generation.
//===----------------------------------------------------------------------===//

/// Proof-of-concept rewriter. This rule recognizes certain math kernels
/// and replaces these with corresponding calls into the sparse library.
struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LinalgOpRewriter(MLIRContext *context, bool rt)
      : OpRewritePattern(context), enableRT(rt) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumDpsInits() != 1)
      return failure(); // reject multi-output

    const unsigned numLoops = op.getNumLoops();
    const unsigned numTensors = op->getNumOperands();
    const auto iteratorTypes = op.getIteratorTypesArray();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();

    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr i, j, k;
    bindDims(getContext(), i, j, k);

    // TODO: more robust patterns, transposed versions, more kernels...

    // Recognize a SpMV kernel.
    if (numLoops == 2 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isReductionIterator(iteratorTypes[1]) &&
        maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
      return rewriteSpMV(rewriter, op, enableRT);
    }

    // Recognize a SpMM kernel.
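    //   C(i,j) = sum_k A(i,k) * B(k,j), with A sparse and B, C dense
    //   (lowering is not implemented yet, see rewriteSpMM).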
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
      return rewriteSpMM(rewriter, op, enableRT);
    }

    return failure();
  }

private:
  bool enableRT;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating GPU rewriting rules.
//
// Currently two sets of rewriting rules are made available. The first set
// implements direct code generation, currently by means of converting the
// outermost parallel loop into GPU threads. The second set implements
// library recognition of a set of sparse operations. Eventually, the right
// combination of these two approaches has to be found.
//===----------------------------------------------------------------------===//

void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                            unsigned numThreads) {
  patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
}

void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
                                           bool enableRT) {
  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
}