|
|
|
#ifndef _data_types_cuh |
|
#define _data_types_cuh |
|
#include "marlin.cuh" |
|
#include <cuda_fp16.h> |
|
#include <cuda_bf16.h> |
|
|
|
namespace marlin { |
|
|
|
// Primary ScalarType template: intentionally empty. The per-element-type
// aliases and conversion helpers are supplied by explicit specializations, so
// instantiating it with a type that has no specialization yields a class with
// no members.
template <typename T>
class ScalarType {};
|
|
|
// ScalarType specialization for `half` elements.
//
// Provides the scalar / packed-pair type aliases, the Vec-based fragment
// aliases, and thin wrappers around the cuda_fp16.h conversion intrinsics.
template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  // Fragment aliases. `Vec` is not declared in this file (presumably it comes
  // from marlin.cuh -- TODO confirm); only the element type and the element
  // count are fixed here.
  using FragA = Vec<scalar_t2, 4>;
  using FragB = Vec<scalar_t2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<scalar_t2, 1>;
  using FragZP = Vec<scalar_t2, 4>;

  // Device only: converts a single half to float via __half2float.
  __device__ static inline float num2float(const scalar_t x) {
    return __half2float(x);
  }

  // Device only: builds a half2 from one half via __half2half2 (per the CUDA
  // half intrinsics documentation, both lanes receive `x`).
  __device__ static inline scalar_t2 num2num2(const scalar_t x) {
    return __half2half2(x);
  }

  // Device only: packs two half values into a half2 via __halves2half2,
  // passing `x1` first and `x2` second.
  __device__ static inline scalar_t2 nums2num2(const scalar_t x1,
                                               const scalar_t x2) {
    return __halves2half2(x1, x2);
  }

  // Host and device: converts a float to half via __float2half.
  __host__ __device__ static inline scalar_t float2num(const float x) {
    return __float2half(x);
  }
};
|
|
|
// ScalarType specialization for `nv_bfloat16` elements.
//
// Provides the scalar / packed-pair type aliases, the Vec-based fragment
// aliases, and thin wrappers around the cuda_bf16.h conversion intrinsics.
template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

  // Fragment aliases. `Vec` is not declared in this file (presumably it comes
  // from marlin.cuh -- TODO confirm); only the element type and the element
  // count are fixed here.
  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;
  using FragZP = Vec<nv_bfloat162, 4>;

  // Host and device: converts a float to nv_bfloat16 via __float2bfloat16.
  //
  // This is the only helper marked __host__ __device__, so it is kept outside
  // the __CUDA_ARCH__ guard below: __CUDA_ARCH__ is undefined during the host
  // compilation pass, and a guard that requires it to be defined would remove
  // the function from host code entirely. __float2bfloat16 itself is
  // documented as a __host__ __device__ conversion.
  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }

  // The remaining helpers are device-only and are compiled only for device
  // passes targeting compute capability 8.0 or newer (requires SM80+).
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // Device only: converts a single nv_bfloat16 to float via __bfloat162float.
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  // Device only: builds an nv_bfloat162 from one nv_bfloat16 via
  // __bfloat162bfloat162.
  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  // Device only: packs two nv_bfloat16 values into an nv_bfloat162 via
  // __halves2bfloat162, passing `x1` first and `x2` second.
  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }
#endif
};
|
|
|
} |
|
|
|
#endif |
|
|