//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file contains definitions of patterns to lower GPU Subgroup MMA ops to // SPIRV Dialect ops. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h" #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/TypeUtilities.h" using namespace mlir; /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op /// when the elementwise op directly supports with cooperative matrix type. /// Returns false if cannot. /// /// See SPV_NV_cooperative_matrix for supported elementwise ops. 
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, spirv::CooperativeMatrixNVType coopType, ValueRange operands) { switch (op.getOpType()) { case gpu::MMAElementwiseOp::ADDF: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::ADDI: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::SUBF: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::SUBI: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::DIVF: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::DIVS: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::DIVU: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::NEGATEF: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::NEGATES: builder.replaceOpWithNewOp(op, coopType, operands); return true; default: break; } return false; } namespace { /// This class implements the conversion of GPU MMA loadOp to /// CooperativeMatrixLoad op in the SPIRV dialect. 
struct WmmaLoadOpToSPIRVLowering : public OpConversionPattern {
  // NOTE(review): template arguments appear to be missing throughout this
  // copy of the pattern (the OpConversionPattern op type and the targets of
  // `cast`, `getTypeConverter`, `rewriter.create`, `static_cast` and
  // `replaceOpWithNewOp`); as written it cannot compile. Restore them from
  // upstream — TODO confirm the exact SPIR-V op class names.
  using OpConversionPattern::OpConversionPattern;

  /// Replaces `subgroupMmaLoadMatrixOp` with a SPIR-V cooperative matrix load
  /// reading from the (converted) source memref at the op's indices.
  /// Always returns success().
  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = subgroupMmaLoadMatrixOp->getLoc();
    gpu::MMAMatrixType retType =
        cast(subgroupMmaLoadMatrixOp.getRes().getType());
    // Type of the ORIGINAL source memref; getElementPtr needs it alongside the
    // converted memref value from the adaptor.
    auto memrefType =
        cast(subgroupMmaLoadMatrixOp.getSrcMemref().getType());
    // Address the load starts from: the source buffer indexed by the op's
    // (converted) indices.
    Value bufferPtr = spirv::getElementPtr(
        *getTypeConverter(), memrefType, adaptor.getSrcMemref(),
        adaptor.getIndices(), loc, rewriter);
    auto coopType = convertMMAToSPIRVType(retType);
    // The op's leading dimension is passed as the stride operand, built from
    // an i32 IntegerAttr.
    int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
    auto i32Type = rewriter.getI32Type();
    auto strideValue = rewriter.create(
        loc, i32Type, IntegerAttr::get(i32Type, stride));
    // A set `transpose` attribute is lowered as a column-major load; the
    // layout flag is passed as an i1 value.
    bool isColMajor = static_cast(subgroupMmaLoadMatrixOp.getTranspose());
    auto columnMajor = rewriter.create(
        loc, rewriter.getI1Type(), rewriter.getBoolAttr(isColMajor));
    // A default-constructed (null) MemoryAccessAttr: no memory-access operand.
    rewriter.replaceOpWithNewOp(
        subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor,
        spirv::MemoryAccessAttr());
    return success();
  }
};

/// Converts the GPU MMA store op (gpu::SubgroupMmaStoreMatrixOp) to the
/// CooperativeMatrixStore op in the SPIR-V dialect.
struct WmmaStoreOpToSPIRVLowering : public OpConversionPattern {
  // NOTE(review): template arguments appear to be missing throughout this
  // copy of the pattern (the OpConversionPattern op type and the targets of
  // `cast`, `getTypeConverter`, `rewriter.create`, `static_cast` and
  // `replaceOpWithNewOp`); as written it cannot compile. Restore them from
  // upstream — TODO confirm the exact SPIR-V op class names.
  using OpConversionPattern::OpConversionPattern;

