//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert Vector dialect to SPIRV dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include using namespace mlir; /// Gets the first integer value from `attr`, assuming it is an integer array /// attribute. static uint64_t getFirstIntValue(ArrayAttr attr) { return (*attr.getAsValueRange().begin()).getZExtValue(); } /// Returns the number of bits for the given scalar/vector type. static int getNumBits(Type type) { // TODO: This does not take into account any memory layout or widening // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even // though in practice it will likely be stored as in a 4xi64 vector register. if (auto vectorType = dyn_cast(type)) return vectorType.getNumElements() * vectorType.getElementTypeBitWidth(); return type.getIntOrFloatBitWidth(); } namespace { struct VectorBitcastConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstType = getTypeConverter()->convertType(bitcastOp.getType()); if (!dstType) return failure(); if (dstType == adaptor.getSource().getType()) { rewriter.replaceOp(bitcastOp, adaptor.getSource()); return success(); } // Check that the source and destination type have the same bitwidth. // Depending on the target environment, we may need to emulate certain // types, which can cause issue with bitcast. Type srcType = adaptor.getSource().getType(); if (getNumBits(dstType) != getNumBits(srcType)) { return rewriter.notifyMatchFailure( bitcastOp, llvm::formatv("different source ({0}) and target ({1}) bitwidth", srcType, dstType)); } rewriter.replaceOpWithNewOp(bitcastOp, dstType, adaptor.getSource()); return success(); } }; struct VectorBroadcastConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resultType = getTypeConverter()->convertType(castOp.getResultVectorType()); if (!resultType) return failure(); if (isa(resultType)) { rewriter.replaceOp(castOp, adaptor.getSource()); return success(); } SmallVector source(castOp.getResultVectorType().getNumElements(), adaptor.getSource()); rewriter.replaceOpWithNewOp( castOp, castOp.getResultVectorType(), source); return success(); } }; struct VectorExtractOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only support extracting a scalar value now. VectorType resultVectorType = dyn_cast(extractOp.getType()); if (resultVectorType && resultVectorType.getNumElements() > 1) return failure(); Type dstType = getTypeConverter()->convertType(extractOp.getType()); if (!dstType) return failure(); if (isa(adaptor.getVector().getType())) { rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } int32_t id = getFirstIntValue(extractOp.getPosition()); rewriter.replaceOpWithNewOp( extractOp, adaptor.getVector(), id); return success(); } }; struct VectorExtractStridedSliceOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstType = getTypeConverter()->convertType(extractOp.getType()); if (!dstType) return failure(); uint64_t offset = getFirstIntValue(extractOp.getOffsets()); uint64_t size = getFirstIntValue(extractOp.getSizes()); uint64_t stride = getFirstIntValue(extractOp.getStrides()); if (stride != 1) return failure(); Value srcVector = adaptor.getOperands().front(); // Extract vector<1xT> case. if (isa(dstType)) { rewriter.replaceOpWithNewOp(extractOp, srcVector, offset); return success(); } SmallVector indices(size); std::iota(indices.begin(), indices.end(), offset); rewriter.replaceOpWithNewOp( extractOp, dstType, srcVector, srcVector, rewriter.getI32ArrayAttr(indices)); return success(); } }; template struct VectorFmaOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstType = getTypeConverter()->convertType(fmaOp.getType()); if (!dstType) return failure(); rewriter.replaceOpWithNewOp(fmaOp, dstType, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); return success(); } }; struct VectorInsertOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (isa(insertOp.getSourceType())) return rewriter.notifyMatchFailure(insertOp, "unsupported vector source"); if (!getTypeConverter()->convertType(insertOp.getDestVectorType())) return rewriter.notifyMatchFailure(insertOp, "unsupported dest vector type"); // Special case for inserting scalar values into size-1 vectors. if (insertOp.getSourceType().isIntOrFloat() && insertOp.getDestVectorType().getNumElements() == 1) { rewriter.replaceOp(insertOp, adaptor.getSource()); return success(); } int32_t id = getFirstIntValue(insertOp.getPosition()); rewriter.replaceOpWithNewOp( insertOp, adaptor.getSource(), adaptor.getDest(), id); return success(); } }; struct VectorExtractElementOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resultType = getTypeConverter()->convertType(extractOp.getType()); if (!resultType) return failure(); if (isa(adaptor.getVector().getType())) { rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } APInt cstPos; if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) rewriter.replaceOpWithNewOp( extractOp, resultType, adaptor.getVector(), rewriter.getI32ArrayAttr({static_cast(cstPos.getSExtValue())})); else rewriter.replaceOpWithNewOp( extractOp, resultType, adaptor.getVector(), adaptor.getPosition()); return success(); } }; struct VectorInsertElementOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type vectorType = getTypeConverter()->convertType(insertOp.getType()); if (!vectorType) return failure(); if (isa(vectorType)) { rewriter.replaceOp(insertOp, adaptor.getSource()); return success(); } APInt cstPos; if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) rewriter.replaceOpWithNewOp( insertOp, adaptor.getSource(), adaptor.getDest(), cstPos.getSExtValue()); else rewriter.replaceOpWithNewOp( insertOp, vectorType, insertOp.getDest(), adaptor.getSource(), adaptor.getPosition()); return success(); } }; struct VectorInsertStridedSliceOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value srcVector = adaptor.getOperands().front(); Value dstVector = adaptor.getOperands().back(); uint64_t stride = getFirstIntValue(insertOp.getStrides()); if (stride != 1) return failure(); uint64_t offset = getFirstIntValue(insertOp.getOffsets()); if (isa(srcVector.getType())) { assert(!isa(dstVector.getType())); rewriter.replaceOpWithNewOp( insertOp, dstVector.getType(), srcVector, dstVector, rewriter.getI32ArrayAttr(offset)); return success(); } uint64_t totalSize = cast(dstVector.getType()).getNumElements(); uint64_t insertSize = cast(srcVector.getType()).getNumElements(); SmallVector indices(totalSize); std::iota(indices.begin(), indices.end(), 0); std::iota(indices.begin() + offset, indices.begin() + offset + insertSize, totalSize); rewriter.replaceOpWithNewOp( insertOp, dstVector.getType(), dstVector, srcVector, rewriter.getI32ArrayAttr(indices)); return success(); } }; template struct VectorReductionPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resultType = typeConverter->convertType(reduceOp.getType()); if (!resultType) return failure(); auto srcVectorType = dyn_cast(adaptor.getVector().getType()); if (!srcVectorType || srcVectorType.getRank() != 1) return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source"); // Extract all elements. int numElements = srcVectorType.getDimSize(0); SmallVector values; values.reserve(numElements + (adaptor.getAcc() != nullptr)); Location loc = reduceOp.getLoc(); for (int i = 0; i < numElements; ++i) { values.push_back(rewriter.create( loc, srcVectorType.getElementType(), adaptor.getVector(), rewriter.getI32ArrayAttr({i}))); } if (Value acc = adaptor.getAcc()) values.push_back(acc); // Reduce them. Value result = values.front(); for (Value next : llvm::ArrayRef(values).drop_front()) { switch (reduceOp.getKind()) { #define INT_AND_FLOAT_CASE(kind, iop, fop) \ case vector::CombiningKind::kind: \ if (resultType.isa()) { \ result = rewriter.create(loc, resultType, result, next); \ } else { \ assert(resultType.isa()); \ result = rewriter.create(loc, resultType, result, next); \ } \ break #define INT_OR_FLOAT_CASE(kind, fop) \ case vector::CombiningKind::kind: \ result = rewriter.create(loc, resultType, result, next); \ break INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp); INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp); INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp); INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp); INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp); INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp); INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp); case vector::CombiningKind::AND: case vector::CombiningKind::OR: case vector::CombiningKind::XOR: return rewriter.notifyMatchFailure(reduceOp, "unimplemented"); } } rewriter.replaceOp(reduceOp, result); return success(); } }; class VectorSplatPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return failure(); if (isa(dstType)) { rewriter.replaceOp(op, adaptor.getInput()); } else { auto dstVecType = cast(dstType); SmallVector source(dstVecType.getNumElements(), adaptor.getInput()); rewriter.replaceOpWithNewOp(op, dstType, source); } return success(); } }; struct VectorShuffleOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto oldResultType = shuffleOp.getResultVectorType(); Type newResultType = getTypeConverter()->convertType(oldResultType); if (!newResultType) return rewriter.notifyMatchFailure(shuffleOp, "unsupported result vector type"); auto oldSourceType = shuffleOp.getV1VectorType(); if (oldSourceType.getNumElements() > 1) { SmallVector components = llvm::to_vector<4>( llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t { return cast(attr).getValue().getZExtValue(); })); rewriter.replaceOpWithNewOp( shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(), rewriter.getI32ArrayAttr(components)); return success(); } SmallVector oldOperands = {adaptor.getV1(), adaptor.getV2()}; SmallVector newOperands; newOperands.reserve(oldResultType.getNumElements()); for (const APInt &i : shuffleOp.getMask().getAsValueRange()) { newOperands.push_back(oldOperands[i.getZExtValue()]); } rewriter.replaceOpWithNewOp( shuffleOp, newResultType, newOperands); return success(); } }; struct VectorReductionToDotProd final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override { if (op.getKind() != vector::CombiningKind::ADD) return rewriter.notifyMatchFailure(op, "combining kind is not 'add'"); auto resultType = dyn_cast(op.getType()); if (!resultType) return rewriter.notifyMatchFailure(op, "result is not an integer"); int64_t resultBitwidth = resultType.getIntOrFloatBitWidth(); if (!llvm::is_contained({32, 64}, resultBitwidth)) return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth"); VectorType inVecTy = op.getSourceVectorType(); if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) || inVecTy.getShape().size() != 1 || inVecTy.isScalable()) return rewriter.notifyMatchFailure(op, "unsupported vector shape"); auto mul = op.getVector().getDefiningOp(); if (!mul) return rewriter.notifyMatchFailure( op, "reduction operand is not 'arith.muli'"); if (succeeded(handleCase(op, mul, rewriter))) return success(); if (succeeded(handleCase(op, mul, rewriter))) return success(); if (succeeded(handleCase(op, mul, rewriter))) return success(); if (succeeded(handleCase(op, mul, rewriter))) return success(); return failure(); } private: template static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul, PatternRewriter &rewriter) { auto lhs = mul.getLhs().getDefiningOp(); if (!lhs) return failure(); Value lhsIn = lhs.getIn(); auto lhsInType = cast(lhsIn.getType()); if (!lhsInType.getElementType().isInteger(8)) return failure(); auto rhs = mul.getRhs().getDefiningOp(); if (!rhs) return failure(); Value rhsIn = rhs.getIn(); auto rhsInType = cast(rhsIn.getType()); if (!rhsInType.getElementType().isInteger(8)) return failure(); if (op.getSourceVectorType().getNumElements() == 3) { IntegerType i8Type = rewriter.getI8Type(); auto v4i8Type = VectorType::get({4}, i8Type); Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter); lhsIn = rewriter.create( loc, v4i8Type, ValueRange{lhsIn, zero}); rhsIn = rewriter.create( loc, v4i8Type, ValueRange{rhsIn, zero}); } // There's no variant of dot prod ops for unsigned LHS and signed RHS, so // we have to swap operands instead in that case. if (SwapOperands) std::swap(lhsIn, rhsIn); if (Value acc = op.getAcc()) { rewriter.replaceOpWithNewOp(op, op.getType(), lhsIn, rhsIn, acc, nullptr); } else { rewriter.replaceOpWithNewOp(op, op.getType(), lhsIn, rhsIn, nullptr); } return success(); } }; } // namespace #define CL_MAX_MIN_OPS \ spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \ spirv::CLSMaxOp, spirv::CLSMinOp #define GL_MAX_MIN_OPS \ spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \ spirv::GLSMaxOp, spirv::GLSMinOp void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add< VectorBitcastConvert, VectorBroadcastConvert, VectorExtractElementOpConvert, VectorExtractOpConvert, VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, VectorFmaOpConvert, VectorInsertElementOpConvert, VectorInsertOpConvert, VectorReductionPattern, VectorReductionPattern, VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, VectorSplatPattern>(typeConverter, patterns.getContext()); } void mlir::populateVectorReductionToSPIRVDotProductPatterns( RewritePatternSet &patterns) { patterns.add(patterns.getContext()); }