Spaces:
Sleeping
Sleeping
/*************************************************************************************************** | |
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
* SPDX-License-Identifier: BSD-3-Clause | |
* | |
* Redistribution and use in source and binary forms, with or without | |
* modification, are permitted provided that the following conditions are met: | |
* | |
* 1. Redistributions of source code must retain the above copyright notice, this | |
* list of conditions and the following disclaimer. | |
* | |
* 2. Redistributions in binary form must reproduce the above copyright notice, | |
* this list of conditions and the following disclaimer in the documentation | |
* and/or other materials provided with the distribution. | |
* | |
* 3. Neither the name of the copyright holder nor the names of its | |
* contributors may be used to endorse or promote products derived from | |
* this software without specific prior written permission. | |
* | |
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
* | |
**************************************************************************************************/ | |
#include <iosfwd> | |
#include <complex> | |
#include "cutlass/cutlass.h" | |
#include "cutlass/numeric_types.h" | |
#include "cutlass/complex.h" | |
#include "cutlass/blas3.h" | |
#include "cutlass/layout/matrix.h" | |
#include "cutlass/library/library.h" | |
#include "cutlass/library/util.h" | |
namespace cutlass { | |
namespace library { | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
Provider enumerant; | |
} | |
Provider_enumerants[] = { | |
{"none", "None", Provider::kNone}, | |
{"cutlass", "CUTLASS", Provider::kCUTLASS}, | |
{"host", "reference_host", Provider::kReferenceHost}, | |
{"device", "reference_device", Provider::kReferenceDevice}, | |
{"cublas", "cuBLAS", Provider::kCUBLAS}, | |
{"cudnn", "cuDNN", Provider::kCUDNN}, | |
}; | |
/// Returns the name of a Provider enumerant.
///
/// The first row of Provider_enumerants whose enumerant equals `provider` supplies the
/// result: its `pretty` string when `pretty` is true, otherwise its `text` string.
/// When no row matches, "Invalid" (pretty) or "invalid" is returned.
char const *to_string(Provider provider, bool pretty) {
  for (auto const &entry : Provider_enumerants) {
    if (entry.enumerant != provider) {
      continue;
    }
    return pretty ? entry.pretty : entry.text;
  }

  if (pretty) {
    return "Invalid";
  }
  return "invalid";
}
/// Parses a Provider enumerant from a string | |
template <> | |
Provider from_string<Provider>(std::string const &str) { | |
for (auto const & possible : Provider_enumerants) { | |
if ((str.compare(possible.text) == 0) || | |
(str.compare(possible.pretty) == 0)) { | |
return possible.enumerant; | |
} | |
} | |
return Provider::kInvalid; | |
} | |
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
GemmKind enumerant; | |
} | |
GemmKind_enumerants[] = { | |
{"gemm", "<Gemm>", GemmKind::kGemm}, | |
{"spgemm", "<Sparse>", GemmKind::kSparse}, | |
{"universal", "<Universal>", GemmKind::kUniversal}, | |
{"planar_complex", "<PlanarComplex>", GemmKind::kPlanarComplex}, | |
{"planar_complex_array", "<PlanarComplexArray>", GemmKind::kPlanarComplexArray}, | |
{"grouped", "<Grouped>", GemmKind::kGrouped}, | |
}; | |
/// Returns the name of a GemmKind enumerant.
///
/// Looks `type` up in GemmKind_enumerants and yields the row's `pretty` or `text`
/// string according to `pretty`. Unlisted values yield "Invalid" / "invalid".
char const *to_string(GemmKind type, bool pretty) {
  char const *name = pretty ? "Invalid" : "invalid";

  for (auto const &row : GemmKind_enumerants) {
    if (row.enumerant == type) {
      name = pretty ? row.pretty : row.text;
      break;
    }
  }

  return name;
}
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
RankKKind enumerant; | |
} | |
RankKKind_enumerants[] = { | |
{"universal", "<Universal>", RankKKind::kUniversal}, | |
}; | |
/// Returns the name of a RankKKind enumerant.
///
/// The first row of RankKKind_enumerants matching `type` supplies its `pretty` string
/// when `pretty` is true and its `text` string otherwise. Values without a row map to
/// "Invalid" / "invalid".
char const *to_string(RankKKind type, bool pretty) {
  for (auto const &entry : RankKKind_enumerants) {
    if (entry.enumerant != type) {
      continue;
    }
    return pretty ? entry.pretty : entry.text;
  }

  if (pretty) {
    return "Invalid";
  }
  return "invalid";
}
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
TrmmKind enumerant; | |
} | |
TrmmKind_enumerants[] = { | |
{"universal", "<Universal>", TrmmKind::kUniversal}, | |
}; | |
/// Returns the name of a TrmmKind enumerant.
///
/// Looks `type` up in TrmmKind_enumerants and yields the row's `pretty` or `text`
/// string according to `pretty`. Unlisted values yield "Invalid" / "invalid".
char const *to_string(TrmmKind type, bool pretty) {
  char const *name = pretty ? "Invalid" : "invalid";

  for (auto const &row : TrmmKind_enumerants) {
    if (row.enumerant == type) {
      name = pretty ? row.pretty : row.text;
      break;
    }
  }

  return name;
}
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
SymmKind enumerant; | |
} | |
SymmKind_enumerants[] = { | |
{"universal", "<Universal>", SymmKind::kUniversal}, | |
}; | |
/// Returns the name of a SymmKind enumerant.
///
/// The first row of SymmKind_enumerants matching `type` supplies its `pretty` string
/// when `pretty` is true and its `text` string otherwise. Values without a row map to
/// "Invalid" / "invalid".
char const *to_string(SymmKind type, bool pretty) {
  for (auto const &entry : SymmKind_enumerants) {
    if (entry.enumerant != type) {
      continue;
    }
    return pretty ? entry.pretty : entry.text;
  }

  if (pretty) {
    return "Invalid";
  }
  return "invalid";
}
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
SideMode enumerant; | |
} | |
SideMode_enumerants[] = { | |
{"left", "Left", SideMode::kLeft}, | |
{"right", "Right", SideMode::kRight} | |
}; | |
/// Returns the name of a SideMode enumerant.
///
/// Looks `type` up in SideMode_enumerants and yields the row's `pretty` or `text`
/// string according to `pretty`. Unlisted values yield "Invalid" / "invalid".
char const *to_string(SideMode type, bool pretty) {
  char const *name = pretty ? "Invalid" : "invalid";

  for (auto const &row : SideMode_enumerants) {
    if (row.enumerant == type) {
      name = pretty ? row.pretty : row.text;
      break;
    }
  }

  return name;
}
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
FillMode enumerant; | |
} | |
FillMode_enumerants[] = { | |
{"lower", "Lower", FillMode::kLower}, | |
{"upper", "Upper", FillMode::kUpper} | |
}; | |
/// Returns the name of a FillMode enumerant.
///
/// The first row of FillMode_enumerants matching `type` supplies its `pretty` string
/// when `pretty` is true and its `text` string otherwise. Values without a row map to
/// "Invalid" / "invalid".
char const *to_string(FillMode type, bool pretty) {
  for (auto const &entry : FillMode_enumerants) {
    if (entry.enumerant != type) {
      continue;
    }
    return pretty ? entry.pretty : entry.text;
  }

  if (pretty) {
    return "Invalid";
  }
  return "invalid";
}
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
BlasMode enumerant; | |
} | |
BlasMode_enumerants[] = { | |
{"symmetric", "Symmetric", BlasMode::kSymmetric}, | |
{"hermitian", "Hermitian", BlasMode::kHermitian} | |
}; | |
/// Returns the name of a BlasMode enumerant.
///
/// Looks `type` up in BlasMode_enumerants and yields the row's `pretty` or `text`
/// string according to `pretty`. Unlisted values yield "Invalid" / "invalid".
char const *to_string(BlasMode type, bool pretty) {
  char const *name = pretty ? "Invalid" : "invalid";

  for (auto const &row : BlasMode_enumerants) {
    if (row.enumerant == type) {
      name = pretty ? row.pretty : row.text;
      break;
    }
  }

  return name;
}
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
DiagType enumerant; | |
} | |
DiagType_enumerants[] = { | |
{"nonunit", "NonUnit", DiagType::kNonUnit}, | |
{"unit", "Unit", DiagType::kUnit} | |
}; | |
/// Returns the name of a DiagType enumerant.
///
/// The first row of DiagType_enumerants matching `type` supplies its `pretty` string
/// when `pretty` is true and its `text` string otherwise. Values without a row map to
/// "Invalid" / "invalid".
char const *to_string(DiagType type, bool pretty) {
  for (auto const &entry : DiagType_enumerants) {
    if (entry.enumerant != type) {
      continue;
    }
    return pretty ? entry.pretty : entry.text;
  }

  if (pretty) {
    return "Invalid";
  }
  return "invalid";
}
///////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
OperationKind enumerant; | |
} | |
OperationKind_enumerants[] = { | |
{"eq_gemm", "EqGemm", OperationKind::kEqGemm}, | |
{"gemm", "Gemm", OperationKind::kGemm}, | |
{"rank_k", "RankK", OperationKind::kRankK}, | |
{"rank_2k", "Rank2K", OperationKind::kRank2K}, | |
{"trmm", "Trmm", OperationKind::kTrmm}, | |
{"symm", "Symm", OperationKind::kSymm}, | |
{"conv2d", "Conv2d", OperationKind::kConv2d}, | |
{"conv3d", "Conv3d", OperationKind::kConv3d}, | |
{"spgemm", "SparseGemm", OperationKind::kSparseGemm}, | |
}; | |
/// Returns the name of an OperationKind enumerant.
///
/// Looks `enumerant` up in OperationKind_enumerants and yields the row's `pretty` or
/// `text` string according to `pretty`. Unlisted values yield "Invalid" / "invalid".
char const *to_string(OperationKind enumerant, bool pretty) {
  char const *name = pretty ? "Invalid" : "invalid";

  for (auto const &row : OperationKind_enumerants) {
    if (row.enumerant == enumerant) {
      name = pretty ? row.pretty : row.text;
      break;
    }
  }

  return name;
}
/// Converts a Status enumerant from a string | |
template <> | |
OperationKind from_string<OperationKind>(std::string const &str) { | |
for (auto const & possible : OperationKind_enumerants) { | |
if ((str.compare(possible.text) == 0) || | |
(str.compare(possible.pretty) == 0)) { | |
return possible.enumerant; | |
} | |
} | |
return OperationKind::kInvalid; | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
Status enumerant; | |
} | |
Status_enumerants[] = { | |
{"success", "Success", Status::kSuccess}, | |
{"misaligned_operand", "Error: misaligned operand", Status::kErrorMisalignedOperand}, | |
{"invalid_problem", "Error: invalid problem", Status::kErrorInvalidProblem}, | |
{"not_supported", "Error: not supported", Status::kErrorNotSupported}, | |
{"internal", "Error: internal", Status::kErrorInternal} | |
}; | |
/// Returns the name of a Status enumerant.
///
/// The first row of Status_enumerants matching `status` supplies its `pretty` string
/// when `pretty` is true and its `text` string otherwise. Values without a row map to
/// "Invalid" / "invalid".
char const *to_string(Status status, bool pretty) {
  for (auto const &entry : Status_enumerants) {
    if (entry.enumerant != status) {
      continue;
    }
    return pretty ? entry.pretty : entry.text;
  }

  if (pretty) {
    return "Invalid";
  }
  return "invalid";
}
/// Converts a Status enumerant from a string | |
template <> | |
Status from_string<Status>(std::string const &str) { | |
for (auto const & possible : Status_enumerants) { | |
if ((str.compare(possible.text) == 0) || | |
(str.compare(possible.pretty) == 0)) { | |
return possible.enumerant; | |
} | |
} | |
return Status::kInvalid; | |
} | |
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
NumericTypeID enumerant; | |
} | |
NumericTypeID_enumerants[] = { | |
{"unknown", "<unknown>", NumericTypeID::kUnknown}, | |
{"void", "Void", NumericTypeID::kVoid}, | |
{"b1", "B1", NumericTypeID::kB1}, | |
{"u2", "U2", NumericTypeID::kU2}, | |
{"u4", "U4", NumericTypeID::kU4}, | |
{"u8", "U8", NumericTypeID::kU8}, | |
{"u16", "U16", NumericTypeID::kU16}, | |
{"u32", "U32", NumericTypeID::kU32}, | |
{"u64", "U64", NumericTypeID::kU64}, | |
{"s2", "S2", NumericTypeID::kS2}, | |
{"s4", "S4", NumericTypeID::kS4}, | |
{"s8", "S8", NumericTypeID::kS8}, | |
{"s16", "S16", NumericTypeID::kS16}, | |
{"s32", "S32", NumericTypeID::kS32}, | |
{"s64", "S64", NumericTypeID::kS64}, | |
{"fe4m3", "FE4M3", NumericTypeID::kFE4M3}, | |
{"fe5m2", "FE5M2", NumericTypeID::kFE5M2}, | |
{"f16", "F16", NumericTypeID::kF16}, | |
{"bf16", "BF16", NumericTypeID::kBF16}, | |
{"f32", "F32", NumericTypeID::kF32}, | |
{"tf32", "TF32", NumericTypeID::kTF32}, | |
{"f64", "F64", NumericTypeID::kF64}, | |
{"cf16", "CF16", NumericTypeID::kCF16}, | |
{"cbf16", "CBF16", NumericTypeID::kCBF16}, | |
{"cf32", "CF32", NumericTypeID::kCF32}, | |
{"ctf32", "CTF32", NumericTypeID::kCTF32}, | |
{"cf64", "CF64", NumericTypeID::kCF64}, | |
{"cu2", "CU2", NumericTypeID::kCU2}, | |
{"cu4", "CU4", NumericTypeID::kCU4}, | |
{"cu8", "CU8", NumericTypeID::kCU8}, | |
{"cu16", "CU16", NumericTypeID::kCU16}, | |
{"cu32", "CU32", NumericTypeID::kCU32}, | |
{"cu64", "CU64", NumericTypeID::kCU64}, | |
{"cs2", "CS2", NumericTypeID::kCS2}, | |
{"cs4", "CS4", NumericTypeID::kCS4}, | |
{"cs8", "CS8", NumericTypeID::kCS8}, | |
{"cs16", "CS16", NumericTypeID::kCS16}, | |
{"cs32", "CS32", NumericTypeID::kCS32}, | |
{"cs64", "CS64", NumericTypeID::kCS64}, | |
{"*", "<unknown/enumerate all>", NumericTypeID::kUnknown} | |
}; | |
/// Returns the name of a NumericTypeID enumerant.
///
/// Looks `type` up in NumericTypeID_enumerants (first matching row wins) and yields the
/// row's `pretty` or `text` string according to `pretty`. Unlisted values yield
/// "Invalid" / "invalid".
char const *to_string(NumericTypeID type, bool pretty) {
  char const *name = pretty ? "Invalid" : "invalid";

  for (auto const &row : NumericTypeID_enumerants) {
    if (row.enumerant == type) {
      name = pretty ? row.pretty : row.text;
      break;
    }
  }

  return name;
}
/// Parses a NumericTypeID enumerant from a string | |
template <> | |
NumericTypeID from_string<NumericTypeID>(std::string const &str) { | |
for (auto const & possible : NumericTypeID_enumerants) { | |
if ((str.compare(possible.text) == 0) || | |
(str.compare(possible.pretty) == 0)) { | |
return possible.enumerant; | |
} | |
} | |
return NumericTypeID::kInvalid; | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Returns the width, in bits, of one element of the given numeric type.
///
/// Types not listed below (for example kUnknown, kVoid and the complex integer types)
/// report 0.
int sizeof_bits(NumericTypeID type) {
  int bits = 0;

  switch (type) {
    case NumericTypeID::kB1:
      bits = 1;
      break;

    case NumericTypeID::kS2:
    case NumericTypeID::kU2:
      bits = 2;
      break;

    case NumericTypeID::kS4:
    case NumericTypeID::kU4:
      bits = 4;
      break;

    case NumericTypeID::kFE4M3:
    case NumericTypeID::kFE5M2:
    case NumericTypeID::kS8:
    case NumericTypeID::kU8:
      bits = 8;
      break;

    case NumericTypeID::kF16:
    case NumericTypeID::kBF16:
    case NumericTypeID::kS16:
    case NumericTypeID::kU16:
      bits = 16;
      break;

    case NumericTypeID::kTF32:
    case NumericTypeID::kF32:
    case NumericTypeID::kCF16:
    case NumericTypeID::kCBF16:
    case NumericTypeID::kS32:
    case NumericTypeID::kU32:
      bits = 32;
      break;

    case NumericTypeID::kF64:
    case NumericTypeID::kCF32:
    case NumericTypeID::kCTF32:
    case NumericTypeID::kS64:
    case NumericTypeID::kU64:
      bits = 64;
      break;

    case NumericTypeID::kCF64:
      bits = 128;
      break;

    default:
      break;
  }

  return bits;
}
/// Returns true if the numeric type is a complex data type or false if real-valued.
///
/// Covers both the complex floating-point IDs (kCF16, kCBF16, kCF32, kCTF32, kCF64) and
/// the complex integer IDs (kCU2..kCU64, kCS2..kCS64). The complex integer IDs were
/// previously missing and were therefore reported as real-valued.
bool is_complex_type(NumericTypeID type) {
  switch (type) {
    // complex floating-point
    case NumericTypeID::kCF16:
    case NumericTypeID::kCBF16:
    case NumericTypeID::kCF32:
    case NumericTypeID::kCTF32:
    case NumericTypeID::kCF64:
    // complex unsigned integer
    case NumericTypeID::kCU2:
    case NumericTypeID::kCU4:
    case NumericTypeID::kCU8:
    case NumericTypeID::kCU16:
    case NumericTypeID::kCU32:
    case NumericTypeID::kCU64:
    // complex signed integer
    case NumericTypeID::kCS2:
    case NumericTypeID::kCS4:
    case NumericTypeID::kCS8:
    case NumericTypeID::kCS16:
    case NumericTypeID::kCS32:
    case NumericTypeID::kCS64:
      return true;
    default:
      break;
  }
  return false;
}
/// Returns the real-valued type underlying a complex-valued type.
///
/// Complex floating-point and complex integer IDs map to the ID of their real
/// component (e.g. kCF32 -> kF32, kCS8 -> kS8). Any other ID is returned unchanged.
/// The complex integer IDs were previously not mapped and came back unchanged.
NumericTypeID get_real_type(NumericTypeID type) {
  switch (type) {
    // complex floating-point
    case NumericTypeID::kCF16: return NumericTypeID::kF16;
    case NumericTypeID::kCF32: return NumericTypeID::kF32;
    case NumericTypeID::kCF64: return NumericTypeID::kF64;
    case NumericTypeID::kCBF16: return NumericTypeID::kBF16;
    case NumericTypeID::kCTF32: return NumericTypeID::kTF32;
    // complex unsigned integer
    case NumericTypeID::kCU2: return NumericTypeID::kU2;
    case NumericTypeID::kCU4: return NumericTypeID::kU4;
    case NumericTypeID::kCU8: return NumericTypeID::kU8;
    case NumericTypeID::kCU16: return NumericTypeID::kU16;
    case NumericTypeID::kCU32: return NumericTypeID::kU32;
    case NumericTypeID::kCU64: return NumericTypeID::kU64;
    // complex signed integer
    case NumericTypeID::kCS2: return NumericTypeID::kS2;
    case NumericTypeID::kCS4: return NumericTypeID::kS4;
    case NumericTypeID::kCS8: return NumericTypeID::kS8;
    case NumericTypeID::kCS16: return NumericTypeID::kS16;
    case NumericTypeID::kCS32: return NumericTypeID::kS32;
    case NumericTypeID::kCS64: return NumericTypeID::kS64;
    default: break;
  }
  // Already real-valued (or not a numeric element type): nothing to strip.
  return type;
}
/// Returns true for the real signed (kS2..kS64) and unsigned (kU2..kU64) integer types,
/// and false for every other NumericTypeID.
bool is_integer_type(NumericTypeID type) {
  switch (type) {
    // signed integers
    case NumericTypeID::kS2:
    case NumericTypeID::kS4:
    case NumericTypeID::kS8:
    case NumericTypeID::kS16:
    case NumericTypeID::kS32:
    case NumericTypeID::kS64:
    // unsigned integers
    case NumericTypeID::kU2:
    case NumericTypeID::kU4:
    case NumericTypeID::kU8:
    case NumericTypeID::kU16:
    case NumericTypeID::kU32:
    case NumericTypeID::kU64:
      return true;

    default:
      return false;
  }
}
/// Returns true for the real floating-point types and the signed integer types
/// (kS2..kS64); every other NumericTypeID, including the complex types, yields false.
bool is_signed_type(NumericTypeID type) {
  switch (type) {
    // real floating-point
    case NumericTypeID::kFE4M3:
    case NumericTypeID::kFE5M2:
    case NumericTypeID::kF16:
    case NumericTypeID::kBF16:
    case NumericTypeID::kTF32:
    case NumericTypeID::kF32:
    case NumericTypeID::kF64:
    // signed integers
    case NumericTypeID::kS2:
    case NumericTypeID::kS4:
    case NumericTypeID::kS8:
    case NumericTypeID::kS16:
    case NumericTypeID::kS32:
    case NumericTypeID::kS64:
      return true;

    default:
      return false;
  }
}
/// Returns true when the type is classified as an integer by is_integer_type() and as
/// signed by is_signed_type().
bool is_signed_integer(NumericTypeID type) {
  if (!is_integer_type(type)) {
    return false;
  }
  return is_signed_type(type);
}
/// Returns true when the type is classified as an integer by is_integer_type() but not
/// as signed by is_signed_type().
bool is_unsigned_integer(NumericTypeID type) {
  if (!is_integer_type(type)) {
    return false;
  }
  return !is_signed_type(type);
}
/// Returns true for floating-point types, both real (kFE4M3 .. kF64) and complex
/// (kCF16 .. kCF64); every other NumericTypeID yields false.
bool is_float_type(NumericTypeID type) {
  switch (type) {
    // real floating-point
    case NumericTypeID::kFE4M3:
    case NumericTypeID::kFE5M2:
    case NumericTypeID::kF16:
    case NumericTypeID::kBF16:
    case NumericTypeID::kTF32:
    case NumericTypeID::kF32:
    case NumericTypeID::kF64:
    // complex floating-point
    case NumericTypeID::kCF16:
    case NumericTypeID::kCBF16:
    case NumericTypeID::kCTF32:
    case NumericTypeID::kCF32:
    case NumericTypeID::kCF64:
      return true;

    default:
      return false;
  }
}
///////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
LayoutTypeID layout; | |
char const *alias; | |
} | |
layout_aliases[] = { | |
{LayoutTypeID::kUnknown, "unknown"}, | |
{LayoutTypeID::kRowMajor, "row"}, | |
{LayoutTypeID::kRowMajor, "t"}, | |
{LayoutTypeID::kColumnMajor, "column"}, | |
{LayoutTypeID::kColumnMajor, "col"}, | |
{LayoutTypeID::kColumnMajor, "n"}, | |
{LayoutTypeID::kColumnMajorInterleavedK2, "nk2"}, | |
{LayoutTypeID::kRowMajorInterleavedK2, "tk2"}, | |
{LayoutTypeID::kColumnMajorInterleavedK4, "nk4"}, | |
{LayoutTypeID::kRowMajorInterleavedK4, "tk4"}, | |
{LayoutTypeID::kColumnMajorInterleavedK16, "nk16"}, | |
{LayoutTypeID::kRowMajorInterleavedK16, "tk16"}, | |
{LayoutTypeID::kColumnMajorInterleavedK32, "nk32"}, | |
{LayoutTypeID::kRowMajorInterleavedK32, "tk32"}, | |
{LayoutTypeID::kColumnMajorInterleavedK64, "nk64"}, | |
{LayoutTypeID::kRowMajorInterleavedK64, "tk64"}, | |
{LayoutTypeID::kTensorNCHW, "nchw"}, | |
{LayoutTypeID::kTensorNCDHW, "ncdhw"}, | |
{LayoutTypeID::kTensorNHWC, "nhwc"}, | |
{LayoutTypeID::kTensorNDHWC, "ndhwc"}, | |
{LayoutTypeID::kTensorNC32HW32, "nc32hw32"}, | |
{LayoutTypeID::kTensorNC64HW64, "nc64hw64"}, | |
{LayoutTypeID::kTensorC32RSK32, "c32rsk32"}, | |
{LayoutTypeID::kTensorC64RSK64, "c64rsk64"}, | |
{LayoutTypeID::kUnknown, "*"}, | |
{LayoutTypeID::kInvalid, nullptr} | |
}; | |
/// Converts a LayoutTypeID enumerant to a string.
///
/// Returns the first non-null alias that layout_aliases lists for `layout`. The table
/// does not distinguish plain and pretty spellings, so `pretty` only selects the
/// fallback text: "Invalid" when true, "invalid" otherwise.
///
/// The fallback is also used for LayoutTypeID::kInvalid. Its table row carries a null
/// alias, which used to be returned as-is and handed callers a null pointer.
char const *to_string(LayoutTypeID layout, bool pretty) {
  for (auto const & alias : layout_aliases) {
    // Rows with a null alias (the kInvalid sentinel) have no printable name.
    if (alias.layout == layout && alias.alias != nullptr) {
      return alias.alias;
    }
  }
  return pretty ? "Invalid" : "invalid";
}
/// Parses a LayoutTypeID enumerant from a string | |
template <> | |
LayoutTypeID from_string<LayoutTypeID>(std::string const &str) { | |
for (auto const & alias : layout_aliases) { | |
if (str.compare(alias.alias) == 0) { | |
return alias.layout; | |
} | |
} | |
return LayoutTypeID::kInvalid; | |
} | |
/// Returns the `kStrideRank` constant of the cutlass::layout class that corresponds to
/// `layout_id`.
///
/// Throws std::runtime_error for any LayoutTypeID without a case below.
int get_layout_stride_rank(LayoutTypeID layout_id) {
  namespace cl = cutlass::layout;

  int rank = 0;

  switch (layout_id) {
    // canonical matrix layouts
    case LayoutTypeID::kColumnMajor:
      rank = cl::ColumnMajor::kStrideRank;
      break;
    case LayoutTypeID::kRowMajor:
      rank = cl::RowMajor::kStrideRank;
      break;

    // interleaved matrix layouts
    case LayoutTypeID::kColumnMajorInterleavedK2:
      rank = cl::ColumnMajorInterleaved<2>::kStrideRank;
      break;
    case LayoutTypeID::kRowMajorInterleavedK2:
      rank = cl::RowMajorInterleaved<2>::kStrideRank;
      break;
    case LayoutTypeID::kColumnMajorInterleavedK4:
      rank = cl::ColumnMajorInterleaved<4>::kStrideRank;
      break;
    case LayoutTypeID::kRowMajorInterleavedK4:
      rank = cl::RowMajorInterleaved<4>::kStrideRank;
      break;
    case LayoutTypeID::kColumnMajorInterleavedK16:
      rank = cl::ColumnMajorInterleaved<16>::kStrideRank;
      break;
    case LayoutTypeID::kRowMajorInterleavedK16:
      rank = cl::RowMajorInterleaved<16>::kStrideRank;
      break;
    case LayoutTypeID::kColumnMajorInterleavedK32:
      rank = cl::ColumnMajorInterleaved<32>::kStrideRank;
      break;
    case LayoutTypeID::kRowMajorInterleavedK32:
      rank = cl::RowMajorInterleaved<32>::kStrideRank;
      break;
    case LayoutTypeID::kColumnMajorInterleavedK64:
      rank = cl::ColumnMajorInterleaved<64>::kStrideRank;
      break;
    case LayoutTypeID::kRowMajorInterleavedK64:
      rank = cl::RowMajorInterleaved<64>::kStrideRank;
      break;

    // tensor layouts
    case LayoutTypeID::kTensorNCHW:
      rank = cl::TensorNCHW::kStrideRank;
      break;
    case LayoutTypeID::kTensorNHWC:
      rank = cl::TensorNHWC::kStrideRank;
      break;
    case LayoutTypeID::kTensorNDHWC:
      rank = cl::TensorNDHWC::kStrideRank;
      break;
    case LayoutTypeID::kTensorNC32HW32:
      rank = cl::TensorNCxHWx<32>::kStrideRank;
      break;
    case LayoutTypeID::kTensorNC64HW64:
      rank = cl::TensorNCxHWx<64>::kStrideRank;
      break;
    case LayoutTypeID::kTensorC32RSK32:
      rank = cl::TensorCxRSKx<32>::kStrideRank;
      break;
    case LayoutTypeID::kTensorC64RSK64:
      rank = cl::TensorCxRSKx<64>::kStrideRank;
      break;

    default:
      throw std::runtime_error("Unsupported LayoutTypeID in LayoutType::get_stride_rank");
  }

  return rank;
}
///////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
OpcodeClassID enumerant; | |
} | |
OpcodeClassID_enumerants[] = { | |
{"simt", "<simt>", OpcodeClassID::kSimt}, | |
{"tensorop", "<tensorop>", OpcodeClassID::kTensorOp}, | |
{"wmmatensorop", "<wmmatensorop>", OpcodeClassID::kWmmaTensorOp}, | |
{"wmma", "<wmma>", OpcodeClassID::kWmmaTensorOp}, | |
}; | |
/// Returns the name of an OpcodeClassID enumerant.
///
/// The first row of OpcodeClassID_enumerants matching `type` supplies its `pretty`
/// string when `pretty` is true and its `text` string otherwise. Values without a row
/// map to "Invalid" / "invalid".
char const *to_string(OpcodeClassID type, bool pretty) {
  for (auto const &entry : OpcodeClassID_enumerants) {
    if (entry.enumerant != type) {
      continue;
    }
    return pretty ? entry.pretty : entry.text;
  }

  if (pretty) {
    return "Invalid";
  }
  return "invalid";
}
/// Converts a OpcodeClassID enumerant from a string | |
template <> | |
OpcodeClassID from_string<OpcodeClassID>(std::string const &str) { | |
for (auto const & possible : OpcodeClassID_enumerants) { | |
if ((str.compare(possible.text) == 0) || | |
(str.compare(possible.pretty) == 0)) { | |
return possible.enumerant; | |
} | |
} | |
return OpcodeClassID::kInvalid; | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
ComplexTransform enumerant; | |
} | |
ComplexTransform_enumerants[] = { | |
{"n", "none", ComplexTransform::kNone}, | |
{"c", "conj", ComplexTransform::kConjugate} | |
}; | |
/// Returns the name of a ComplexTransform enumerant.
///
/// Looks `type` up in ComplexTransform_enumerants and yields the row's `pretty` or
/// `text` string according to `pretty`. Unlisted values yield "Invalid" / "invalid".
char const *to_string(ComplexTransform type, bool pretty) {
  char const *name = pretty ? "Invalid" : "invalid";

  for (auto const &row : ComplexTransform_enumerants) {
    if (row.enumerant == type) {
      name = pretty ? row.pretty : row.text;
      break;
    }
  }

  return name;
}
/// Converts a ComplexTransform enumerant from a string | |
template <> | |
ComplexTransform from_string<ComplexTransform>(std::string const &str) { | |
for (auto const & possible : ComplexTransform_enumerants) { | |
if ((str.compare(possible.text) == 0) || | |
(str.compare(possible.pretty) == 0)) { | |
return possible.enumerant; | |
} | |
} | |
return ComplexTransform::kInvalid; | |
} | |
static struct { | |
char const *text; | |
char const *pretty; | |
SplitKMode enumerant; | |
} | |
SplitKMode_enumerants[] = { | |
{"serial", "<serial>", SplitKMode::kSerial}, | |
{"parallel", "<parallel>", SplitKMode::kParallel}, | |
}; | |
/// Returns the name of a SplitKMode enumerant.
///
/// The first row of SplitKMode_enumerants matching `type` supplies its `pretty` string
/// when `pretty` is true and its `text` string otherwise. Values without a row map to
/// "Invalid" / "invalid".
char const *to_string(SplitKMode type, bool pretty) {
  for (auto const &entry : SplitKMode_enumerants) {
    if (entry.enumerant != type) {
      continue;
    }
    return pretty ? entry.pretty : entry.text;
  }

  if (pretty) {
    return "Invalid";
  }
  return "invalid";
}
/// Converts a SplitKMode enumerant from a string | |
template <> | |
SplitKMode from_string<SplitKMode>(std::string const &str) { | |
for (auto const & possible : SplitKMode_enumerants) { | |
if ((str.compare(possible.text) == 0) || | |
(str.compare(possible.pretty) == 0)) { | |
return possible.enumerant; | |
} | |
} | |
return SplitKMode::kInvalid; | |
} | |
///////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
ConvModeID enumerant; | |
} | |
ConvModeID_enumerants[] = { | |
{"cross", "<cross>", ConvModeID::kCrossCorrelation}, | |
{"conv", "<conv>", ConvModeID::kConvolution}, | |
}; | |
/// Returns the name of a ConvModeID enumerant.
///
/// Looks `type` up in ConvModeID_enumerants and yields the row's `pretty` or `text`
/// string according to `pretty`. Unlisted values yield "Invalid" / "invalid".
char const *to_string(ConvModeID type, bool pretty) {
  char const *name = pretty ? "Invalid" : "invalid";

  for (auto const &row : ConvModeID_enumerants) {
    if (row.enumerant == type) {
      name = pretty ? row.pretty : row.text;
      break;
    }
  }

  return name;
}
/// Converts a ConvModeID enumerant from a string | |
template <> | |
ConvModeID from_string<ConvModeID>(std::string const &str) { | |
for (auto const & possible : ConvModeID_enumerants) { | |
if ((str.compare(possible.text) == 0) || | |
(str.compare(possible.pretty) == 0)) { | |
return possible.enumerant; | |
} | |
} | |
return ConvModeID::kInvalid; | |
} | |
static struct { | |
char const *text; | |
char const *pretty; | |
IteratorAlgorithmID enumerant; | |
} | |
IteratorAlgorithmID_enumerants[] = { | |
{"none", "<none>", IteratorAlgorithmID::kNone}, | |
{"analytic", "<analytic>", IteratorAlgorithmID::kAnalytic}, | |
{"optimized", "<optimized>", IteratorAlgorithmID::kOptimized}, | |
{"fixed_channels", "<fixed_channels>", IteratorAlgorithmID::kFixedChannels}, | |
{"few_channels", "<few_channels>", IteratorAlgorithmID::kFewChannels}, | |
}; | |
/// Returns the name of an IteratorAlgorithmID enumerant.
///
/// The first row of IteratorAlgorithmID_enumerants matching `type` supplies its `pretty`
/// string when `pretty` is true and its `text` string otherwise. Values without a row
/// map to "Invalid" / "invalid".
char const *to_string(IteratorAlgorithmID type, bool pretty) {
  for (auto const &entry : IteratorAlgorithmID_enumerants) {
    if (entry.enumerant != type) {
      continue;
    }
    return pretty ? entry.pretty : entry.text;
  }

  if (pretty) {
    return "Invalid";
  }
  return "invalid";
}
/// Converts a ConvModeID enumerant from a string | |
template <> | |
IteratorAlgorithmID from_string<IteratorAlgorithmID>(std::string const &str) { | |
for (auto const & possible : IteratorAlgorithmID_enumerants) { | |
if ((str.compare(possible.text) == 0) || | |
(str.compare(possible.pretty) == 0)) { | |
return possible.enumerant; | |
} | |
} | |
return IteratorAlgorithmID::kInvalid; | |
} | |
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
ConvKind enumerant; | |
} | |
ConvKind_enumerants[] = { | |
{"unknown", "<unknown>", ConvKind::kUnknown}, | |
{"fprop", "<fprop>", ConvKind::kFprop}, | |
{"dgrad", "<dgrad>", ConvKind::kDgrad}, | |
{"wgrad", "<wgrad>", ConvKind::kWgrad}, | |
}; | |
/// Returns the name of a ConvKind enumerant.
///
/// Looks `type` up in ConvKind_enumerants and yields the row's `pretty` or `text`
/// string according to `pretty`. Unlisted values yield "Invalid" / "invalid".
char const *to_string(ConvKind type, bool pretty) {
  char const *name = pretty ? "Invalid" : "invalid";

  for (auto const &row : ConvKind_enumerants) {
    if (row.enumerant == type) {
      name = pretty ? row.pretty : row.text;
      break;
    }
  }

  return name;
}
/// Converts a ConvKind enumerant from a string | |
template <> | |
ConvKind from_string<ConvKind>(std::string const &str) { | |
for (auto const & possible : ConvKind_enumerants) { | |
if ((str.compare(possible.text) == 0) || | |
(str.compare(possible.pretty) == 0)) { | |
return possible.enumerant; | |
} | |
} | |
return ConvKind::kInvalid; | |
} | |
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
static struct { | |
char const *text; | |
char const *pretty; | |
RasterOrder enumerant; | |
} | |
RasterOrder_enumerants[] = { | |
{"along_n", "<along_n>", RasterOrder::kAlongN}, | |
{"along_m", "<along_m>", RasterOrder::kAlongM}, | |
{"heuristic", "<heuristic>", RasterOrder::kHeuristic}, | |
}; | |
/// Returns the name of a RasterOrder enumerant.
///
/// The first row of RasterOrder_enumerants matching `type` supplies its `pretty` string
/// when `pretty` is true and its `text` string otherwise. Values without a row map to
/// "Invalid" / "invalid".
char const *to_string(RasterOrder type, bool pretty) {
  for (auto const &entry : RasterOrder_enumerants) {
    if (entry.enumerant != type) {
      continue;
    }
    return pretty ? entry.pretty : entry.text;
  }

  if (pretty) {
    return "Invalid";
  }
  return "invalid";
}
/// Converts a RasterOrder enumerant from a string | |
template <> | |
RasterOrder from_string<RasterOrder>(std::string const &str) { | |
for (auto const & possible : RasterOrder_enumerants) { | |
if ((str.compare(possible.text) == 0) || | |
(str.compare(possible.pretty) == 0)) { | |
return possible.enumerant; | |
} | |
} | |
return RasterOrder::kInvalid; | |
} | |
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. | |
bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string const &str) { | |
int size_bytes = sizeof_bits(type) / 8; | |
if (!size_bytes) { | |
return false; | |
} | |
bytes.resize(size_bytes, 0); | |
std::stringstream ss; | |
ss << str; | |
switch (type) { | |
case NumericTypeID::kU8: | |
{ | |
ss >> *reinterpret_cast<uint8_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kU16: | |
{ | |
ss >> *reinterpret_cast<uint16_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kU32: | |
{ | |
ss >> *reinterpret_cast<uint32_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kU64: | |
{ | |
ss >> *reinterpret_cast<uint64_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kS8: | |
{ | |
ss >> *reinterpret_cast<int8_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kS16: | |
{ | |
ss >> *reinterpret_cast<int16_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kS32: | |
{ | |
ss >> *reinterpret_cast<int32_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kS64: | |
{ | |
ss >> *reinterpret_cast<int64_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kFE4M3: | |
{ | |
float tmp; | |
ss >> tmp; | |
*reinterpret_cast<float_e4m3_t *>(bytes.data()) = static_cast<float_e4m3_t>(tmp); | |
} | |
break; | |
case NumericTypeID::kFE5M2: | |
{ | |
float tmp; | |
ss >> tmp; | |
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(tmp); | |
} | |
break; | |
case NumericTypeID::kF16: | |
{ | |
float tmp; | |
ss >> tmp; | |
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(tmp); | |
} | |
break; | |
case NumericTypeID::kBF16: | |
{ | |
float tmp; | |
ss >> tmp; | |
*reinterpret_cast<bfloat16_t *>(bytes.data()) = static_cast<bfloat16_t>(tmp); | |
} | |
break; | |
case NumericTypeID::kTF32: | |
{ | |
float tmp; | |
ss >> tmp; | |
*reinterpret_cast<tfloat32_t *>(bytes.data()) = static_cast<tfloat32_t>(tmp); | |
} | |
break; | |
case NumericTypeID::kF32: | |
{ | |
ss >> *reinterpret_cast<float *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kF64: | |
{ | |
ss >> *reinterpret_cast<double *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kCF16: | |
{ | |
std::complex<float> tmp; | |
ss >> tmp; | |
cutlass::complex<cutlass::half_t> *x = reinterpret_cast<cutlass::complex<half_t> *>(bytes.data()); | |
x->real() = static_cast<half_t>(std::real(tmp)); | |
x->imag() = static_cast<half_t>(std::imag(tmp)); | |
} | |
break; | |
case NumericTypeID::kCBF16: | |
{ | |
std::complex<float> tmp; | |
ss >> tmp; | |
cutlass::complex<cutlass::bfloat16_t> *x = reinterpret_cast<cutlass::complex<bfloat16_t> *>(bytes.data()); | |
x->real() = static_cast<bfloat16_t>(std::real(tmp)); | |
x->imag() = static_cast<bfloat16_t>(std::imag(tmp)); | |
} | |
break; | |
case NumericTypeID::kCF32: | |
{ | |
ss >> *reinterpret_cast<std::complex<float>*>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kCTF32: | |
{ | |
std::complex<float> tmp; | |
ss >> tmp; | |
cutlass::complex<cutlass::tfloat32_t> *x = reinterpret_cast<cutlass::complex<tfloat32_t> *>(bytes.data()); | |
x->real() = static_cast<tfloat32_t>(std::real(tmp)); | |
x->imag() = static_cast<tfloat32_t>(std::imag(tmp)); | |
} | |
break; | |
case NumericTypeID::kCF64: | |
{ | |
ss >> *reinterpret_cast<std::complex<double>*>(bytes.data()); | |
} | |
break; | |
default: | |
return false; | |
} | |
return true; | |
} | |
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
/// Formats a signed 64-bit integer as text by streaming it into a string buffer.
std::string lexical_cast(int64_t int_value) {
  std::ostringstream formatter;
  formatter << int_value;
  return formatter.str();
}
/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid. | |
std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type) { | |
size_t size_bytes = sizeof_bits(type) / 8; | |
if (!size_bytes || size_bytes != bytes.size()) { | |
return "<invalid>"; | |
} | |
bytes.resize(size_bytes, 0); | |
std::stringstream ss; | |
switch (type) { | |
case NumericTypeID::kU8: | |
{ | |
ss << *reinterpret_cast<uint8_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kU16: | |
{ | |
ss << *reinterpret_cast<uint16_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kU32: | |
{ | |
ss << *reinterpret_cast<uint32_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kU64: | |
{ | |
ss << *reinterpret_cast<uint64_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kS8: | |
{ | |
ss << *reinterpret_cast<int8_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kS16: | |
{ | |
ss << *reinterpret_cast<int16_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kS32: | |
{ | |
ss << *reinterpret_cast<int32_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kS64: | |
{ | |
ss << *reinterpret_cast<int64_t *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kFE4M3: | |
{ | |
float tmp = *reinterpret_cast<float_e4m3_t *>(bytes.data()); | |
ss << tmp; | |
} | |
break; | |
case NumericTypeID::kFE5M2: | |
{ | |
float tmp = *reinterpret_cast<float_e5m2_t *>(bytes.data()); | |
ss << tmp; | |
} | |
break; | |
case NumericTypeID::kF16: | |
{ | |
float tmp = *reinterpret_cast<half_t *>(bytes.data()); | |
ss << tmp; | |
} | |
break; | |
case NumericTypeID::kBF16: | |
{ | |
float tmp = *reinterpret_cast<bfloat16_t *>(bytes.data()); | |
ss << tmp; | |
} | |
break; | |
case NumericTypeID::kTF32: | |
{ | |
float tmp = *reinterpret_cast<tfloat32_t *>(bytes.data()); | |
ss << tmp; | |
} | |
break; | |
case NumericTypeID::kF32: | |
{ | |
ss << *reinterpret_cast<float *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kF64: | |
{ | |
ss << *reinterpret_cast<double *>(bytes.data()); | |
} | |
break; | |
case NumericTypeID::kCF16: | |
{ | |
cutlass::complex<half_t> const *x = | |
reinterpret_cast<cutlass::complex<half_t> const *>(bytes.data()); | |
ss << float(x->real()); | |
if (x->imag() != cutlass::half_t()) { | |
ss << "+i" << float(x->imag()); | |
} | |
} | |
break; | |
case NumericTypeID::kCBF16: | |
{ | |
cutlass::complex<bfloat16_t> const *x = | |
reinterpret_cast<cutlass::complex<bfloat16_t> const *>(bytes.data()); | |
ss << float(x->real()); | |
if (x->imag() != cutlass::bfloat16_t()) { | |
ss << "+i" << float(x->imag()); | |
} | |
} | |
break; | |
case NumericTypeID::kCF32: | |
{ | |
cutlass::complex<float> const * x = reinterpret_cast<cutlass::complex<float> const *>(bytes.data()); | |
ss << x->real(); | |
if (x->imag() != float()) { | |
ss << "+i" << x->imag(); | |
} | |
} | |
break; | |
case NumericTypeID::kCTF32: | |
{ | |
cutlass::complex<tfloat32_t> const * x = reinterpret_cast<cutlass::complex<tfloat32_t> const *>(bytes.data()); | |
ss << float(x->real()); | |
if (x->imag() != tfloat32_t()) { | |
ss << "+i" << float(x->imag()); | |
} | |
} | |
break; | |
case NumericTypeID::kCF64: | |
{ | |
cutlass::complex<double> const * x = reinterpret_cast<cutlass::complex<double> const *>(bytes.data()); | |
ss << x->real(); | |
if (x->imag() != double()) { | |
ss << "+i" << x->imag(); | |
} | |
} | |
break; | |
default: | |
return "<unknown>"; | |
} | |
return ss.str(); | |
} | |
/// Casts from a signed int64 to the destination type. Returns true if successful. | |
bool cast_from_int64(std::vector<uint8_t> &bytes, NumericTypeID type, int64_t src) { | |
int size_bytes = sizeof_bits(type) / 8; | |
if (!size_bytes) { | |
return false; | |
} | |
bytes.resize(size_bytes, 0); | |
switch (type) { | |
case NumericTypeID::kU8: | |
{ | |
*reinterpret_cast<uint8_t *>(bytes.data()) = static_cast<uint8_t>(src); | |
} | |
break; | |
case NumericTypeID::kU16: | |
{ | |
*reinterpret_cast<uint16_t *>(bytes.data()) = static_cast<uint16_t>(src); | |
} | |
break; | |
case NumericTypeID::kU32: | |
{ | |
*reinterpret_cast<uint32_t *>(bytes.data()) = static_cast<uint32_t>(src); | |
} | |
break; | |
case NumericTypeID::kU64: | |
{ | |
*reinterpret_cast<uint64_t *>(bytes.data()) = static_cast<uint64_t>(src); | |
} | |
break; | |
case NumericTypeID::kS8: | |
{ | |
*reinterpret_cast<int8_t *>(bytes.data()) = static_cast<int8_t>(src); | |
} | |
break; | |
case NumericTypeID::kS16: | |
{ | |
*reinterpret_cast<int16_t *>(bytes.data()) = static_cast<int16_t>(src); | |
} | |
break; | |
case NumericTypeID::kS32: | |
{ | |
*reinterpret_cast<int32_t *>(bytes.data()) = static_cast<int32_t>(src); | |
} | |
break; | |
case NumericTypeID::kS64: | |
{ | |
*reinterpret_cast<int64_t *>(bytes.data()) = static_cast<int64_t>(src); | |
} | |
break; | |
case NumericTypeID::kFE4M3: | |
{ | |
*reinterpret_cast<float_e4m3_t *>(bytes.data()) = static_cast<float_e4m3_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kFE5M2: | |
{ | |
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kF16: | |
{ | |
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kBF16: | |
{ | |
*reinterpret_cast<bfloat16_t *>(bytes.data()) = static_cast<bfloat16_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kTF32: | |
{ | |
*reinterpret_cast<tfloat32_t *>(bytes.data()) = static_cast<tfloat32_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kF32: | |
{ | |
*reinterpret_cast<float *>(bytes.data()) = static_cast<float>(src); | |
} | |
break; | |
case NumericTypeID::kF64: | |
{ | |
*reinterpret_cast<double *>(bytes.data()) = double(src); | |
} | |
break; | |
case NumericTypeID::kCF16: | |
{ | |
cutlass::complex<cutlass::half_t> *x = reinterpret_cast<cutlass::complex<half_t> *>(bytes.data()); | |
x->real() = static_cast<half_t>(float(src)); | |
x->imag() = static_cast<half_t>(float(0)); | |
} | |
break; | |
case NumericTypeID::kCF32: | |
{ | |
*reinterpret_cast<cutlass::complex<float>*>(bytes.data()) = cutlass::complex<float>(float(src), float(0)); | |
} | |
break; | |
case NumericTypeID::kCF64: | |
{ | |
*reinterpret_cast<cutlass::complex<double>*>(bytes.data()) = cutlass::complex<double>(double(src), double(0)); | |
} | |
break; | |
default: | |
return false; | |
} | |
return true; | |
} | |
/// Casts from an unsigned int64 to the destination type. Returns true if successful. | |
bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t src) { | |
int size_bytes = sizeof_bits(type) / 8; | |
if (!size_bytes) { | |
return false; | |
} | |
bytes.resize(size_bytes, 0); | |
switch (type) { | |
case NumericTypeID::kU8: | |
{ | |
*reinterpret_cast<uint8_t *>(bytes.data()) = static_cast<uint8_t>(src); | |
} | |
break; | |
case NumericTypeID::kU16: | |
{ | |
*reinterpret_cast<uint16_t *>(bytes.data()) = static_cast<uint16_t>(src); | |
} | |
break; | |
case NumericTypeID::kU32: | |
{ | |
*reinterpret_cast<uint32_t *>(bytes.data()) = static_cast<uint32_t>(src); | |
} | |
break; | |
case NumericTypeID::kU64: | |
{ | |
*reinterpret_cast<uint64_t *>(bytes.data()) = static_cast<uint64_t>(src); | |
} | |
break; | |
case NumericTypeID::kS8: | |
{ | |
*reinterpret_cast<int8_t *>(bytes.data()) = static_cast<int8_t>(src); | |
} | |
break; | |
case NumericTypeID::kS16: | |
{ | |
*reinterpret_cast<int16_t *>(bytes.data()) = static_cast<int16_t>(src); | |
} | |
break; | |
case NumericTypeID::kS32: | |
{ | |
*reinterpret_cast<int32_t *>(bytes.data()) = static_cast<int32_t>(src); | |
} | |
break; | |
case NumericTypeID::kS64: | |
{ | |
*reinterpret_cast<int64_t *>(bytes.data()) = static_cast<int64_t>(src); | |
} | |
break; | |
case NumericTypeID::kFE4M3: | |
{ | |
*reinterpret_cast<float_e4m3_t *>(bytes.data()) = static_cast<float_e4m3_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kFE5M2: | |
{ | |
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kF16: | |
{ | |
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kBF16: | |
{ | |
*reinterpret_cast<bfloat16_t *>(bytes.data()) = static_cast<bfloat16_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kTF32: | |
{ | |
*reinterpret_cast<tfloat32_t *>(bytes.data()) = static_cast<tfloat32_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kF32: | |
{ | |
*reinterpret_cast<float *>(bytes.data()) = static_cast<float>(src); | |
} | |
break; | |
case NumericTypeID::kF64: | |
{ | |
*reinterpret_cast<double *>(bytes.data()) = double(src); | |
} | |
break; | |
case NumericTypeID::kCF16: | |
{ | |
cutlass::complex<cutlass::half_t> *x = reinterpret_cast<cutlass::complex<half_t> *>(bytes.data()); | |
x->real() = static_cast<half_t>(float(src)); | |
x->imag() = static_cast<half_t>(float(0)); | |
} | |
break; | |
case NumericTypeID::kCF32: | |
{ | |
*reinterpret_cast<std::complex<float>*>(bytes.data()) = std::complex<float>(float(src), float(0)); | |
} | |
break; | |
case NumericTypeID::kCF64: | |
{ | |
*reinterpret_cast<std::complex<double>*>(bytes.data()) = std::complex<double>(double(src), double(0)); | |
} | |
break; | |
default: | |
return false; | |
} | |
return true; | |
} | |
/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. | |
bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double src) { | |
int size_bytes = sizeof_bits(type) / 8; | |
if (!size_bytes) { | |
return false; | |
} | |
bytes.resize(size_bytes, 0); | |
switch (type) { | |
case NumericTypeID::kU8: | |
{ | |
*reinterpret_cast<uint8_t *>(bytes.data()) = static_cast<uint8_t>(src); | |
} | |
break; | |
case NumericTypeID::kU16: | |
{ | |
*reinterpret_cast<uint16_t *>(bytes.data()) = static_cast<uint16_t>(src); | |
} | |
break; | |
case NumericTypeID::kU32: | |
{ | |
*reinterpret_cast<uint32_t *>(bytes.data()) = static_cast<uint32_t>(src); | |
} | |
break; | |
case NumericTypeID::kU64: | |
{ | |
*reinterpret_cast<uint64_t *>(bytes.data()) = static_cast<uint64_t>(src); | |
} | |
break; | |
case NumericTypeID::kS8: | |
{ | |
*reinterpret_cast<int8_t *>(bytes.data()) = static_cast<int8_t>(src); | |
} | |
break; | |
case NumericTypeID::kS16: | |
{ | |
*reinterpret_cast<int16_t *>(bytes.data()) = static_cast<int16_t>(src); | |
} | |
break; | |
case NumericTypeID::kS32: | |
{ | |
*reinterpret_cast<int32_t *>(bytes.data()) = static_cast<int32_t>(src); | |
} | |
break; | |
case NumericTypeID::kS64: | |
{ | |
*reinterpret_cast<int64_t *>(bytes.data()) = static_cast<int64_t>(src); | |
} | |
break; | |
case NumericTypeID::kFE4M3: | |
{ | |
*reinterpret_cast<float_e4m3_t *>(bytes.data()) = static_cast<float_e4m3_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kFE5M2: | |
{ | |
*reinterpret_cast<float_e5m2_t *>(bytes.data()) = static_cast<float_e5m2_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kF16: | |
{ | |
*reinterpret_cast<half_t *>(bytes.data()) = static_cast<half_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kBF16: | |
{ | |
*reinterpret_cast<bfloat16_t *>(bytes.data()) = static_cast<bfloat16_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kTF32: | |
{ | |
*reinterpret_cast<tfloat32_t *>(bytes.data()) = static_cast<tfloat32_t>(float(src)); | |
} | |
break; | |
case NumericTypeID::kF32: | |
{ | |
*reinterpret_cast<float *>(bytes.data()) = static_cast<float>(src); | |
} | |
break; | |
case NumericTypeID::kF64: | |
{ | |
*reinterpret_cast<double *>(bytes.data()) = src; | |
} | |
break; | |
case NumericTypeID::kCF16: | |
{ | |
cutlass::complex<cutlass::half_t> *x = reinterpret_cast<cutlass::complex<half_t> *>(bytes.data()); | |
x->real() = static_cast<half_t>(float(src)); | |
x->imag() = static_cast<half_t>(float(0)); | |
} | |
break; | |
case NumericTypeID::kCBF16: | |
{ | |
cutlass::complex<cutlass::bfloat16_t> *x = reinterpret_cast<cutlass::complex<bfloat16_t> *>(bytes.data()); | |
x->real() = static_cast<bfloat16_t>(bfloat16_t(src)); | |
x->imag() = static_cast<bfloat16_t>(bfloat16_t(0)); | |
} | |
break; | |
case NumericTypeID::kCF32: | |
{ | |
*reinterpret_cast<cutlass::complex<float>*>(bytes.data()) = cutlass::complex<float>(float(src), float()); | |
} | |
break; | |
case NumericTypeID::kCTF32: | |
{ | |
*reinterpret_cast<cutlass::complex<tfloat32_t>*>(bytes.data()) = cutlass::complex<tfloat32_t>(tfloat32_t(src), tfloat32_t()); | |
} | |
break; | |
case NumericTypeID::kCF64: | |
{ | |
*reinterpret_cast<cutlass::complex<double>*>(bytes.data()) = cutlass::complex<double>(src, double()); | |
} | |
break; | |
default: | |
return false; | |
} | |
return true; | |
} | |
/////////////////////////////////////////////////////////////////////////////////////////////////// | |
} // namespace library | |
} // namespace cutlass | |
/////////////////////////////////////////////////////////////////////////////////////////////////// | |