csukuangfj commited on
Commit
b3907de
·
1 Parent(s): e7db88e

remove 1.16.0

Browse files
1.16.0/onnxruntime.xcframework/Headers/onnxruntime_c_api.h DELETED
The diff for this file is too large to render. See raw diff
 
1.16.0/onnxruntime.xcframework/Headers/onnxruntime_cxx_api.h DELETED
The diff for this file is too large to render. See raw diff
 
1.16.0/onnxruntime.xcframework/Headers/onnxruntime_cxx_inline.h DELETED
@@ -1,1886 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- // Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
5
- // If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
6
- //
7
- // These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter
8
- // the main C++ file with implementation details.
9
-
10
- #include <cstring>
11
-
12
- namespace Ort {
13
-
14
- namespace detail {
15
- inline void ThrowStatus(const Status& st) {
16
- std::string error_message = st.GetErrorMessage();
17
- OrtErrorCode error_code = st.GetErrorCode();
18
- ORT_CXX_API_THROW(std::move(error_message), error_code);
19
- }
20
- } // namespace detail
21
-
22
- inline void ThrowOnError(OrtStatus* ort_status) {
23
- if (ort_status) {
24
- Ort::Status st(ort_status);
25
- detail::ThrowStatus(st);
26
- }
27
- }
28
-
29
- inline void ThrowOnError(const Status& st) {
30
- if (st) {
31
- detail::ThrowStatus(st);
32
- }
33
- }
34
-
35
- inline Status::Status(OrtStatus* status) noexcept : Base<OrtStatus>{status} {
36
- }
37
-
38
- inline Status::Status(const std::exception& e) noexcept {
39
- p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
40
- }
41
-
42
- inline Status::Status(const Exception& e) noexcept {
43
- p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
44
- }
45
-
46
- inline Status::Status(const char* message, OrtErrorCode code) noexcept {
47
- p_ = GetApi().CreateStatus(code, message);
48
- }
49
-
50
- inline std::string Status::GetErrorMessage() const {
51
- std::string message(GetApi().GetErrorMessage(p_));
52
- return message;
53
- }
54
-
55
- inline OrtErrorCode Status::GetErrorCode() const {
56
- return GetApi().GetErrorCode(p_);
57
- }
58
-
59
- inline bool Status::IsOK() const noexcept {
60
- return (p_ == nullptr);
61
- }
62
-
63
- // This template converts a C++ type into it's ONNXTensorElementDataType
64
- template <typename T>
65
- struct TypeToTensorType;
66
- template <>
67
- struct TypeToTensorType<float> {
68
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
69
- };
70
- template <>
71
- struct TypeToTensorType<Float16_t> {
72
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
73
- };
74
- template <>
75
- struct TypeToTensorType<BFloat16_t> {
76
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
77
- };
78
- template <>
79
- struct TypeToTensorType<double> {
80
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
81
- };
82
- template <>
83
- struct TypeToTensorType<int8_t> {
84
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
85
- };
86
- template <>
87
- struct TypeToTensorType<int16_t> {
88
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
89
- };
90
- template <>
91
- struct TypeToTensorType<int32_t> {
92
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
93
- };
94
- template <>
95
- struct TypeToTensorType<int64_t> {
96
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
97
- };
98
- template <>
99
- struct TypeToTensorType<uint8_t> {
100
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
101
- };
102
- template <>
103
- struct TypeToTensorType<uint16_t> {
104
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
105
- };
106
- template <>
107
- struct TypeToTensorType<uint32_t> {
108
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
109
- };
110
- template <>
111
- struct TypeToTensorType<uint64_t> {
112
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
113
- };
114
- template <>
115
- struct TypeToTensorType<bool> {
116
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
117
- };
118
-
119
- template <>
120
- struct TypeToTensorType<Float8E4M3FN_t> {
121
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN;
122
- };
123
- template <>
124
- struct TypeToTensorType<Float8E4M3FNUZ_t> {
125
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ;
126
- };
127
- template <>
128
- struct TypeToTensorType<Float8E5M2_t> {
129
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2;
130
- };
131
- template <>
132
- struct TypeToTensorType<Float8E5M2FNUZ_t> {
133
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ;
134
- };
135
-
136
- inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept {
137
- if (IsNaN() || rhs.IsNaN()) {
138
- // IEEE defines that NaN is not equal to anything, including itself.
139
- return false;
140
- }
141
- return val == rhs.val;
142
- }
143
-
144
- inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept {
145
- if (IsNaN() || rhs.IsNaN()) {
146
- // IEEE defines that NaN is unordered with respect to everything, including itself.
147
- return false;
148
- }
149
-
150
- const bool left_is_negative = IsNegative();
151
- if (left_is_negative != rhs.IsNegative()) {
152
- // When the signs of left and right differ, we know that left is less than right if it is
153
- // the negative value. The exception to this is if both values are zero, in which case IEEE
154
- // says they should be equal, even if the signs differ.
155
- return left_is_negative && !AreZero(*this, rhs);
156
- }
157
- return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
158
- }
159
-
160
- inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
161
- : allocator_(allocator), p_(p), size_(size) {
162
- }
163
-
164
- inline MemoryAllocation::~MemoryAllocation() {
165
- if (p_ != nullptr) {
166
- // We do not throw out of destructor
167
- auto ret = GetApi().AllocatorFree(allocator_, p_);
168
- static_cast<void>(ret);
169
- }
170
- }
171
-
172
- inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
173
- *this = std::move(o);
174
- }
175
-
176
- inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept {
177
- OrtAllocator* alloc = nullptr;
178
- void* p = nullptr;
179
- size_t sz = 0;
180
-
181
- // Swap out this
182
- std::swap(alloc, allocator_);
183
- std::swap(p, p_);
184
- std::swap(sz, size_);
185
-
186
- // Swap with incoming
187
- std::swap(allocator_, o.allocator_);
188
- std::swap(p_, o.p_);
189
- std::swap(size_, o.size_);
190
-
191
- // Destroy this instance if needed
192
- MemoryAllocation this_alloc(alloc, p, sz);
193
- return *this;
194
- }
195
-
196
- namespace detail {
197
-
198
- template <typename T>
199
- inline void* AllocatorImpl<T>::Alloc(size_t size) {
200
- void* out;
201
- ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
202
- return out;
203
- }
204
-
205
- template <typename T>
206
- inline MemoryAllocation AllocatorImpl<T>::GetAllocation(size_t size) {
207
- void* out;
208
- ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
209
- MemoryAllocation result(this->p_, out, size);
210
- return result;
211
- }
212
-
213
- template <typename T>
214
- inline void AllocatorImpl<T>::Free(void* p) {
215
- ThrowOnError(GetApi().AllocatorFree(this->p_, p));
216
- }
217
-
218
- template <typename T>
219
- inline ConstMemoryInfo AllocatorImpl<T>::GetInfo() const {
220
- const OrtMemoryInfo* out;
221
- ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
222
- return ConstMemoryInfo{out};
223
- }
224
-
225
- } // namespace detail
226
-
227
- inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
228
- ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
229
- }
230
-
231
- inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
232
- ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
233
- }
234
-
235
- namespace detail {
236
-
237
- template <typename T>
238
- inline std::string MemoryInfoImpl<T>::GetAllocatorName() const {
239
- const char* name = nullptr;
240
- ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
241
- return std::string(name);
242
- }
243
-
244
- template <typename T>
245
- inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
246
- OrtAllocatorType type;
247
- ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
248
- return type;
249
- }
250
-
251
- template <typename T>
252
- inline int MemoryInfoImpl<T>::GetDeviceId() const {
253
- int id = 0;
254
- ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
255
- return id;
256
- }
257
-
258
- template <typename T>
259
- inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
260
- OrtMemoryInfoDeviceType type;
261
- GetApi().MemoryInfoGetDeviceType(this->p_, &type);
262
- return type;
263
- }
264
-
265
- template <typename T>
266
- inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
267
- OrtMemType type;
268
- ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
269
- return type;
270
- }
271
-
272
- template <typename T>
273
- template <typename U>
274
- inline bool MemoryInfoImpl<T>::operator==(const MemoryInfoImpl<U>& o) const {
275
- int comp_result = 0;
276
- ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
277
- return comp_result == 0;
278
- }
279
-
280
- } // namespace detail
281
-
282
- inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
283
- OrtMemoryInfo* p;
284
- ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
285
- return MemoryInfo(p);
286
- }
287
-
288
- inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
289
- ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
290
- }
291
-
292
- namespace detail {
293
- template <typename T>
294
- inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
295
- AllocatorWithDefaultOptions allocator;
296
- return binding_utils::GetOutputNamesHelper(this->p_, allocator);
297
- }
298
-
299
- template <typename T>
300
- inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
301
- return binding_utils::GetOutputNamesHelper(this->p_, allocator);
302
- }
303
-
304
- template <typename T>
305
- inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
306
- AllocatorWithDefaultOptions allocator;
307
- return binding_utils::GetOutputValuesHelper(this->p_, allocator);
308
- }
309
-
310
- template <typename T>
311
- inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
312
- return binding_utils::GetOutputValuesHelper(this->p_, allocator);
313
- }
314
-
315
- template <typename T>
316
- inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
317
- ThrowOnError(GetApi().BindInput(this->p_, name, value));
318
- }
319
-
320
- template <typename T>
321
- inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
322
- ThrowOnError(GetApi().BindOutput(this->p_, name, value));
323
- }
324
-
325
- template <typename T>
326
- inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
327
- ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
328
- }
329
-
330
- template <typename T>
331
- inline void IoBindingImpl<T>::ClearBoundInputs() {
332
- GetApi().ClearBoundInputs(this->p_);
333
- }
334
-
335
- template <typename T>
336
- inline void IoBindingImpl<T>::ClearBoundOutputs() {
337
- GetApi().ClearBoundOutputs(this->p_);
338
- }
339
-
340
- template <typename T>
341
- inline void IoBindingImpl<T>::SynchronizeInputs() {
342
- ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
343
- }
344
-
345
- template <typename T>
346
- inline void IoBindingImpl<T>::SynchronizeOutputs() {
347
- ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
348
- }
349
-
350
- namespace binding_utils {
351
- inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
352
- std::vector<std::string> result;
353
- auto free_fn = detail::AllocatedFree(allocator);
354
- using Ptr = std::unique_ptr<void, decltype(free_fn)>;
355
-
356
- char* buffer = nullptr;
357
- size_t* lengths = nullptr;
358
- size_t count = 0;
359
- ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
360
-
361
- if (count == 0) {
362
- return result;
363
- }
364
-
365
- Ptr buffer_g(buffer, free_fn);
366
- Ptr lengths_g(lengths, free_fn);
367
-
368
- result.reserve(count);
369
- for (size_t i = 0; i < count; ++i) {
370
- auto sz = *lengths;
371
- result.emplace_back(buffer, sz);
372
- buffer += sz;
373
- ++lengths;
374
- }
375
- return result;
376
- }
377
-
378
- inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
379
- std::vector<Value> result;
380
- size_t owned = 0;
381
- size_t output_count = 0;
382
- // Lambda to release the buffer when no longer needed and
383
- // make sure that we destroy all instances on exception
384
- auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
385
- if (buffer) {
386
- while (owned < output_count) {
387
- auto* p = buffer + owned++;
388
- GetApi().ReleaseValue(*p);
389
- }
390
- allocator->Free(allocator, buffer);
391
- }
392
- };
393
- using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
394
-
395
- OrtValue** output_buffer = nullptr;
396
- ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
397
- if (output_count == 0) {
398
- return result;
399
- }
400
-
401
- Ptr buffer_g(output_buffer, free_fn);
402
-
403
- result.reserve(output_count);
404
- for (size_t i = 0; i < output_count; ++i) {
405
- result.emplace_back(output_buffer[i]);
406
- ++owned;
407
- }
408
- return result;
409
- }
410
-
411
- } // namespace binding_utils
412
- } // namespace detail
413
-
414
- inline IoBinding::IoBinding(Session& session) {
415
- ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
416
- }
417
-
418
- inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
419
- ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
420
- }
421
-
422
- inline ThreadingOptions::ThreadingOptions() {
423
- ThrowOnError(GetApi().CreateThreadingOptions(&p_));
424
- }
425
-
426
- inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) {
427
- ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
428
- return *this;
429
- }
430
-
431
- inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) {
432
- ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
433
- return *this;
434
- }
435
-
436
- inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) {
437
- ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
438
- return *this;
439
- }
440
-
441
- inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() {
442
- ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_));
443
- return *this;
444
- }
445
-
446
- inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
447
- ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
448
- return *this;
449
- }
450
-
451
- inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
452
- ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
453
- return *this;
454
- }
455
-
456
- inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
457
- ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
458
- return *this;
459
- }
460
-
461
- inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
462
- ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
463
- if (strcmp(logid, "onnxruntime-node") == 0) {
464
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
465
- } else {
466
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
467
- }
468
- }
469
-
470
- inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
471
- ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
472
- if (strcmp(logid, "onnxruntime-node") == 0) {
473
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
474
- } else {
475
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
476
- }
477
- }
478
-
479
- inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
480
- ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
481
- if (strcmp(logid, "onnxruntime-node") == 0) {
482
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
483
- } else {
484
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
485
- }
486
- }
487
-
488
- inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
489
- OrtLoggingLevel logging_level, _In_ const char* logid) {
490
- ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
491
- if (strcmp(logid, "onnxruntime-node") == 0) {
492
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
493
- } else {
494
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
495
- }
496
- }
497
-
498
- inline Env& Env::EnableTelemetryEvents() {
499
- ThrowOnError(GetApi().EnableTelemetryEvents(p_));
500
- return *this;
501
- }
502
-
503
- inline Env& Env::DisableTelemetryEvents() {
504
- ThrowOnError(GetApi().DisableTelemetryEvents(p_));
505
- return *this;
506
- }
507
-
508
- inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
509
- ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
510
- return *this;
511
- }
512
-
513
- inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
514
- ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
515
- return *this;
516
- }
517
-
518
- inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg) {
519
- std::vector<const char*> keys, values;
520
- auto num_entries = options.size();
521
- if (num_entries > 0) {
522
- keys.reserve(num_entries);
523
- values.reserve(num_entries);
524
- for (const auto& entry : options) {
525
- keys.push_back(entry.first.c_str());
526
- values.push_back(entry.second.c_str());
527
- }
528
- }
529
- ThrowOnError(GetApi().CreateAndRegisterAllocatorV2(p_, provider_type.c_str(), mem_info, arena_cfg, keys.data(), values.data(), num_entries));
530
- return *this;
531
- }
532
-
533
- inline CustomOpDomain::CustomOpDomain(const char* domain) {
534
- ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
535
- }
536
-
537
- inline void CustomOpDomain::Add(const OrtCustomOp* op) {
538
- ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
539
- }
540
-
541
- inline RunOptions::RunOptions() {
542
- ThrowOnError(GetApi().CreateRunOptions(&p_));
543
- }
544
-
545
- inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
546
- ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
547
- return *this;
548
- }
549
-
550
- inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
551
- ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
552
- return *this;
553
- }
554
-
555
- inline int RunOptions::GetRunLogVerbosityLevel() const {
556
- int out;
557
- ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
558
- return out;
559
- }
560
-
561
- inline int RunOptions::GetRunLogSeverityLevel() const {
562
- int out;
563
- ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
564
- return out;
565
- }
566
-
567
- inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
568
- ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
569
- return *this;
570
- }
571
-
572
- inline const char* RunOptions::GetRunTag() const {
573
- const char* out;
574
- ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
575
- return out;
576
- }
577
-
578
- inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
579
- ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
580
- return *this;
581
- }
582
-
583
- inline RunOptions& RunOptions::SetTerminate() {
584
- ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
585
- return *this;
586
- }
587
-
588
- inline RunOptions& RunOptions::UnsetTerminate() {
589
- ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
590
- return *this;
591
- }
592
-
593
- namespace detail {
594
-
595
- template <typename T>
596
- inline Ort::SessionOptions ConstSessionOptionsImpl<T>::Clone() const {
597
- OrtSessionOptions* out;
598
- ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
599
- return SessionOptions{out};
600
- }
601
-
602
- template <typename T>
603
- inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
604
- size_t size = 0;
605
- // Feed nullptr for the data buffer to query the true size of the string value
606
- Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
607
-
608
- std::string out;
609
- out.resize(size);
610
- Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
611
- out.resize(size - 1); // remove the terminating character '\0'
612
-
613
- return out;
614
- }
615
-
616
- template <typename T>
617
- inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
618
- int out = 0;
619
- Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
620
- return static_cast<bool>(out);
621
- }
622
-
623
- template <typename T>
624
- inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
625
- if (!this->HasConfigEntry(config_key)) {
626
- return def;
627
- }
628
-
629
- return this->GetConfigEntry(config_key);
630
- }
631
-
632
- template <typename T>
633
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetIntraOpNumThreads(int intra_op_num_threads) {
634
- ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
635
- return *this;
636
- }
637
-
638
- template <typename T>
639
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetInterOpNumThreads(int inter_op_num_threads) {
640
- ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
641
- return *this;
642
- }
643
-
644
- template <typename T>
645
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
646
- ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
647
- return *this;
648
- }
649
-
650
- template <typename T>
651
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
652
- ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
653
- return *this;
654
- }
655
-
656
- template <typename T>
657
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
658
- ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
659
- return *this;
660
- }
661
-
662
- template <typename T>
663
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableProfiling() {
664
- ThrowOnError(GetApi().DisableProfiling(this->p_));
665
- return *this;
666
- }
667
-
668
- template <typename T>
669
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableOrtCustomOps() {
670
- ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
671
- return *this;
672
- }
673
-
674
- template <typename T>
675
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableMemPattern() {
676
- ThrowOnError(GetApi().EnableMemPattern(this->p_));
677
- return *this;
678
- }
679
-
680
- template <typename T>
681
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableMemPattern() {
682
- ThrowOnError(GetApi().DisableMemPattern(this->p_));
683
- return *this;
684
- }
685
-
686
- template <typename T>
687
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableCpuMemArena() {
688
- ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
689
- return *this;
690
- }
691
-
692
- template <typename T>
693
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableCpuMemArena() {
694
- ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
695
- return *this;
696
- }
697
-
698
- template <typename T>
699
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
700
- ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
701
- return *this;
702
- }
703
-
704
- template <typename T>
705
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
706
- ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
707
- return *this;
708
- }
709
-
710
- template <typename T>
711
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogSeverityLevel(int level) {
712
- ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
713
- return *this;
714
- }
715
-
716
- template <typename T>
717
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::Add(OrtCustomOpDomain* custom_op_domain) {
718
- ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
719
- return *this;
720
- }
721
-
722
- template <typename T>
723
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
724
- ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
725
- return *this;
726
- }
727
-
728
- template <typename T>
729
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
730
- ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
731
- return *this;
732
- }
733
-
734
- template <typename T>
735
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisablePerSessionThreads() {
736
- ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
737
- return *this;
738
- }
739
-
740
- template <typename T>
741
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
742
- const std::vector<Value>& ort_values) {
743
- const size_t inputs_num = names.size();
744
- if (inputs_num != ort_values.size()) {
745
- ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
746
- }
747
- std::vector<const char*> names_ptr;
748
- std::vector<const OrtValue*> ort_values_ptrs;
749
- names_ptr.reserve(inputs_num);
750
- ort_values_ptrs.reserve(inputs_num);
751
- for (size_t i = 0; i < inputs_num; ++i) {
752
- names_ptr.push_back(names[i].c_str());
753
- ort_values_ptrs.push_back(ort_values[i]);
754
- }
755
- ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
756
- return *this;
757
- }
758
-
759
- template <typename T>
760
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
761
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
762
- return *this;
763
- }
764
-
765
- template <typename T>
766
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) {
767
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
768
- return *this;
769
- }
770
-
771
- template <typename T>
772
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
773
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
774
- return *this;
775
- }
776
-
777
- template <typename T>
778
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
779
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
780
- return *this;
781
- }
782
-
783
- template <typename T>
784
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
785
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
786
- return *this;
787
- }
788
-
789
- template <typename T>
790
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
791
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
792
- return *this;
793
- }
794
-
795
- template <typename T>
796
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
797
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
798
- return *this;
799
- }
800
-
801
- template <typename T>
802
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) {
803
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options));
804
- return *this;
805
- }
806
-
807
- template <typename T>
808
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
809
- const std::string& provider_name,
810
- const std::unordered_map<std::string, std::string>& provider_options) {
811
- auto num_entries = provider_options.size();
812
- std::vector<const char*> keys, values;
813
- if (num_entries > 0) {
814
- keys.reserve(num_entries);
815
- values.reserve(num_entries);
816
-
817
- for (const auto& entry : provider_options) {
818
- keys.push_back(entry.first.c_str());
819
- values.push_back(entry.second.c_str());
820
- }
821
- }
822
-
823
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
824
- keys.data(), values.data(), num_entries));
825
-
826
- return *this;
827
- }
828
-
829
- template <typename T>
830
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
831
- ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
832
- return *this;
833
- }
834
-
835
- template <typename T>
836
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
837
- ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
838
- return *this;
839
- }
840
-
841
- template <typename T>
842
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
843
- ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
844
- return *this;
845
- }
846
-
847
- template <typename T>
848
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
849
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
850
- return *this;
851
- }
852
-
853
- template <typename T>
854
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name,
855
- const CustomOpConfigs& custom_op_configs) {
856
- // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
857
- // the custom op library.
858
- for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
859
- AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
860
- }
861
-
862
- ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
863
- return *this;
864
- }
865
-
866
- template <typename T>
867
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
868
- ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
869
- return *this;
870
- }
871
-
872
- /// Session
873
- template <typename T>
874
- inline size_t ConstSessionImpl<T>::GetInputCount() const {
875
- size_t out;
876
- ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
877
- return out;
878
- }
879
-
880
- template <typename T>
881
- inline size_t ConstSessionImpl<T>::GetOutputCount() const {
882
- size_t out;
883
- ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
884
- return out;
885
- }
886
-
887
- template <typename T>
888
- inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
889
- size_t out;
890
- ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
891
- return out;
892
- }
893
-
894
- template <typename T>
895
- inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
896
- char* out;
897
- ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
898
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
899
- }
900
-
901
- template <typename T>
902
- inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
903
- char* out;
904
- ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
905
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
906
- }
907
-
908
- template <typename T>
909
- inline AllocatedStringPtr ConstSessionImpl<T>::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
910
- char* out;
911
- ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
912
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
913
- }
914
-
915
- template <typename T>
916
- inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
917
- uint64_t out;
918
- ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
919
- return out;
920
- }
921
-
922
- template <typename T>
923
- inline ModelMetadata ConstSessionImpl<T>::GetModelMetadata() const {
924
- OrtModelMetadata* out;
925
- ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
926
- return ModelMetadata{out};
927
- }
928
-
929
- template <typename T>
930
- inline TypeInfo ConstSessionImpl<T>::GetInputTypeInfo(size_t index) const {
931
- OrtTypeInfo* out;
932
- ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
933
- return TypeInfo{out};
934
- }
935
-
936
- template <typename T>
937
- inline TypeInfo ConstSessionImpl<T>::GetOutputTypeInfo(size_t index) const {
938
- OrtTypeInfo* out;
939
- ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
940
- return TypeInfo{out};
941
- }
942
-
943
- template <typename T>
944
- inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t index) const {
945
- OrtTypeInfo* out;
946
- ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
947
- return TypeInfo{out};
948
- }
949
-
950
- template <typename T>
951
- inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
952
- const char* const* output_names, size_t output_count) {
953
- std::vector<Value> output_values;
954
- output_values.reserve(output_count);
955
- for (size_t i = 0; i < output_count; i++)
956
- output_values.emplace_back(nullptr);
957
- Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
958
- return output_values;
959
- }
960
-
961
- template <typename T>
962
- inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
963
- const char* const* output_names, Value* output_values, size_t output_count) {
964
- static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
965
- auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
966
- auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
967
- ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
968
- }
969
-
970
- template <typename T>
971
- inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
972
- ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
973
- }
974
-
975
- template <typename T>
976
- inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
977
- const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
978
- auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
979
- auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
980
- ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names,
981
- ort_input_values, input_count, output_names, output_count,
982
- ort_output_values, callback, user_data));
983
- }
984
-
985
- template <typename T>
986
- inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
987
- char* out = nullptr;
988
- ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
989
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
990
- }
991
-
992
- } // namespace detail
993
-
994
- inline SessionOptions::SessionOptions() {
995
- ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
996
- }
997
-
998
- /// CustomOpConfigs
999
- inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
1000
- std::string config_key = "custom_op.";
1001
-
1002
- config_key += custom_op_name;
1003
- config_key += ".";
1004
- config_key += config;
1005
-
1006
- return config_key;
1007
- }
1008
-
1009
- inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
1010
- const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
1011
- flat_configs_[full_flat_key] = config_value;
1012
- return *this;
1013
- }
1014
-
1015
- inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
1016
- return flat_configs_;
1017
- }
1018
-
1019
- inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
1020
- ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
1021
- }
1022
-
1023
- inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
1024
- OrtPrepackedWeightsContainer* prepacked_weights_container) {
1025
- ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
1026
- }
1027
-
1028
- inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
1029
- ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
1030
- }
1031
-
1032
- inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
1033
- const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
1034
- ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
1035
- prepacked_weights_container, &this->p_));
1036
- }
1037
-
1038
- inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
1039
- char* out;
1040
- ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
1041
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1042
- }
1043
-
1044
- inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
1045
- char* out;
1046
- ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
1047
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1048
- }
1049
-
1050
- inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
1051
- char* out;
1052
- ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
1053
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1054
- }
1055
-
1056
- inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const {
1057
- char* out;
1058
- ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
1059
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1060
- }
1061
-
1062
- inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const {
1063
- char* out;
1064
- ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
1065
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1066
- }
1067
-
1068
- inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
1069
- char* out;
1070
- ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
1071
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1072
- }
1073
-
1074
- inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
1075
- auto deletor = detail::AllocatedFree(allocator);
1076
- std::vector<AllocatedStringPtr> result;
1077
-
1078
- char** out = nullptr;
1079
- int64_t num_keys = 0;
1080
- ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
1081
- if (num_keys <= 0) {
1082
- return result;
1083
- }
1084
-
1085
- // array of pointers will be freed
1086
- std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
1087
- // reserve may throw
1088
- auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
1089
- std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
1090
- result.reserve(static_cast<size_t>(num_keys));
1091
- strings_guard.release();
1092
- for (int64_t i = 0; i < num_keys; ++i) {
1093
- result.push_back(AllocatedStringPtr(out[i], deletor));
1094
- }
1095
-
1096
- return result;
1097
- }
1098
-
1099
- inline int64_t ModelMetadata::GetVersion() const {
1100
- int64_t out;
1101
- ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
1102
- return out;
1103
- }
1104
-
1105
- namespace detail {
1106
-
1107
- template <typename T>
1108
- inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
1109
- ONNXTensorElementDataType out;
1110
- ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
1111
- return out;
1112
- }
1113
-
1114
- template <typename T>
1115
- inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
1116
- size_t out;
1117
- ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
1118
- return static_cast<size_t>(out);
1119
- }
1120
-
1121
- template <typename T>
1122
- inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
1123
- size_t out;
1124
- ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
1125
- return out;
1126
- }
1127
-
1128
- template <typename T>
1129
- inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
1130
- ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
1131
- }
1132
-
1133
- template <typename T>
1134
- inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
1135
- ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
1136
- }
1137
-
1138
- template <typename T>
1139
- inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
1140
- std::vector<int64_t> out(GetDimensionsCount(), 0);
1141
- ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
1142
- return out;
1143
- }
1144
-
1145
- template <typename T>
1146
- inline ConstTensorTypeAndShapeInfo TypeInfoImpl<T>::GetTensorTypeAndShapeInfo() const {
1147
- const OrtTensorTypeAndShapeInfo* out;
1148
- ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
1149
- return ConstTensorTypeAndShapeInfo{out};
1150
- }
1151
-
1152
- template <typename T>
1153
- inline ConstSequenceTypeInfo TypeInfoImpl<T>::GetSequenceTypeInfo() const {
1154
- const OrtSequenceTypeInfo* out;
1155
- ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
1156
- return ConstSequenceTypeInfo{out};
1157
- }
1158
-
1159
- template <typename T>
1160
- inline ConstMapTypeInfo TypeInfoImpl<T>::GetMapTypeInfo() const {
1161
- const OrtMapTypeInfo* out;
1162
- ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
1163
- return ConstMapTypeInfo{out};
1164
- }
1165
-
1166
- template <typename T>
1167
- inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
1168
- ONNXType out;
1169
- ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
1170
- return out;
1171
- }
1172
-
1173
- template <typename T>
1174
- inline TypeInfo SequenceTypeInfoImpl<T>::GetSequenceElementType() const {
1175
- OrtTypeInfo* output;
1176
- ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
1177
- return TypeInfo{output};
1178
- }
1179
-
1180
- template <typename T>
1181
- inline TypeInfo OptionalTypeInfoImpl<T>::GetOptionalElementType() const {
1182
- OrtTypeInfo* info;
1183
- ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info));
1184
- return TypeInfo{info};
1185
- }
1186
-
1187
- template <typename T>
1188
- inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
1189
- ONNXTensorElementDataType out;
1190
- ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
1191
- return out;
1192
- }
1193
-
1194
- template <typename T>
1195
- inline TypeInfo MapTypeInfoImpl<T>::GetMapValueType() const {
1196
- OrtTypeInfo* output;
1197
- ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
1198
- return TypeInfo{output};
1199
- }
1200
-
1201
- template <typename T>
1202
- inline ConstOptionalTypeInfo TypeInfoImpl<T>::GetOptionalTypeInfo() const {
1203
- const OrtOptionalTypeInfo* info;
1204
- ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info));
1205
- return ConstOptionalTypeInfo{info};
1206
- }
1207
-
1208
- } // namespace detail
1209
-
1210
- namespace detail {
1211
-
1212
- template <typename T>
1213
- template <typename R>
1214
- inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
1215
- ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
1216
- }
1217
-
1218
- template <typename T>
1219
- inline bool ConstValueImpl<T>::IsTensor() const {
1220
- int out;
1221
- ThrowOnError(GetApi().IsTensor(this->p_, &out));
1222
- return out != 0;
1223
- }
1224
-
1225
- template <typename T>
1226
- inline bool ConstValueImpl<T>::HasValue() const {
1227
- int out;
1228
- ThrowOnError(GetApi().HasValue(this->p_, &out));
1229
- return out != 0;
1230
- }
1231
-
1232
- template <typename T>
1233
- inline size_t ConstValueImpl<T>::GetCount() const {
1234
- size_t out;
1235
- ThrowOnError(GetApi().GetValueCount(this->p_, &out));
1236
- return out;
1237
- }
1238
-
1239
- template <typename T>
1240
- inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
1241
- OrtValue* out;
1242
- ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
1243
- return Value{out};
1244
- }
1245
-
1246
- template <typename T>
1247
- inline size_t ConstValueImpl<T>::GetStringTensorDataLength() const {
1248
- size_t out;
1249
- ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
1250
- return out;
1251
- }
1252
-
1253
- template <typename T>
1254
- inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
1255
- size_t out;
1256
- ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
1257
- return out;
1258
- }
1259
-
1260
- template <typename T>
1261
- template <typename R>
1262
- inline const R* ConstValueImpl<T>::GetTensorData() const {
1263
- R* out;
1264
- ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
1265
- return out;
1266
- }
1267
-
1268
- template <typename T>
1269
- inline const void* ConstValueImpl<T>::GetTensorRawData() const {
1270
- void* out;
1271
- ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
1272
- return out;
1273
- }
1274
-
1275
- template <typename T>
1276
- inline TypeInfo ConstValueImpl<T>::GetTypeInfo() const {
1277
- OrtTypeInfo* output;
1278
- ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
1279
- return TypeInfo{output};
1280
- }
1281
-
1282
- template <typename T>
1283
- inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetTensorTypeAndShapeInfo() const {
1284
- OrtTensorTypeAndShapeInfo* output;
1285
- ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
1286
- return TensorTypeAndShapeInfo{output};
1287
- }
1288
-
1289
- template <typename T>
1290
- inline ConstMemoryInfo ConstValueImpl<T>::GetTensorMemoryInfo() const {
1291
- const OrtMemoryInfo* mem_info;
1292
- ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
1293
- return ConstMemoryInfo(mem_info);
1294
- }
1295
-
1296
- template <typename T>
1297
- inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
1298
- ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
1299
- }
1300
-
1301
- template <typename T>
1302
- inline std::string ConstValueImpl<T>::GetStringTensorElement(size_t element_index) const {
1303
- size_t buffer_length;
1304
- ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length));
1305
-
1306
- std::string s;
1307
- s.resize(buffer_length);
1308
- ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0]));
1309
- return s;
1310
- }
1311
-
1312
- template <typename T>
1313
- inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
1314
- ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
1315
- }
1316
-
1317
- #if !defined(DISABLE_SPARSE_TENSORS)
1318
- template <typename T>
1319
- inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
1320
- OrtSparseFormat format;
1321
- ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
1322
- return format;
1323
- }
1324
-
1325
- template <typename T>
1326
- inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorValuesTypeAndShapeInfo() const {
1327
- OrtTensorTypeAndShapeInfo* output;
1328
- ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
1329
- return TensorTypeAndShapeInfo{output};
1330
- }
1331
-
1332
- template <typename T>
1333
- inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
1334
- OrtTensorTypeAndShapeInfo* output;
1335
- ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
1336
- return TensorTypeAndShapeInfo{output};
1337
- }
1338
-
1339
- template <typename T>
1340
- template <typename R>
1341
- inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
1342
- const void* out;
1343
- ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
1344
- return reinterpret_cast<const R*>(out);
1345
- }
1346
-
1347
- template <typename T>
1348
- inline bool ConstValueImpl<T>::IsSparseTensor() const {
1349
- int out;
1350
- ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
1351
- return out != 0;
1352
- }
1353
-
1354
- template <typename T>
1355
- template <typename R>
1356
- inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
1357
- const void* out;
1358
- ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
1359
- return reinterpret_cast<const R*>(out);
1360
- }
1361
-
1362
- #endif
1363
-
1364
- template <typename T>
1365
- void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
1366
- ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
1367
- }
1368
-
1369
- template <typename T>
1370
- void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
1371
- ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
1372
- }
1373
-
1374
- template <typename T>
1375
- inline char* ValueImpl<T>::GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length) {
1376
- char* result;
1377
- ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result));
1378
- return result;
1379
- }
1380
-
1381
- template <typename T>
1382
- void* ValueImpl<T>::GetTensorMutableRawData() {
1383
- void* out;
1384
- ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
1385
- return out;
1386
- }
1387
-
1388
- template <typename T>
1389
- template <typename R>
1390
- R* ValueImpl<T>::GetTensorMutableData() {
1391
- R* out;
1392
- ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
1393
- return out;
1394
- }
1395
-
1396
- template <typename T>
1397
- template <typename R>
1398
- R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
1399
- static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
1400
- R* out;
1401
- ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
1402
- return *out;
1403
- }
1404
-
1405
- #if !defined(DISABLE_SPARSE_TENSORS)
1406
- template <typename T>
1407
- void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
1408
- ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
1409
- }
1410
-
1411
- template <typename T>
1412
- void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
1413
- ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
1414
- }
1415
-
1416
- template <typename T>
1417
- void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
1418
- ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
1419
- }
1420
-
1421
- template <typename T>
1422
- void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
1423
- const int64_t* indices_data, size_t indices_num) {
1424
- ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
1425
- values_param.values_shape_len, values_param.data.p_data,
1426
- indices_data, indices_num));
1427
- }
1428
-
1429
- template <typename T>
1430
- void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1431
- const OrtSparseValuesParam& values,
1432
- const int64_t* inner_indices_data, size_t inner_indices_num,
1433
- const int64_t* outer_indices_data, size_t outer_indices_num) {
1434
- ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1435
- inner_indices_data, inner_indices_num,
1436
- outer_indices_data, outer_indices_num));
1437
- }
1438
-
1439
- template <typename T>
1440
- void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1441
- const OrtSparseValuesParam& values,
1442
- const Shape& indices_shape,
1443
- const int32_t* indices_data) {
1444
- ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1445
- indices_shape.shape, indices_shape.shape_len,
1446
- indices_data));
1447
- }
1448
-
1449
- #endif // !defined(DISABLE_SPARSE_TENSORS)
1450
-
1451
- } // namespace detail
1452
-
1453
- template <typename T>
1454
- inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
1455
- return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
1456
- }
1457
-
1458
- inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1459
- ONNXTensorElementDataType type) {
1460
- OrtValue* out;
1461
- ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
1462
- return Value{out};
1463
- }
1464
-
1465
- template <typename T>
1466
- inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
1467
- return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
1468
- }
1469
-
1470
- inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
1471
- OrtValue* out;
1472
- ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
1473
- return Value{out};
1474
- }
1475
-
1476
- #if !defined(DISABLE_SPARSE_TENSORS)
1477
-
1478
- template <typename T>
1479
- inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1480
- const Shape& values_shape) {
1481
- return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
1482
- }
1483
-
1484
- inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1485
- const Shape& values_shape, ONNXTensorElementDataType type) {
1486
- OrtValue* out;
1487
- ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
1488
- values_shape.shape, values_shape.shape_len, type, &out));
1489
- return Value{out};
1490
- }
1491
-
1492
- template <typename T>
1493
- inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
1494
- return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
1495
- }
1496
-
1497
- inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
1498
- ONNXTensorElementDataType type) {
1499
- OrtValue* out;
1500
- ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
1501
- return Value{out};
1502
- }
1503
- #endif // !defined(DISABLE_SPARSE_TENSORS)
1504
-
1505
- inline Value Value::CreateMap(const Value& keys, const Value& values) {
1506
- OrtValue* out;
1507
- const OrtValue* inputs[2] = {keys, values};
1508
- ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
1509
- return Value{out};
1510
- }
1511
-
1512
- inline Value Value::CreateSequence(const std::vector<Value>& values) {
1513
- OrtValue* out;
1514
- std::vector<const OrtValue*> values_ort{values.data(), values.data() + values.size()};
1515
- ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
1516
- return Value{out};
1517
- }
1518
-
1519
- template <typename T>
1520
- inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
1521
- OrtValue* out;
1522
- ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
1523
- return Value{out};
1524
- }
1525
-
1526
- //
1527
- // Custom OP Inlines
1528
- //
1529
- inline Logger::Logger(const OrtLogger* logger) : logger_(logger) {
1530
- Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_));
1531
- }
1532
-
1533
- inline OrtLoggingLevel Logger::GetLoggingSeverityLevel() const noexcept {
1534
- return cached_severity_level_;
1535
- }
1536
-
1537
- inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
1538
- const char* func_name, const char* message) const noexcept {
1539
- OrtStatus* status = GetApi().Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number,
1540
- func_name);
1541
- return Status{status};
1542
- }
1543
-
1544
- // Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security)
1545
- // for gcc and clang. The alternative is to use actual C-style variadic parameters and apply
1546
- // __attribute__(format(printf...)), which does not work with variadic templates.
1547
- #if defined(__GNUC__)
1548
- #pragma GCC diagnostic push
1549
- #pragma GCC diagnostic ignored "-Wformat-nonliteral"
1550
- #pragma GCC diagnostic ignored "-Wformat-security"
1551
- #elif defined(__clang__)
1552
- #pragma clang diagnostic push
1553
- #pragma clang diagnostic ignored "-Wformat-nonliteral"
1554
- #pragma clang diagnostic ignored "-Wformat-security"
1555
- #endif
1556
- template <typename... Args>
1557
- inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path,
1558
- int line_number, const char* func_name, const char* format,
1559
- Args&&... args) const noexcept {
1560
- int msg_len = std::snprintf(nullptr, 0U, format, std::forward<Args>(args)...);
1561
-
1562
- if (msg_len < 0) { // Formatting error
1563
- return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL);
1564
- }
1565
-
1566
- OrtStatus* status = nullptr;
1567
- const size_t buffer_size = static_cast<size_t>(msg_len) + 1U;
1568
-
1569
- constexpr size_t kStackBufferSize = 1024;
1570
-
1571
- if (buffer_size < kStackBufferSize) {
1572
- char buffer[kStackBufferSize];
1573
- snprintf(buffer, kStackBufferSize, format, std::forward<Args>(args)...);
1574
- status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name);
1575
- } else {
1576
- // std::make_unique is only supported starting at C++14.
1577
- #if (__cplusplus >= 201402L) || (_MSC_VER >= 1900)
1578
- auto buffer = std::make_unique<char[]>(buffer_size);
1579
- #else
1580
- std::unique_ptr<char[]> buffer(new char[buffer_size]);
1581
- #endif
1582
- std::snprintf(buffer.get(), buffer_size, format, std::forward<Args>(args)...);
1583
- status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name);
1584
- }
1585
-
1586
- return Status{status};
1587
- }
1588
- // Re-enable -Wformat-nonliteral and -Wformat-security
1589
- #if defined(__GNUC__)
1590
- #pragma GCC diagnostic pop
1591
- #elif defined(__clang__)
1592
- #pragma clang diagnostic pop
1593
- #endif
1594
-
1595
- inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
1596
- }
1597
-
1598
- inline size_t KernelContext::GetInputCount() const {
1599
- size_t out = 0;
1600
- Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
1601
- return out;
1602
- }
1603
-
1604
- inline size_t KernelContext::GetOutputCount() const {
1605
- size_t out = 0;
1606
- Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
1607
- return out;
1608
- }
1609
-
1610
- inline ConstValue KernelContext::GetInput(size_t index) const {
1611
- const OrtValue* out = nullptr;
1612
- Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
1613
- return ConstValue{out};
1614
- }
1615
-
1616
- inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
1617
- OrtValue* out = nullptr;
1618
- Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
1619
- return UnownedValue(out);
1620
- }
1621
-
1622
- inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
1623
- OrtValue* out = nullptr;
1624
- Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
1625
- return UnownedValue(out);
1626
- }
1627
-
1628
- inline void* KernelContext::GetGPUComputeStream() const {
1629
- void* out = nullptr;
1630
- Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
1631
- return out;
1632
- }
1633
-
1634
- inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const {
1635
- OrtAllocator* out = nullptr;
1636
- Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out));
1637
- return out;
1638
- }
1639
-
1640
- inline Logger KernelContext::GetLogger() const {
1641
- const OrtLogger* out = nullptr;
1642
- ThrowOnError(GetApi().KernelContext_GetLogger(this->ctx_, &out));
1643
- return Logger{out};
1644
- }
1645
-
1646
- inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
1647
- Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
1648
- }
1649
-
1650
- namespace detail {
1651
- template <typename T>
1652
- inline KernelInfo KernelInfoImpl<T>::Copy() const {
1653
- OrtKernelInfo* info_copy = nullptr;
1654
- Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
1655
- return KernelInfo{info_copy};
1656
- }
1657
-
1658
- template <typename T>
1659
- inline size_t KernelInfoImpl<T>::GetInputCount() const {
1660
- size_t out = 0;
1661
- ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
1662
- return out;
1663
- }
1664
-
1665
- template <typename T>
1666
- inline size_t KernelInfoImpl<T>::GetOutputCount() const {
1667
- size_t out = 0;
1668
- ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
1669
- return out;
1670
- }
1671
-
1672
- template <typename T>
1673
- inline std::string KernelInfoImpl<T>::GetInputName(size_t index) const {
1674
- size_t size = 0;
1675
-
1676
- // Feed nullptr for the data buffer to query the true size of the string value
1677
- Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
1678
-
1679
- std::string out;
1680
- out.resize(size);
1681
- Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
1682
- out.resize(size - 1); // remove the terminating character '\0'
1683
-
1684
- return out;
1685
- }
1686
-
1687
- template <typename T>
1688
- inline std::string KernelInfoImpl<T>::GetOutputName(size_t index) const {
1689
- size_t size = 0;
1690
-
1691
- // Feed nullptr for the data buffer to query the true size of the string value
1692
- Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
1693
-
1694
- std::string out;
1695
- out.resize(size);
1696
- Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
1697
- out.resize(size - 1); // remove the terminating character '\0'
1698
-
1699
- return out;
1700
- }
1701
-
1702
- template <typename T>
1703
- inline TypeInfo KernelInfoImpl<T>::GetInputTypeInfo(size_t index) const {
1704
- OrtTypeInfo* out = nullptr;
1705
- ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
1706
- return TypeInfo{out};
1707
- }
1708
-
1709
- template <typename T>
1710
- inline TypeInfo KernelInfoImpl<T>::GetOutputTypeInfo(size_t index) const {
1711
- OrtTypeInfo* out = nullptr;
1712
- ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
1713
- return TypeInfo{out};
1714
- }
1715
-
1716
- template <typename T>
1717
- inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
1718
- OrtValue* out = nullptr;
1719
- ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
1720
- return Value{out};
1721
- }
1722
-
1723
- template <typename T>
1724
- inline ConstValue KernelInfoImpl<T>::GetTensorConstantInput(size_t index, int* is_constant) const {
1725
- const OrtValue* out = nullptr;
1726
- ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out));
1727
- return ConstValue{out};
1728
- }
1729
-
1730
- template <typename T>
1731
- inline std::string KernelInfoImpl<T>::GetNodeName() const {
1732
- size_t size = 0;
1733
-
1734
- // Feed nullptr for the data buffer to query the true size of the string value
1735
- Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size));
1736
-
1737
- std::string out;
1738
- out.resize(size);
1739
- Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size));
1740
- out.resize(size - 1); // remove the terminating character '\0'
1741
-
1742
- return out;
1743
- }
1744
-
1745
- template <typename T>
1746
- inline Logger KernelInfoImpl<T>::GetLogger() const {
1747
- const OrtLogger* out = nullptr;
1748
- ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out));
1749
- return Logger{out};
1750
- }
1751
-
1752
- inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
1753
- Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
1754
- }
1755
-
1756
- inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
1757
- Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
1758
- }
1759
-
1760
- inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
1761
- size_t size = 0;
1762
- // Feed nullptr for the data buffer to query the true size of the string attribute
1763
- Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
1764
-
1765
- std::string out;
1766
- out.resize(size);
1767
- Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
1768
- out.resize(size - 1); // remove the terminating character '\0'
1769
- out.swap(result);
1770
- }
1771
-
1772
- inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
1773
- size_t size = 0;
1774
- // Feed nullptr for the data buffer to query the true size of the attribute
1775
- Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
1776
-
1777
- std::vector<float> out;
1778
- out.resize(size);
1779
- Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
1780
- out.swap(result);
1781
- }
1782
-
1783
- inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
1784
- size_t size = 0;
1785
-
1786
- // Feed nullptr for the data buffer to query the true size of the attribute
1787
- Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
1788
-
1789
- std::vector<int64_t> out;
1790
- out.resize(size);
1791
- Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
1792
- out.swap(result);
1793
- }
1794
- } // namespace detail
1795
-
1796
- inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
1797
-
1798
- inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
1799
-
1800
- inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
1801
- const char** type_constraint_names,
1802
- const ONNXTensorElementDataType* type_constraint_values,
1803
- size_t type_constraint_count,
1804
- const OpAttr* attr_values, size_t attr_count,
1805
- size_t input_count, size_t output_count) {
1806
- static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
1807
- "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
1808
- auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
1809
- OrtOp* op;
1810
- Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1811
- static_cast<int>(type_constraint_count),
1812
- attr_input_values,
1813
- static_cast<int>(attr_count),
1814
- static_cast<int>(input_count),
1815
- static_cast<int>(output_count), &op));
1816
- return Op{op};
1817
- }
1818
-
1819
- inline void Op::Invoke(const OrtKernelContext* context,
1820
- const Value* input_values,
1821
- size_t input_count,
1822
- Value* output_values,
1823
- size_t output_count) {
1824
- static_assert(sizeof(Value) == sizeof(OrtValue*),
1825
- "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1826
- auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1827
- auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1828
- Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
1829
- ort_output_values, static_cast<int>(output_count)));
1830
- }
1831
-
1832
- inline void Op::Invoke(const OrtKernelContext* context,
1833
- const OrtValue* const* input_values,
1834
- size_t input_count,
1835
- OrtValue* const* output_values,
1836
- size_t output_count) {
1837
- Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
1838
- output_values, static_cast<int>(output_count)));
1839
- }
1840
-
1841
- inline std::string GetVersionString() {
1842
- return OrtGetApiBase()->GetVersionString();
1843
- }
1844
-
1845
- inline std::string GetBuildInfoString() {
1846
- return GetApi().GetBuildInfoString();
1847
- }
1848
-
1849
- inline std::vector<std::string> GetAvailableProviders() {
1850
- char** providers;
1851
- int len;
1852
-
1853
- auto release_fn = [&len](char** providers) {
1854
- // This should always return nullptr.
1855
- ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
1856
- };
1857
-
1858
- ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
1859
- std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
1860
- std::vector<std::string> available_providers;
1861
- available_providers.reserve(static_cast<size_t>(len));
1862
- for (int i = 0; i < len; ++i) {
1863
- available_providers.emplace_back(providers[i]);
1864
- }
1865
- return available_providers;
1866
- }
1867
-
1868
- template <typename TOp, typename TKernel, bool WithStatus>
1869
- void CustomOpBase<TOp, TKernel, WithStatus>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
1870
- ConstSessionOptions options) const {
1871
- const TOp* derived = static_cast<const TOp*>(this);
1872
- std::vector<std::string> keys = derived->GetSessionConfigKeys();
1873
-
1874
- out.reserve(keys.size());
1875
-
1876
- std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
1877
- const size_t prefix_size = config_entry_key.length();
1878
-
1879
- for (const auto& key : keys) {
1880
- config_entry_key.resize(prefix_size);
1881
- config_entry_key.append(key);
1882
- out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
1883
- }
1884
- }
1885
-
1886
- } // namespace Ort
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1.16.0/onnxruntime.xcframework/Headers/onnxruntime_float16.h DELETED
@@ -1,540 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <stdint.h>
7
- #include <cmath>
8
- #include <cstring>
9
- #include <limits>
10
-
11
- namespace onnxruntime_float16 {
12
-
13
- namespace detail {
14
-
15
- enum class endian {
16
- #if defined(_WIN32)
17
- little = 0,
18
- big = 1,
19
- native = little,
20
- #elif defined(__GNUC__) || defined(__clang__)
21
- little = __ORDER_LITTLE_ENDIAN__,
22
- big = __ORDER_BIG_ENDIAN__,
23
- native = __BYTE_ORDER__,
24
- #else
25
- #error onnxruntime_float16::detail::endian is not implemented in this environment.
26
- #endif
27
- };
28
-
29
- static_assert(
30
- endian::native == endian::little || endian::native == endian::big,
31
- "Only little-endian or big-endian native byte orders are supported.");
32
-
33
- } // namespace detail
34
-
35
- /// <summary>
36
- /// Shared implementation between public and internal classes. CRTP pattern.
37
- /// </summary>
38
- template <class Derived>
39
- struct Float16Impl {
40
- protected:
41
- /// <summary>
42
- /// Converts from float to uint16_t float16 representation
43
- /// </summary>
44
- /// <param name="v"></param>
45
- /// <returns></returns>
46
- constexpr static uint16_t ToUint16Impl(float v) noexcept;
47
-
48
- /// <summary>
49
- /// Converts float16 to float
50
- /// </summary>
51
- /// <returns>float representation of float16 value</returns>
52
- float ToFloatImpl() const noexcept;
53
-
54
- /// <summary>
55
- /// Creates an instance that represents absolute value.
56
- /// </summary>
57
- /// <returns>Absolute value</returns>
58
- uint16_t AbsImpl() const noexcept {
59
- return static_cast<uint16_t>(val & ~kSignMask);
60
- }
61
-
62
- /// <summary>
63
- /// Creates a new instance with the sign flipped.
64
- /// </summary>
65
- /// <returns>Flipped sign instance</returns>
66
- uint16_t NegateImpl() const noexcept {
67
- return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
68
- }
69
-
70
- public:
71
- // uint16_t special values
72
- static constexpr uint16_t kSignMask = 0x8000U;
73
- static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
74
- static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
75
- static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
76
- static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
77
- static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
78
- static constexpr uint16_t kEpsilonBits = 0x4170U;
79
- static constexpr uint16_t kMinValueBits = 0xFBFFU; // Minimum normal number
80
- static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number
81
- static constexpr uint16_t kOneBits = 0x3C00U;
82
- static constexpr uint16_t kMinusOneBits = 0xBC00U;
83
-
84
- uint16_t val{0};
85
-
86
- Float16Impl() = default;
87
-
88
- /// <summary>
89
- /// Checks if the value is negative
90
- /// </summary>
91
- /// <returns>true if negative</returns>
92
- bool IsNegative() const noexcept {
93
- return static_cast<int16_t>(val) < 0;
94
- }
95
-
96
- /// <summary>
97
- /// Tests if the value is NaN
98
- /// </summary>
99
- /// <returns>true if NaN</returns>
100
- bool IsNaN() const noexcept {
101
- return AbsImpl() > kPositiveInfinityBits;
102
- }
103
-
104
- /// <summary>
105
- /// Tests if the value is finite
106
- /// </summary>
107
- /// <returns>true if finite</returns>
108
- bool IsFinite() const noexcept {
109
- return AbsImpl() < kPositiveInfinityBits;
110
- }
111
-
112
- /// <summary>
113
- /// Tests if the value represents positive infinity.
114
- /// </summary>
115
- /// <returns>true if positive infinity</returns>
116
- bool IsPositiveInfinity() const noexcept {
117
- return val == kPositiveInfinityBits;
118
- }
119
-
120
- /// <summary>
121
- /// Tests if the value represents negative infinity
122
- /// </summary>
123
- /// <returns>true if negative infinity</returns>
124
- bool IsNegativeInfinity() const noexcept {
125
- return val == kNegativeInfinityBits;
126
- }
127
-
128
- /// <summary>
129
- /// Tests if the value is either positive or negative infinity.
130
- /// </summary>
131
- /// <returns>True if absolute value is infinity</returns>
132
- bool IsInfinity() const noexcept {
133
- return AbsImpl() == kPositiveInfinityBits;
134
- }
135
-
136
- /// <summary>
137
- /// Tests if the value is NaN or zero. Useful for comparisons.
138
- /// </summary>
139
- /// <returns>True if NaN or zero.</returns>
140
- bool IsNaNOrZero() const noexcept {
141
- auto abs = AbsImpl();
142
- return (abs == 0 || abs > kPositiveInfinityBits);
143
- }
144
-
145
- /// <summary>
146
- /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
147
- /// </summary>
148
- /// <returns>True if so</returns>
149
- bool IsNormal() const noexcept {
150
- auto abs = AbsImpl();
151
- return (abs < kPositiveInfinityBits) // is finite
152
- && (abs != 0) // is not zero
153
- && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
154
- }
155
-
156
- /// <summary>
157
- /// Tests if the value is subnormal (denormal).
158
- /// </summary>
159
- /// <returns>True if so</returns>
160
- bool IsSubnormal() const noexcept {
161
- auto abs = AbsImpl();
162
- return (abs < kPositiveInfinityBits) // is finite
163
- && (abs != 0) // is not zero
164
- && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
165
- }
166
-
167
- /// <summary>
168
- /// Creates an instance that represents absolute value.
169
- /// </summary>
170
- /// <returns>Absolute value</returns>
171
- Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
172
-
173
- /// <summary>
174
- /// Creates a new instance with the sign flipped.
175
- /// </summary>
176
- /// <returns>Flipped sign instance</returns>
177
- Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
178
-
179
- /// <summary>
180
- /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
181
- /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
182
- /// and therefore equivalent, if the resulting value is still zero.
183
- /// </summary>
184
- /// <param name="lhs">first value</param>
185
- /// <param name="rhs">second value</param>
186
- /// <returns>True if both arguments represent zero</returns>
187
- static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
188
- return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
189
- }
190
-
191
- bool operator==(const Float16Impl& rhs) const noexcept {
192
- if (IsNaN() || rhs.IsNaN()) {
193
- // IEEE defines that NaN is not equal to anything, including itself.
194
- return false;
195
- }
196
- return val == rhs.val;
197
- }
198
-
199
- bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }
200
-
201
- bool operator<(const Float16Impl& rhs) const noexcept {
202
- if (IsNaN() || rhs.IsNaN()) {
203
- // IEEE defines that NaN is unordered with respect to everything, including itself.
204
- return false;
205
- }
206
-
207
- const bool left_is_negative = IsNegative();
208
- if (left_is_negative != rhs.IsNegative()) {
209
- // When the signs of left and right differ, we know that left is less than right if it is
210
- // the negative value. The exception to this is if both values are zero, in which case IEEE
211
- // says they should be equal, even if the signs differ.
212
- return left_is_negative && !AreZero(*this, rhs);
213
- }
214
- return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
215
- }
216
- };
217
-
218
- // The following Float16_t conversions are based on the code from
219
- // Eigen library.
220
-
221
- // The conversion routines are Copyright (c) Fabian Giesen, 2016.
222
- // The original license follows:
223
- //
224
- // Copyright (c) Fabian Giesen, 2016
225
- // All rights reserved.
226
- // Redistribution and use in source and binary forms, with or without
227
- // modification, are permitted.
228
- // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
229
- // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
230
- // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
231
- // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
232
- // HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
233
- // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
234
- // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
235
- // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
236
- // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
237
- // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
238
- // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
239
-
240
- namespace detail {
241
- union float32_bits {
242
- unsigned int u;
243
- float f;
244
- };
245
- } // namespace detail
246
-
247
- template <class Derived>
248
- inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
249
- detail::float32_bits f{};
250
- f.f = v;
251
-
252
- constexpr detail::float32_bits f32infty = {255 << 23};
253
- constexpr detail::float32_bits f16max = {(127 + 16) << 23};
254
- constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
255
- constexpr unsigned int sign_mask = 0x80000000u;
256
- uint16_t val = static_cast<uint16_t>(0x0u);
257
-
258
- unsigned int sign = f.u & sign_mask;
259
- f.u ^= sign;
260
-
261
- // NOTE all the integer compares in this function can be safely
262
- // compiled into signed compares since all operands are below
263
- // 0x80000000. Important if you want fast straight SSE2 code
264
- // (since there's no unsigned PCMPGTD).
265
-
266
- if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
267
- val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
268
- } else { // (De)normalized number or zero
269
- if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
270
- // use a magic value to align our 10 mantissa bits at the bottom of
271
- // the float. as long as FP addition is round-to-nearest-even this
272
- // just works.
273
- f.f += denorm_magic.f;
274
-
275
- // and one integer subtract of the bias later, we have our final float!
276
- val = static_cast<uint16_t>(f.u - denorm_magic.u);
277
- } else {
278
- unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
279
-
280
- // update exponent, rounding bias part 1
281
- // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
282
- // without arithmetic overflow.
283
- f.u += 0xc8000fffU;
284
- // rounding bias part 2
285
- f.u += mant_odd;
286
- // take the bits!
287
- val = static_cast<uint16_t>(f.u >> 13);
288
- }
289
- }
290
-
291
- val |= static_cast<uint16_t>(sign >> 16);
292
- return val;
293
- }
294
-
295
- template <class Derived>
296
- inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
297
- constexpr detail::float32_bits magic = {113 << 23};
298
- constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
299
- detail::float32_bits o{};
300
-
301
- o.u = (val & 0x7fff) << 13; // exponent/mantissa bits
302
- unsigned int exp = shifted_exp & o.u; // just the exponent
303
- o.u += (127 - 15) << 23; // exponent adjust
304
-
305
- // handle exponent special cases
306
- if (exp == shifted_exp) { // Inf/NaN?
307
- o.u += (128 - 16) << 23; // extra exp adjust
308
- } else if (exp == 0) { // Zero/Denormal?
309
- o.u += 1 << 23; // extra exp adjust
310
- o.f -= magic.f; // re-normalize
311
- }
312
-
313
- // Attempt to workaround the Internal Compiler Error on ARM64
314
- // for bitwise | operator, including std::bitset
315
- #if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
316
- if (IsNegative()) {
317
- return -o.f;
318
- }
319
- #else
320
- // original code:
321
- o.u |= (val & 0x8000U) << 16U; // sign bit
322
- #endif
323
- return o.f;
324
- }
325
-
326
- /// Shared implementation between public and internal classes. CRTP pattern.
327
- template <class Derived>
328
- struct BFloat16Impl {
329
- protected:
330
- /// <summary>
331
- /// Converts from float to uint16_t float16 representation
332
- /// </summary>
333
- /// <param name="v"></param>
334
- /// <returns></returns>
335
- static uint16_t ToUint16Impl(float v) noexcept;
336
-
337
- /// <summary>
338
- /// Converts bfloat16 to float
339
- /// </summary>
340
- /// <returns>float representation of bfloat16 value</returns>
341
- float ToFloatImpl() const noexcept;
342
-
343
- /// <summary>
344
- /// Creates an instance that represents absolute value.
345
- /// </summary>
346
- /// <returns>Absolute value</returns>
347
- uint16_t AbsImpl() const noexcept {
348
- return static_cast<uint16_t>(val & ~kSignMask);
349
- }
350
-
351
- /// <summary>
352
- /// Creates a new instance with the sign flipped.
353
- /// </summary>
354
- /// <returns>Flipped sign instance</returns>
355
- uint16_t NegateImpl() const noexcept {
356
- return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
357
- }
358
-
359
- public:
360
- // uint16_t special values
361
- static constexpr uint16_t kSignMask = 0x8000U;
362
- static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
363
- static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
364
- static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
365
- static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
366
- static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
367
- static constexpr uint16_t kSignaling_NaNBits = 0x7F80U;
368
- static constexpr uint16_t kEpsilonBits = 0x0080U;
369
- static constexpr uint16_t kMinValueBits = 0xFF7FU;
370
- static constexpr uint16_t kMaxValueBits = 0x7F7FU;
371
- static constexpr uint16_t kRoundToNearest = 0x7FFFU;
372
- static constexpr uint16_t kOneBits = 0x3F80U;
373
- static constexpr uint16_t kMinusOneBits = 0xBF80U;
374
-
375
- uint16_t val{0};
376
-
377
- BFloat16Impl() = default;
378
-
379
- /// <summary>
380
- /// Checks if the value is negative
381
- /// </summary>
382
- /// <returns>true if negative</returns>
383
- bool IsNegative() const noexcept {
384
- return static_cast<int16_t>(val) < 0;
385
- }
386
-
387
- /// <summary>
388
- /// Tests if the value is NaN
389
- /// </summary>
390
- /// <returns>true if NaN</returns>
391
- bool IsNaN() const noexcept {
392
- return AbsImpl() > kPositiveInfinityBits;
393
- }
394
-
395
- /// <summary>
396
- /// Tests if the value is finite
397
- /// </summary>
398
- /// <returns>true if finite</returns>
399
- bool IsFinite() const noexcept {
400
- return AbsImpl() < kPositiveInfinityBits;
401
- }
402
-
403
- /// <summary>
404
- /// Tests if the value represents positive infinity.
405
- /// </summary>
406
- /// <returns>true if positive infinity</returns>
407
- bool IsPositiveInfinity() const noexcept {
408
- return val == kPositiveInfinityBits;
409
- }
410
-
411
- /// <summary>
412
- /// Tests if the value represents negative infinity
413
- /// </summary>
414
- /// <returns>true if negative infinity</returns>
415
- bool IsNegativeInfinity() const noexcept {
416
- return val == kNegativeInfinityBits;
417
- }
418
-
419
- /// <summary>
420
- /// Tests if the value is either positive or negative infinity.
421
- /// </summary>
422
- /// <returns>True if absolute value is infinity</returns>
423
- bool IsInfinity() const noexcept {
424
- return AbsImpl() == kPositiveInfinityBits;
425
- }
426
-
427
- /// <summary>
428
- /// Tests if the value is NaN or zero. Useful for comparisons.
429
- /// </summary>
430
- /// <returns>True if NaN or zero.</returns>
431
- bool IsNaNOrZero() const noexcept {
432
- auto abs = AbsImpl();
433
- return (abs == 0 || abs > kPositiveInfinityBits);
434
- }
435
-
436
- /// <summary>
437
- /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
438
- /// </summary>
439
- /// <returns>True if so</returns>
440
- bool IsNormal() const noexcept {
441
- auto abs = AbsImpl();
442
- return (abs < kPositiveInfinityBits) // is finite
443
- && (abs != 0) // is not zero
444
- && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
445
- }
446
-
447
- /// <summary>
448
- /// Tests if the value is subnormal (denormal).
449
- /// </summary>
450
- /// <returns>True if so</returns>
451
- bool IsSubnormal() const noexcept {
452
- auto abs = AbsImpl();
453
- return (abs < kPositiveInfinityBits) // is finite
454
- && (abs != 0) // is not zero
455
- && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
456
- }
457
-
458
- /// <summary>
459
- /// Creates an instance that represents absolute value.
460
- /// </summary>
461
- /// <returns>Absolute value</returns>
462
- Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
463
-
464
- /// <summary>
465
- /// Creates a new instance with the sign flipped.
466
- /// </summary>
467
- /// <returns>Flipped sign instance</returns>
468
- Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
469
-
470
- /// <summary>
471
- /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
472
- /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
473
- /// and therefore equivalent, if the resulting value is still zero.
474
- /// </summary>
475
- /// <param name="lhs">first value</param>
476
- /// <param name="rhs">second value</param>
477
- /// <returns>True if both arguments represent zero</returns>
478
- static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
479
- // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
480
- // for two values by or'ing the private bits together and stripping the sign. They are both zero,
481
- // and therefore equivalent, if the resulting value is still zero.
482
- return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
483
- }
484
- };
485
-
486
- template <class Derived>
487
- inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
488
- uint16_t result;
489
- if (std::isnan(v)) {
490
- result = kPositiveQNaNBits;
491
- } else {
492
- auto get_msb_half = [](float fl) {
493
- uint16_t result;
494
- #ifdef __cpp_if_constexpr
495
- if constexpr (detail::endian::native == detail::endian::little) {
496
- #else
497
- if (detail::endian::native == detail::endian::little) {
498
- #endif
499
- std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
500
- } else {
501
- std::memcpy(&result, &fl, sizeof(uint16_t));
502
- }
503
- return result;
504
- };
505
-
506
- uint16_t upper_bits = get_msb_half(v);
507
- union {
508
- uint32_t U32;
509
- float F32;
510
- };
511
- F32 = v;
512
- U32 += (upper_bits & 1) + kRoundToNearest;
513
- result = get_msb_half(F32);
514
- }
515
- return result;
516
- }
517
-
518
- template <class Derived>
519
- inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
520
- if (IsNaN()) {
521
- return std::numeric_limits<float>::quiet_NaN();
522
- }
523
- float result;
524
- char* const first = reinterpret_cast<char*>(&result);
525
- char* const second = first + sizeof(uint16_t);
526
- #ifdef __cpp_if_constexpr
527
- if constexpr (detail::endian::native == detail::endian::little) {
528
- #else
529
- if (detail::endian::native == detail::endian::little) {
530
- #endif
531
- std::memset(first, 0, sizeof(uint16_t));
532
- std::memcpy(second, &val, sizeof(uint16_t));
533
- } else {
534
- std::memcpy(first, &val, sizeof(uint16_t));
535
- std::memset(second, 0, sizeof(uint16_t));
536
- }
537
- return result;
538
- }
539
-
540
- } // namespace onnxruntime_float16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1.16.0/onnxruntime.xcframework/Info.plist DELETED
@@ -1,53 +0,0 @@
1
- <?xml version="1.0" encoding="UTF-8"?>
2
- <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3
- <plist version="1.0">
4
- <dict>
5
- <key>AvailableLibraries</key>
6
- <array>
7
- <dict>
8
- <key>LibraryIdentifier</key>
9
- <string>ios-arm64_x86_64-simulator</string>
10
- <key>LibraryPath</key>
11
- <string>onnxruntime.a</string>
12
- <key>SupportedArchitectures</key>
13
- <array>
14
- <string>arm64</string>
15
- <string>x86_64</string>
16
- </array>
17
- <key>SupportedPlatform</key>
18
- <string>ios</string>
19
- <key>SupportedPlatformVariant</key>
20
- <string>simulator</string>
21
- </dict>
22
- <dict>
23
- <key>LibraryIdentifier</key>
24
- <string>ios-arm64</string>
25
- <key>LibraryPath</key>
26
- <string>onnxruntime.a</string>
27
- <key>SupportedArchitectures</key>
28
- <array>
29
- <string>arm64</string>
30
- </array>
31
- <key>SupportedPlatform</key>
32
- <string>ios</string>
33
- </dict>
34
- <dict>
35
- <key>LibraryIdentifier</key>
36
- <string>macos-arm64_x86_64</string>
37
- <key>LibraryPath</key>
38
- <string>onnxruntime.a</string>
39
- <key>SupportedArchitectures</key>
40
- <array>
41
- <string>arm64</string>
42
- <string>x86_64</string>
43
- </array>
44
- <key>SupportedPlatform</key>
45
- <string>macos</string>
46
- </dict>
47
- </array>
48
- <key>CFBundlePackageType</key>
49
- <string>XFWK</string>
50
- <key>XCFrameworkFormatVersion</key>
51
- <string>1.0</string>
52
- </dict>
53
- </plist>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1.16.0/onnxruntime.xcframework/ios-arm64/libonnxruntime.a DELETED
@@ -1 +0,0 @@
1
- onnxruntime.a
 
 
1.16.0/onnxruntime.xcframework/ios-arm64/onnxruntime.a DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a1c60c43be2f6f048da22c560ce8f1d720f39d5bf089fdaab3b524f22d248d53
3
- size 60835152
 
 
 
 
1.16.0/onnxruntime.xcframework/ios-arm64_x86_64-simulator/libonnxruntime.a DELETED
@@ -1 +0,0 @@
1
- onnxruntime.a
 
 
1.16.0/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.a DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6e7ff688aaf0cbe41178ea00e2b3c5f7fa7bf8a66fa0be9171e0513e475053b7
3
- size 124397272
 
 
 
 
1.16.0/onnxruntime.xcframework/macos-arm64_x86_64/libonnxruntime.a DELETED
@@ -1 +0,0 @@
1
- onnxruntime.a
 
 
1.16.0/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.a DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:103277f881185df9920cb6f933fd34534905013883b1756864414c0d0ab28f66
3
- size 126830032