Add missing `scalar_type.hpp`
core/scalar_type.hpp  ADDED  +347 -0
@@ -0,0 +1,347 @@
#pragma once

// For TORCH_CHECK
#include <torch/library.h>

// standard headers for the std:: utilities used below
#include <cstdint>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>

namespace vllm {

//
// ScalarType can represent a wide range of floating point and integer types,
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// The type definitions on the Python side can be found in vllm/scalar_type.py;
// these type definitions should be kept up to date with any Python API
// changes here.
//
class ScalarType {
 public:
  enum NanRepr : uint8_t {
    NAN_NONE = 0,                // nans are not supported
    NAN_IEEE_754 = 1,            // nans are: exp all 1s, mantissa not all 0s
    NAN_EXTD_RANGE_MAX_MIN = 2,  // nans are: exp all 1s, mantissa all 1s

    NAN_REPR_ID_MAX
  };

  constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
                       int32_t bias, bool finite_values_only = false,
                       NanRepr nan_repr = NAN_IEEE_754)
      : exponent(exponent),
        mantissa(mantissa),
        signed_(signed_),
        bias(bias),
        finite_values_only(finite_values_only),
        nan_repr(nan_repr) {}

  static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
    return ScalarType(0, size_bits - 1, true, bias);
  }

  static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
    return ScalarType(0, size_bits, false, bias);
  }

  // IEEE 754 compliant floating point type
  static constexpr ScalarType float_IEEE754(uint8_t exponent,
                                            uint8_t mantissa) {
    TORCH_CHECK(mantissa > 0 && exponent > 0);
    return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
  }

  // IEEE 754 non-compliant floating point type
  static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
                                     bool finite_values_only,
                                     NanRepr nan_repr) {
    TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
    TORCH_CHECK(mantissa > 0 && exponent > 0);
    TORCH_CHECK(nan_repr != NAN_IEEE_754,
                "use `float_IEEE754` constructor for floating point types "
                "that follow IEEE 754 conventions");
    return ScalarType(exponent, mantissa, true, 0, finite_values_only,
                      nan_repr);
  }

  uint8_t const exponent;  // size of the exponent field (0 for integer types)
  uint8_t const mantissa;  // size of the mantissa field (size of the integer
                           // excluding the sign bit for integer types)
  bool const signed_;  // flag if the type supports negative numbers (i.e. has
                       // a sign bit)
  int32_t const bias;  // stored values equal value + bias,
                       // used for quantized types
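                       // e.g. in kU4B8 (uint4 with bias 8) a stored nibble s
                       // represents the value s - 8, giving the range [-8, 7]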

  // Extra floating point info
  bool const finite_values_only;  // i.e. no +/-inf if true
  NanRepr const nan_repr;         // how NaNs are represented
                                  // (not applicable for integer types)

  using Id = int64_t;

 private:
  // Field size in id
  template <typename T_>
  static constexpr size_t member_id_field_width() {
    using T = std::decay_t<T_>;
    return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
  }

  template <typename Fn, typename Init, typename Member, typename... Rest>
  static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
                                              Rest... rest) {
    auto new_val = f(val, member);
    if constexpr (sizeof...(rest) > 0) {
      return reduce_members_helper(f, new_val, rest...);
    } else {
      return new_val;
    }
  }

  template <typename Fn, typename Init>
  constexpr auto reduce_members(Fn f, Init init) const {
    // Should be in constructor order for `from_id`
    return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
                                 finite_values_only, nan_repr);
  }

  template <typename Fn, typename Init>
  static constexpr auto reduce_member_types(Fn f, Init init) {
    constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
    return dummy_type.reduce_members(f, init);
  }

  static constexpr auto id_size_bits() {
    return reduce_member_types(
        [](int acc, auto member) -> int {
          return acc + member_id_field_width<decltype(member)>();
        },
        0);
  }

 public:
  // Unique id for this scalar type that can be computed at compile time for
  // C++17 template specialization; this is not needed once we migrate to
  // C++20 and can pass literal classes as template parameters
  constexpr Id id() const {
    static_assert(id_size_bits() <= sizeof(Id) * 8,
                  "ScalarType id is too large to be stored");

    auto or_and_advance = [](std::pair<Id, uint32_t> result,
                             auto member) -> std::pair<Id, uint32_t> {
      auto [id, bit_offset] = result;
      auto constexpr bits = member_id_field_width<decltype(member)>();
      return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
                       << bit_offset,
              bit_offset + bits};
    };
    return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
  }
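
  // With the current member widths the id layout, LSB first, is:
  //   bits [0, 8)   exponent
  //   bits [8, 16)  mantissa
  //   bits [16, 17) signed_
  //   bits [17, 49) bias
  //   bits [49, 50) finite_values_only
  //   bits [50, 58) nan_repr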

  // Create a ScalarType from an id, for C++17 template specialization;
  // this is not needed once we migrate to C++20 and can pass literal
  // classes as template parameters
  static constexpr ScalarType from_id(Id id) {
    auto extract_and_advance = [id](auto result, auto member) {
      using T = decltype(member);
      auto [tuple, bit_offset] = result;
      auto constexpr bits = member_id_field_width<T>();
      auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
                                          ((uint64_t(1) << bits) - 1));
      auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
      return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
    };

    auto [tuple_args, _] = reduce_member_types(extract_and_advance,
                                               std::pair<std::tuple<>, int>{});
    return std::apply([](auto... args) { return ScalarType(args...); },
                      tuple_args);
  }
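
  // Note: from_id is intended to be the inverse of id(), i.e.
  // from_id(s.id()) == s for any ScalarType s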

  constexpr int64_t size_bits() const {
    return mantissa + exponent + is_signed();
  }
  constexpr bool is_signed() const { return signed_; }
  constexpr bool is_integer() const { return exponent == 0; }
  constexpr bool is_floating_point() const { return exponent > 0; }
  constexpr bool is_ieee_754() const {
    return is_floating_point() && finite_values_only == false &&
           nan_repr == NAN_IEEE_754;
  }
  constexpr bool has_nans() const {
    return is_floating_point() && nan_repr != NAN_NONE;
  }
  constexpr bool has_infs() const {
    return is_floating_point() && finite_values_only == false;
  }
  constexpr bool has_bias() const { return bias != 0; }

 private:
  double _floating_point_max() const {
    TORCH_CHECK(mantissa <= 52 && exponent <= 11,
                "Cannot represent max/min as a double for type ", str());

    uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
      max_mantissa -= 1;
    }

    uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
      TORCH_CHECK(exponent < 11,
                  "Cannot represent max/min as a double for type ", str());
      max_exponent += 1;
    }

    // adjust the exponent to match that of a double;
    // for now we assume the exponent bias is the standard 2^(e-1) - 1 (where
    // e is the number of exponent bits). There is some precedent for
    // non-standard biases, for example `float8_e4m3b11fnuz` here:
    // https://github.com/jax-ml/ml_dtypes, but to avoid premature
    // over-complication we just assume the standard exponent bias until
    // there is a need to support non-standard biases
    uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
    uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1;  // double e = 11

    uint64_t max_exponent_double =
        max_exponent - exponent_bias + exponent_bias_double;

    // shift the mantissa and exponent into position for a double
    uint64_t double_raw =
        (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);

    return *reinterpret_cast<double*>(&double_raw);
  }
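
  // Worked example: for kFE4M3fn (e=4, m=3, NAN_EXTD_RANGE_MAX_MIN) the
  // largest finite encoding is exp=0b1111, mantissa=0b110, i.e.
  // (1 + 6/8) * 2^(15 - 7) = 448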

  constexpr std::variant<int64_t, double> _raw_max() const {
    if (is_floating_point()) {
      return {_floating_point_max()};
    } else {
      TORCH_CHECK(size_bits() < 64 || (size_bits() == 64 && is_signed()),
                  "Cannot represent max as an int64_t");
      return {(int64_t(1) << mantissa) - 1};
    }
  }

  constexpr std::variant<int64_t, double> _raw_min() const {
    if (is_floating_point()) {
      TORCH_CHECK(is_signed(),
                  "We currently assume all floating point types are signed");
      constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);

      double max = _floating_point_max();
      uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
      uint64_t min_raw = max_raw | sign_bit_double;
      return {*reinterpret_cast<double*>(&min_raw)};
    } else {
      TORCH_CHECK(!is_signed() || size_bits() <= 64,
                  "Cannot represent min as an int64_t");
      if (is_signed()) {
        // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0,
        // then perform an arithmetic shift right to set all the bits above
        // (size_bits() - 1) to 1
        return {INT64_MIN >> (64 - size_bits())};
      } else {
        return {int64_t(0)};
      }
    }
  }

 public:
  // Max representable value for this scalar type
  // (accounting for bias if there is one)
  constexpr std::variant<int64_t, double> max() const {
    return std::visit(
        [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
        _raw_max());
  }

  // Min representable value for this scalar type
  // (accounting for bias if there is one)
  constexpr std::variant<int64_t, double> min() const {
    return std::visit(
        [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
        _raw_min());
  }

  std::string str() const {
    /* naming generally follows: https://github.com/jax-ml/ml_dtypes
     * for floating point types (leading f) the scheme is:
     *   `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
     * flags:
     *   - no flags: means it follows IEEE 754 conventions
     *   - f: means finite values only (no infinities)
     *   - n: means nans are supported (non-standard encoding)
     * for integer types the scheme is:
     *   `[u]int<size_bits>[b<bias>]`
     *   - if bias is not present it means the bias is zero
     */
    if (is_floating_point()) {
      auto ret = "float" + std::to_string(size_bits()) + "_e" +
                 std::to_string(exponent) + "m" + std::to_string(mantissa);
      if (!is_ieee_754()) {
        if (finite_values_only) {
          ret += "f";
        }
        if (nan_repr != NAN_NONE) {
          ret += "n";
        }
      }
      return ret;
    } else {
      auto ret = (is_signed() ? "int" : "uint") + std::to_string(size_bits());
      if (has_bias()) {
        ret += "b" + std::to_string(bias);
      }
      return ret;
    }
  }

  constexpr bool operator==(ScalarType const& other) const {
    return mantissa == other.mantissa && exponent == other.exponent &&
           bias == other.bias && signed_ == other.signed_ &&
           finite_values_only == other.finite_values_only &&
           nan_repr == other.nan_repr;
  }
};

using ScalarTypeId = ScalarType::Id;

// "rust style" names, generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
static inline constexpr auto kS4 = ScalarType::int_(4);
static inline constexpr auto kU4 = ScalarType::uint(4);
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);

static inline constexpr auto kFE3M2f =
    ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
    ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
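// e.g. kFE4M3fn.str() == "float8_e4m3fn" and kU4B8.str() == "uint4b8"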

// Fixed width style names, generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
static inline constexpr auto kInt4 = kS4;
static inline constexpr auto kUint4 = kU4;
static inline constexpr auto kUint4b8 = kU4B8;
static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;

static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
static inline constexpr auto kFloat16_e5m10 = kFE5M10;

// colloquial names
static inline constexpr auto kHalf = kFE5M10;
static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;

static inline constexpr auto kFloat16Id = kFloat16.id();
}  // namespace vllm
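
For reference, here is a minimal usage sketch (not part of this diff; the file
name `example.cpp` and the `describe` helper are hypothetical, while every
`vllm::` name used is defined in the header above). It shows the id()/from_id()
round-trip that enables C++17 template specialization on a ScalarType:

// example.cpp (hypothetical)
#include "core/scalar_type.hpp"

#include <iostream>

// Dispatch on a ScalarType via its integral id, since C++17 cannot take
// literal class types as non-type template parameters.
template <vllm::ScalarTypeId kId>
void describe() {
  constexpr vllm::ScalarType t = vllm::ScalarType::from_id(kId);
  std::cout << t.str() << ": " << t.size_bits() << " bits, "
            << (t.is_integer() ? "integer" : "floating point") << '\n';
}

int main() {
  describe<vllm::kU4B8.id()>();     // uint4b8: 4 bits, integer
  describe<vllm::kFE4M3fn.id()>();  // float8_e4m3fn: 8 bits, floating point
  return 0;
}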