//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert Vector dialect to SPIRV dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" #include "../PassDetail.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; namespace { struct VectorBroadcastConvert final : public SPIRVOpLowering { using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (broadcastOp.source().getType().isa() || !spirv::CompositeType::isValid(broadcastOp.getVectorType())) return failure(); vector::BroadcastOp::Adaptor adaptor(operands); SmallVector source(broadcastOp.getVectorType().getNumElements(), adaptor.source()); Value construct = rewriter.create( broadcastOp.getLoc(), broadcastOp.getVectorType(), source); rewriter.replaceOp(broadcastOp, construct); return success(); } }; struct VectorExtractOpConvert final : public SPIRVOpLowering { using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(vector::ExtractOp extractOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (extractOp.getType().isa() || !spirv::CompositeType::isValid(extractOp.getVectorType())) return failure(); vector::ExtractOp::Adaptor adaptor(operands); int32_t id = extractOp.position().begin()->cast().getInt(); Value newExtract = rewriter.create( extractOp.getLoc(), adaptor.vector(), id); rewriter.replaceOp(extractOp, newExtract); return success(); } }; struct VectorInsertOpConvert final : public SPIRVOpLowering { using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(vector::InsertOp insertOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (insertOp.getSourceType().isa() || !spirv::CompositeType::isValid(insertOp.getDestVectorType())) return failure(); vector::InsertOp::Adaptor adaptor(operands); int32_t id = insertOp.position().begin()->cast().getInt(); Value newInsert = rewriter.create( insertOp.getLoc(), adaptor.source(), adaptor.dest(), id); rewriter.replaceOp(insertOp, newInsert); return success(); } }; struct VectorExtractElementOpConvert final : public SPIRVOpLowering { using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(vector::ExtractElementOp extractElementOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) return failure(); vector::ExtractElementOp::Adaptor adaptor(operands); Value newExtractElement = rewriter.create( extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(), extractElementOp.position()); rewriter.replaceOp(extractElementOp, newExtractElement); return success(); } }; struct VectorInsertElementOpConvert final : public SPIRVOpLowering { using SPIRVOpLowering::SPIRVOpLowering; LogicalResult matchAndRewrite(vector::InsertElementOp insertElementOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) return failure(); vector::InsertElementOp::Adaptor adaptor(operands); Value newInsertElement = rewriter.create( insertElementOp.getLoc(), insertElementOp.getType(), insertElementOp.dest(), adaptor.source(), insertElementOp.position()); rewriter.replaceOp(insertElementOp, newInsertElement); return success(); } }; } // namespace void mlir::populateVectorToSPIRVPatterns(MLIRContext *context, SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { patterns.insert(context, typeConverter); }