//===- StorageBase.cpp - TACO-flavored sparse tensor representation -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file contains method definitions for `SparseTensorStorageBase`. // In particular we want to ensure that the default implementations of // the "partial method specialization" trick aren't inline (since there's // no benefit). Though this also helps ensure that we avoid weak-vtables: // // // This file is part of the lightweight runtime support library for sparse // tensor manipulations. The functionality of the support library is meant // to simplify benchmarking, testing, and debugging MLIR code operating on // sparse tensors. However, the provided functionality is **not** part of // core MLIR itself. // //===----------------------------------------------------------------------===// #include "mlir/ExecutionEngine/SparseTensor/Storage.h" using namespace mlir::sparse_tensor; SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank, const uint64_t *lvlSizes, const DimLevelType *lvlTypes, const uint64_t *lvl2dim) : dimSizes(dimSizes, dimSizes + dimRank), lvlSizes(lvlSizes, lvlSizes + lvlRank), lvlTypes(lvlTypes, lvlTypes + lvlRank), lvl2dim(lvl2dim, lvl2dim + lvlRank) { // TODO: If we do get any nullptrs, I'm pretty sure these assertions // will run too late (i.e., after copying things into vectors above). // But since those fields are const I'm not sure there's any clean way // to assert things before copying... assert(dimSizes && "Got nullptr for dimension sizes"); assert(lvlSizes && "Got nullptr for level sizes"); assert(lvlTypes && "Got nullptr for level types"); assert(lvl2dim && "Got nullptr for level-to-dimension mapping"); // Validate dim-indexed parameters. assert(dimRank > 0 && "Trivial shape is unsupported"); for (uint64_t d = 0; d < dimRank; ++d) assert(dimSizes[d] > 0 && "Dimension size zero has trivial storage"); // Validate level-indexed parameters. assert(lvlRank > 0 && "Trivial shape is unsupported"); for (uint64_t l = 0; l < lvlRank; ++l) { assert(lvlSizes[l] > 0 && "Level size zero has trivial storage"); const auto dlt = lvlTypes[l]; // Avoid redundant bounds checking. // We use `MLIR_SPARSETENSOR_FATAL` here instead of `assert` so that // when this ctor is successful then all the methods can rely on the // fact that each level-type satisfies one of these options (even // when `NDEBUG` is true), thereby reducing the need to re-assert things. if (!(isDenseDLT(dlt) || isCompressedDLT(dlt) || isSingletonDLT(dlt))) MLIR_SPARSETENSOR_FATAL("unsupported level type: %d\n", static_cast(dlt)); } } // Helper macro for generating error messages when some // `SparseTensorStorage` is cast to `SparseTensorStorageBase` // and then the wrong "partial method specialization" is called. #define FATAL_PIV(NAME) \ MLIR_SPARSETENSOR_FATAL(" type mismatch for: " #NAME); #define IMPL_NEWENUMERATOR(VNAME, V) \ void SparseTensorStorageBase::newEnumerator( \ SparseTensorEnumeratorBase **, uint64_t, const uint64_t *, uint64_t, \ const uint64_t *) const { \ FATAL_PIV("newEnumerator" #VNAME); \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_NEWENUMERATOR) #undef IMPL_NEWENUMERATOR #define IMPL_GETPOSITIONS(PNAME, P) \ void SparseTensorStorageBase::getPositions(std::vector

**, uint64_t) { \ FATAL_PIV("getPositions" #PNAME); \ } MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETPOSITIONS) #undef IMPL_GETPOSITIONS #define IMPL_GETCOORDINATES(CNAME, C) \ void SparseTensorStorageBase::getCoordinates(std::vector **, uint64_t) { \ FATAL_PIV("getCoordinates" #CNAME); \ } MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES) #undef IMPL_GETCOORDINATES #define IMPL_GETVALUES(VNAME, V) \ void SparseTensorStorageBase::getValues(std::vector **) { \ FATAL_PIV("getValues" #VNAME); \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETVALUES) #undef IMPL_GETVALUES #define IMPL_LEXINSERT(VNAME, V) \ void SparseTensorStorageBase::lexInsert(const uint64_t *, V) { \ FATAL_PIV("lexInsert" #VNAME); \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT) #undef IMPL_LEXINSERT #define IMPL_EXPINSERT(VNAME, V) \ void SparseTensorStorageBase::expInsert(uint64_t *, V *, bool *, uint64_t *, \ uint64_t) { \ FATAL_PIV("expInsert" #VNAME); \ } MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT) #undef IMPL_EXPINSERT #undef FATAL_PIV