danieldk (HF staff) committed
Commit 9d045c3 · 1 Parent(s): 439141b

Add missing `scalar_type.hpp`

Files changed (1):
  1. core/scalar_type.hpp +347 -0
core/scalar_type.hpp ADDED
@@ -0,0 +1,347 @@
#pragma once

// For TORCH_CHECK
#include <torch/library.h>

// Standard headers for std::variant, std::string, std::tuple, std::pair,
// std::is_same_v, and INT64_MIN used below
#include <cstdint>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>

namespace vllm {

//
// ScalarType can represent a wide range of floating point and integer types,
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// The type definitions on the Python side can be found in vllm/scalar_type.py;
// these type definitions should be kept up to date with any Python API changes
// here.
//
class ScalarType {
 public:
  enum NanRepr : uint8_t {
    NAN_NONE = 0,                // nans are not supported
    NAN_IEEE_754 = 1,            // nans are: exp all 1s, mantissa not all 0s
    NAN_EXTD_RANGE_MAX_MIN = 2,  // nans are: exp all 1s, mantissa all 1s

    NAN_REPR_ID_MAX
  };

  constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
                       int32_t bias, bool finite_values_only = false,
                       NanRepr nan_repr = NAN_IEEE_754)
      : exponent(exponent),
        mantissa(mantissa),
        signed_(signed_),
        bias(bias),
        finite_values_only(finite_values_only),
        nan_repr(nan_repr) {}

  static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
    return ScalarType(0, size_bits - 1, true, bias);
  }

  static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
    return ScalarType(0, size_bits, false, bias);
  }

  // IEEE 754 compliant floating point type
  static constexpr ScalarType float_IEEE754(uint8_t exponent,
                                            uint8_t mantissa) {
    TORCH_CHECK(mantissa > 0 && exponent > 0);
    return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
  }

  // IEEE 754 non-compliant floating point type
  static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
                                     bool finite_values_only,
                                     NanRepr nan_repr) {
    TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
    TORCH_CHECK(mantissa > 0 && exponent > 0);
    TORCH_CHECK(nan_repr != NAN_IEEE_754,
                "use `float_IEEE754` constructor for floating point types "
                "that follow IEEE 754 conventions");
    return ScalarType(exponent, mantissa, true, 0, finite_values_only,
                      nan_repr);
  }
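
  // For example (illustrative): a 4-bit unsigned integer with a bias of 8 is
  // `uint(4, 8)`, and the fp8 e4m3fn format (finite values only, non-IEEE NaN
  // encoding) is `float_(4, 3, true, NAN_EXTD_RANGE_MAX_MIN)`; the named
  // constants at the bottom of this file are built exactly this way.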

  uint8_t const exponent;  // size of the exponent field (0 for integer types)
  uint8_t const mantissa;  // size of the mantissa field (size of the integer
                           // excluding the sign bit for integer types)
  bool const signed_;  // flag if the type supports negative numbers (i.e. has
                       // a sign bit)
  int32_t const bias;  // stored values equal value + bias,
                       // used for quantized types

  // Extra floating point info
  bool const finite_values_only;  // i.e. no +/-inf if true
  NanRepr const nan_repr;         // how NaNs are represented
                                  // (not applicable for integer types)

  using Id = int64_t;

 private:
  // Field size in id
  template <typename T_>
  static constexpr size_t member_id_field_width() {
    using T = std::decay_t<T_>;
    return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
  }

  template <typename Fn, typename Init, typename Member, typename... Rest>
  static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
                                              Rest... rest) {
    auto new_val = f(val, member);
    if constexpr (sizeof...(rest) > 0) {
      return reduce_members_helper(f, new_val, rest...);
    } else {
      return new_val;
    }
  }

  template <typename Fn, typename Init>
  constexpr auto reduce_members(Fn f, Init init) const {
    // Should be in constructor order for `from_id`
    return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
                                 finite_values_only, nan_repr);
  }

  template <typename Fn, typename Init>
  static constexpr auto reduce_member_types(Fn f, Init init) {
    constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
    return dummy_type.reduce_members(f, init);
  }

  static constexpr auto id_size_bits() {
    return reduce_member_types(
        [](int acc, auto member) -> int {
          return acc + member_id_field_width<decltype(member)>();
        },
        0);
  }

 public:
  // unique id for this scalar type that can be computed at compile time for
  // c++17 template specialization; this is not needed once we migrate to
  // c++20 and can pass literal classes as template parameters
  constexpr Id id() const {
    static_assert(id_size_bits() <= sizeof(Id) * 8,
                  "ScalarType id is too large to be stored");

    auto or_and_advance = [](std::pair<Id, uint32_t> result,
                             auto member) -> std::pair<Id, uint32_t> {
      auto [id, bit_offset] = result;
      auto constexpr bits = member_id_field_width<decltype(member)>();
      return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
                       << bit_offset,
              bit_offset + bits};
    };
    return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
  }

  // create a ScalarType from an id, for c++17 template specialization;
  // this is not needed once we migrate to c++20 and can pass literal
  // classes as template parameters
  static constexpr ScalarType from_id(Id id) {
    auto extract_and_advance = [id](auto result, auto member) {
      using T = decltype(member);
      auto [tuple, bit_offset] = result;
      auto constexpr bits = member_id_field_width<T>();
      auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
                                          ((uint64_t(1) << bits) - 1));
      auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
      return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
    };

    auto [tuple_args, _] = reduce_member_types(extract_and_advance,
                                               std::pair<std::tuple<>, int>{});
    return std::apply([](auto... args) { return ScalarType(args...); },
                      tuple_args);
  }

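  // Since `reduce_members` visits the fields in constructor order, `id()`
  // packs exactly the argument tuple that `from_id` rebuilds, so the id
  // round-trips by construction (illustrative):
  //   constexpr auto t = ScalarType::uint(4, 8);
  //   static_assert(ScalarType::from_id(t.id()) == t);
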
  constexpr int64_t size_bits() const {
    return mantissa + exponent + is_signed();
  }
  constexpr bool is_signed() const { return signed_; }
  constexpr bool is_integer() const { return exponent == 0; }
  constexpr bool is_floating_point() const { return exponent > 0; }
  constexpr bool is_ieee_754() const {
    return is_floating_point() && finite_values_only == false &&
           nan_repr == NAN_IEEE_754;
  }
  constexpr bool has_nans() const {
    return is_floating_point() && nan_repr != NAN_NONE;
  }
  constexpr bool has_infs() const {
    return is_floating_point() && finite_values_only == false;
  }
  constexpr bool has_bias() const { return bias != 0; }

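  // For a type like fp8 e4m3fn (kFE4M3fn below) these report (illustrative):
  //   size_bits() == 8, has_nans() == true, has_infs() == false
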
 private:
  double _floating_point_max() const {
    TORCH_CHECK(mantissa <= 52 && exponent <= 11,
                "Cannot represent max/min as a double for type ", str());

    uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
      max_mantissa -= 1;
    }

    uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
    if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
      TORCH_CHECK(exponent < 11,
                  "Cannot represent max/min as a double for type ", str());
      max_exponent += 1;
    }

    // adjust the exponent to match that of a double;
    // for now we assume the exponent bias is the standard 2^(e-1) - 1 (where
    // e is the number of exponent bits). there is some precedent for
    // non-standard biases, for example `float8_e4m3b11fnuz` here:
    // https://github.com/jax-ml/ml_dtypes but to avoid premature
    // over-complication we just assume the standard exponent bias until there
    // is a need to support non-standard biases
    uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
    uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1;  // double e = 11

    uint64_t max_exponent_double =
        max_exponent - exponent_bias + exponent_bias_double;

    // shift the mantissa and exponent into position for a double
    uint64_t double_raw =
        (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);

    return *reinterpret_cast<double*>(&double_raw);
  }

  constexpr std::variant<int64_t, double> _raw_max() const {
    if (is_floating_point()) {
      return {_floating_point_max()};
    } else {
      TORCH_CHECK(size_bits() < 64 || (size_bits() == 64 && is_signed()),
                  "Cannot represent max as an int64_t");
      return {(int64_t(1) << mantissa) - 1};
    }
  }

  constexpr std::variant<int64_t, double> _raw_min() const {
    if (is_floating_point()) {
      TORCH_CHECK(is_signed(),
                  "We currently assume all floating point types are signed");
      constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);

      double max = _floating_point_max();
      uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
      uint64_t min_raw = max_raw | sign_bit_double;
      return {*reinterpret_cast<double*>(&min_raw)};
    } else {
      TORCH_CHECK(!is_signed() || size_bits() <= 64,
                  "Cannot represent min as an int64_t");
      if (is_signed()) {
        // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0,
        // then perform an arithmetic shift right to set all the bits above
        // (size_bits() - 1) to 1
        return {INT64_MIN >> (64 - size_bits())};
      } else {
        return {int64_t(0)};
      }
    }
  }

 public:
  // Max representable value for this scalar type.
  // (accounting for bias if there is one)
  constexpr std::variant<int64_t, double> max() const {
    return std::visit(
        [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
        _raw_max());
  }

  // Min representable value for this scalar type.
  // (accounting for bias if there is one)
  constexpr std::variant<int64_t, double> min() const {
    return std::visit(
        [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
        _raw_min());
  }

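  // Example with a bias (illustrative): kU4B8 (defined below) stores
  // `value + 8` in 4 bits, so the raw range [0, 15] maps to a logical range
  // of min() == -8 and max() == 7.
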
  std::string str() const {
    /* naming generally follows: https://github.com/jax-ml/ml_dtypes
     * for floating point types (leading f) the scheme is:
     *   `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
     * flags:
     *   - no-flags: means it follows IEEE 754 conventions
     *   - f: means finite values only (no infinities)
     *   - n: means nans are supported (non-standard encoding)
     * for integer types the scheme is:
     *   `[u]int<size_bits>[b<bias>]`
     *   - if bias is not present it means it is zero
     */
    if (is_floating_point()) {
      auto ret = "float" + std::to_string(size_bits()) + "_e" +
                 std::to_string(exponent) + "m" + std::to_string(mantissa);
      if (!is_ieee_754()) {
        if (finite_values_only) {
          ret += "f";
        }
        if (nan_repr != NAN_NONE) {
          ret += "n";
        }
      }
      return ret;
    } else {
      auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
      if (has_bias()) {
        ret += "b" + std::to_string(bias);
      }
      return ret;
    }
  }

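  // Sample names produced by this scheme (illustrative):
  //   kU4B8.str()    -> "uint4b8"
  //   kFE4M3fn.str() -> "float8_e4m3fn"
  //   kFE5M2.str()   -> "float8_e5m2"
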
  constexpr bool operator==(ScalarType const& other) const {
    return mantissa == other.mantissa && exponent == other.exponent &&
           bias == other.bias && signed_ == other.signed_ &&
           finite_values_only == other.finite_values_only &&
           nan_repr == other.nan_repr;
  }
};

using ScalarTypeId = ScalarType::Id;

// "rust style" names generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
static inline constexpr auto kS4 = ScalarType::int_(4);
static inline constexpr auto kU4 = ScalarType::uint(4);
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);

static inline constexpr auto kFE3M2f =
    ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
    ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);

// Fixed width style names, generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
static inline constexpr auto kInt4 = kS4;
static inline constexpr auto kUint4 = kU4;
static inline constexpr auto kUint4b8 = kU4B8;
static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;

static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
static inline constexpr auto kFloat16_e5m10 = kFE5M10;

// colloquial names
static inline constexpr auto kHalf = kFE5M10;
static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;

static inline constexpr auto kFloat16Id = kFloat16.id();
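
// Usage sketch (illustrative, not part of the original commit): the ids above
// let kernels be specialized per scalar type under C++17, where literal class
// types cannot yet be template parameters:
//
//   template <ScalarTypeId kTypeId>
//   void dequant_kernel();  // hypothetical kernel, one specialization per type
//
//   template <>
//   void dequant_kernel<kU4B8.id()>() { /* 4-bit, bias-8 path */ }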
}  // namespace vllm