  /// Replaces `subgroupMmaStoreMatrixOp` with a SPIR-V cooperative matrix
  /// store of the (converted) source matrix into the destination memref at
  /// the op's indices. Always returns success().
  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = subgroupMmaStoreMatrixOp->getLoc();
    // Type of the ORIGINAL destination memref; getElementPtr needs it
    // alongside the converted memref value from the adaptor.
    auto memrefType =
        cast(subgroupMmaStoreMatrixOp.getDstMemref().getType());
    // Address the store starts at: the destination buffer indexed by the
    // op's (converted) indices.
    Value bufferPtr = spirv::getElementPtr(
        *getTypeConverter(), memrefType, adaptor.getDstMemref(),
        adaptor.getIndices(), loc, rewriter);
    // The op's leading dimension is passed as the stride operand, built from
    // an i32 IntegerAttr.
    int64_t stride = subgroupMmaStoreMatrixOp.getLeadDimension().getSExtValue();
    auto i32Type = rewriter.getI32Type();
    auto strideValue = rewriter.create(
        loc, i32Type, IntegerAttr::get(i32Type, stride));
    // A set `transpose` attribute is lowered as a column-major store; the
    // layout flag is passed as an i1 value.
    bool useColMajor =
        static_cast(subgroupMmaStoreMatrixOp.getTranspose());
    auto columnMajor = rewriter.create(
        loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor));
    // A default-constructed (null) MemoryAccessAttr: no memory-access operand.
    rewriter.replaceOpWithNewOp(
        subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue,
        columnMajor, spirv::MemoryAccessAttr());
    return success();
  }
};

/// Converts the GPU MMA compute op (gpu::SubgroupMmaComputeOp) to the
/// CooperativeMatrixMulAdd op in the SPIR-V dialect.
struct WmmaMmaOpToSPIRVLowering : public OpConversionPattern {
  // NOTE(review): the OpConversionPattern op type and the
  // `replaceOpWithNewOp` target appear to be missing in this copy — restore
  // them from upstream.
  using OpConversionPattern::OpConversionPattern;

  /// Replaces `subgroupMmaComputeOp` with a cooperative matrix multiply-add
  /// of the converted A, B and C operands; the result takes the converted C
  /// (accumulator) type. Always returns success().
  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp(
        subgroupMmaComputeOp, adaptor.getOpC().getType(), adaptor.getOpA(),
        adaptor.getOpB(), adaptor.getOpC());
    return success();
  }
};

/// Converts the GPU MMA ConstantMatrixOp to constant SPIR-V cooperative
/// matrix ops.
struct WmmaConstantOpToSPIRVLowering : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value cst = adaptor.getOperands()[0]; auto coopType = convertMMAToSPIRVType( cast(subgroupMmaConstantMatrixOp.getType())); rewriter.replaceOpWithNewOp( subgroupMmaConstantMatrixOp, coopType, cst); return success(); } }; /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for /// the default case. struct WmmaElementwiseOpToSPIRVDefaultLowering : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // All operands should be of cooperative matrix types. for (Value operand : adaptor.getOperands()) { if (!isa(operand.getType())) return failure(); } auto coopType = convertMMAToSPIRVType( cast(elementwiseOp.getType())); return success(createElementwiseOp(rewriter, elementwiseOp, coopType, adaptor.getOperands())); } }; /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for /// matrix times scalar case. struct WmmaElementwiseOpToSPIRVScalarMulLowering : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (adaptor.getOperands().size() != 2) return failure(); // All operands should be of cooperative matrix types. for (Value operand : adaptor.getOperands()) { if (!isa(operand.getType())) return failure(); } if (elementwiseOp.getOpType() != gpu::MMAElementwiseOp::MULF) return failure(); // Use the original operands to check whether one of the operands is a splat // scalar value. 
Value lhs = elementwiseOp.getOperands().front(); Value rhs = elementwiseOp.getOperands().back(); Value splat = nullptr; Value matrix = nullptr; if (lhs.getDefiningOp()) { splat = adaptor.getOperands().front(); matrix = adaptor.getOperands().back(); } else if (rhs.getDefiningOp()) { matrix = adaptor.getOperands().front(); splat = adaptor.getOperands().back(); } if (!splat || !matrix) return failure(); // Constant MMA matrix ops are converted to spirv.CompositeConstruct ops. Value scalar = nullptr; auto cc = splat.getDefiningOp(); if (!cc) return failure(); assert(cc.getConstituents().size() == 1); scalar = cc.getConstituents().front(); auto coopType = convertMMAToSPIRVType( cast(elementwiseOp.getType())); rewriter.replaceOpWithNewOp( elementwiseOp, coopType, ValueRange{matrix, scalar}); return success(); } }; } // namespace /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`. mlir::spirv::CooperativeMatrixNVType mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) { ArrayRef retTypeShape = type.getShape(); Type elementType = type.getElementType(); return spirv::CooperativeMatrixNVType::get( elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]); } void mlir::populateGpuWMMAToSPIRVConversionPatterns( SPIRVTypeConverter &converter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); patterns.add(converter, context); // Give the following patterns higher benefit to prevail over the default one. patterns.add(converter, context, /*benefit=*/2); }