|
#pragma once |
|
|
|
#include "vectorization.cuh" |
|
|
|
#include <cmath> |
|
#include <c10/core/ScalarType.h> |
|
|
|
// Platform-selected FP8 element type for the quantization kernels, plus the
// magnitude bound the kernels saturate to before converting.
// Note: the two branches give FP8_E4M3_MAX different types — FP8_TYPE on CUDA
// (deduced via auto from numeric_limits) and float on ROCm.
#if !defined(USE_ROCM)

#include <c10/util/Float8_e4m3fn.h>

typedef c10::Float8_e4m3fn FP8_TYPE;

// CUDA: the saturation bound is the type's own numeric_limits maximum.
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
    std::numeric_limits<FP8_TYPE>::max();

#else  // USE_ROCM

#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/hip_float8.h"

typedef c10::Float8_e4m3fnuz FP8_TYPE;

// ROCm: the saturation bound is hard-coded to 224.0f rather than read from
// std::numeric_limits<FP8_TYPE>::max().
// NOTE(review): presumably a deliberate accuracy trade-off below the type's
// maximum — confirm the rationale before changing it.
constexpr float FP8_E4M3_MAX = 224.0f;

#endif  // USE_ROCM

// c10::ScalarType tag that corresponds to FP8_TYPE on the current platform.
static constexpr auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;
|
|
|
namespace vllm {

// Atomically raises *addr to max(*addr, value) and returns the previous
// contents of *addr, using integer atomics on the float's bit pattern.
//
// - value >= 0: non-negative floats order the same way as their signed-int bit
//   patterns, so atomicMax on the int view is used. A negative value already
//   stored has a negative int view and loses.
// - value < 0: negative floats order in reverse of their unsigned bit patterns,
//   so atomicMin on the unsigned view is used. A non-negative value already
//   stored has a smaller unsigned view and wins.
// NOTE(review): the bit-ordering argument does not hold for NaN; callers in this
// file only pass non-negative absmax-derived values.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));
  return old;
}

// Scales `val` and converts it to FP8_TYPE, saturating to
// [-FP8_E4M3_MAX, FP8_E4M3_MAX] before the conversion.
//   is_scale_inverted == true : x = val * scale  (scale is a reciprocal)
//   is_scale_inverted == false: x = val / scale
// fmin/fmax return the non-NaN operand, so a NaN `x` is clamped to
// +FP8_E4M3_MAX rather than propagated.
template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }

  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
  return static_cast<c10::Float8_e4m3fn>(r);
#else
  // Convert through hip_fp8 and rebuild the c10 type from its raw bits.
  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
                              c10::Float8_e4m3fnuz::from_bits());
#endif
}

// Computes max_i |input[i]| over num_elems elements and folds
// (absmax / FP8_E4M3_MAX) into *scale with atomicMaxFloat.
//
// Each thread scans its elements with a grid-stride loop. Each block then
// reduces its per-thread maxima in a 1024-float shared-memory array, so
// blockDim.x must be <= 1024. Any blockDim.x works, not only powers of two.
// *scale is only ever raised, never reset.
// NOTE(review): the caller presumably initializes *scale (e.g. to 0) before
// launch — confirm at the call site.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];

  // Widen to 64 bits before multiplying so neither the start index nor the
  // stride can wrap in 32-bit unsigned arithmetic for very large grids.
  int64_t const stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t i = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  // Per-thread absmax, accumulated in float. Accumulating in scalar_t made
  // max() mix scalar_t and float operands.
  float tmp = 0.0f;
  for (; i < num_elems; i += stride) {
    tmp = fmaxf(tmp, fabsf(static_cast<float>(input[i])));
  }
  cache[threadIdx.x] = tmp;
  __syncthreads();

  // Block-level tree reduction into cache[0]. Each round folds the upper
  // `active - half` entries onto the lower ones. For a power-of-two blockDim.x
  // this matches plain halving; for other sizes the odd middle entry is
  // carried into the next round instead of being dropped. Readers touch
  // [half, active) and writers touch [0, active - half), which are disjoint.
  for (unsigned int active = blockDim.x; active > 1;) {
    unsigned int const half = (active + 1) / 2;
    if (threadIdx.x < active - half) {
      cache[threadIdx.x] =
          fmaxf(cache[threadIdx.x], cache[threadIdx.x + half]);
    }
    __syncthreads();  // reached by every thread: outside the divergent branch
    active = half;
  }

  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}

// Returns the maximum |input[i]| over the elements visited by this thread, or
// 0.0f if it visits none.
// The main loop reads 4 elements at a time as vec4_t<scalar_t>, visiting groups
// tid, tid + step, .... The tail loop then covers the last num_elems % 4
// elements with the same tid/step pattern.
// NOTE(review): the reinterpret_cast assumes `input` is aligned for
// vec4_t<scalar_t>, and tid/step are presumably a global thread index and the
// total thread count — confirm at call sites.
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  vec4_t<scalar_t> const* vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);

  int64_t const num_vec_elems = num_elems >> 2;  // number of whole vec4 groups
  float absmax_val = 0.0f;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    // Convert explicitly to float so fabsf/fmaxf operate on float, not scalar_t.
    absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(in_vec.x)));
    absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(in_vec.y)));
    absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(in_vec.z)));
    absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(in_vec.w)));
  }

  // Scalar tail: elements past the last whole vec4 group.
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(input[i])));
  }

  return absmax_val;
}

// Writes out[i] = scaled_fp8_conversion<is_scale_inverted>(input[i], scale)
// for the elements visited by this thread.
// The main loop converts 4 elements per iteration (vec4_t load, q8x4_t store),
// visiting groups tid, tid + step, .... The tail loop handles the last
// num_elems % 4 elements one at a time.
// NOTE(review): the reinterpret_casts assume `input` is aligned for
// vec4_t<scalar_t> and `out` for q8x4_t<FP8_TYPE>, and tid/step presumably
// partition the range across all launched threads — confirm at call sites.
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  using float8x4_t = q8x4_t<FP8_TYPE>;

  auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
  auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);

  int64_t const num_vec_elems = num_elems >> 2;  // number of whole vec4 groups

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    float8x4_t out_vec;

    out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.x), scale);
    out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.y), scale);
    out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.z), scale);
    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.w), scale);
    vectorized_out[i] = out_vec;
  }

  // Scalar tail: elements past the last whole vec4 group.
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    out[i] = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(input[i]), scale);
  }
}

}  // namespace vllm
|
|