/*************************************************************************************************** * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #include #include #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/complex.h" #include "cutlass/blas3.h" #include "cutlass/layout/matrix.h" #include "cutlass/library/library.h" #include "cutlass/library/util.h" namespace cutlass { namespace library { ///////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; Provider enumerant; } Provider_enumerants[] = { {"none", "None", Provider::kNone}, {"cutlass", "CUTLASS", Provider::kCUTLASS}, {"host", "reference_host", Provider::kReferenceHost}, {"device", "reference_device", Provider::kReferenceDevice}, {"cublas", "cuBLAS", Provider::kCUBLAS}, {"cudnn", "cuDNN", Provider::kCUDNN}, }; /// Converts a Provider enumerant to a string char const *to_string(Provider provider, bool pretty) { for (auto const & possible : Provider_enumerants) { if (provider == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /// Parses a Provider enumerant from a string template <> Provider from_string(std::string const &str) { for (auto const & possible : Provider_enumerants) { if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } return Provider::kInvalid; } /////////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; GemmKind enumerant; } GemmKind_enumerants[] = { {"gemm", "", GemmKind::kGemm}, {"spgemm", "", GemmKind::kSparse}, {"universal", "", GemmKind::kUniversal}, {"planar_complex", "", GemmKind::kPlanarComplex}, {"planar_complex_array", "", GemmKind::kPlanarComplexArray}, {"grouped", "", GemmKind::kGrouped}, }; /// Converts a GemmKind enumerant to a string char const *to_string(GemmKind type, bool pretty) { for (auto const & possible : GemmKind_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /////////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; RankKKind enumerant; } RankKKind_enumerants[] = { {"universal", "", RankKKind::kUniversal}, }; /// Converts a SyrkKind enumerant to a string char const *to_string(RankKKind type, bool pretty) { for (auto const & possible :RankKKind_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /////////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; TrmmKind enumerant; } TrmmKind_enumerants[] = { {"universal", "", TrmmKind::kUniversal}, }; /// Converts a TrmmKind enumerant to a string char const *to_string(TrmmKind type, bool pretty) { for (auto const & possible :TrmmKind_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /////////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; SymmKind enumerant; } SymmKind_enumerants[] = { {"universal", "", SymmKind::kUniversal}, }; /// Converts a SymmKind enumerant to a string char const *to_string(SymmKind type, bool pretty) { for (auto const & possible :SymmKind_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /////////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; SideMode enumerant; } SideMode_enumerants[] = { {"left", "Left", SideMode::kLeft}, {"right", "Right", SideMode::kRight} }; /// Converts a SideMode enumerant to a string char const *to_string(SideMode type, bool pretty) { for (auto const & possible :SideMode_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /////////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; FillMode enumerant; } FillMode_enumerants[] = { {"lower", "Lower", FillMode::kLower}, {"upper", "Upper", FillMode::kUpper} }; /// Converts a FillMode enumerant to a string char const *to_string(FillMode type, bool pretty) { for (auto const & possible :FillMode_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /////////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; BlasMode enumerant; } BlasMode_enumerants[] = { {"symmetric", "Symmetric", BlasMode::kSymmetric}, {"hermitian", "Hermitian", BlasMode::kHermitian} }; /// Converts a BlasMode enumerant to a string char const *to_string(BlasMode type, bool pretty) { for (auto const & possible :BlasMode_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /////////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; DiagType enumerant; } DiagType_enumerants[] = { {"nonunit", "NonUnit", DiagType::kNonUnit}, {"unit", "Unit", DiagType::kUnit} }; /// Converts a DiagType enumerant to a string char const *to_string(DiagType type, bool pretty) { for (auto const & possible :DiagType_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } ///////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; OperationKind enumerant; } OperationKind_enumerants[] = { {"eq_gemm", "EqGemm", OperationKind::kEqGemm}, {"gemm", "Gemm", OperationKind::kGemm}, {"rank_k", "RankK", OperationKind::kRankK}, {"rank_2k", "Rank2K", OperationKind::kRank2K}, {"trmm", "Trmm", OperationKind::kTrmm}, {"symm", "Symm", OperationKind::kSymm}, {"conv2d", "Conv2d", OperationKind::kConv2d}, {"conv3d", "Conv3d", OperationKind::kConv3d}, {"spgemm", "SparseGemm", OperationKind::kSparseGemm}, }; /// Converts a Status enumerant to a string char const *to_string(OperationKind enumerant, bool pretty) { for (auto const & possible : OperationKind_enumerants) { if (enumerant == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /// Converts a Status enumerant from a string template <> OperationKind from_string(std::string const &str) { for (auto const & possible : OperationKind_enumerants) { if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } return OperationKind::kInvalid; } ///////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; Status enumerant; } Status_enumerants[] = { {"success", "Success", Status::kSuccess}, {"misaligned_operand", "Error: misaligned operand", Status::kErrorMisalignedOperand}, {"invalid_problem", "Error: invalid problem", Status::kErrorInvalidProblem}, {"not_supported", "Error: not supported", Status::kErrorNotSupported}, {"internal", "Error: internal", Status::kErrorInternal} }; /// Converts a Status enumerant to a string char const *to_string(Status status, bool pretty) { for (auto const & possible : Status_enumerants) { if (status == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /// Converts a Status enumerant from a string template <> Status from_string(std::string const &str) { for (auto const & possible : Status_enumerants) { if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } return Status::kInvalid; } /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; NumericTypeID enumerant; } NumericTypeID_enumerants[] = { {"unknown", "", NumericTypeID::kUnknown}, {"void", "Void", NumericTypeID::kVoid}, {"b1", "B1", NumericTypeID::kB1}, {"u2", "U2", NumericTypeID::kU2}, {"u4", "U4", NumericTypeID::kU4}, {"u8", "U8", NumericTypeID::kU8}, {"u16", "U16", NumericTypeID::kU16}, {"u32", "U32", NumericTypeID::kU32}, {"u64", "U64", NumericTypeID::kU64}, {"s2", "S2", NumericTypeID::kS2}, {"s4", "S4", NumericTypeID::kS4}, {"s8", "S8", NumericTypeID::kS8}, {"s16", "S16", NumericTypeID::kS16}, {"s32", "S32", NumericTypeID::kS32}, {"s64", "S64", NumericTypeID::kS64}, {"fe4m3", "FE4M3", NumericTypeID::kFE4M3}, {"fe5m2", "FE5M2", NumericTypeID::kFE5M2}, {"f16", "F16", NumericTypeID::kF16}, {"bf16", "BF16", NumericTypeID::kBF16}, {"f32", "F32", NumericTypeID::kF32}, {"tf32", "TF32", NumericTypeID::kTF32}, {"f64", "F64", NumericTypeID::kF64}, {"cf16", "CF16", NumericTypeID::kCF16}, {"cbf16", "CBF16", NumericTypeID::kCBF16}, {"cf32", "CF32", NumericTypeID::kCF32}, {"ctf32", "CTF32", NumericTypeID::kCTF32}, {"cf64", "CF64", NumericTypeID::kCF64}, {"cu2", "CU2", NumericTypeID::kCU2}, {"cu4", "CU4", NumericTypeID::kCU4}, {"cu8", "CU8", NumericTypeID::kCU8}, {"cu16", "CU16", NumericTypeID::kCU16}, {"cu32", "CU32", NumericTypeID::kCU32}, {"cu64", "CU64", NumericTypeID::kCU64}, {"cs2", "CS2", NumericTypeID::kCS2}, {"cs4", "CS4", NumericTypeID::kCS4}, {"cs8", "CS8", NumericTypeID::kCS8}, {"cs16", "CS16", NumericTypeID::kCS16}, {"cs32", "CS32", NumericTypeID::kCS32}, {"cs64", "CS64", NumericTypeID::kCS64}, {"*", "", NumericTypeID::kUnknown} }; /// Converts a NumericTypeID enumerant to a string char const *to_string(NumericTypeID type, bool pretty) { for (auto const & possible : NumericTypeID_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /// Parses a NumericTypeID enumerant from a string template <> NumericTypeID from_string(std::string const &str) { for (auto const & possible : NumericTypeID_enumerants) { if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } return NumericTypeID::kInvalid; } ///////////////////////////////////////////////////////////////////////////////////////////////// /// Returns the size of a data type in bits int sizeof_bits(NumericTypeID type) { switch (type) { case NumericTypeID::kFE4M3: return 8; case NumericTypeID::kFE5M2: return 8; case NumericTypeID::kF16: return 16; case NumericTypeID::kBF16: return 16; case NumericTypeID::kTF32: return 32; case NumericTypeID::kF32: return 32; case NumericTypeID::kF64: return 64; case NumericTypeID::kCF16: return 32; case NumericTypeID::kCBF16: return 32; case NumericTypeID::kCF32: return 64; case NumericTypeID::kCTF32: return 64; case NumericTypeID::kCF64: return 128; case NumericTypeID::kS2: return 2; case NumericTypeID::kS4: return 4; case NumericTypeID::kS8: return 8; case NumericTypeID::kS16: return 16; case NumericTypeID::kS32: return 32; case NumericTypeID::kS64: return 64; case NumericTypeID::kU2: return 2; case NumericTypeID::kU4: return 4; case NumericTypeID::kU8: return 8; case NumericTypeID::kU16: return 16; case NumericTypeID::kU32: return 32; case NumericTypeID::kU64: return 64; case NumericTypeID::kB1: return 1; default: break; } return 0; } /// Returns true if the numeric type is a complex data type or false if real-valued. bool is_complex_type(NumericTypeID type) { switch (type) { case NumericTypeID::kCF16: return true; case NumericTypeID::kCF32: return true; case NumericTypeID::kCF64: return true; case NumericTypeID::kCBF16: return true; case NumericTypeID::kCTF32: return true; default: break; } return false; } /// Returns the field underlying a complex valued type NumericTypeID get_real_type(NumericTypeID type) { switch (type) { case NumericTypeID::kCF16: return NumericTypeID::kF16; case NumericTypeID::kCF32: return NumericTypeID::kF32; case NumericTypeID::kCF64: return NumericTypeID::kF64; case NumericTypeID::kCBF16: return NumericTypeID::kBF16; case NumericTypeID::kCTF32: return NumericTypeID::kTF32; default: break; } return type; } /// Returns true if numeric type is integer bool is_integer_type(NumericTypeID type) { switch (type) { case NumericTypeID::kS2: return true; case NumericTypeID::kS4: return true; case NumericTypeID::kS8: return true; case NumericTypeID::kS16: return true; case NumericTypeID::kS32: return true; case NumericTypeID::kS64: return true; case NumericTypeID::kU2: return true; case NumericTypeID::kU4: return true; case NumericTypeID::kU8: return true; case NumericTypeID::kU16: return true; case NumericTypeID::kU32: return true; case NumericTypeID::kU64: return true; default: break; } return false; } /// Returns true if numeric type is signed bool is_signed_type(NumericTypeID type) { switch (type) { case NumericTypeID::kFE4M3: return true; case NumericTypeID::kFE5M2: return true; case NumericTypeID::kF16: return true; case NumericTypeID::kBF16: return true; case NumericTypeID::kTF32: return true; case NumericTypeID::kF32: return true; case NumericTypeID::kF64: return true; case NumericTypeID::kS2: return true; case NumericTypeID::kS4: return true; case NumericTypeID::kS8: return true; case NumericTypeID::kS16: return true; case NumericTypeID::kS32: return true; case NumericTypeID::kS64: return true; default: break; } return false; } /// Returns true if numeric type is a signed integer bool is_signed_integer(NumericTypeID type) { return is_integer_type(type) && is_signed_type(type); } /// returns true if numeric type is an unsigned integer bool is_unsigned_integer(NumericTypeID type) { return is_integer_type(type) && !is_signed_type(type); } /// Returns true if numeric type is floating-point type bool is_float_type(NumericTypeID type) { switch (type) { case NumericTypeID::kFE4M3: return true; case NumericTypeID::kFE5M2: return true; case NumericTypeID::kF16: return true; case NumericTypeID::kBF16: return true; case NumericTypeID::kTF32: return true; case NumericTypeID::kF32: return true; case NumericTypeID::kF64: return true; case NumericTypeID::kCF16: return true; case NumericTypeID::kCBF16: return true; case NumericTypeID::kCTF32: return true; case NumericTypeID::kCF32: return true; case NumericTypeID::kCF64: return true; default: break; } return false; } ///////////////////////////////////////////////////////////////////////////////////////////////// static struct { LayoutTypeID layout; char const *alias; } layout_aliases[] = { {LayoutTypeID::kUnknown, "unknown"}, {LayoutTypeID::kRowMajor, "row"}, {LayoutTypeID::kRowMajor, "t"}, {LayoutTypeID::kColumnMajor, "column"}, {LayoutTypeID::kColumnMajor, "col"}, {LayoutTypeID::kColumnMajor, "n"}, {LayoutTypeID::kColumnMajorInterleavedK2, "nk2"}, {LayoutTypeID::kRowMajorInterleavedK2, "tk2"}, {LayoutTypeID::kColumnMajorInterleavedK4, "nk4"}, {LayoutTypeID::kRowMajorInterleavedK4, "tk4"}, {LayoutTypeID::kColumnMajorInterleavedK16, "nk16"}, {LayoutTypeID::kRowMajorInterleavedK16, "tk16"}, {LayoutTypeID::kColumnMajorInterleavedK32, "nk32"}, {LayoutTypeID::kRowMajorInterleavedK32, "tk32"}, {LayoutTypeID::kColumnMajorInterleavedK64, "nk64"}, {LayoutTypeID::kRowMajorInterleavedK64, "tk64"}, {LayoutTypeID::kTensorNCHW, "nchw"}, {LayoutTypeID::kTensorNCDHW, "ncdhw"}, {LayoutTypeID::kTensorNHWC, "nhwc"}, {LayoutTypeID::kTensorNDHWC, "ndhwc"}, {LayoutTypeID::kTensorNC32HW32, "nc32hw32"}, {LayoutTypeID::kTensorNC64HW64, "nc64hw64"}, {LayoutTypeID::kTensorC32RSK32, "c32rsk32"}, {LayoutTypeID::kTensorC64RSK64, "c64rsk64"}, {LayoutTypeID::kUnknown, "*"}, {LayoutTypeID::kInvalid, nullptr} }; /// Converts a LayoutTypeID enumerant to a string char const *to_string(LayoutTypeID layout, bool pretty) { for (auto const & alias : layout_aliases) { if (alias.layout == layout) { return alias.alias; } } return pretty ? "Invalid" : "invalid"; } /// Parses a LayoutTypeID enumerant from a string template <> LayoutTypeID from_string(std::string const &str) { for (auto const & alias : layout_aliases) { if (str.compare(alias.alias) == 0) { return alias.layout; } } return LayoutTypeID::kInvalid; } /// Gets stride rank for the layout_id (static function) int get_layout_stride_rank(LayoutTypeID layout_id) { switch (layout_id) { case LayoutTypeID::kColumnMajor: return cutlass::layout::ColumnMajor::kStrideRank; case LayoutTypeID::kRowMajor: return cutlass::layout::RowMajor::kStrideRank; case LayoutTypeID::kColumnMajorInterleavedK2: return cutlass::layout::ColumnMajorInterleaved<2>::kStrideRank; case LayoutTypeID::kRowMajorInterleavedK2: return cutlass::layout::RowMajorInterleaved<2>::kStrideRank; case LayoutTypeID::kColumnMajorInterleavedK4: return cutlass::layout::ColumnMajorInterleaved<4>::kStrideRank; case LayoutTypeID::kRowMajorInterleavedK4: return cutlass::layout::RowMajorInterleaved<4>::kStrideRank; case LayoutTypeID::kColumnMajorInterleavedK16: return cutlass::layout::ColumnMajorInterleaved<16>::kStrideRank; case LayoutTypeID::kRowMajorInterleavedK16: return cutlass::layout::RowMajorInterleaved<16>::kStrideRank; case LayoutTypeID::kColumnMajorInterleavedK32: return cutlass::layout::ColumnMajorInterleaved<32>::kStrideRank; case LayoutTypeID::kRowMajorInterleavedK32: return cutlass::layout::RowMajorInterleaved<32>::kStrideRank; case LayoutTypeID::kColumnMajorInterleavedK64: return cutlass::layout::ColumnMajorInterleaved<64>::kStrideRank; case LayoutTypeID::kRowMajorInterleavedK64: return cutlass::layout::RowMajorInterleaved<64>::kStrideRank; case LayoutTypeID::kTensorNCHW: return cutlass::layout::TensorNCHW::kStrideRank; case LayoutTypeID::kTensorNHWC: return cutlass::layout::TensorNHWC::kStrideRank; case LayoutTypeID::kTensorNDHWC: return cutlass::layout::TensorNDHWC::kStrideRank; case LayoutTypeID::kTensorNC32HW32: return cutlass::layout::TensorNCxHWx<32>::kStrideRank; case LayoutTypeID::kTensorNC64HW64: return cutlass::layout::TensorNCxHWx<64>::kStrideRank; case LayoutTypeID::kTensorC32RSK32: return cutlass::layout::TensorCxRSKx<32>::kStrideRank; case LayoutTypeID::kTensorC64RSK64: return cutlass::layout::TensorCxRSKx<64>::kStrideRank; default: throw std::runtime_error("Unsupported LayoutTypeID in LayoutType::get_stride_rank"); } } ///////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; OpcodeClassID enumerant; } OpcodeClassID_enumerants[] = { {"simt", "", OpcodeClassID::kSimt}, {"tensorop", "", OpcodeClassID::kTensorOp}, {"wmmatensorop", "", OpcodeClassID::kWmmaTensorOp}, {"wmma", "", OpcodeClassID::kWmmaTensorOp}, }; /// Converts a OpcodeClassID enumerant to a string char const *to_string(OpcodeClassID type, bool pretty) { for (auto const & possible : OpcodeClassID_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /// Converts a OpcodeClassID enumerant from a string template <> OpcodeClassID from_string(std::string const &str) { for (auto const & possible : OpcodeClassID_enumerants) { if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } return OpcodeClassID::kInvalid; } ///////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; ComplexTransform enumerant; } ComplexTransform_enumerants[] = { {"n", "none", ComplexTransform::kNone}, {"c", "conj", ComplexTransform::kConjugate} }; /// Converts a ComplexTransform enumerant to a string char const *to_string(ComplexTransform type, bool pretty) { for (auto const & possible : ComplexTransform_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /// Converts a ComplexTransform enumerant from a string template <> ComplexTransform from_string(std::string const &str) { for (auto const & possible : ComplexTransform_enumerants) { if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } return ComplexTransform::kInvalid; } static struct { char const *text; char const *pretty; SplitKMode enumerant; } SplitKMode_enumerants[] = { {"serial", "", SplitKMode::kSerial}, {"parallel", "", SplitKMode::kParallel}, }; /// Converts a SplitKMode enumerant to a string char const *to_string(SplitKMode type, bool pretty) { for (auto const & possible : SplitKMode_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /// Converts a SplitKMode enumerant from a string template <> SplitKMode from_string(std::string const &str) { for (auto const & possible : SplitKMode_enumerants) { if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } return SplitKMode::kInvalid; } ///////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; ConvModeID enumerant; } ConvModeID_enumerants[] = { {"cross", "", ConvModeID::kCrossCorrelation}, {"conv", "", ConvModeID::kConvolution}, }; /// Converts a ConvModeID enumerant to a string char const *to_string(ConvModeID type, bool pretty) { for (auto const & possible : ConvModeID_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /// Converts a ConvModeID enumerant from a string template <> ConvModeID from_string(std::string const &str) { for (auto const & possible : ConvModeID_enumerants) { if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } return ConvModeID::kInvalid; } static struct { char const *text; char const *pretty; IteratorAlgorithmID enumerant; } IteratorAlgorithmID_enumerants[] = { {"none", "", IteratorAlgorithmID::kNone}, {"analytic", "", IteratorAlgorithmID::kAnalytic}, {"optimized", "", IteratorAlgorithmID::kOptimized}, {"fixed_channels", "", IteratorAlgorithmID::kFixedChannels}, {"few_channels", "", IteratorAlgorithmID::kFewChannels}, }; /// Converts a ConvModeID enumerant to a string char const *to_string(IteratorAlgorithmID type, bool pretty) { for (auto const & possible : IteratorAlgorithmID_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /// Converts a ConvModeID enumerant from a string template <> IteratorAlgorithmID from_string(std::string const &str) { for (auto const & possible : IteratorAlgorithmID_enumerants) { if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } return IteratorAlgorithmID::kInvalid; } /////////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; ConvKind enumerant; } ConvKind_enumerants[] = { {"unknown", "", ConvKind::kUnknown}, {"fprop", "", ConvKind::kFprop}, {"dgrad", "", ConvKind::kDgrad}, {"wgrad", "", ConvKind::kWgrad}, }; /// Converts a ConvKind enumerant to a string char const *to_string(ConvKind type, bool pretty) { for (auto const & possible : ConvKind_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /// Converts a ConvKind enumerant from a string template <> ConvKind from_string(std::string const &str) { for (auto const & possible : ConvKind_enumerants) { if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } return ConvKind::kInvalid; } /////////////////////////////////////////////////////////////////////////////////////////////////// static struct { char const *text; char const *pretty; RasterOrder enumerant; } RasterOrder_enumerants[] = { {"along_n", "", RasterOrder::kAlongN}, {"along_m", "", RasterOrder::kAlongM}, {"heuristic", "", RasterOrder::kHeuristic}, }; /// Converts a RasterOrder enumerant to a string char const *to_string(RasterOrder type, bool pretty) { for (auto const & possible : RasterOrder_enumerants) { if (type == possible.enumerant) { if (pretty) { return possible.pretty; } else { return possible.text; } } } return pretty ? "Invalid" : "invalid"; } /// Converts a RasterOrder enumerant from a string template <> RasterOrder from_string(std::string const &str) { for (auto const & possible : RasterOrder_enumerants) { if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } return RasterOrder::kInvalid; } /////////////////////////////////////////////////////////////////////////////////////////////////// /// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str) { int size_bytes = sizeof_bits(type) / 8; if (!size_bytes) { return false; } bytes.resize(size_bytes, 0); std::stringstream ss; ss << str; switch (type) { case NumericTypeID::kU8: { ss >> *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kU16: { ss >> *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kU32: { ss >> *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kU64: { ss >> *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kS8: { ss >> *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kS16: { ss >> *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kS32: { ss >> *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kS64: { ss >> *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kFE4M3: { float tmp; ss >> tmp; *reinterpret_cast(bytes.data()) = static_cast(tmp); } break; case NumericTypeID::kFE5M2: { float tmp; ss >> tmp; *reinterpret_cast(bytes.data()) = static_cast(tmp); } break; case NumericTypeID::kF16: { float tmp; ss >> tmp; *reinterpret_cast(bytes.data()) = static_cast(tmp); } break; case NumericTypeID::kBF16: { float tmp; ss >> tmp; *reinterpret_cast(bytes.data()) = static_cast(tmp); } break; case NumericTypeID::kTF32: { float tmp; ss >> tmp; *reinterpret_cast(bytes.data()) = static_cast(tmp); } break; case NumericTypeID::kF32: { ss >> *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kF64: { ss >> *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kCF16: { std::complex tmp; ss >> tmp; cutlass::complex *x = reinterpret_cast *>(bytes.data()); x->real() = static_cast(std::real(tmp)); x->imag() = static_cast(std::imag(tmp)); } break; case NumericTypeID::kCBF16: { std::complex tmp; ss >> tmp; cutlass::complex *x = reinterpret_cast *>(bytes.data()); x->real() = static_cast(std::real(tmp)); x->imag() = static_cast(std::imag(tmp)); } break; case NumericTypeID::kCF32: { ss >> *reinterpret_cast*>(bytes.data()); } break; case NumericTypeID::kCTF32: { std::complex tmp; ss >> tmp; cutlass::complex *x = reinterpret_cast *>(bytes.data()); x->real() = static_cast(std::real(tmp)); x->imag() = static_cast(std::imag(tmp)); } break; case NumericTypeID::kCF64: { ss >> *reinterpret_cast*>(bytes.data()); } break; default: return false; } return true; } /////////////////////////////////////////////////////////////////////////////////////////////////// std::string lexical_cast(int64_t int_value) { std::stringstream ss; ss << int_value; return ss.str(); } /// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid. std::string lexical_cast(std::vector &bytes, NumericTypeID type) { size_t size_bytes = sizeof_bits(type) / 8; if (!size_bytes || size_bytes != bytes.size()) { return ""; } bytes.resize(size_bytes, 0); std::stringstream ss; switch (type) { case NumericTypeID::kU8: { ss << *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kU16: { ss << *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kU32: { ss << *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kU64: { ss << *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kS8: { ss << *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kS16: { ss << *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kS32: { ss << *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kS64: { ss << *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kFE4M3: { float tmp = *reinterpret_cast(bytes.data()); ss << tmp; } break; case NumericTypeID::kFE5M2: { float tmp = *reinterpret_cast(bytes.data()); ss << tmp; } break; case NumericTypeID::kF16: { float tmp = *reinterpret_cast(bytes.data()); ss << tmp; } break; case NumericTypeID::kBF16: { float tmp = *reinterpret_cast(bytes.data()); ss << tmp; } break; case NumericTypeID::kTF32: { float tmp = *reinterpret_cast(bytes.data()); ss << tmp; } break; case NumericTypeID::kF32: { ss << *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kF64: { ss << *reinterpret_cast(bytes.data()); } break; case NumericTypeID::kCF16: { cutlass::complex const *x = reinterpret_cast const *>(bytes.data()); ss << float(x->real()); if (x->imag() != cutlass::half_t()) { ss << "+i" << float(x->imag()); } } break; case NumericTypeID::kCBF16: { cutlass::complex const *x = reinterpret_cast const *>(bytes.data()); ss << float(x->real()); if (x->imag() != cutlass::bfloat16_t()) { ss << "+i" << float(x->imag()); } } break; case NumericTypeID::kCF32: { cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); ss << x->real(); if (x->imag() != float()) { ss << "+i" << x->imag(); } } break; case NumericTypeID::kCTF32: { cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); ss << float(x->real()); if (x->imag() != tfloat32_t()) { ss << "+i" << float(x->imag()); } } break; case NumericTypeID::kCF64: { cutlass::complex const * x = reinterpret_cast const *>(bytes.data()); ss << x->real(); if (x->imag() != double()) { ss << "+i" << x->imag(); } } break; default: return ""; } return ss.str(); } /// Casts from a signed int64 to the destination type. Returns true if successful. bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src) { int size_bytes = sizeof_bits(type) / 8; if (!size_bytes) { return false; } bytes.resize(size_bytes, 0); switch (type) { case NumericTypeID::kU8: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU16: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU32: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU64: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS8: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS16: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS32: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS64: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kFE4M3: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kFE5M2: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kF16: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kBF16: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kTF32: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kF32: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kF64: { *reinterpret_cast(bytes.data()) = double(src); } break; case NumericTypeID::kCF16: { cutlass::complex *x = reinterpret_cast *>(bytes.data()); x->real() = static_cast(float(src)); x->imag() = static_cast(float(0)); } break; case NumericTypeID::kCF32: { *reinterpret_cast*>(bytes.data()) = cutlass::complex(float(src), float(0)); } break; case NumericTypeID::kCF64: { *reinterpret_cast*>(bytes.data()) = cutlass::complex(double(src), double(0)); } break; default: return false; } return true; } /// Casts from an unsigned int64 to the destination type. Returns true if successful. bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src) { int size_bytes = sizeof_bits(type) / 8; if (!size_bytes) { return false; } bytes.resize(size_bytes, 0); switch (type) { case NumericTypeID::kU8: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU16: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU32: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU64: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS8: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS16: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS32: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS64: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kFE4M3: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kFE5M2: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kF16: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kBF16: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kTF32: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kF32: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kF64: { *reinterpret_cast(bytes.data()) = double(src); } break; case NumericTypeID::kCF16: { cutlass::complex *x = reinterpret_cast *>(bytes.data()); x->real() = static_cast(float(src)); x->imag() = static_cast(float(0)); } break; case NumericTypeID::kCF32: { *reinterpret_cast*>(bytes.data()) = std::complex(float(src), float(0)); } break; case NumericTypeID::kCF64: { *reinterpret_cast*>(bytes.data()) = std::complex(double(src), double(0)); } break; default: return false; } return true; } /// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. bool cast_from_double(std::vector &bytes, NumericTypeID type, double src) { int size_bytes = sizeof_bits(type) / 8; if (!size_bytes) { return false; } bytes.resize(size_bytes, 0); switch (type) { case NumericTypeID::kU8: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU16: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU32: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU64: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS8: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS16: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS32: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS64: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kFE4M3: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kFE5M2: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kF16: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kBF16: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kTF32: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kF32: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kF64: { *reinterpret_cast(bytes.data()) = src; } break; case NumericTypeID::kCF16: { cutlass::complex *x = reinterpret_cast *>(bytes.data()); x->real() = static_cast(float(src)); x->imag() = static_cast(float(0)); } break; case NumericTypeID::kCBF16: { cutlass::complex *x = reinterpret_cast *>(bytes.data()); x->real() = static_cast(bfloat16_t(src)); x->imag() = static_cast(bfloat16_t(0)); } break; case NumericTypeID::kCF32: { *reinterpret_cast*>(bytes.data()) = cutlass::complex(float(src), float()); } break; case NumericTypeID::kCTF32: { *reinterpret_cast*>(bytes.data()) = cutlass::complex(tfloat32_t(src), tfloat32_t()); } break; case NumericTypeID::kCF64: { *reinterpret_cast*>(bytes.data()) = cutlass::complex(src, double()); } break; default: return false; } return true; } /////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace library } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////////