csukuangfj committed on
Commit 3c2bebf · 1 Parent(s): ad41697

Add onnxruntime.xcframework 1.16.2

1.16.2/onnxruntime.xcframework/Headers/coreml_provider_factory.h ADDED
@@ -0,0 +1,45 @@
+ // Copyright (c) Microsoft Corporation. All rights reserved.
+ // Licensed under the MIT License.
+ #pragma once
+
+ #include "onnxruntime_c_api.h"
+
+ // COREMLFlags are bool options we want to set for the CoreML EP.
+ // This enum is defined as bit flags and cannot have negative values.
+ // To generate a uint32_t coreml_flags for use with OrtSessionOptionsAppendExecutionProvider_CoreML below:
+ //   uint32_t coreml_flags = 0;
+ //   coreml_flags |= COREML_FLAG_USE_CPU_ONLY;
+ enum COREMLFlags {
+   COREML_FLAG_USE_NONE = 0x000,
+
+   // Use CPU only in the CoreML EP. This may decrease performance but provides
+   // reference output values without precision loss, which is useful for validation.
+   COREML_FLAG_USE_CPU_ONLY = 0x001,
+
+   // Enable the CoreML EP on subgraphs.
+   COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002,
+
+   // By default the CoreML execution provider is enabled for all compatible Apple devices.
+   // Enabling this option restricts the CoreML EP to Apple devices with an ANE (Apple Neural Engine).
+   // Note that enabling this option does not guarantee the entire model is executed using the ANE only.
+   COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004,
+
+   // Only allow the CoreML EP to take nodes whose inputs have static shapes. By default it also allows
+   // inputs with dynamic shapes. However, performance may be negatively impacted if inputs have dynamic shapes.
+   COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008,
+
+   // Keep COREML_FLAG_LAST at the end of the enum definition
+   // and assign the last COREMLFlag to it.
+   COREML_FLAG_LAST = COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES,
+ };
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ ORT_EXPORT ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CoreML,
+                           _In_ OrtSessionOptions* options, uint32_t coreml_flags);
+
+ #ifdef __cplusplus
+ }
+ #endif
1.16.2/onnxruntime.xcframework/Headers/onnxruntime_c_api.h ADDED
The diff for this file is too large to render. See raw diff
 
1.16.2/onnxruntime.xcframework/Headers/onnxruntime_cxx_api.h ADDED
The diff for this file is too large to render. See raw diff
 
1.16.2/onnxruntime.xcframework/Headers/onnxruntime_cxx_inline.h ADDED
@@ -0,0 +1,1886 @@
+ // Copyright (c) Microsoft Corporation. All rights reserved.
+ // Licensed under the MIT License.
+
+ // Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
+ // If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
+ //
+ // These are the inline implementations of the C++ header APIs. They're in this separate file so as not to clutter
+ // the main C++ file with implementation details.
+
+ #include <cstring>
+
+ namespace Ort {
+
+ namespace detail {
+ inline void ThrowStatus(const Status& st) {
+   std::string error_message = st.GetErrorMessage();
+   OrtErrorCode error_code = st.GetErrorCode();
+   ORT_CXX_API_THROW(std::move(error_message), error_code);
+ }
+ }  // namespace detail
+
+ inline void ThrowOnError(OrtStatus* ort_status) {
+   if (ort_status) {
+     Ort::Status st(ort_status);
+     detail::ThrowStatus(st);
+   }
+ }
+
+ inline void ThrowOnError(const Status& st) {
+   if (st) {
+     detail::ThrowStatus(st);
+   }
+ }
+
+ inline Status::Status(OrtStatus* status) noexcept : Base<OrtStatus>{status} {
+ }
+
+ inline Status::Status(const std::exception& e) noexcept {
+   p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
+ }
+
+ inline Status::Status(const Exception& e) noexcept {
+   p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
+ }
+
+ inline Status::Status(const char* message, OrtErrorCode code) noexcept {
+   p_ = GetApi().CreateStatus(code, message);
+ }
+
+ inline std::string Status::GetErrorMessage() const {
+   std::string message(GetApi().GetErrorMessage(p_));
+   return message;
+ }
+
+ inline OrtErrorCode Status::GetErrorCode() const {
+   return GetApi().GetErrorCode(p_);
+ }
+
+ inline bool Status::IsOK() const noexcept {
+   return (p_ == nullptr);
+ }
+
+ // This template converts a C++ type into its ONNXTensorElementDataType
+ template <typename T>
+ struct TypeToTensorType;
+ template <>
+ struct TypeToTensorType<float> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+ };
+ template <>
+ struct TypeToTensorType<Float16_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
+ };
+ template <>
+ struct TypeToTensorType<BFloat16_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
+ };
+ template <>
+ struct TypeToTensorType<double> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
+ };
+ template <>
+ struct TypeToTensorType<int8_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
+ };
+ template <>
+ struct TypeToTensorType<int16_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
+ };
+ template <>
+ struct TypeToTensorType<int32_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
+ };
+ template <>
+ struct TypeToTensorType<int64_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+ };
+ template <>
+ struct TypeToTensorType<uint8_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
+ };
+ template <>
+ struct TypeToTensorType<uint16_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
+ };
+ template <>
+ struct TypeToTensorType<uint32_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
+ };
+ template <>
+ struct TypeToTensorType<uint64_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
+ };
+ template <>
+ struct TypeToTensorType<bool> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
+ };
+
+ template <>
+ struct TypeToTensorType<Float8E4M3FN_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN;
+ };
+ template <>
+ struct TypeToTensorType<Float8E4M3FNUZ_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ;
+ };
+ template <>
+ struct TypeToTensorType<Float8E5M2_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2;
+ };
+ template <>
+ struct TypeToTensorType<Float8E5M2FNUZ_t> {
+   static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ;
+ };
+
+ inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept {
+   if (IsNaN() || rhs.IsNaN()) {
+     // IEEE defines that NaN is not equal to anything, including itself.
+     return false;
+   }
+   return val == rhs.val;
+ }
+
+ inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept {
+   if (IsNaN() || rhs.IsNaN()) {
+     // IEEE defines that NaN is unordered with respect to everything, including itself.
+     return false;
+   }
+
+   const bool left_is_negative = IsNegative();
+   if (left_is_negative != rhs.IsNegative()) {
+     // When the signs of left and right differ, we know that left is less than right if it is
+     // the negative value. The exception to this is if both values are zero, in which case IEEE
+     // says they should be equal, even if the signs differ.
+     return left_is_negative && !AreZero(*this, rhs);
+   }
+   return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
+ }
+
+ inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
+     : allocator_(allocator), p_(p), size_(size) {
+ }
+
+ inline MemoryAllocation::~MemoryAllocation() {
+   if (p_ != nullptr) {
+     // We do not throw out of destructor
+     auto ret = GetApi().AllocatorFree(allocator_, p_);
+     static_cast<void>(ret);
+   }
+ }
+
+ inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
+   *this = std::move(o);
+ }
+
+ inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept {
+   OrtAllocator* alloc = nullptr;
+   void* p = nullptr;
+   size_t sz = 0;
+
+   // Swap out this
+   std::swap(alloc, allocator_);
+   std::swap(p, p_);
+   std::swap(sz, size_);
+
+   // Swap with incoming
+   std::swap(allocator_, o.allocator_);
+   std::swap(p_, o.p_);
+   std::swap(size_, o.size_);
+
+   // Destroy this instance if needed
+   MemoryAllocation this_alloc(alloc, p, sz);
+   return *this;
+ }
+
+ namespace detail {
+
+ template <typename T>
+ inline void* AllocatorImpl<T>::Alloc(size_t size) {
+   void* out;
+   ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
+   return out;
+ }
+
+ template <typename T>
+ inline MemoryAllocation AllocatorImpl<T>::GetAllocation(size_t size) {
+   void* out;
+   ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
+   MemoryAllocation result(this->p_, out, size);
+   return result;
+ }
+
+ template <typename T>
+ inline void AllocatorImpl<T>::Free(void* p) {
+   ThrowOnError(GetApi().AllocatorFree(this->p_, p));
+ }
+
+ template <typename T>
+ inline ConstMemoryInfo AllocatorImpl<T>::GetInfo() const {
+   const OrtMemoryInfo* out;
+   ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
+   return ConstMemoryInfo{out};
+ }
+
+ }  // namespace detail
+
+ inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
+   ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
+ }
+
+ inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
+   ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
+ }
+
+ namespace detail {
+
+ template <typename T>
+ inline std::string MemoryInfoImpl<T>::GetAllocatorName() const {
+   const char* name = nullptr;
+   ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
+   return std::string(name);
+ }
+
+ template <typename T>
+ inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
+   OrtAllocatorType type;
+   ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
+   return type;
+ }
+
+ template <typename T>
+ inline int MemoryInfoImpl<T>::GetDeviceId() const {
+   int id = 0;
+   ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
+   return id;
+ }
+
+ template <typename T>
+ inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
+   OrtMemoryInfoDeviceType type;
+   GetApi().MemoryInfoGetDeviceType(this->p_, &type);
+   return type;
+ }
+
+ template <typename T>
+ inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
+   OrtMemType type;
+   ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
+   return type;
+ }
+
+ template <typename T>
+ template <typename U>
+ inline bool MemoryInfoImpl<T>::operator==(const MemoryInfoImpl<U>& o) const {
+   int comp_result = 0;
+   ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
+   return comp_result == 0;
+ }
+
+ }  // namespace detail
+
+ inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
+   OrtMemoryInfo* p;
+   ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
+   return MemoryInfo(p);
+ }
+
+ inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
+   ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
+ }
+
+ namespace detail {
+ template <typename T>
+ inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
+   AllocatorWithDefaultOptions allocator;
+   return binding_utils::GetOutputNamesHelper(this->p_, allocator);
+ }
+
+ template <typename T>
+ inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
+   return binding_utils::GetOutputNamesHelper(this->p_, allocator);
+ }
+
+ template <typename T>
+ inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
+   AllocatorWithDefaultOptions allocator;
+   return binding_utils::GetOutputValuesHelper(this->p_, allocator);
+ }
+
+ template <typename T>
+ inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
+   return binding_utils::GetOutputValuesHelper(this->p_, allocator);
+ }
+
+ template <typename T>
+ inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
+   ThrowOnError(GetApi().BindInput(this->p_, name, value));
+ }
+
+ template <typename T>
+ inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
+   ThrowOnError(GetApi().BindOutput(this->p_, name, value));
+ }
+
+ template <typename T>
+ inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
+   ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
+ }
+
+ template <typename T>
+ inline void IoBindingImpl<T>::ClearBoundInputs() {
+   GetApi().ClearBoundInputs(this->p_);
+ }
+
+ template <typename T>
+ inline void IoBindingImpl<T>::ClearBoundOutputs() {
+   GetApi().ClearBoundOutputs(this->p_);
+ }
+
+ template <typename T>
+ inline void IoBindingImpl<T>::SynchronizeInputs() {
+   ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
+ }
+
+ template <typename T>
+ inline void IoBindingImpl<T>::SynchronizeOutputs() {
+   ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
+ }
+
+ namespace binding_utils {
+ inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
+   std::vector<std::string> result;
+   auto free_fn = detail::AllocatedFree(allocator);
+   using Ptr = std::unique_ptr<void, decltype(free_fn)>;
+
+   char* buffer = nullptr;
+   size_t* lengths = nullptr;
+   size_t count = 0;
+   ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
+
+   if (count == 0) {
+     return result;
+   }
+
+   Ptr buffer_g(buffer, free_fn);
+   Ptr lengths_g(lengths, free_fn);
+
+   result.reserve(count);
+   for (size_t i = 0; i < count; ++i) {
+     auto sz = *lengths;
+     result.emplace_back(buffer, sz);
+     buffer += sz;
+     ++lengths;
+   }
+   return result;
+ }
+
+ inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
+   std::vector<Value> result;
+   size_t owned = 0;
+   size_t output_count = 0;
+   // Lambda to release the buffer when no longer needed and
+   // make sure that we destroy all instances on exception
+   auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
+     if (buffer) {
+       while (owned < output_count) {
+         auto* p = buffer + owned++;
+         GetApi().ReleaseValue(*p);
+       }
+       allocator->Free(allocator, buffer);
+     }
+   };
+   using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
+
+   OrtValue** output_buffer = nullptr;
+   ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
+   if (output_count == 0) {
+     return result;
+   }
+
+   Ptr buffer_g(output_buffer, free_fn);
+
+   result.reserve(output_count);
+   for (size_t i = 0; i < output_count; ++i) {
+     result.emplace_back(output_buffer[i]);
+     ++owned;
+   }
+   return result;
+ }
+
+ }  // namespace binding_utils
+ }  // namespace detail
+
+ inline IoBinding::IoBinding(Session& session) {
+   ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
+ }
+
+ inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
+   ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
+ }
+
+ inline ThreadingOptions::ThreadingOptions() {
+   ThrowOnError(GetApi().CreateThreadingOptions(&p_));
+ }
+
+ inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) {
+   ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
+   return *this;
+ }
+
+ inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) {
+   ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
+   return *this;
+ }
+
+ inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) {
+   ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
+   return *this;
+ }
+
+ inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() {
+   ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_));
+   return *this;
+ }
+
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
+   ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
+   return *this;
+ }
+
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
+   ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
+   return *this;
+ }
+
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
+   ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
+   return *this;
+ }
+
+ inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
+   ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
+   if (strcmp(logid, "onnxruntime-node") == 0) {
+     ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
+   } else {
+     ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
+   }
+ }
+
+ inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
+   ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
+   if (strcmp(logid, "onnxruntime-node") == 0) {
+     ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
+   } else {
+     ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
+   }
+ }
+
+ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
+   ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
+   if (strcmp(logid, "onnxruntime-node") == 0) {
+     ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
+   } else {
+     ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
+   }
+ }
+
+ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
+                 OrtLoggingLevel logging_level, _In_ const char* logid) {
+   ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
+   if (strcmp(logid, "onnxruntime-node") == 0) {
+     ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
+   } else {
+     ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
+   }
+ }
+
+ inline Env& Env::EnableTelemetryEvents() {
+   ThrowOnError(GetApi().EnableTelemetryEvents(p_));
+   return *this;
+ }
+
+ inline Env& Env::DisableTelemetryEvents() {
+   ThrowOnError(GetApi().DisableTelemetryEvents(p_));
+   return *this;
+ }
+
+ inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
+   ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
+   return *this;
+ }
+
+ inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
+   ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
+   return *this;
+ }
+
+ inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg) {
+   std::vector<const char*> keys, values;
+   auto num_entries = options.size();
+   if (num_entries > 0) {
+     keys.reserve(num_entries);
+     values.reserve(num_entries);
+     for (const auto& entry : options) {
+       keys.push_back(entry.first.c_str());
+       values.push_back(entry.second.c_str());
+     }
+   }
+   ThrowOnError(GetApi().CreateAndRegisterAllocatorV2(p_, provider_type.c_str(), mem_info, arena_cfg, keys.data(), values.data(), num_entries));
+   return *this;
+ }
+
+ inline CustomOpDomain::CustomOpDomain(const char* domain) {
+   ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
+ }
+
+ inline void CustomOpDomain::Add(const OrtCustomOp* op) {
+   ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
+ }
+
+ inline RunOptions::RunOptions() {
+   ThrowOnError(GetApi().CreateRunOptions(&p_));
+ }
+
+ inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
+   ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
+   return *this;
+ }
+
+ inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
+   ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
+   return *this;
+ }
+
+ inline int RunOptions::GetRunLogVerbosityLevel() const {
+   int out;
+   ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
+   return out;
+ }
+
+ inline int RunOptions::GetRunLogSeverityLevel() const {
+   int out;
+   ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
+   return out;
+ }
+
+ inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
+   ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
+   return *this;
+ }
+
+ inline const char* RunOptions::GetRunTag() const {
+   const char* out;
+   ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
+   return out;
+ }
+
+ inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
+   ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
+   return *this;
+ }
+
+ inline RunOptions& RunOptions::SetTerminate() {
+   ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
+   return *this;
+ }
+
+ inline RunOptions& RunOptions::UnsetTerminate() {
+   ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
+   return *this;
+ }
+
+ namespace detail {
+
+ template <typename T>
+ inline Ort::SessionOptions ConstSessionOptionsImpl<T>::Clone() const {
+   OrtSessionOptions* out;
+   ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
+   return SessionOptions{out};
+ }
+
+ template <typename T>
+ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
+   size_t size = 0;
+   // Feed nullptr for the data buffer to query the true size of the string value
+   Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
+
+   std::string out;
+   out.resize(size);
+   Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
+   out.resize(size - 1);  // remove the terminating character '\0'
+
+   return out;
+ }
+
+ template <typename T>
+ inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
+   int out = 0;
+   Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
+   return static_cast<bool>(out);
+ }
+
+ template <typename T>
+ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
+   if (!this->HasConfigEntry(config_key)) {
+     return def;
+   }
+
+   return this->GetConfigEntry(config_key);
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetIntraOpNumThreads(int intra_op_num_threads) {
+   ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetInterOpNumThreads(int inter_op_num_threads) {
+   ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
+   ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
+   ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
+   ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableProfiling() {
+   ThrowOnError(GetApi().DisableProfiling(this->p_));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableOrtCustomOps() {
+   ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableMemPattern() {
+   ThrowOnError(GetApi().EnableMemPattern(this->p_));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableMemPattern() {
+   ThrowOnError(GetApi().DisableMemPattern(this->p_));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableCpuMemArena() {
+   ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableCpuMemArena() {
+   ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
+   ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
+   ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogSeverityLevel(int level) {
+   ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::Add(OrtCustomOpDomain* custom_op_domain) {
+   ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
+   ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
+   ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
+   return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisablePerSessionThreads() {
+   ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
737
+ return *this;
738
+ }
739
+
740
+ template <typename T>
741
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
742
+ const std::vector<Value>& ort_values) {
743
+ const size_t inputs_num = names.size();
744
+ if (inputs_num != ort_values.size()) {
745
+ ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
746
+ }
747
+ std::vector<const char*> names_ptr;
748
+ std::vector<const OrtValue*> ort_values_ptrs;
749
+ names_ptr.reserve(inputs_num);
750
+ ort_values_ptrs.reserve(inputs_num);
751
+ for (size_t i = 0; i < inputs_num; ++i) {
752
+ names_ptr.push_back(names[i].c_str());
753
+ ort_values_ptrs.push_back(ort_values[i]);
754
+ }
755
+ ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
756
+ return *this;
757
+ }
758
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) {
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) {
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options));
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
+ const std::string& provider_name,
+ const std::unordered_map<std::string, std::string>& provider_options) {
+ auto num_entries = provider_options.size();
+ std::vector<const char*> keys, values;
+ if (num_entries > 0) {
+ keys.reserve(num_entries);
+ values.reserve(num_entries);
+
+ for (const auto& entry : provider_options) {
+ keys.push_back(entry.first.c_str());
+ values.push_back(entry.second.c_str());
+ }
+ }
+
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
+ keys.data(), values.data(), num_entries));
+
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
+ ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
+ ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
+ ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name,
+ const CustomOpConfigs& custom_op_configs) {
+ // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
+ // the custom op library.
+ for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
+ AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
+ }
+
+ ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
+ return *this;
+ }
+
+ template <typename T>
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
+ ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
+ return *this;
+ }
+
+ /// Session
+ template <typename T>
+ inline size_t ConstSessionImpl<T>::GetInputCount() const {
+ size_t out;
+ ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
+ return out;
+ }
+
+ template <typename T>
+ inline size_t ConstSessionImpl<T>::GetOutputCount() const {
+ size_t out;
+ ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
+ return out;
+ }
+
+ template <typename T>
+ inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
+ size_t out;
+ ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
+ return out;
+ }
+
+ template <typename T>
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
+ char* out;
+ ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
+ }
+
+ template <typename T>
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
+ char* out;
+ ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
+ }
+
+ template <typename T>
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
+ char* out;
+ ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
+ }
+
+ template <typename T>
+ inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
+ uint64_t out;
+ ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
+ return out;
+ }
+
+ template <typename T>
+ inline ModelMetadata ConstSessionImpl<T>::GetModelMetadata() const {
+ OrtModelMetadata* out;
+ ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
+ return ModelMetadata{out};
+ }
+
+ template <typename T>
+ inline TypeInfo ConstSessionImpl<T>::GetInputTypeInfo(size_t index) const {
+ OrtTypeInfo* out;
+ ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
+ return TypeInfo{out};
+ }
+
+ template <typename T>
+ inline TypeInfo ConstSessionImpl<T>::GetOutputTypeInfo(size_t index) const {
+ OrtTypeInfo* out;
+ ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
+ return TypeInfo{out};
+ }
+
+ template <typename T>
+ inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t index) const {
+ OrtTypeInfo* out;
+ ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
+ return TypeInfo{out};
+ }
+
+ template <typename T>
+ inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
+ const char* const* output_names, size_t output_count) {
+ std::vector<Value> output_values;
+ output_values.reserve(output_count);
+ for (size_t i = 0; i < output_count; i++)
+ output_values.emplace_back(nullptr);
+ Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
+ return output_values;
+ }
+
+ template <typename T>
+ inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
+ const char* const* output_names, Value* output_values, size_t output_count) {
+ static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
+ ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
+ }
+
+ template <typename T>
+ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
+ ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
+ }
+
+ template <typename T>
+ inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
+ const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
+ ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names,
+ ort_input_values, input_count, output_names, output_count,
+ ort_output_values, callback, user_data));
+ }
+
+ template <typename T>
+ inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
+ char* out = nullptr;
+ ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
+ }
+
+ } // namespace detail
+
+ inline SessionOptions::SessionOptions() {
+ ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
+ }
+
+ /// CustomOpConfigs
+ inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
+ std::string config_key = "custom_op.";
+
+ config_key += custom_op_name;
+ config_key += ".";
+ config_key += config;
+
+ return config_key;
+ }
+
+ inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
+ const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
+ flat_configs_[full_flat_key] = config_value;
+ return *this;
+ }
+
+ inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
+ return flat_configs_;
+ }
+
+ inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
+ ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
+ }
+
+ inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
+ OrtPrepackedWeightsContainer* prepacked_weights_container) {
+ ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
+ }
+
+ inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
+ ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
+ }
+
+ inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
+ const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
+ ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
+ prepacked_weights_container, &this->p_));
+ }
+
+ inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
+ char* out;
+ ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
+ }
+
+ inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
+ char* out;
+ ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
+ }
+
+ inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
+ char* out;
+ ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
+ }
+
+ inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const {
+ char* out;
+ ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
+ }
+
+ inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const {
+ char* out;
+ ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
+ }
+
+ inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
+ char* out;
+ ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
+ }
+
+ inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
+ auto deletor = detail::AllocatedFree(allocator);
+ std::vector<AllocatedStringPtr> result;
+
+ char** out = nullptr;
+ int64_t num_keys = 0;
+ ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
+ if (num_keys <= 0) {
+ return result;
+ }
+
+ // array of pointers will be freed
+ std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
+ // reserve may throw
+ auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
+ std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
+ result.reserve(static_cast<size_t>(num_keys));
+ strings_guard.release();
+ for (int64_t i = 0; i < num_keys; ++i) {
+ result.push_back(AllocatedStringPtr(out[i], deletor));
+ }
+
+ return result;
+ }
+
+ inline int64_t ModelMetadata::GetVersion() const {
+ int64_t out;
+ ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
+ return out;
+ }
+
+ namespace detail {
+
+ template <typename T>
+ inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
+ ONNXTensorElementDataType out;
+ ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
+ return out;
+ }
+
+ template <typename T>
+ inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
+ size_t out;
+ ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
+ return static_cast<size_t>(out);
+ }
+
+ template <typename T>
+ inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
+ size_t out;
+ ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
+ return out;
+ }
+
+ template <typename T>
+ inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
+ ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
+ }
+
+ template <typename T>
+ inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
+ ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
+ }
+
+ template <typename T>
+ inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
+ std::vector<int64_t> out(GetDimensionsCount(), 0);
+ ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
+ return out;
+ }
+
+ template <typename T>
+ inline ConstTensorTypeAndShapeInfo TypeInfoImpl<T>::GetTensorTypeAndShapeInfo() const {
+ const OrtTensorTypeAndShapeInfo* out;
+ ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
+ return ConstTensorTypeAndShapeInfo{out};
+ }
+
+ template <typename T>
+ inline ConstSequenceTypeInfo TypeInfoImpl<T>::GetSequenceTypeInfo() const {
+ const OrtSequenceTypeInfo* out;
+ ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
+ return ConstSequenceTypeInfo{out};
+ }
+
+ template <typename T>
+ inline ConstMapTypeInfo TypeInfoImpl<T>::GetMapTypeInfo() const {
+ const OrtMapTypeInfo* out;
+ ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
+ return ConstMapTypeInfo{out};
+ }
+
+ template <typename T>
+ inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
+ ONNXType out;
+ ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
+ return out;
+ }
+
+ template <typename T>
+ inline TypeInfo SequenceTypeInfoImpl<T>::GetSequenceElementType() const {
+ OrtTypeInfo* output;
+ ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
+ return TypeInfo{output};
+ }
+
+ template <typename T>
+ inline TypeInfo OptionalTypeInfoImpl<T>::GetOptionalElementType() const {
+ OrtTypeInfo* info;
+ ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info));
+ return TypeInfo{info};
+ }
+
+ template <typename T>
+ inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
+ ONNXTensorElementDataType out;
+ ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
+ return out;
+ }
+
+ template <typename T>
+ inline TypeInfo MapTypeInfoImpl<T>::GetMapValueType() const {
+ OrtTypeInfo* output;
+ ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
+ return TypeInfo{output};
+ }
+
+ template <typename T>
+ inline ConstOptionalTypeInfo TypeInfoImpl<T>::GetOptionalTypeInfo() const {
+ const OrtOptionalTypeInfo* info;
+ ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info));
+ return ConstOptionalTypeInfo{info};
+ }
+
+ } // namespace detail
+
+ namespace detail {
+
+ template <typename T>
+ template <typename R>
+ inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
+ ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
+ }
+
+ template <typename T>
+ inline bool ConstValueImpl<T>::IsTensor() const {
+ int out;
+ ThrowOnError(GetApi().IsTensor(this->p_, &out));
+ return out != 0;
+ }
+
+ template <typename T>
+ inline bool ConstValueImpl<T>::HasValue() const {
+ int out;
+ ThrowOnError(GetApi().HasValue(this->p_, &out));
+ return out != 0;
+ }
+
+ template <typename T>
+ inline size_t ConstValueImpl<T>::GetCount() const {
+ size_t out;
+ ThrowOnError(GetApi().GetValueCount(this->p_, &out));
+ return out;
+ }
+
+ template <typename T>
+ inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
+ OrtValue* out;
+ ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
+ return Value{out};
+ }
+
+ template <typename T>
+ inline size_t ConstValueImpl<T>::GetStringTensorDataLength() const {
+ size_t out;
+ ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
+ return out;
+ }
+
+ template <typename T>
+ inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
+ size_t out;
+ ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
+ return out;
+ }
+
+ template <typename T>
+ template <typename R>
+ inline const R* ConstValueImpl<T>::GetTensorData() const {
+ R* out;
+ ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
+ return out;
+ }
+
+ template <typename T>
+ inline const void* ConstValueImpl<T>::GetTensorRawData() const {
+ void* out;
+ ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
+ return out;
+ }
+
+ template <typename T>
+ inline TypeInfo ConstValueImpl<T>::GetTypeInfo() const {
+ OrtTypeInfo* output;
+ ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
+ return TypeInfo{output};
+ }
+
+ template <typename T>
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetTensorTypeAndShapeInfo() const {
+ OrtTensorTypeAndShapeInfo* output;
+ ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
+ return TensorTypeAndShapeInfo{output};
+ }
+
+ template <typename T>
+ inline ConstMemoryInfo ConstValueImpl<T>::GetTensorMemoryInfo() const {
+ const OrtMemoryInfo* mem_info;
+ ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
+ return ConstMemoryInfo(mem_info);
+ }
+
+ template <typename T>
+ inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
+ ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
+ }
+
+ template <typename T>
+ inline std::string ConstValueImpl<T>::GetStringTensorElement(size_t element_index) const {
+ size_t buffer_length;
+ ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length));
+
+ std::string s;
+ s.resize(buffer_length);
+ ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0]));
+ return s;
+ }
+
+ template <typename T>
+ inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
+ ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
+ }
+
+ #if !defined(DISABLE_SPARSE_TENSORS)
+ template <typename T>
+ inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
+ OrtSparseFormat format;
+ ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
+ return format;
+ }
+
+ template <typename T>
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorValuesTypeAndShapeInfo() const {
+ OrtTensorTypeAndShapeInfo* output;
+ ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
+ return TensorTypeAndShapeInfo{output};
+ }
+
+ template <typename T>
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
+ OrtTensorTypeAndShapeInfo* output;
+ ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
+ return TensorTypeAndShapeInfo{output};
+ }
+
+ template <typename T>
+ template <typename R>
+ inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
+ const void* out;
+ ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
+ return reinterpret_cast<const R*>(out);
+ }
+
+ template <typename T>
+ inline bool ConstValueImpl<T>::IsSparseTensor() const {
+ int out;
+ ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
+ return out != 0;
+ }
+
+ template <typename T>
+ template <typename R>
+ inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
+ const void* out;
+ ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
+ return reinterpret_cast<const R*>(out);
+ }
+
+ #endif
+
+ template <typename T>
+ void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
+ ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
+ }
+
+ template <typename T>
+ void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
+ ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
+ }
+
+ template <typename T>
+ inline char* ValueImpl<T>::GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length) {
+ char* result;
+ ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result));
+ return result;
+ }
+
+ template <typename T>
+ void* ValueImpl<T>::GetTensorMutableRawData() {
+ void* out;
+ ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
+ return out;
+ }
+
+ template <typename T>
+ template <typename R>
+ R* ValueImpl<T>::GetTensorMutableData() {
+ R* out;
+ ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
+ return out;
+ }
+
+ template <typename T>
+ template <typename R>
+ R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
+ static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
+ R* out;
+ ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
+ return *out;
+ }
+
+ #if !defined(DISABLE_SPARSE_TENSORS)
+ template <typename T>
+ void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
+ ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
+ }
+
+ template <typename T>
+ void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
+ ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
+ }
+
+ template <typename T>
+ void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
+ ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
+ }
+
+ template <typename T>
+ void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
+ const int64_t* indices_data, size_t indices_num) {
+ ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
+ values_param.values_shape_len, values_param.data.p_data,
+ indices_data, indices_num));
+ }
+
+ template <typename T>
+ void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
+ const OrtSparseValuesParam& values,
+ const int64_t* inner_indices_data, size_t inner_indices_num,
+ const int64_t* outer_indices_data, size_t outer_indices_num) {
+ ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
+ inner_indices_data, inner_indices_num,
+ outer_indices_data, outer_indices_num));
+ }
+
+ template <typename T>
+ void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
+ const OrtSparseValuesParam& values,
+ const Shape& indices_shape,
+ const int32_t* indices_data) {
+ ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
+ indices_shape.shape, indices_shape.shape_len,
+ indices_data));
+ }
+
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
+
+ } // namespace detail
+
+ template <typename T>
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
+ return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
+ }
+
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
+ ONNXTensorElementDataType type) {
+ OrtValue* out;
+ ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
+ return Value{out};
+ }
+
+ template <typename T>
+ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
+ return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
+ }
+
+ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
+ OrtValue* out;
+ ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
+ return Value{out};
+ }
+
+ #if !defined(DISABLE_SPARSE_TENSORS)
+
+ template <typename T>
+ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1480
+ const Shape& values_shape) {
1481
+ return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
1482
+ }
1483
+
1484
+ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1485
+ const Shape& values_shape, ONNXTensorElementDataType type) {
1486
+ OrtValue* out;
1487
+ ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
1488
+ values_shape.shape, values_shape.shape_len, type, &out));
1489
+ return Value{out};
1490
+ }
1491
+
1492
+ template <typename T>
1493
+ inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
1494
+ return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
1495
+ }
1496
+
1497
+ inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
1498
+ ONNXTensorElementDataType type) {
1499
+ OrtValue* out;
1500
+ ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
1501
+ return Value{out};
1502
+ }
1503
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
1504
+
1505
+ inline Value Value::CreateMap(const Value& keys, const Value& values) {
1506
+ OrtValue* out;
1507
+ const OrtValue* inputs[2] = {keys, values};
1508
+ ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
1509
+ return Value{out};
1510
+ }
1511
+
1512
+ inline Value Value::CreateSequence(const std::vector<Value>& values) {
1513
+ OrtValue* out;
1514
+ std::vector<const OrtValue*> values_ort{values.data(), values.data() + values.size()};
1515
+ ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
1516
+ return Value{out};
1517
+ }
1518
+
1519
+ template <typename T>
1520
+ inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
1521
+ OrtValue* out;
1522
+ ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
1523
+ return Value{out};
1524
+ }
1525
+
1526
+ //
1527
+ // Custom OP Inlines
1528
+ //
1529
+ inline Logger::Logger(const OrtLogger* logger) : logger_(logger) {
1530
+ Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_));
1531
+ }
1532
+
1533
+ inline OrtLoggingLevel Logger::GetLoggingSeverityLevel() const noexcept {
1534
+ return cached_severity_level_;
1535
+ }
1536
+
1537
+ inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
1538
+ const char* func_name, const char* message) const noexcept {
1539
+ OrtStatus* status = GetApi().Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number,
1540
+ func_name);
1541
+ return Status{status};
1542
+ }
1543
+
1544
+ // Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security)
1545
+ // for gcc and clang. The alternative is to use actual C-style variadic parameters and apply
1546
+ // __attribute__(format(printf...)), which does not work with variadic templates.
1547
+ #if defined(__GNUC__)
1548
+ #pragma GCC diagnostic push
1549
+ #pragma GCC diagnostic ignored "-Wformat-nonliteral"
1550
+ #pragma GCC diagnostic ignored "-Wformat-security"
1551
+ #elif defined(__clang__)
1552
+ #pragma clang diagnostic push
1553
+ #pragma clang diagnostic ignored "-Wformat-nonliteral"
1554
+ #pragma clang diagnostic ignored "-Wformat-security"
1555
+ #endif
1556
+ template <typename... Args>
1557
+ inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path,
1558
+ int line_number, const char* func_name, const char* format,
1559
+ Args&&... args) const noexcept {
1560
+ int msg_len = std::snprintf(nullptr, 0U, format, std::forward<Args>(args)...);
1561
+
1562
+ if (msg_len < 0) { // Formatting error
1563
+ return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL);
1564
+ }
1565
+
1566
+ OrtStatus* status = nullptr;
1567
+ const size_t buffer_size = static_cast<size_t>(msg_len) + 1U;
1568
+
1569
+ constexpr size_t kStackBufferSize = 1024;
1570
+
1571
+ if (buffer_size < kStackBufferSize) {
1572
+ char buffer[kStackBufferSize];
1573
+ snprintf(buffer, kStackBufferSize, format, std::forward<Args>(args)...);
1574
+ status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name);
1575
+ } else {
1576
+ // std::make_unique is only supported starting at C++14.
1577
+ #if (__cplusplus >= 201402L) || (_MSC_VER >= 1900)
1578
+ auto buffer = std::make_unique<char[]>(buffer_size);
1579
+ #else
1580
+ std::unique_ptr<char[]> buffer(new char[buffer_size]);
1581
+ #endif
1582
+ std::snprintf(buffer.get(), buffer_size, format, std::forward<Args>(args)...);
1583
+ status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name);
1584
+ }
1585
+
1586
+ return Status{status};
1587
+ }
1588
+ // Re-enable -Wformat-nonliteral and -Wformat-security
1589
+ #if defined(__GNUC__)
1590
+ #pragma GCC diagnostic pop
1591
+ #elif defined(__clang__)
1592
+ #pragma clang diagnostic pop
1593
+ #endif
1594
+
+ inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
+ }
+
+ inline size_t KernelContext::GetInputCount() const {
+ size_t out = 0;
+ Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
+ return out;
+ }
+
+ inline size_t KernelContext::GetOutputCount() const {
+ size_t out = 0;
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
+ return out;
+ }
+
+ inline ConstValue KernelContext::GetInput(size_t index) const {
+ const OrtValue* out = nullptr;
+ Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
+ return ConstValue{out};
+ }
+
+ inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
+ OrtValue* out = nullptr;
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
+ return UnownedValue(out);
+ }
+
+ inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
+ OrtValue* out = nullptr;
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
+ return UnownedValue(out);
+ }
+
+ inline void* KernelContext::GetGPUComputeStream() const {
+ void* out = nullptr;
+ Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
+ return out;
+ }
+
+ inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const {
+ OrtAllocator* out = nullptr;
+ Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out));
+ return out;
+ }
+
+ inline Logger KernelContext::GetLogger() const {
+ const OrtLogger* out = nullptr;
+ ThrowOnError(GetApi().KernelContext_GetLogger(this->ctx_, &out));
+ return Logger{out};
+ }
+
+ inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
+ Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
+ }
+
+ namespace detail {
+ template <typename T>
+ inline KernelInfo KernelInfoImpl<T>::Copy() const {
+ OrtKernelInfo* info_copy = nullptr;
+ Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
+ return KernelInfo{info_copy};
+ }
+
+ template <typename T>
+ inline size_t KernelInfoImpl<T>::GetInputCount() const {
+ size_t out = 0;
+ ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
+ return out;
+ }
+
+ template <typename T>
+ inline size_t KernelInfoImpl<T>::GetOutputCount() const {
+ size_t out = 0;
+ ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
+ return out;
+ }
+
+ template <typename T>
+ inline std::string KernelInfoImpl<T>::GetInputName(size_t index) const {
+ size_t size = 0;
+
+ // Feed nullptr for the data buffer to query the true size of the string value
+ Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
+
+ std::string out;
+ out.resize(size);
+ Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
+ out.resize(size - 1);  // remove the terminating character '\0'
+
+ return out;
+ }
+
+ template <typename T>
+ inline std::string KernelInfoImpl<T>::GetOutputName(size_t index) const {
+ size_t size = 0;
+
+ // Feed nullptr for the data buffer to query the true size of the string value
+ Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
+
+ std::string out;
+ out.resize(size);
+ Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
+ out.resize(size - 1);  // remove the terminating character '\0'
+
+ return out;
+ }
+
+ template <typename T>
+ inline TypeInfo KernelInfoImpl<T>::GetInputTypeInfo(size_t index) const {
+ OrtTypeInfo* out = nullptr;
+ ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
+ return TypeInfo{out};
+ }
+
+ template <typename T>
+ inline TypeInfo KernelInfoImpl<T>::GetOutputTypeInfo(size_t index) const {
+ OrtTypeInfo* out = nullptr;
+ ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
+ return TypeInfo{out};
+ }
+
+ template <typename T>
+ inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
+ OrtValue* out = nullptr;
+ ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
+ return Value{out};
+ }
+
+ template <typename T>
+ inline ConstValue KernelInfoImpl<T>::GetTensorConstantInput(size_t index, int* is_constant) const {
+ const OrtValue* out = nullptr;
+ ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out));
+ return ConstValue{out};
+ }
+
+ template <typename T>
+ inline std::string KernelInfoImpl<T>::GetNodeName() const {
+ size_t size = 0;
+
+ // Feed nullptr for the data buffer to query the true size of the string value
+ Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size));
+
+ std::string out;
+ out.resize(size);
+ Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size));
+ out.resize(size - 1);  // remove the terminating character '\0'
+
+ return out;
+ }
+
+ template <typename T>
+ inline Logger KernelInfoImpl<T>::GetLogger() const {
+ const OrtLogger* out = nullptr;
+ ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out));
+ return Logger{out};
+ }
+
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
+ }
+
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
+ }
+
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
+ size_t size = 0;
+ // Feed nullptr for the data buffer to query the true size of the string attribute
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
+
+ std::string out;
+ out.resize(size);
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
+ out.resize(size - 1);  // remove the terminating character '\0'
+ out.swap(result);
+ }
+
+ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
+ size_t size = 0;
+ // Feed nullptr for the data buffer to query the true size of the attribute
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
+
+ std::vector<float> out;
+ out.resize(size);
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
+ out.swap(result);
+ }
+
+ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
+ size_t size = 0;
+
+ // Feed nullptr for the data buffer to query the true size of the attribute
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
+
+ std::vector<int64_t> out;
+ out.resize(size);
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
+ out.swap(result);
+ }
+ }  // namespace detail
+
+ inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
+
+ inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
+
+ inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
+ const char** type_constraint_names,
+ const ONNXTensorElementDataType* type_constraint_values,
+ size_t type_constraint_count,
+ const OpAttr* attr_values, size_t attr_count,
+ size_t input_count, size_t output_count) {
+ static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
+ "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
+ auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
+ OrtOp* op;
+ Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
+ static_cast<int>(type_constraint_count),
+ attr_input_values,
+ static_cast<int>(attr_count),
+ static_cast<int>(input_count),
+ static_cast<int>(output_count), &op));
+ return Op{op};
+ }
+
+ inline void Op::Invoke(const OrtKernelContext* context,
+ const Value* input_values,
+ size_t input_count,
+ Value* output_values,
+ size_t output_count) {
+ static_assert(sizeof(Value) == sizeof(OrtValue*),
+ "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
+ Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
+ ort_output_values, static_cast<int>(output_count)));
+ }
+
+ inline void Op::Invoke(const OrtKernelContext* context,
+ const OrtValue* const* input_values,
+ size_t input_count,
+ OrtValue* const* output_values,
+ size_t output_count) {
+ Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
+ output_values, static_cast<int>(output_count)));
+ }
+
+ inline std::string GetVersionString() {
+ return OrtGetApiBase()->GetVersionString();
+ }
+
+ inline std::string GetBuildInfoString() {
+ return GetApi().GetBuildInfoString();
+ }
+
+ inline std::vector<std::string> GetAvailableProviders() {
+ char** providers;
+ int len;
+
+ auto release_fn = [&len](char** providers) {
+ // This should always return nullptr.
+ ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
+ };
+
+ ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
+ std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
+ std::vector<std::string> available_providers;
+ available_providers.reserve(static_cast<size_t>(len));
+ for (int i = 0; i < len; ++i) {
+ available_providers.emplace_back(providers[i]);
+ }
+ return available_providers;
+ }
+
+ template <typename TOp, typename TKernel, bool WithStatus>
+ void CustomOpBase<TOp, TKernel, WithStatus>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
+ ConstSessionOptions options) const {
+ const TOp* derived = static_cast<const TOp*>(this);
+ std::vector<std::string> keys = derived->GetSessionConfigKeys();
+
+ out.reserve(keys.size());
+
+ std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
+ const size_t prefix_size = config_entry_key.length();
+
+ for (const auto& key : keys) {
+ config_entry_key.resize(prefix_size);
+ config_entry_key.append(key);
+ out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
+ }
+ }
+
+ }  // namespace Ort
1.16.2/onnxruntime.xcframework/Headers/onnxruntime_float16.h ADDED
@@ -0,0 +1,540 @@
+ // Copyright (c) Microsoft Corporation. All rights reserved.
+ // Licensed under the MIT License.
+
+ #pragma once
+
+ #include <stdint.h>
+ #include <cmath>
+ #include <cstring>
+ #include <limits>
+
+ namespace onnxruntime_float16 {
+
+ namespace detail {
+
+ enum class endian {
+ #if defined(_WIN32)
+ little = 0,
+ big = 1,
+ native = little,
+ #elif defined(__GNUC__) || defined(__clang__)
+ little = __ORDER_LITTLE_ENDIAN__,
+ big = __ORDER_BIG_ENDIAN__,
+ native = __BYTE_ORDER__,
+ #else
+ #error onnxruntime_float16::detail::endian is not implemented in this environment.
+ #endif
+ };
+
+ static_assert(
+ endian::native == endian::little || endian::native == endian::big,
+ "Only little-endian or big-endian native byte orders are supported.");
+
+ }  // namespace detail
+
+ /// <summary>
+ /// Shared implementation between public and internal classes. CRTP pattern.
+ /// </summary>
+ template <class Derived>
+ struct Float16Impl {
+ protected:
+ /// <summary>
+ /// Converts from float to uint16_t float16 representation
+ /// </summary>
+ /// <param name="v"></param>
+ /// <returns></returns>
+ constexpr static uint16_t ToUint16Impl(float v) noexcept;
+
+ /// <summary>
+ /// Converts float16 to float
+ /// </summary>
+ /// <returns>float representation of float16 value</returns>
+ float ToFloatImpl() const noexcept;
+
+ /// <summary>
+ /// Creates an instance that represents absolute value.
+ /// </summary>
+ /// <returns>Absolute value</returns>
+ uint16_t AbsImpl() const noexcept {
+ return static_cast<uint16_t>(val & ~kSignMask);
+ }
+
+ /// <summary>
+ /// Creates a new instance with the sign flipped.
+ /// </summary>
+ /// <returns>Flipped sign instance</returns>
+ uint16_t NegateImpl() const noexcept {
+ return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
+ }
+
+ public:
+ // uint16_t special values
+ static constexpr uint16_t kSignMask = 0x8000U;
+ static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
+ static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
+ static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
+ static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
+ static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
+ static constexpr uint16_t kEpsilonBits = 0x4170U;
+ static constexpr uint16_t kMinValueBits = 0xFBFFU;  // Minimum normal number
+ static constexpr uint16_t kMaxValueBits = 0x7BFFU;  // Largest normal number
+ static constexpr uint16_t kOneBits = 0x3C00U;
+ static constexpr uint16_t kMinusOneBits = 0xBC00U;
+
+ uint16_t val{0};
+
+ Float16Impl() = default;
+
+ /// <summary>
+ /// Checks if the value is negative
+ /// </summary>
+ /// <returns>true if negative</returns>
+ bool IsNegative() const noexcept {
+ return static_cast<int16_t>(val) < 0;
+ }
+
+ /// <summary>
+ /// Tests if the value is NaN
+ /// </summary>
+ /// <returns>true if NaN</returns>
+ bool IsNaN() const noexcept {
+ return AbsImpl() > kPositiveInfinityBits;
+ }
+
+ /// <summary>
+ /// Tests if the value is finite
+ /// </summary>
+ /// <returns>true if finite</returns>
+ bool IsFinite() const noexcept {
+ return AbsImpl() < kPositiveInfinityBits;
+ }
+
+ /// <summary>
+ /// Tests if the value represents positive infinity.
+ /// </summary>
+ /// <returns>true if positive infinity</returns>
+ bool IsPositiveInfinity() const noexcept {
+ return val == kPositiveInfinityBits;
+ }
+
+ /// <summary>
+ /// Tests if the value represents negative infinity
+ /// </summary>
+ /// <returns>true if negative infinity</returns>
+ bool IsNegativeInfinity() const noexcept {
+ return val == kNegativeInfinityBits;
+ }
+
+ /// <summary>
+ /// Tests if the value is either positive or negative infinity.
+ /// </summary>
+ /// <returns>True if absolute value is infinity</returns>
+ bool IsInfinity() const noexcept {
+ return AbsImpl() == kPositiveInfinityBits;
+ }
+
+ /// <summary>
+ /// Tests if the value is NaN or zero. Useful for comparisons.
+ /// </summary>
+ /// <returns>True if NaN or zero.</returns>
+ bool IsNaNOrZero() const noexcept {
+ auto abs = AbsImpl();
+ return (abs == 0 || abs > kPositiveInfinityBits);
+ }
+
+ /// <summary>
+ /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
+ /// </summary>
+ /// <returns>True if so</returns>
+ bool IsNormal() const noexcept {
+ auto abs = AbsImpl();
+ return (abs < kPositiveInfinityBits)           // is finite
+ && (abs != 0)                                  // is not zero
+ && ((abs & kBiasedExponentMask) != 0);         // is not subnormal (has a non-zero exponent)
+ }
+
+ /// <summary>
+ /// Tests if the value is subnormal (denormal).
+ /// </summary>
+ /// <returns>True if so</returns>
+ bool IsSubnormal() const noexcept {
+ auto abs = AbsImpl();
+ return (abs < kPositiveInfinityBits)           // is finite
+ && (abs != 0)                                  // is not zero
+ && ((abs & kBiasedExponentMask) == 0);         // is subnormal (has a zero exponent)
+ }
+
+ /// <summary>
+ /// Creates an instance that represents absolute value.
+ /// </summary>
+ /// <returns>Absolute value</returns>
+ Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
+
+ /// <summary>
+ /// Creates a new instance with the sign flipped.
+ /// </summary>
+ /// <returns>Flipped sign instance</returns>
+ Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
+
+ /// <summary>
+ /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
+ /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
+ /// and therefore equivalent, if the resulting value is still zero.
+ /// </summary>
+ /// <param name="lhs">first value</param>
+ /// <param name="rhs">second value</param>
+ /// <returns>True if both arguments represent zero</returns>
+ static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
+ return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
+ }
+
+ bool operator==(const Float16Impl& rhs) const noexcept {
+ if (IsNaN() || rhs.IsNaN()) {
+ // IEEE defines that NaN is not equal to anything, including itself.
+ return false;
+ }
+ return val == rhs.val;
+ }
+
+ bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }
+
+ bool operator<(const Float16Impl& rhs) const noexcept {
+ if (IsNaN() || rhs.IsNaN()) {
+ // IEEE defines that NaN is unordered with respect to everything, including itself.
+ return false;
+ }
+
+ const bool left_is_negative = IsNegative();
+ if (left_is_negative != rhs.IsNegative()) {
+ // When the signs of left and right differ, we know that left is less than right if it is
+ // the negative value. The exception to this is if both values are zero, in which case IEEE
+ // says they should be equal, even if the signs differ.
+ return left_is_negative && !AreZero(*this, rhs);
+ }
+ return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
+ }
+ };
+
+ // The following Float16_t conversions are based on the code from
+ // Eigen library.
+
+ // The conversion routines are Copyright (c) Fabian Giesen, 2016.
+ // The original license follows:
+ //
+ // Copyright (c) Fabian Giesen, 2016
+ // All rights reserved.
+ // Redistribution and use in source and binary forms, with or without
+ // modification, are permitted.
+ // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ // HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ namespace detail {
+ union float32_bits {
+ unsigned int u;
+ float f;
+ };
+ }  // namespace detail
+
+ template <class Derived>
+ inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
+ detail::float32_bits f{};
+ f.f = v;
+
+ constexpr detail::float32_bits f32infty = {255 << 23};
+ constexpr detail::float32_bits f16max = {(127 + 16) << 23};
+ constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
+ constexpr unsigned int sign_mask = 0x80000000u;
+ uint16_t val = static_cast<uint16_t>(0x0u);
+
+ unsigned int sign = f.u & sign_mask;
+ f.u ^= sign;
+
+ // NOTE all the integer compares in this function can be safely
+ // compiled into signed compares since all operands are below
+ // 0x80000000. Important if you want fast straight SSE2 code
+ // (since there's no unsigned PCMPGTD).
+
+ if (f.u >= f16max.u) {                          // result is Inf or NaN (all exponent bits set)
+ val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00;    // NaN->qNaN and Inf->Inf
+ } else {                                       // (De)normalized number or zero
+ if (f.u < (113 << 23)) {                       // resulting FP16 is subnormal or zero
+ // use a magic value to align our 10 mantissa bits at the bottom of
+ // the float. as long as FP addition is round-to-nearest-even this
+ // just works.
+ f.f += denorm_magic.f;
+
+ // and one integer subtract of the bias later, we have our final float!
+ val = static_cast<uint16_t>(f.u - denorm_magic.u);
+ } else {
+ unsigned int mant_odd = (f.u >> 13) & 1;       // resulting mantissa is odd
+
+ // update exponent, rounding bias part 1
+ // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
+ // without arithmetic overflow.
+ f.u += 0xc8000fffU;
+ // rounding bias part 2
+ f.u += mant_odd;
+ // take the bits!
+ val = static_cast<uint16_t>(f.u >> 13);
+ }
+ }
+
+ val |= static_cast<uint16_t>(sign >> 16);
+ return val;
+ }
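The rounding logic of `ToUint16Impl` can be traced in a short Python sketch that mirrors the bit manipulation above, using `struct` to reinterpret floats since Python has no unions (this is an illustrative port, not part of the header):

```python
import struct

def f32_to_f16_bits(v: float) -> int:
    """Mirror of the Giesen float->half conversion, operating on raw bits."""
    (u,) = struct.unpack("<I", struct.pack("<f", v))
    sign = u & 0x80000000
    u ^= sign

    f32infty = 255 << 23        # float32 +inf bit pattern
    f16max = (127 + 16) << 23   # smallest float32 that overflows half range
    denorm_magic = ((127 - 15) + (23 - 10) + 1) << 23

    if u >= f16max:             # Inf or NaN (all exponent bits set)
        val = 0x7E00 if u > f32infty else 0x7C00
    elif u < (113 << 23):       # half result is subnormal or zero
        # Add the magic value as a *float* so round-to-nearest-even FP
        # addition aligns the mantissa, then subtract the magic bits.
        (f,) = struct.unpack("<f", struct.pack("<I", u))
        (magic_f,) = struct.unpack("<f", struct.pack("<I", denorm_magic))
        (summed,) = struct.unpack("<I", struct.pack("<f", f + magic_f))
        val = summed - denorm_magic
    else:                       # normal number
        mant_odd = (u >> 13) & 1                # tie-breaker for nearest-even
        u = (u + 0xC8000FFF) & 0xFFFFFFFF       # exponent rebias + rounding bias
        u += mant_odd
        val = (u >> 13) & 0xFFFF

    return val | (sign >> 16)

print(hex(f32_to_f16_bits(1.0)))  # 0x3c00, matching kOneBits above
```

Checking a few values against the constants defined earlier in the header: `1.0` gives `0x3C00` (`kOneBits`), `65504.0` gives `0x7BFF` (`kMaxValueBits`), and infinity gives `0x7C00` (`kPositiveInfinityBits`).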
+
295
+ template <class Derived>
296
+ inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
297
+ constexpr detail::float32_bits magic = {113 << 23};
298
+ constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
299
+ detail::float32_bits o{};
300
+
301
+ o.u = (val & 0x7fff) << 13; // exponent/mantissa bits
302
+ unsigned int exp = shifted_exp & o.u; // just the exponent
303
+ o.u += (127 - 15) << 23; // exponent adjust
304
+
305
+ // handle exponent special cases
306
+ if (exp == shifted_exp) { // Inf/NaN?
307
+ o.u += (128 - 16) << 23; // extra exp adjust
308
+ } else if (exp == 0) { // Zero/Denormal?
309
+ o.u += 1 << 23; // extra exp adjust
310
+ o.f -= magic.f; // re-normalize
311
+ }
312
+
313
+ // Attempt to workaround the Internal Compiler Error on ARM64
314
+ // for bitwise | operator, including std::bitset
315
+ #if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
316
+ if (IsNegative()) {
317
+ return -o.f;
318
+ }
319
+ #else
320
+ // original code:
321
+ o.u |= (val & 0x8000U) << 16U; // sign bit
322
+ #endif
323
+ return o.f;
324
+ }
+
+/// Shared implementation between public and internal classes. CRTP pattern.
+template <class Derived>
+struct BFloat16Impl {
+ protected:
+  /// <summary>
+  /// Converts from float to uint16_t bfloat16 representation
+  /// </summary>
+  /// <param name="v"></param>
+  /// <returns></returns>
+  static uint16_t ToUint16Impl(float v) noexcept;
+
+  /// <summary>
+  /// Converts bfloat16 to float
+  /// </summary>
+  /// <returns>float representation of bfloat16 value</returns>
+  float ToFloatImpl() const noexcept;
+
+  /// <summary>
+  /// Creates an instance that represents absolute value.
+  /// </summary>
+  /// <returns>Absolute value</returns>
+  uint16_t AbsImpl() const noexcept {
+    return static_cast<uint16_t>(val & ~kSignMask);
+  }
+
+  /// <summary>
+  /// Creates a new instance with the sign flipped.
+  /// </summary>
+  /// <returns>Flipped sign instance</returns>
+  uint16_t NegateImpl() const noexcept {
+    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
+  }
+
+ public:
+  // uint16_t special values
+  static constexpr uint16_t kSignMask = 0x8000U;
+  static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
+  static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
+  static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
+  static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
+  static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
+  static constexpr uint16_t kSignaling_NaNBits = 0x7F80U;
+  static constexpr uint16_t kEpsilonBits = 0x0080U;
+  static constexpr uint16_t kMinValueBits = 0xFF7FU;
+  static constexpr uint16_t kMaxValueBits = 0x7F7FU;
+  static constexpr uint16_t kRoundToNearest = 0x7FFFU;
+  static constexpr uint16_t kOneBits = 0x3F80U;
+  static constexpr uint16_t kMinusOneBits = 0xBF80U;
+
+  uint16_t val{0};
+
+  BFloat16Impl() = default;
+
+  /// <summary>
+  /// Checks if the value is negative
+  /// </summary>
+  /// <returns>true if negative</returns>
+  bool IsNegative() const noexcept {
+    return static_cast<int16_t>(val) < 0;
+  }
+
+  /// <summary>
+  /// Tests if the value is NaN
+  /// </summary>
+  /// <returns>true if NaN</returns>
+  bool IsNaN() const noexcept {
+    return AbsImpl() > kPositiveInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value is finite
+  /// </summary>
+  /// <returns>true if finite</returns>
+  bool IsFinite() const noexcept {
+    return AbsImpl() < kPositiveInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value represents positive infinity.
+  /// </summary>
+  /// <returns>true if positive infinity</returns>
+  bool IsPositiveInfinity() const noexcept {
+    return val == kPositiveInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value represents negative infinity
+  /// </summary>
+  /// <returns>true if negative infinity</returns>
+  bool IsNegativeInfinity() const noexcept {
+    return val == kNegativeInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value is either positive or negative infinity.
+  /// </summary>
+  /// <returns>True if absolute value is infinity</returns>
+  bool IsInfinity() const noexcept {
+    return AbsImpl() == kPositiveInfinityBits;
+  }
+
+  /// <summary>
+  /// Tests if the value is NaN or zero. Useful for comparisons.
+  /// </summary>
+  /// <returns>True if NaN or zero.</returns>
+  bool IsNaNOrZero() const noexcept {
+    auto abs = AbsImpl();
+    return (abs == 0 || abs > kPositiveInfinityBits);
+  }
+
+  /// <summary>
+  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
+  /// </summary>
+  /// <returns>True if so</returns>
+  bool IsNormal() const noexcept {
+    auto abs = AbsImpl();
+    return (abs < kPositiveInfinityBits)           // is finite
+           && (abs != 0)                           // is not zero
+           && ((abs & kBiasedExponentMask) != 0);  // is not subnormal (has a non-zero exponent)
+  }
+
+  /// <summary>
+  /// Tests if the value is subnormal (denormal).
+  /// </summary>
+  /// <returns>True if so</returns>
+  bool IsSubnormal() const noexcept {
+    auto abs = AbsImpl();
+    return (abs < kPositiveInfinityBits)           // is finite
+           && (abs != 0)                           // is not zero
+           && ((abs & kBiasedExponentMask) == 0);  // is subnormal (has a zero exponent)
+  }
+
+  /// <summary>
+  /// Creates an instance that represents absolute value.
+  /// </summary>
+  /// <returns>Absolute value</returns>
+  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
+
+  /// <summary>
+  /// Creates a new instance with the sign flipped.
+  /// </summary>
+  /// <returns>Flipped sign instance</returns>
+  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
+
+  /// <summary>
+  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
+  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
+  /// and therefore equivalent, if the resulting value is still zero.
+  /// </summary>
+  /// <param name="lhs">first value</param>
+  /// <param name="rhs">second value</param>
+  /// <returns>True if both arguments represent zero</returns>
+  static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
+    // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
+    // for two values by or'ing the private bits together and stripping the sign. They are both zero,
+    // and therefore equivalent, if the resulting value is still zero.
+    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
+  }
+};
+
+template <class Derived>
+inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
+  uint16_t result;
+  if (std::isnan(v)) {
+    result = kPositiveQNaNBits;
+  } else {
+    auto get_msb_half = [](float fl) {
+      uint16_t result;
+#ifdef __cpp_if_constexpr
+      if constexpr (detail::endian::native == detail::endian::little) {
+#else
+      if (detail::endian::native == detail::endian::little) {
+#endif
+        std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
+      } else {
+        std::memcpy(&result, &fl, sizeof(uint16_t));
+      }
+      return result;
+    };
+
+    uint16_t upper_bits = get_msb_half(v);
+    union {
+      uint32_t U32;
+      float F32;
+    };
+    F32 = v;
+    U32 += (upper_bits & 1) + kRoundToNearest;
+    result = get_msb_half(F32);
+  }
+  return result;
+}
+
+template <class Derived>
+inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
+  if (IsNaN()) {
+    return std::numeric_limits<float>::quiet_NaN();
+  }
+  float result;
+  char* const first = reinterpret_cast<char*>(&result);
+  char* const second = first + sizeof(uint16_t);
+#ifdef __cpp_if_constexpr
+  if constexpr (detail::endian::native == detail::endian::little) {
+#else
+  if (detail::endian::native == detail::endian::little) {
+#endif
+    std::memset(first, 0, sizeof(uint16_t));
+    std::memcpy(second, &val, sizeof(uint16_t));
+  } else {
+    std::memcpy(first, &val, sizeof(uint16_t));
+    std::memset(second, 0, sizeof(uint16_t));
+  }
+  return result;
+}
+
+}  // namespace onnxruntime_float16
1.16.2/onnxruntime.xcframework/Info.plist ADDED
@@ -0,0 +1,53 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+<plist version="1.0">
+<dict>
+  <key>AvailableLibraries</key>
+  <array>
+    <dict>
+      <key>LibraryIdentifier</key>
+      <string>ios-arm64</string>
+      <key>LibraryPath</key>
+      <string>onnxruntime.a</string>
+      <key>SupportedArchitectures</key>
+      <array>
+        <string>arm64</string>
+      </array>
+      <key>SupportedPlatform</key>
+      <string>ios</string>
+    </dict>
+    <dict>
+      <key>LibraryIdentifier</key>
+      <string>ios-arm64_x86_64-simulator</string>
+      <key>LibraryPath</key>
+      <string>onnxruntime.a</string>
+      <key>SupportedArchitectures</key>
+      <array>
+        <string>arm64</string>
+        <string>x86_64</string>
+      </array>
+      <key>SupportedPlatform</key>
+      <string>ios</string>
+      <key>SupportedPlatformVariant</key>
+      <string>simulator</string>
+    </dict>
+    <dict>
+      <key>LibraryIdentifier</key>
+      <string>macos-arm64_x86_64</string>
+      <key>LibraryPath</key>
+      <string>onnxruntime.a</string>
+      <key>SupportedArchitectures</key>
+      <array>
+        <string>arm64</string>
+        <string>x86_64</string>
+      </array>
+      <key>SupportedPlatform</key>
+      <string>macos</string>
+    </dict>
+  </array>
+  <key>CFBundlePackageType</key>
+  <string>XFWK</string>
+  <key>XCFrameworkFormatVersion</key>
+  <string>1.0</string>
+</dict>
+</plist>
1.16.2/onnxruntime.xcframework/ios-arm64/libonnxruntime.a ADDED
@@ -0,0 +1 @@
+onnxruntime.a
1.16.2/onnxruntime.xcframework/ios-arm64/onnxruntime.a ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d23ba169a0449233dba898027509c23cfc530a29abe1f4ade79e188bdb7ab24d
+size 61402536
1.16.2/onnxruntime.xcframework/ios-arm64_x86_64-simulator/libonnxruntime.a ADDED
@@ -0,0 +1 @@
+onnxruntime.a
1.16.2/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.a ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72c7dbf49e6d5cb6e2c667c1a6afbcb33cfd41ec2e11db62dd8455c4e4cf9441
+size 125536184
1.16.2/onnxruntime.xcframework/macos-arm64_x86_64/libonnxruntime.a ADDED
@@ -0,0 +1 @@
+onnxruntime.a
1.16.2/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.a ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2bba4bce6b4f7f92fa898c93aa3a4dae39553dd883e307e2d7e51fcf03e63d43
+size 127981440