csukuangfj commited on
Commit
e212fe9
·
1 Parent(s): a93c67c

Add onnxruntime.xcframework 1.20.0

Browse files
1.20.0/onnxruntime.xcframework/Headers/coreml_provider_factory.h ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+ #pragma once
4
+
5
+ #include "onnxruntime_c_api.h"
6
+
7
+ // COREMLFlags are bool options we want to set for CoreML EP
8
+ // This enum is defined as bit flags, and cannot have negative value
9
+ // To generate an uint32_t coreml_flags for using with OrtSessionOptionsAppendExecutionProvider_CoreML below,
10
+ // uint32_t coreml_flags = 0;
11
+ // coreml_flags |= COREML_FLAG_USE_CPU_ONLY;
12
+ enum COREMLFlags {
13
+ COREML_FLAG_USE_NONE = 0x000,
14
+
15
+ // Using CPU only in CoreML EP, this may decrease the perf but will provide
16
+ // reference output value without precision loss, which is useful for validation
17
+ COREML_FLAG_USE_CPU_ONLY = 0x001,
18
+
19
+ // Enable CoreML EP on subgraph
20
+ COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002,
21
+
22
+ // By default CoreML Execution provider will be enabled for all compatible Apple devices
23
+ // Enable this option will only enable CoreML EP for Apple devices with ANE (Apple Neural Engine)
24
+ // Please note, enable this option does not guarantee the entire model to be executed using ANE only
25
+ COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004,
26
+
27
+ // Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also allow inputs with
28
+ // dynamic shapes. However, the performance may be negatively impacted if inputs have dynamic shapes.
29
+ COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008,
30
+
31
+ // Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later.
32
+ COREML_FLAG_CREATE_MLPROGRAM = 0x010,
33
+
34
+ // Exclude ANE as sometimes this decrease performance
35
+ // https://developer.apple.com/documentation/coreml/mlcomputeunits?language=objc
36
+ // there are four compute units:
37
+ // MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll
38
+ COREML_FLAG_USE_CPU_AND_GPU = 0x020,
39
+ // Keep COREML_FLAG_LAST at the end of the enum definition
40
+ // And assign the last COREMLFlag to it
41
+ COREML_FLAG_LAST = COREML_FLAG_USE_CPU_AND_GPU,
42
+ };
43
+
44
+ #ifdef __cplusplus
45
+ extern "C" {
46
+ #endif
47
+
48
+ ORT_EXPORT ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CoreML,
49
+ _In_ OrtSessionOptions* options, uint32_t coreml_flags);
50
+
51
+ #ifdef __cplusplus
52
+ }
53
+ #endif
1.20.0/onnxruntime.xcframework/Headers/cpu_provider_factory.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #include "onnxruntime_c_api.h"
5
+
6
+ #ifdef __cplusplus
7
+ extern "C" {
8
+ #endif
9
+
10
+ /**
11
+ * \param use_arena zero: false. non-zero: true.
12
+ */
13
+ ORT_EXPORT
14
+ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
15
+ ORT_ALL_ARGS_NONNULL;
16
+
17
+ #ifdef __cplusplus
18
+ }
19
+ #endif
1.20.0/onnxruntime.xcframework/Headers/onnxruntime_c_api.h ADDED
The diff for this file is too large to render. See raw diff
 
1.20.0/onnxruntime.xcframework/Headers/onnxruntime_cxx_api.h ADDED
The diff for this file is too large to render. See raw diff
 
1.20.0/onnxruntime.xcframework/Headers/onnxruntime_cxx_inline.h ADDED
@@ -0,0 +1,2170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ // Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
5
+ // If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
6
+ //
7
+ // These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter
8
+ // the main C++ file with implementation details.
9
+
10
+ #include <algorithm>
11
+ #include <functional>
12
+ #include <iterator>
13
+ #include <type_traits>
14
+
15
+ // Convert OrtStatus to Ort::Status and return
16
+ // instead of throwing
17
+ #define ORT_CXX_RETURN_ON_API_FAIL(expression) \
18
+ { \
19
+ auto ort_status = (expression); \
20
+ if (ort_status) { \
21
+ return Ort::Status(ort_status); \
22
+ } \
23
+ }
24
+
25
+ #ifdef __cpp_if_constexpr
26
+ #define ORT_CXX_IF_CONSTEXPR if constexpr
27
+ #else
28
+ #define ORT_CXX_IF_CONSTEXPR if
29
+ #endif
30
+
31
+ namespace Ort {
32
+
33
+ namespace detail {
34
+ inline void ThrowStatus(const Status& st) {
35
+ std::string error_message = st.GetErrorMessage();
36
+ OrtErrorCode error_code = st.GetErrorCode();
37
+ ORT_CXX_API_THROW(std::move(error_message), error_code);
38
+ }
39
+ } // namespace detail
40
+
41
+ inline void ThrowOnError(OrtStatus* ort_status) {
42
+ if (ort_status) {
43
+ Ort::Status st(ort_status);
44
+ detail::ThrowStatus(st);
45
+ }
46
+ }
47
+
48
+ inline void ThrowOnError(const Status& st) {
49
+ if (st) {
50
+ detail::ThrowStatus(st);
51
+ }
52
+ }
53
+
54
+ inline Status::Status(OrtStatus* status) noexcept : Base<OrtStatus>{status} {
55
+ }
56
+
57
+ inline Status::Status(const std::exception& e) noexcept {
58
+ p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
59
+ }
60
+
61
+ inline Status::Status(const Exception& e) noexcept {
62
+ p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
63
+ }
64
+
65
+ inline Status::Status(const char* message, OrtErrorCode code) noexcept {
66
+ p_ = GetApi().CreateStatus(code, message);
67
+ }
68
+
69
+ inline std::string Status::GetErrorMessage() const {
70
+ std::string message(GetApi().GetErrorMessage(p_));
71
+ return message;
72
+ }
73
+
74
+ inline OrtErrorCode Status::GetErrorCode() const {
75
+ return GetApi().GetErrorCode(p_);
76
+ }
77
+
78
+ inline bool Status::IsOK() const noexcept {
79
+ return (p_ == nullptr);
80
+ }
81
+
82
+ // This template converts a C++ type into it's ONNXTensorElementDataType
83
+ template <typename T>
84
+ struct TypeToTensorType;
85
+ template <>
86
+ struct TypeToTensorType<float> {
87
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
88
+ };
89
+ template <>
90
+ struct TypeToTensorType<Float16_t> {
91
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
92
+ };
93
+ template <>
94
+ struct TypeToTensorType<BFloat16_t> {
95
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
96
+ };
97
+ template <>
98
+ struct TypeToTensorType<double> {
99
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
100
+ };
101
+ template <>
102
+ struct TypeToTensorType<int8_t> {
103
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
104
+ };
105
+ template <>
106
+ struct TypeToTensorType<int16_t> {
107
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
108
+ };
109
+ template <>
110
+ struct TypeToTensorType<int32_t> {
111
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
112
+ };
113
+ template <>
114
+ struct TypeToTensorType<int64_t> {
115
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
116
+ };
117
+ template <>
118
+ struct TypeToTensorType<uint8_t> {
119
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
120
+ };
121
+ template <>
122
+ struct TypeToTensorType<uint16_t> {
123
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
124
+ };
125
+ template <>
126
+ struct TypeToTensorType<uint32_t> {
127
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
128
+ };
129
+ template <>
130
+ struct TypeToTensorType<uint64_t> {
131
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
132
+ };
133
+ template <>
134
+ struct TypeToTensorType<bool> {
135
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
136
+ };
137
+
138
+ template <>
139
+ struct TypeToTensorType<Float8E4M3FN_t> {
140
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN;
141
+ };
142
+ template <>
143
+ struct TypeToTensorType<Float8E4M3FNUZ_t> {
144
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ;
145
+ };
146
+ template <>
147
+ struct TypeToTensorType<Float8E5M2_t> {
148
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2;
149
+ };
150
+ template <>
151
+ struct TypeToTensorType<Float8E5M2FNUZ_t> {
152
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ;
153
+ };
154
+
155
+ inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept {
156
+ if (IsNaN() || rhs.IsNaN()) {
157
+ // IEEE defines that NaN is not equal to anything, including itself.
158
+ return false;
159
+ }
160
+ return val == rhs.val;
161
+ }
162
+
163
+ inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept {
164
+ if (IsNaN() || rhs.IsNaN()) {
165
+ // IEEE defines that NaN is unordered with respect to everything, including itself.
166
+ return false;
167
+ }
168
+
169
+ const bool left_is_negative = IsNegative();
170
+ if (left_is_negative != rhs.IsNegative()) {
171
+ // When the signs of left and right differ, we know that left is less than right if it is
172
+ // the negative value. The exception to this is if both values are zero, in which case IEEE
173
+ // says they should be equal, even if the signs differ.
174
+ return left_is_negative && !AreZero(*this, rhs);
175
+ }
176
+ return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
177
+ }
178
+
179
+ inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
180
+ : allocator_(allocator), p_(p), size_(size) {
181
+ }
182
+
183
+ inline MemoryAllocation::~MemoryAllocation() {
184
+ if (p_ != nullptr) {
185
+ // We do not throw out of destructor
186
+ auto ret = GetApi().AllocatorFree(allocator_, p_);
187
+ static_cast<void>(ret);
188
+ }
189
+ }
190
+
191
+ inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
192
+ *this = std::move(o);
193
+ }
194
+
195
+ inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept {
196
+ OrtAllocator* alloc = nullptr;
197
+ void* p = nullptr;
198
+ size_t sz = 0;
199
+
200
+ // Swap out this
201
+ std::swap(alloc, allocator_);
202
+ std::swap(p, p_);
203
+ std::swap(sz, size_);
204
+
205
+ // Swap with incoming
206
+ std::swap(allocator_, o.allocator_);
207
+ std::swap(p_, o.p_);
208
+ std::swap(size_, o.size_);
209
+
210
+ // Destroy this instance if needed
211
+ MemoryAllocation this_alloc(alloc, p, sz);
212
+ return *this;
213
+ }
214
+
215
+ namespace detail {
216
+
217
+ template <typename T>
218
+ inline void* AllocatorImpl<T>::Alloc(size_t size) {
219
+ void* out;
220
+ ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
221
+ return out;
222
+ }
223
+
224
+ template <typename T>
225
+ inline MemoryAllocation AllocatorImpl<T>::GetAllocation(size_t size) {
226
+ void* out;
227
+ ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
228
+ MemoryAllocation result(this->p_, out, size);
229
+ return result;
230
+ }
231
+
232
+ template <typename T>
233
+ inline void AllocatorImpl<T>::Free(void* p) {
234
+ ThrowOnError(GetApi().AllocatorFree(this->p_, p));
235
+ }
236
+
237
+ template <typename T>
238
+ inline ConstMemoryInfo AllocatorImpl<T>::GetInfo() const {
239
+ const OrtMemoryInfo* out;
240
+ ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
241
+ return ConstMemoryInfo{out};
242
+ }
243
+
244
+ } // namespace detail
245
+
246
+ inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
247
+ ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
248
+ }
249
+
250
+ inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
251
+ ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
252
+ }
253
+
254
+ namespace detail {
255
+
256
+ template <typename T>
257
+ inline std::string MemoryInfoImpl<T>::GetAllocatorName() const {
258
+ const char* name = nullptr;
259
+ ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
260
+ return std::string(name);
261
+ }
262
+
263
+ template <typename T>
264
+ inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
265
+ OrtAllocatorType type;
266
+ ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
267
+ return type;
268
+ }
269
+
270
+ template <typename T>
271
+ inline int MemoryInfoImpl<T>::GetDeviceId() const {
272
+ int id = 0;
273
+ ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
274
+ return id;
275
+ }
276
+
277
+ template <typename T>
278
+ inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
279
+ OrtMemoryInfoDeviceType type;
280
+ GetApi().MemoryInfoGetDeviceType(this->p_, &type);
281
+ return type;
282
+ }
283
+
284
+ template <typename T>
285
+ inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
286
+ OrtMemType type;
287
+ ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
288
+ return type;
289
+ }
290
+
291
+ template <typename T>
292
+ template <typename U>
293
+ inline bool MemoryInfoImpl<T>::operator==(const MemoryInfoImpl<U>& o) const {
294
+ int comp_result = 0;
295
+ ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
296
+ return comp_result == 0;
297
+ }
298
+
299
+ } // namespace detail
300
+
301
+ inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
302
+ OrtMemoryInfo* p;
303
+ ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
304
+ return MemoryInfo(p);
305
+ }
306
+
307
+ inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
308
+ ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
309
+ }
310
+
311
+ namespace detail {
312
+ template <typename T>
313
+ inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
314
+ AllocatorWithDefaultOptions allocator;
315
+ return binding_utils::GetOutputNamesHelper(this->p_, allocator);
316
+ }
317
+
318
+ template <typename T>
319
+ inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
320
+ return binding_utils::GetOutputNamesHelper(this->p_, allocator);
321
+ }
322
+
323
+ template <typename T>
324
+ inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
325
+ AllocatorWithDefaultOptions allocator;
326
+ return binding_utils::GetOutputValuesHelper(this->p_, allocator);
327
+ }
328
+
329
+ template <typename T>
330
+ inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
331
+ return binding_utils::GetOutputValuesHelper(this->p_, allocator);
332
+ }
333
+
334
+ template <typename T>
335
+ inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
336
+ ThrowOnError(GetApi().BindInput(this->p_, name, value));
337
+ }
338
+
339
+ template <typename T>
340
+ inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
341
+ ThrowOnError(GetApi().BindOutput(this->p_, name, value));
342
+ }
343
+
344
+ template <typename T>
345
+ inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
346
+ ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
347
+ }
348
+
349
+ template <typename T>
350
+ inline void IoBindingImpl<T>::ClearBoundInputs() {
351
+ GetApi().ClearBoundInputs(this->p_);
352
+ }
353
+
354
+ template <typename T>
355
+ inline void IoBindingImpl<T>::ClearBoundOutputs() {
356
+ GetApi().ClearBoundOutputs(this->p_);
357
+ }
358
+
359
+ template <typename T>
360
+ inline void IoBindingImpl<T>::SynchronizeInputs() {
361
+ ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
362
+ }
363
+
364
+ template <typename T>
365
+ inline void IoBindingImpl<T>::SynchronizeOutputs() {
366
+ ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
367
+ }
368
+
369
+ namespace binding_utils {
370
+ inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
371
+ std::vector<std::string> result;
372
+ auto free_fn = detail::AllocatedFree(allocator);
373
+ using Ptr = std::unique_ptr<void, decltype(free_fn)>;
374
+
375
+ char* buffer = nullptr;
376
+ size_t* lengths = nullptr;
377
+ size_t count = 0;
378
+ ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
379
+
380
+ if (count == 0) {
381
+ return result;
382
+ }
383
+
384
+ Ptr buffer_g(buffer, free_fn);
385
+ Ptr lengths_g(lengths, free_fn);
386
+
387
+ result.reserve(count);
388
+ for (size_t i = 0; i < count; ++i) {
389
+ auto sz = *lengths;
390
+ result.emplace_back(buffer, sz);
391
+ buffer += sz;
392
+ ++lengths;
393
+ }
394
+ return result;
395
+ }
396
+
397
+ inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
398
+ std::vector<Value> result;
399
+ size_t owned = 0;
400
+ size_t output_count = 0;
401
+ // Lambda to release the buffer when no longer needed and
402
+ // make sure that we destroy all instances on exception
403
+ auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
404
+ if (buffer) {
405
+ while (owned < output_count) {
406
+ auto* p = buffer + owned++;
407
+ GetApi().ReleaseValue(*p);
408
+ }
409
+ allocator->Free(allocator, buffer);
410
+ }
411
+ };
412
+ using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
413
+
414
+ OrtValue** output_buffer = nullptr;
415
+ ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
416
+ if (output_count == 0) {
417
+ return result;
418
+ }
419
+
420
+ Ptr buffer_g(output_buffer, free_fn);
421
+
422
+ result.reserve(output_count);
423
+ for (size_t i = 0; i < output_count; ++i) {
424
+ result.emplace_back(output_buffer[i]);
425
+ ++owned;
426
+ }
427
+ return result;
428
+ }
429
+
430
+ } // namespace binding_utils
431
+ } // namespace detail
432
+
433
+ inline IoBinding::IoBinding(Session& session) {
434
+ ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
435
+ }
436
+
437
+ inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
438
+ ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
439
+ }
440
+
441
+ inline ThreadingOptions::ThreadingOptions() {
442
+ ThrowOnError(GetApi().CreateThreadingOptions(&p_));
443
+ }
444
+
445
+ inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) {
446
+ ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
447
+ return *this;
448
+ }
449
+
450
+ inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) {
451
+ ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
452
+ return *this;
453
+ }
454
+
455
+ inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) {
456
+ ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
457
+ return *this;
458
+ }
459
+
460
+ inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() {
461
+ ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_));
462
+ return *this;
463
+ }
464
+
465
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
466
+ ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
467
+ return *this;
468
+ }
469
+
470
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
471
+ ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
472
+ return *this;
473
+ }
474
+
475
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
476
+ ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
477
+ return *this;
478
+ }
479
+
480
+ inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
481
+ ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
482
+ if (strcmp(logid, "onnxruntime-node") == 0) {
483
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
484
+ } else {
485
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
486
+ }
487
+ }
488
+
489
+ inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
490
+ ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
491
+ if (strcmp(logid, "onnxruntime-node") == 0) {
492
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
493
+ } else {
494
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
495
+ }
496
+ }
497
+
498
+ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
499
+ ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
500
+ if (strcmp(logid, "onnxruntime-node") == 0) {
501
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
502
+ } else {
503
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
504
+ }
505
+ }
506
+
507
+ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
508
+ OrtLoggingLevel logging_level, _In_ const char* logid) {
509
+ ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
510
+ if (strcmp(logid, "onnxruntime-node") == 0) {
511
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
512
+ } else {
513
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
514
+ }
515
+ }
516
+
517
+ inline Env& Env::EnableTelemetryEvents() {
518
+ ThrowOnError(GetApi().EnableTelemetryEvents(p_));
519
+ return *this;
520
+ }
521
+
522
+ inline Env& Env::DisableTelemetryEvents() {
523
+ ThrowOnError(GetApi().DisableTelemetryEvents(p_));
524
+ return *this;
525
+ }
526
+
527
+ inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
528
+ ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
529
+ return *this;
530
+ }
531
+
532
+ inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
533
+ ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
534
+ return *this;
535
+ }
536
+
537
+ inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg) {
538
+ std::vector<const char*> keys, values;
539
+ auto num_entries = options.size();
540
+ if (num_entries > 0) {
541
+ keys.reserve(num_entries);
542
+ values.reserve(num_entries);
543
+ for (const auto& entry : options) {
544
+ keys.push_back(entry.first.c_str());
545
+ values.push_back(entry.second.c_str());
546
+ }
547
+ }
548
+ ThrowOnError(GetApi().CreateAndRegisterAllocatorV2(p_, provider_type.c_str(), mem_info, arena_cfg, keys.data(), values.data(), num_entries));
549
+ return *this;
550
+ }
551
+
552
+ inline CustomOpDomain::CustomOpDomain(const char* domain) {
553
+ ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
554
+ }
555
+
556
+ inline void CustomOpDomain::Add(const OrtCustomOp* op) {
557
+ ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
558
+ }
559
+
560
+ inline LoraAdapter LoraAdapter::CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path,
561
+ OrtAllocator* allocator) {
562
+ OrtLoraAdapter* p;
563
+ ThrowOnError(GetApi().CreateLoraAdapter(adapter_path.c_str(), allocator, &p));
564
+ return LoraAdapter{p};
565
+ }
566
+
567
+ inline LoraAdapter LoraAdapter::CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes,
568
+ OrtAllocator* allocator) {
569
+ OrtLoraAdapter* p;
570
+ ThrowOnError(GetApi().CreateLoraAdapterFromArray(bytes, num_bytes, allocator, &p));
571
+ return LoraAdapter{p};
572
+ }
573
+
574
+ inline RunOptions::RunOptions() {
575
+ ThrowOnError(GetApi().CreateRunOptions(&p_));
576
+ }
577
+
578
+ inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
579
+ ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
580
+ return *this;
581
+ }
582
+
583
+ inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
584
+ ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
585
+ return *this;
586
+ }
587
+
588
+ inline int RunOptions::GetRunLogVerbosityLevel() const {
589
+ int out;
590
+ ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
591
+ return out;
592
+ }
593
+
594
+ inline int RunOptions::GetRunLogSeverityLevel() const {
595
+ int out;
596
+ ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
597
+ return out;
598
+ }
599
+
600
+ inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
601
+ ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
602
+ return *this;
603
+ }
604
+
605
+ inline const char* RunOptions::GetRunTag() const {
606
+ const char* out;
607
+ ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
608
+ return out;
609
+ }
610
+
611
+ inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
612
+ ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
613
+ return *this;
614
+ }
615
+
616
+ inline RunOptions& RunOptions::SetTerminate() {
617
+ ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
618
+ return *this;
619
+ }
620
+
621
+ inline RunOptions& RunOptions::UnsetTerminate() {
622
+ ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
623
+ return *this;
624
+ }
625
+
626
+ inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter) {
627
+ ThrowOnError(GetApi().RunOptionsAddActiveLoraAdapter(p_, adapter));
628
+ return *this;
629
+ }
630
+
631
+ namespace detail {
632
+
633
+ template <typename T>
634
+ inline Ort::SessionOptions ConstSessionOptionsImpl<T>::Clone() const {
635
+ OrtSessionOptions* out;
636
+ ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
637
+ return SessionOptions{out};
638
+ }
639
+
640
+ template <typename T>
641
+ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
642
+ size_t size = 0;
643
+ // Feed nullptr for the data buffer to query the true size of the string value
644
+ Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
645
+
646
+ std::string out;
647
+ out.resize(size);
648
+ Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
649
+ out.resize(size - 1); // remove the terminating character '\0'
650
+
651
+ return out;
652
+ }
653
+
654
+ template <typename T>
655
+ inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
656
+ int out = 0;
657
+ Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
658
+ return static_cast<bool>(out);
659
+ }
660
+
661
+ template <typename T>
662
+ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
663
+ if (!this->HasConfigEntry(config_key)) {
664
+ return def;
665
+ }
666
+
667
+ return this->GetConfigEntry(config_key);
668
+ }
669
+
670
+ template <typename T>
671
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetIntraOpNumThreads(int intra_op_num_threads) {
672
+ ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
673
+ return *this;
674
+ }
675
+
676
+ template <typename T>
677
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetInterOpNumThreads(int inter_op_num_threads) {
678
+ ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
679
+ return *this;
680
+ }
681
+
682
+ template <typename T>
683
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
684
+ ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
685
+ return *this;
686
+ }
687
+
688
+ template <typename T>
689
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetDeterministicCompute(bool value) {
690
+ ThrowOnError(GetApi().SetDeterministicCompute(this->p_, value));
691
+ return *this;
692
+ }
693
+
694
+ template <typename T>
695
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
696
+ ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
697
+ return *this;
698
+ }
699
+
700
+ template <typename T>
701
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
702
+ ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
703
+ return *this;
704
+ }
705
+
706
+ template <typename T>
707
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableProfiling() {
708
+ ThrowOnError(GetApi().DisableProfiling(this->p_));
709
+ return *this;
710
+ }
711
+
712
+ template <typename T>
713
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableOrtCustomOps() {
714
+ ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
715
+ return *this;
716
+ }
717
+
718
+ template <typename T>
719
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableMemPattern() {
720
+ ThrowOnError(GetApi().EnableMemPattern(this->p_));
721
+ return *this;
722
+ }
723
+
724
+ template <typename T>
725
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableMemPattern() {
726
+ ThrowOnError(GetApi().DisableMemPattern(this->p_));
727
+ return *this;
728
+ }
729
+
730
+ template <typename T>
731
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableCpuMemArena() {
732
+ ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
733
+ return *this;
734
+ }
735
+
736
+ template <typename T>
737
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableCpuMemArena() {
738
+ ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
739
+ return *this;
740
+ }
741
+
742
+ template <typename T>
743
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
744
+ ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
745
+ return *this;
746
+ }
747
+
748
+ template <typename T>
749
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
750
+ ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
751
+ return *this;
752
+ }
753
+
754
+ template <typename T>
755
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogSeverityLevel(int level) {
756
+ ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
757
+ return *this;
758
+ }
759
+
760
+ template <typename T>
761
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::Add(OrtCustomOpDomain* custom_op_domain) {
762
+ ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
763
+ return *this;
764
+ }
765
+
766
+ template <typename T>
767
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
768
+ ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
769
+ return *this;
770
+ }
771
+
772
+ template <typename T>
773
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
774
+ ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
775
+ return *this;
776
+ }
777
+
778
+ template <typename T>
779
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisablePerSessionThreads() {
780
+ ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
781
+ return *this;
782
+ }
783
+
784
+ template <typename T>
785
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
786
+ const std::vector<Value>& ort_values) {
787
+ const size_t inputs_num = names.size();
788
+ if (inputs_num != ort_values.size()) {
789
+ ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
790
+ }
791
+ std::vector<const char*> names_ptr;
792
+ std::vector<const OrtValue*> ort_values_ptrs;
793
+ names_ptr.reserve(inputs_num);
794
+ ort_values_ptrs.reserve(inputs_num);
795
+ for (size_t i = 0; i < inputs_num; ++i) {
796
+ names_ptr.push_back(names[i].c_str());
797
+ ort_values_ptrs.push_back(ort_values[i]);
798
+ }
799
+ ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
800
+ return *this;
801
+ }
802
+
803
+ template <typename T>
804
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializersFromFilesInMemory(const std::vector<std::basic_string<ORTCHAR_T>>& file_names,
805
+ const std::vector<char*>& buffer_array,
806
+ const std::vector<size_t>& file_lengths) {
807
+ const size_t inputs_num = file_names.size();
808
+ if (inputs_num != buffer_array.size()) {
809
+ ORT_CXX_API_THROW("Expecting names and buffer_array to have the same length", ORT_INVALID_ARGUMENT);
810
+ }
811
+ if (inputs_num != file_lengths.size()) {
812
+ ORT_CXX_API_THROW("Expecting names and file_lengths to have the same length", ORT_INVALID_ARGUMENT);
813
+ }
814
+ std::vector<const ORTCHAR_T*> names_ptr;
815
+ names_ptr.reserve(inputs_num);
816
+ for (size_t i = 0; i < inputs_num; ++i) {
817
+ names_ptr.push_back(file_names[i].c_str());
818
+ }
819
+ ThrowOnError(GetApi().AddExternalInitializersFromFilesInMemory(this->p_, names_ptr.data(), buffer_array.data(),
820
+ file_lengths.data(), inputs_num));
821
+ return *this;
822
+ }
823
+
824
+ template <typename T>
825
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
826
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
827
+ return *this;
828
+ }
829
+
830
+ template <typename T>
831
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) {
832
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
833
+ return *this;
834
+ }
835
+
836
+ template <typename T>
837
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
838
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
839
+ return *this;
840
+ }
841
+
842
+ template <typename T>
843
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
844
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
845
+ return *this;
846
+ }
847
+
848
+ template <typename T>
849
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
850
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
851
+ return *this;
852
+ }
853
+
854
+ template <typename T>
855
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
856
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
857
+ return *this;
858
+ }
859
+
860
+ template <typename T>
861
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
862
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
863
+ return *this;
864
+ }
865
+
866
+ template <typename T>
867
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) {
868
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options));
869
+ return *this;
870
+ }
871
+
872
+ template <typename T>
873
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
874
+ const std::string& provider_name,
875
+ const std::unordered_map<std::string, std::string>& provider_options) {
876
+ auto num_entries = provider_options.size();
877
+ std::vector<const char*> keys, values;
878
+ if (num_entries > 0) {
879
+ keys.reserve(num_entries);
880
+ values.reserve(num_entries);
881
+
882
+ for (const auto& entry : provider_options) {
883
+ keys.push_back(entry.first.c_str());
884
+ values.push_back(entry.second.c_str());
885
+ }
886
+ }
887
+
888
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
889
+ keys.data(), values.data(), num_entries));
890
+
891
+ return *this;
892
+ }
893
+
894
+ template <typename T>
895
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
896
+ ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
897
+ return *this;
898
+ }
899
+
900
+ template <typename T>
901
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
902
+ ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
903
+ return *this;
904
+ }
905
+
906
+ template <typename T>
907
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
908
+ ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
909
+ return *this;
910
+ }
911
+
912
+ template <typename T>
913
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
914
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
915
+ return *this;
916
+ }
917
+
918
+ template <typename T>
919
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options) {
920
+ auto num_entries = provider_options.size();
921
+ std::vector<const char*> keys, values;
922
+ if (num_entries > 0) {
923
+ keys.reserve(num_entries);
924
+ values.reserve(num_entries);
925
+
926
+ for (const auto& entry : provider_options) {
927
+ keys.push_back(entry.first.c_str());
928
+ values.push_back(entry.second.c_str());
929
+ }
930
+ }
931
+
932
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO_V2(this->p_,
933
+ keys.data(), values.data(), num_entries));
934
+
935
+ return *this;
936
+ }
937
+
938
+ template <typename T>
939
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_VitisAI(const std::unordered_map<std::string, std::string>& provider_options) {
940
+ auto num_entries = provider_options.size();
941
+ std::vector<const char*> keys, values;
942
+ if (num_entries > 0) {
943
+ keys.reserve(num_entries);
944
+ values.reserve(num_entries);
945
+
946
+ for (const auto& entry : provider_options) {
947
+ keys.push_back(entry.first.c_str());
948
+ values.push_back(entry.second.c_str());
949
+ }
950
+ }
951
+
952
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_VitisAI(this->p_, keys.data(), values.data(), num_entries));
953
+
954
+ return *this;
955
+ }
956
+
957
+ template <typename T>
958
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name,
959
+ const CustomOpConfigs& custom_op_configs) {
960
+ // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
961
+ // the custom op library.
962
+ for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
963
+ AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
964
+ }
965
+
966
+ ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
967
+ return *this;
968
+ }
969
+
970
+ template <typename T>
971
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
972
+ ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
973
+ return *this;
974
+ }
975
+
976
+ /// Session
977
+ template <typename T>
978
+ inline size_t ConstSessionImpl<T>::GetInputCount() const {
979
+ size_t out;
980
+ ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
981
+ return out;
982
+ }
983
+
984
+ template <typename T>
985
+ inline size_t ConstSessionImpl<T>::GetOutputCount() const {
986
+ size_t out;
987
+ ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
988
+ return out;
989
+ }
990
+
991
+ template <typename T>
992
+ inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
993
+ size_t out;
994
+ ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
995
+ return out;
996
+ }
997
+
998
+ template <typename T>
999
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
1000
+ char* out;
1001
+ ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
1002
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1003
+ }
1004
+
1005
+ template <typename T>
1006
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
1007
+ char* out;
1008
+ ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
1009
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1010
+ }
1011
+
1012
+ template <typename T>
1013
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
1014
+ char* out;
1015
+ ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
1016
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1017
+ }
1018
+
1019
+ template <typename T>
1020
+ inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
1021
+ uint64_t out;
1022
+ ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
1023
+ return out;
1024
+ }
1025
+
1026
+ template <typename T>
1027
+ inline ModelMetadata ConstSessionImpl<T>::GetModelMetadata() const {
1028
+ OrtModelMetadata* out;
1029
+ ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
1030
+ return ModelMetadata{out};
1031
+ }
1032
+
1033
+ template <typename T>
1034
+ inline TypeInfo ConstSessionImpl<T>::GetInputTypeInfo(size_t index) const {
1035
+ OrtTypeInfo* out;
1036
+ ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
1037
+ return TypeInfo{out};
1038
+ }
1039
+
1040
+ template <typename T>
1041
+ inline TypeInfo ConstSessionImpl<T>::GetOutputTypeInfo(size_t index) const {
1042
+ OrtTypeInfo* out;
1043
+ ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
1044
+ return TypeInfo{out};
1045
+ }
1046
+
1047
+ template <typename T>
1048
+ inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t index) const {
1049
+ OrtTypeInfo* out;
1050
+ ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
1051
+ return TypeInfo{out};
1052
+ }
1053
+
1054
+ template <typename T>
1055
+ inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1056
+ const char* const* output_names, size_t output_count) {
1057
+ std::vector<Value> output_values;
1058
+ output_values.reserve(output_count);
1059
+ for (size_t i = 0; i < output_count; i++)
1060
+ output_values.emplace_back(nullptr);
1061
+ Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
1062
+ return output_values;
1063
+ }
1064
+
1065
+ template <typename T>
1066
+ inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1067
+ const char* const* output_names, Value* output_values, size_t output_count) {
1068
+ static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1069
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1070
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1071
+ ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
1072
+ }
1073
+
1074
+ template <typename T>
1075
+ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
1076
+ ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
1077
+ }
1078
+
1079
+ template <typename T>
1080
+ inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1081
+ const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
1082
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1083
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1084
+ ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names,
1085
+ ort_input_values, input_count, output_names, output_count,
1086
+ ort_output_values, callback, user_data));
1087
+ }
1088
+
1089
+ template <typename T>
1090
+ inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
1091
+ char* out = nullptr;
1092
+ ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
1093
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1094
+ }
1095
+
1096
+ template <typename T>
1097
+ inline void SessionImpl<T>::SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len) {
1098
+ ThrowOnError(GetApi().SetEpDynamicOptions(this->p_, keys, values, kv_len));
1099
+ }
1100
+
1101
+ } // namespace detail
1102
+
1103
+ inline SessionOptions::SessionOptions() {
1104
+ ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
1105
+ }
1106
+
1107
+ /// CustomOpConfigs
1108
+ inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
1109
+ std::string config_key = "custom_op.";
1110
+
1111
+ config_key += custom_op_name;
1112
+ config_key += ".";
1113
+ config_key += config;
1114
+
1115
+ return config_key;
1116
+ }
1117
+
1118
+ inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
1119
+ const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
1120
+ flat_configs_[full_flat_key] = config_value;
1121
+ return *this;
1122
+ }
1123
+
1124
+ inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
1125
+ return flat_configs_;
1126
+ }
1127
+
1128
+ inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
1129
+ ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
1130
+ }
1131
+
1132
+ inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
1133
+ OrtPrepackedWeightsContainer* prepacked_weights_container) {
1134
+ ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
1135
+ }
1136
+
1137
+ inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
1138
+ ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
1139
+ }
1140
+
1141
+ inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
1142
+ const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
1143
+ ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
1144
+ prepacked_weights_container, &this->p_));
1145
+ }
1146
+
1147
+ inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
1148
+ char* out;
1149
+ ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
1150
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1151
+ }
1152
+
1153
+ inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
1154
+ char* out;
1155
+ ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
1156
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1157
+ }
1158
+
1159
+ inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
1160
+ char* out;
1161
+ ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
1162
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1163
+ }
1164
+
1165
+ inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const {
1166
+ char* out;
1167
+ ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
1168
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1169
+ }
1170
+
1171
+ inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const {
1172
+ char* out;
1173
+ ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
1174
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1175
+ }
1176
+
1177
+ inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
1178
+ char* out;
1179
+ ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
1180
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1181
+ }
1182
+
1183
+ inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
1184
+ auto deletor = detail::AllocatedFree(allocator);
1185
+ std::vector<AllocatedStringPtr> result;
1186
+
1187
+ char** out = nullptr;
1188
+ int64_t num_keys = 0;
1189
+ ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
1190
+ if (num_keys <= 0) {
1191
+ return result;
1192
+ }
1193
+
1194
+ // array of pointers will be freed
1195
+ std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
1196
+ // reserve may throw
1197
+ auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
1198
+ std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
1199
+ result.reserve(static_cast<size_t>(num_keys));
1200
+ strings_guard.release();
1201
+ for (int64_t i = 0; i < num_keys; ++i) {
1202
+ result.push_back(AllocatedStringPtr(out[i], deletor));
1203
+ }
1204
+
1205
+ return result;
1206
+ }
1207
+
1208
+ inline int64_t ModelMetadata::GetVersion() const {
1209
+ int64_t out;
1210
+ ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
1211
+ return out;
1212
+ }
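// Usage sketch for the ModelMetadata accessors above. It assumes an existing
// Ort::Session named `session` (Session::GetModelMetadata() returns a ModelMetadata);
// the variable names are illustrative only.
//
//   Ort::AllocatorWithDefaultOptions allocator;
//   Ort::ModelMetadata metadata = session.GetModelMetadata();
//   Ort::AllocatedStringPtr producer = metadata.GetProducerNameAllocated(allocator);
//   int64_t version = metadata.GetVersion();
//   // The string is freed through the same allocator when `producer` goes out of scope.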
1213
+
1214
+ namespace detail {
1215
+
1216
+ template <typename T>
1217
+ inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
1218
+ ONNXTensorElementDataType out;
1219
+ ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
1220
+ return out;
1221
+ }
1222
+
1223
+ template <typename T>
1224
+ inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
1225
+ size_t out;
1226
+ ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
1227
+ return static_cast<size_t>(out);
1228
+ }
1229
+
1230
+ template <typename T>
1231
+ inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
1232
+ size_t out;
1233
+ ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
1234
+ return out;
1235
+ }
1236
+
1237
+ template <typename T>
1238
+ inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
1239
+ ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
1240
+ }
1241
+
1242
+ template <typename T>
1243
+ inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
1244
+ ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
1245
+ }
1246
+
1247
+ template <typename T>
1248
+ inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
1249
+ std::vector<int64_t> out(GetDimensionsCount(), 0);
1250
+ ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
1251
+ return out;
1252
+ }
1253
+
1254
+ template <typename T>
1255
+ inline ConstTensorTypeAndShapeInfo TypeInfoImpl<T>::GetTensorTypeAndShapeInfo() const {
1256
+ const OrtTensorTypeAndShapeInfo* out;
1257
+ ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
1258
+ return ConstTensorTypeAndShapeInfo{out};
1259
+ }
1260
+
1261
+ template <typename T>
1262
+ inline ConstSequenceTypeInfo TypeInfoImpl<T>::GetSequenceTypeInfo() const {
1263
+ const OrtSequenceTypeInfo* out;
1264
+ ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
1265
+ return ConstSequenceTypeInfo{out};
1266
+ }
1267
+
1268
+ template <typename T>
1269
+ inline ConstMapTypeInfo TypeInfoImpl<T>::GetMapTypeInfo() const {
1270
+ const OrtMapTypeInfo* out;
1271
+ ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
1272
+ return ConstMapTypeInfo{out};
1273
+ }
1274
+
1275
+ template <typename T>
1276
+ inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
1277
+ ONNXType out;
1278
+ ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
1279
+ return out;
1280
+ }
1281
+
1282
+ template <typename T>
1283
+ inline TypeInfo SequenceTypeInfoImpl<T>::GetSequenceElementType() const {
1284
+ OrtTypeInfo* output;
1285
+ ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
1286
+ return TypeInfo{output};
1287
+ }
1288
+
1289
+ template <typename T>
1290
+ inline TypeInfo OptionalTypeInfoImpl<T>::GetOptionalElementType() const {
1291
+ OrtTypeInfo* info;
1292
+ ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info));
1293
+ return TypeInfo{info};
1294
+ }
1295
+
1296
+ template <typename T>
1297
+ inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
1298
+ ONNXTensorElementDataType out;
1299
+ ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
1300
+ return out;
1301
+ }
1302
+
1303
+ template <typename T>
1304
+ inline TypeInfo MapTypeInfoImpl<T>::GetMapValueType() const {
1305
+ OrtTypeInfo* output;
1306
+ ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
1307
+ return TypeInfo{output};
1308
+ }
1309
+
1310
+ template <typename T>
1311
+ inline ConstOptionalTypeInfo TypeInfoImpl<T>::GetOptionalTypeInfo() const {
1312
+ const OrtOptionalTypeInfo* info;
1313
+ ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info));
1314
+ return ConstOptionalTypeInfo{info};
1315
+ }
1316
+
1317
+ } // namespace detail
1318
+
1319
+ namespace detail {
1320
+
1321
+ template <typename T>
1322
+ template <typename R>
1323
+ inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
1324
+ ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
1325
+ }
1326
+
1327
+ template <typename T>
1328
+ inline bool ConstValueImpl<T>::IsTensor() const {
1329
+ int out;
1330
+ ThrowOnError(GetApi().IsTensor(this->p_, &out));
1331
+ return out != 0;
1332
+ }
1333
+
1334
+ template <typename T>
1335
+ inline bool ConstValueImpl<T>::HasValue() const {
1336
+ int out;
1337
+ ThrowOnError(GetApi().HasValue(this->p_, &out));
1338
+ return out != 0;
1339
+ }
1340
+
1341
+ template <typename T>
1342
+ inline size_t ConstValueImpl<T>::GetCount() const {
1343
+ size_t out;
1344
+ ThrowOnError(GetApi().GetValueCount(this->p_, &out));
1345
+ return out;
1346
+ }
1347
+
1348
+ template <typename T>
1349
+ inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
1350
+ OrtValue* out;
1351
+ ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
1352
+ return Value{out};
1353
+ }
1354
+
1355
+ template <typename T>
1356
+ inline size_t ConstValueImpl<T>::GetStringTensorDataLength() const {
1357
+ size_t out;
1358
+ ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
1359
+ return out;
1360
+ }
1361
+
1362
+ template <typename T>
1363
+ inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
1364
+ size_t out;
1365
+ ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
1366
+ return out;
1367
+ }
1368
+
1369
+ template <typename T>
1370
+ template <typename R>
1371
+ inline const R* ConstValueImpl<T>::GetTensorData() const {
1372
+ R* out;
1373
+ ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
1374
+ return out;
1375
+ }
1376
+
1377
+ template <typename T>
1378
+ inline const void* ConstValueImpl<T>::GetTensorRawData() const {
1379
+ void* out;
1380
+ ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
1381
+ return out;
1382
+ }
1383
+
1384
+ template <typename T>
1385
+ inline TypeInfo ConstValueImpl<T>::GetTypeInfo() const {
1386
+ OrtTypeInfo* output;
1387
+ ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
1388
+ return TypeInfo{output};
1389
+ }
1390
+
1391
+ template <typename T>
1392
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetTensorTypeAndShapeInfo() const {
1393
+ OrtTensorTypeAndShapeInfo* output;
1394
+ ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
1395
+ return TensorTypeAndShapeInfo{output};
1396
+ }
1397
+
1398
+ template <typename T>
1399
+ inline ConstMemoryInfo ConstValueImpl<T>::GetTensorMemoryInfo() const {
1400
+ const OrtMemoryInfo* mem_info;
1401
+ ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
1402
+ return ConstMemoryInfo(mem_info);
1403
+ }
1404
+
1405
+ template <typename T>
1406
+ inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
1407
+ ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
1408
+ }
1409
+
1410
+ template <typename T>
1411
+ inline std::string ConstValueImpl<T>::GetStringTensorElement(size_t element_index) const {
1412
+ size_t buffer_length;
1413
+ ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length));
1414
+
1415
+ std::string s;
1416
+ s.resize(buffer_length);
1417
+ ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0]));
1418
+ return s;
1419
+ }
1420
+
1421
+ template <typename T>
1422
+ inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
1423
+ ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
1424
+ }
1425
+
1426
+ #if !defined(DISABLE_SPARSE_TENSORS)
1427
+ template <typename T>
1428
+ inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
1429
+ OrtSparseFormat format;
1430
+ ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
1431
+ return format;
1432
+ }
1433
+
1434
+ template <typename T>
1435
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorValuesTypeAndShapeInfo() const {
1436
+ OrtTensorTypeAndShapeInfo* output;
1437
+ ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
1438
+ return TensorTypeAndShapeInfo{output};
1439
+ }
1440
+
1441
+ template <typename T>
1442
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
1443
+ OrtTensorTypeAndShapeInfo* output;
1444
+ ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
1445
+ return TensorTypeAndShapeInfo{output};
1446
+ }
1447
+
1448
+ template <typename T>
1449
+ template <typename R>
1450
+ inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
1451
+ const void* out;
1452
+ ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
1453
+ return reinterpret_cast<const R*>(out);
1454
+ }
1455
+
1456
+ template <typename T>
1457
+ inline bool ConstValueImpl<T>::IsSparseTensor() const {
1458
+ int out;
1459
+ ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
1460
+ return out != 0;
1461
+ }
1462
+
1463
+ template <typename T>
1464
+ template <typename R>
1465
+ inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
1466
+ const void* out;
1467
+ ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
1468
+ return reinterpret_cast<const R*>(out);
1469
+ }
1470
+
1471
+ #endif
1472
+
1473
+ template <typename T>
1474
+ void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
1475
+ ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
1476
+ }
1477
+
1478
+ template <typename T>
1479
+ void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
1480
+ ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
1481
+ }
1482
+
1483
+ template <typename T>
1484
+ inline char* ValueImpl<T>::GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length) {
1485
+ char* result;
1486
+ ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result));
1487
+ return result;
1488
+ }
1489
+
1490
+ template <typename T>
1491
+ void* ValueImpl<T>::GetTensorMutableRawData() {
1492
+ void* out;
1493
+ ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
1494
+ return out;
1495
+ }
1496
+
1497
+ template <typename T>
1498
+ template <typename R>
1499
+ R* ValueImpl<T>::GetTensorMutableData() {
1500
+ R* out;
1501
+ ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
1502
+ return out;
1503
+ }
1504
+
1505
+ template <typename T>
1506
+ template <typename R>
1507
+ R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
1508
+ static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
1509
+ R* out;
1510
+ ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
1511
+ return *out;
1512
+ }
1513
+
1514
+ #if !defined(DISABLE_SPARSE_TENSORS)
1515
+ template <typename T>
1516
+ void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
1517
+ ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
1518
+ }
1519
+
1520
+ template <typename T>
1521
+ void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
1522
+ ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
1523
+ }
1524
+
1525
+ template <typename T>
1526
+ void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
1527
+ ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
1528
+ }
1529
+
1530
+ template <typename T>
1531
+ void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
1532
+ const int64_t* indices_data, size_t indices_num) {
1533
+ ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
1534
+ values_param.values_shape_len, values_param.data.p_data,
1535
+ indices_data, indices_num));
1536
+ }
1537
+
1538
+ template <typename T>
1539
+ void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1540
+ const OrtSparseValuesParam& values,
1541
+ const int64_t* inner_indices_data, size_t inner_indices_num,
1542
+ const int64_t* outer_indices_data, size_t outer_indices_num) {
1543
+ ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1544
+ inner_indices_data, inner_indices_num,
1545
+ outer_indices_data, outer_indices_num));
1546
+ }
1547
+
1548
+ template <typename T>
1549
+ void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1550
+ const OrtSparseValuesParam& values,
1551
+ const Shape& indices_shape,
1552
+ const int32_t* indices_data) {
1553
+ ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1554
+ indices_shape.shape, indices_shape.shape_len,
1555
+ indices_data));
1556
+ }
1557
+
1558
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
1559
+
1560
+ } // namespace detail
1561
+
1562
+ template <typename T>
1563
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
1564
+ return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
1565
+ }
1566
+
1567
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1568
+ ONNXTensorElementDataType type) {
1569
+ OrtValue* out;
1570
+ ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
1571
+ return Value{out};
1572
+ }
1573
+
1574
+ template <typename T>
1575
+ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
1576
+ return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
1577
+ }
1578
+
1579
+ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
1580
+ OrtValue* out;
1581
+ ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
1582
+ return Value{out};
1583
+ }
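// Usage sketch for the CreateTensor overloads above. The first form wraps
// caller-owned memory without copying (the buffer must outlive the Ort::Value);
// the second form lets the allocator own the buffer. `data` and `shape` are
// illustrative caller variables.
//
//   std::array<float, 6> data{1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
//   std::array<int64_t, 2> shape{2, 3};
//   Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
//   Ort::Value tensor = Ort::Value::CreateTensor<float>(mem_info, data.data(), data.size(),
//                                                       shape.data(), shape.size());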
1584
+
1585
+ #if !defined(DISABLE_SPARSE_TENSORS)
1586
+
1587
+ template <typename T>
1588
+ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1589
+ const Shape& values_shape) {
1590
+ return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
1591
+ }
1592
+
1593
+ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1594
+ const Shape& values_shape, ONNXTensorElementDataType type) {
1595
+ OrtValue* out;
1596
+ ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
1597
+ values_shape.shape, values_shape.shape_len, type, &out));
1598
+ return Value{out};
1599
+ }
1600
+
1601
+ template <typename T>
1602
+ inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
1603
+ return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
1604
+ }
1605
+
1606
+ inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
1607
+ ONNXTensorElementDataType type) {
1608
+ OrtValue* out;
1609
+ ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
1610
+ return Value{out};
1611
+ }
1612
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
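// Usage sketch for the sparse-tensor factories above (available only when sparse
// tensors are not disabled). A 3x3 matrix with two non-zero values is wrapped in
// COO format; `values`, `dims` and `indices` are illustrative caller-owned buffers
// that must outlive the Ort::Value.
//
//   float values[] = {1.f, 5.f};
//   int64_t dims[] = {3, 3};
//   int64_t values_dims[] = {2};
//   int64_t indices[] = {0, 0, 2, 1};  // (row, col) pairs of the non-zero elements
//   Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
//   Ort::Value sparse = Ort::Value::CreateSparseTensor<float>(mem_info, values,
//                                                             Ort::Shape{dims, 2},
//                                                             Ort::Shape{values_dims, 1});
//   sparse.UseCooIndices(indices, 4);  // implemented earlier in this file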
1613
+
1614
+ inline Value Value::CreateMap(const Value& keys, const Value& values) {
1615
+ OrtValue* out;
1616
+ const OrtValue* inputs[2] = {keys, values};
1617
+ ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
1618
+ return Value{out};
1619
+ }
1620
+
1621
+ inline Value Value::CreateSequence(const std::vector<Value>& values) {
1622
+ OrtValue* out;
1623
+ std::vector<const OrtValue*> values_ort{values.data(), values.data() + values.size()};
1624
+ ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
1625
+ return Value{out};
1626
+ }
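// Usage sketch for CreateMap/CreateSequence above. A map is assembled from two
// tensors of equal length (keys and values); `mem_info`, `key_data` and `val_data`
// are illustrative and would be set up as in the tensor example earlier.
//
//   int64_t map_shape[] = {3};
//   Ort::Value keys = Ort::Value::CreateTensor<int64_t>(mem_info, key_data, 3, map_shape, 1);
//   Ort::Value vals = Ort::Value::CreateTensor<float>(mem_info, val_data, 3, map_shape, 1);
//   Ort::Value map = Ort::Value::CreateMap(keys, vals);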
1627
+
1628
+ template <typename T>
1629
+ inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
1630
+ OrtValue* out;
1631
+ ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
1632
+ return Value{out};
1633
+ }
1634
+
1635
+ //
1636
+ // Custom OP Inlines
1637
+ //
1638
+ inline Logger::Logger(const OrtLogger* logger) : logger_(logger) {
1639
+ Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_));
1640
+ }
1641
+
1642
+ inline OrtLoggingLevel Logger::GetLoggingSeverityLevel() const noexcept {
1643
+ return cached_severity_level_;
1644
+ }
1645
+
1646
+ inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
1647
+ const char* func_name, const char* message) const noexcept {
1648
+ OrtStatus* status = GetApi().Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number,
1649
+ func_name);
1650
+ return Status{status};
1651
+ }
1652
+
1653
+ // Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security)
1654
+ // for gcc and clang. The alternative is to use actual C-style variadic parameters and apply
1655
+ // __attribute__(format(printf...)), which does not work with variadic templates.
1656
+ #if defined(__GNUC__)
1657
+ #pragma GCC diagnostic push
1658
+ #pragma GCC diagnostic ignored "-Wformat-nonliteral"
1659
+ #pragma GCC diagnostic ignored "-Wformat-security"
1660
+ #elif defined(__clang__)
1661
+ #pragma clang diagnostic push
1662
+ #pragma clang diagnostic ignored "-Wformat-nonliteral"
1663
+ #pragma clang diagnostic ignored "-Wformat-security"
1664
+ #endif
1665
+ template <typename... Args>
1666
+ inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path,
1667
+ int line_number, const char* func_name, const char* format,
1668
+ Args&&... args) const noexcept {
1669
+ int msg_len = std::snprintf(nullptr, 0U, format, std::forward<Args>(args)...);
1670
+
1671
+ if (msg_len < 0) { // Formatting error
1672
+ return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL);
1673
+ }
1674
+
1675
+ OrtStatus* status = nullptr;
1676
+ const size_t buffer_size = static_cast<size_t>(msg_len) + 1U;
1677
+
1678
+ constexpr size_t kStackBufferSize = 1024;
1679
+
1680
+ if (buffer_size < kStackBufferSize) {
1681
+ char buffer[kStackBufferSize];
1682
+ std::snprintf(buffer, kStackBufferSize, format, std::forward<Args>(args)...);
1683
+ status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name);
1684
+ } else {
1685
+ // std::make_unique is only supported starting at C++14.
1686
+ #if (__cplusplus >= 201402L) || (_MSC_VER >= 1900)
1687
+ auto buffer = std::make_unique<char[]>(buffer_size);
1688
+ #else
1689
+ std::unique_ptr<char[]> buffer(new char[buffer_size]);
1690
+ #endif
1691
+ std::snprintf(buffer.get(), buffer_size, format, std::forward<Args>(args)...);
1692
+ status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name);
1693
+ }
1694
+
1695
+ return Status{status};
1696
+ }
1697
+ // Re-enable -Wformat-nonliteral and -Wformat-security
1698
+ #if defined(__GNUC__)
1699
+ #pragma GCC diagnostic pop
1700
+ #elif defined(__clang__)
1701
+ #pragma clang diagnostic pop
1702
+ #endif
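// Usage sketch for the Logger helpers above, e.g. inside a custom operator kernel
// where the logger comes from KernelContext::GetLogger() (implemented below). The
// ORT_CXX_LOGF convenience macro from onnxruntime_cxx_api.h, if available in this
// version, fills in file/line/function automatically; `ctx` is an illustrative
// Ort::KernelContext.
//
//   Ort::Logger logger = ctx.GetLogger();
//   ORT_CXX_LOGF(logger, ORT_LOGGING_LEVEL_INFO, "processed %zu inputs", ctx.GetInputCount());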
1703
+
1704
+ inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
1705
+ }
1706
+
1707
+ inline size_t KernelContext::GetInputCount() const {
1708
+ size_t out = 0;
1709
+ Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
1710
+ return out;
1711
+ }
1712
+
1713
+ inline size_t KernelContext::GetOutputCount() const {
1714
+ size_t out = 0;
1715
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
1716
+ return out;
1717
+ }
1718
+
1719
+ inline ConstValue KernelContext::GetInput(size_t index) const {
1720
+ const OrtValue* out = nullptr;
1721
+ Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
1722
+ return ConstValue{out};
1723
+ }
1724
+
1725
+ inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
1726
+ OrtValue* out = nullptr;
1727
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
1728
+ return UnownedValue(out);
1729
+ }
1730
+
1731
+ inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
1732
+ OrtValue* out = nullptr;
1733
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
1734
+ return UnownedValue(out);
1735
+ }
1736
+
1737
+ inline void* KernelContext::GetGPUComputeStream() const {
1738
+ void* out = nullptr;
1739
+ Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
1740
+ return out;
1741
+ }
1742
+
1743
+ inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const {
1744
+ OrtAllocator* out = nullptr;
1745
+ Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out));
1746
+ return out;
1747
+ }
1748
+
1749
+ inline Logger KernelContext::GetLogger() const {
1750
+ const OrtLogger* out = nullptr;
1751
+ ThrowOnError(GetApi().KernelContext_GetLogger(this->ctx_, &out));
1752
+ return Logger{out};
1753
+ }
1754
+
1755
+ inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const {
1756
+ ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data));
1757
+ }
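// Usage sketch for the KernelContext accessors above, as typically used from a
// custom operator's Compute callback. The kernel below simply copies its first
// input to its first output; all names are illustrative.
//
//   void Compute(OrtKernelContext* raw_ctx) {
//     Ort::KernelContext ctx(raw_ctx);
//     Ort::ConstValue input = ctx.GetInput(0);
//     auto type_shape = input.GetTensorTypeAndShapeInfo();
//     Ort::UnownedValue output = ctx.GetOutput(0, type_shape.GetShape());
//     const float* src = input.GetTensorData<float>();
//     float* dst = output.GetTensorMutableData<float>();
//     std::copy(src, src + type_shape.GetElementCount(), dst);  // needs <algorithm>
//   }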
1758
+
1759
+ inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
1760
+ Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
1761
+ }
1762
+
1763
+ namespace detail {
1764
+ template <typename T>
1765
+ inline KernelInfo KernelInfoImpl<T>::Copy() const {
1766
+ OrtKernelInfo* info_copy = nullptr;
1767
+ Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
1768
+ return KernelInfo{info_copy};
1769
+ }
1770
+
1771
+ template <typename T>
1772
+ inline size_t KernelInfoImpl<T>::GetInputCount() const {
1773
+ size_t out = 0;
1774
+ ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
1775
+ return out;
1776
+ }
1777
+
1778
+ template <typename T>
1779
+ inline size_t KernelInfoImpl<T>::GetOutputCount() const {
1780
+ size_t out = 0;
1781
+ ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
1782
+ return out;
1783
+ }
1784
+
1785
+ template <typename T>
1786
+ inline std::string KernelInfoImpl<T>::GetInputName(size_t index) const {
1787
+ size_t size = 0;
1788
+
1789
+ // Feed nullptr for the data buffer to query the true size of the string value
1790
+ Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
1791
+
1792
+ std::string out;
1793
+ out.resize(size);
1794
+ Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
1795
+ out.resize(size - 1); // remove the terminating character '\0'
1796
+
1797
+ return out;
1798
+ }
1799
+
1800
+ template <typename T>
1801
+ inline std::string KernelInfoImpl<T>::GetOutputName(size_t index) const {
1802
+ size_t size = 0;
1803
+
1804
+ // Feed nullptr for the data buffer to query the true size of the string value
1805
+ Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
1806
+
1807
+ std::string out;
1808
+ out.resize(size);
1809
+ Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
1810
+ out.resize(size - 1); // remove the terminating character '\0'
1811
+
1812
+ return out;
1813
+ }
1814
+
1815
+ template <typename T>
1816
+ inline TypeInfo KernelInfoImpl<T>::GetInputTypeInfo(size_t index) const {
1817
+ OrtTypeInfo* out = nullptr;
1818
+ ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
1819
+ return TypeInfo{out};
1820
+ }
1821
+
1822
+ template <typename T>
1823
+ inline TypeInfo KernelInfoImpl<T>::GetOutputTypeInfo(size_t index) const {
1824
+ OrtTypeInfo* out = nullptr;
1825
+ ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
1826
+ return TypeInfo{out};
1827
+ }
1828
+
1829
+ template <typename T>
1830
+ inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
1831
+ OrtValue* out = nullptr;
1832
+ ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
1833
+ return Value{out};
1834
+ }
1835
+
1836
+ template <typename T>
1837
+ inline ConstValue KernelInfoImpl<T>::GetTensorConstantInput(size_t index, int* is_constant) const {
1838
+ const OrtValue* out = nullptr;
1839
+ ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out));
1840
+ return ConstValue{out};
1841
+ }
1842
+
1843
+ template <typename T>
1844
+ inline std::string KernelInfoImpl<T>::GetNodeName() const {
1845
+ size_t size = 0;
1846
+
1847
+ // Feed nullptr for the data buffer to query the true size of the string value
1848
+ Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size));
1849
+
1850
+ std::string out;
1851
+ out.resize(size);
1852
+ Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size));
1853
+ out.resize(size - 1); // remove the terminating character '\0'
1854
+
1855
+ return out;
1856
+ }
1857
+
1858
+ template <typename T>
1859
+ inline Logger KernelInfoImpl<T>::GetLogger() const {
1860
+ const OrtLogger* out = nullptr;
1861
+ ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out));
1862
+ return Logger{out};
1863
+ }
1864
+
1865
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
1866
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
1867
+ }
1868
+
1869
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
1870
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
1871
+ }
1872
+
1873
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
1874
+ size_t size = 0;
1875
+ // Feed nullptr for the data buffer to query the true size of the string attribute
1876
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
1877
+
1878
+ std::string out;
1879
+ out.resize(size);
1880
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
1881
+ out.resize(size - 1); // remove the terminating character '\0'
1882
+ out.swap(result);
1883
+ }
1884
+
1885
+ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
1886
+ size_t size = 0;
1887
+ // Feed nullptr for the data buffer to query the true size of the attribute
1888
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
1889
+
1890
+ std::vector<float> out;
1891
+ out.resize(size);
1892
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
1893
+ out.swap(result);
1894
+ }
1895
+
1896
+ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
1897
+ size_t size = 0;
1898
+
1899
+ // Feed nullptr for the data buffer to query the true size of the attribute
1900
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
1901
+
1902
+ std::vector<int64_t> out;
1903
+ out.resize(size);
1904
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
1905
+ out.swap(result);
1906
+ }
1907
+ } // namespace detail
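// Usage sketch for the attribute helpers above. A custom kernel usually reads its
// node attributes once in its constructor through the templated
// KernelInfo::GetAttribute/GetAttributes wrappers declared in onnxruntime_cxx_api.h
// (assumed available here); "alpha", "mode" and "axes" are illustrative attribute names.
//
//   explicit MyKernel(const OrtKernelInfo* raw_info) {
//     Ort::ConstKernelInfo info{raw_info};
//     float alpha = info.GetAttribute<float>("alpha");
//     std::string mode = info.GetAttribute<std::string>("mode");
//     std::vector<int64_t> axes = info.GetAttributes<int64_t>("axes");
//   }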
1908
+
1909
+ inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
1910
+
1911
+ inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
1912
+
1913
+ inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
1914
+ const char** type_constraint_names,
1915
+ const ONNXTensorElementDataType* type_constraint_values,
1916
+ size_t type_constraint_count,
1917
+ const OpAttr* attr_values, size_t attr_count,
1918
+ size_t input_count, size_t output_count) {
1919
+ static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
1920
+ "OpAttr is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
1921
+ auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
1922
+ OrtOp* op;
1923
+ Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1924
+ static_cast<int>(type_constraint_count),
1925
+ attr_input_values,
1926
+ static_cast<int>(attr_count),
1927
+ static_cast<int>(input_count),
1928
+ static_cast<int>(output_count), &op));
1929
+ return Op{op};
1930
+ }
1931
+
1932
+ inline void Op::Invoke(const OrtKernelContext* context,
1933
+ const Value* input_values,
1934
+ size_t input_count,
1935
+ Value* output_values,
1936
+ size_t output_count) {
1937
+ static_assert(sizeof(Value) == sizeof(OrtValue*),
1938
+ "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1939
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1940
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1941
+ Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
1942
+ ort_output_values, static_cast<int>(output_count)));
1943
+ }
1944
+
1945
+ inline void Op::Invoke(const OrtKernelContext* context,
1946
+ const OrtValue* const* input_values,
1947
+ size_t input_count,
1948
+ OrtValue* const* output_values,
1949
+ size_t output_count) {
1950
+ Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
1951
+ output_values, static_cast<int>(output_count)));
1952
+ }
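// Usage sketch for Op::Create/Op::Invoke above: a custom kernel can delegate part
// of its work to a built-in ONNX operator. `info` is the OrtKernelInfo passed to
// the kernel, `raw_ctx` the OrtKernelContext of the current call; the opset
// version and names are illustrative.
//
//   const char* type_names[] = {"T"};
//   ONNXTensorElementDataType type_values[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT};
//   Ort::Op add_op = Ort::Op::Create(info, "Add", "", /*version*/ 14,
//                                    type_names, type_values, 1,
//                                    nullptr, 0, /*input_count*/ 2, /*output_count*/ 1);
//   ...
//   add_op.Invoke(raw_ctx, inputs, 2, outputs, 1);  // inputs/outputs are Ort::Value arrays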
1953
+
1954
+ inline std::string GetVersionString() {
1955
+ return OrtGetApiBase()->GetVersionString();
1956
+ }
1957
+
1958
+ inline std::string GetBuildInfoString() {
1959
+ return GetApi().GetBuildInfoString();
1960
+ }
1961
+
1962
+ inline std::vector<std::string> GetAvailableProviders() {
1963
+ char** providers;
1964
+ int len;
1965
+
1966
+ auto release_fn = [&len](char** providers) {
1967
+ // ReleaseAvailableProviders is expected to always return nullptr.
1968
+ ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
1969
+ };
1970
+
1971
+ ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
1972
+ std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
1973
+ std::vector<std::string> available_providers;
1974
+ available_providers.reserve(static_cast<size_t>(len));
1975
+ for (int i = 0; i < len; ++i) {
1976
+ available_providers.emplace_back(providers[i]);
1977
+ }
1978
+ return available_providers;
1979
+ }
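// Usage sketch for the helpers above: report the library version and the execution
// providers compiled into this binary (requires <iostream>).
//
//   std::cout << "onnxruntime " << Ort::GetVersionString() << '\n';
//   for (const std::string& ep : Ort::GetAvailableProviders()) {
//     std::cout << "  available EP: " << ep << '\n';
//   }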
1980
+
1981
+ template <typename TOp, typename TKernel, bool WithStatus>
1982
+ void CustomOpBase<TOp, TKernel, WithStatus>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
1983
+ ConstSessionOptions options) const {
1984
+ const TOp* derived = static_cast<const TOp*>(this);
1985
+ std::vector<std::string> keys = derived->GetSessionConfigKeys();
1986
+
1987
+ out.reserve(keys.size());
1988
+
1989
+ std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
1990
+ const size_t prefix_size = config_entry_key.length();
1991
+
1992
+ for (const auto& key : keys) {
1993
+ config_entry_key.resize(prefix_size);
1994
+ config_entry_key.append(key);
1995
+ out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
1996
+ }
1997
+ }
1998
+
1999
+ inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api,
2000
+ OrtShapeInferContext* ctx) : ort_api_(ort_api), ctx_(ctx) {
2001
+ size_t input_count = 0;
2002
+ Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputCount(ctx_, &input_count));
2003
+ for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
2004
+ OrtTensorTypeAndShapeInfo* info{};
2005
+ Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputTypeShape(ctx, ith_input, &info));
2006
+ TensorTypeAndShapeInfo type_shape_info(info);
2007
+ auto integer_shape = type_shape_info.GetShape();
2008
+ std::vector<const char*> symbolic_shape(integer_shape.size(), {});
2009
+ if (!integer_shape.empty()) {
2010
+ type_shape_info.GetSymbolicDimensions(&symbolic_shape[0], integer_shape.size());
2011
+ }
2012
+ Shape shape;
2013
+ for (size_t ith = 0; ith < integer_shape.size(); ++ith) {
2014
+ if (symbolic_shape[ith] && std::string{symbolic_shape[ith]}.size() > 0) {
2015
+ shape.emplace_back(symbolic_shape[ith]);
2016
+ } else {
2017
+ shape.emplace_back(integer_shape[ith]);
2018
+ }
2019
+ }
2020
+ input_shapes_.push_back(std::move(shape));
2021
+ type_shape_info.release();
2022
+ }
2023
+ }
2024
+
2025
+ inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type) {
2026
+ OrtTensorTypeAndShapeInfo* info = {};
2027
+ ORT_CXX_RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info));
2028
+ ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetTensorElementType(info, type));
2029
+
2030
+ using InfoPtr = std::unique_ptr<OrtTensorTypeAndShapeInfo, std::function<void(OrtTensorTypeAndShapeInfo*)>>;
2031
+
2032
+ InfoPtr info_ptr(info, [this](OrtTensorTypeAndShapeInfo* obj) {
2033
+ ort_api_->ReleaseTensorTypeAndShapeInfo(obj);
2034
+ });
2035
+
2036
+ std::vector<int64_t> integer_dims;
2037
+ std::vector<const char*> symbolic_dims;
2038
+
2039
+ for (const auto dim : shape) {
2040
+ if (dim.IsInt()) {
2041
+ integer_dims.push_back(dim.AsInt());
2042
+ symbolic_dims.push_back("");
2043
+ } else {
2044
+ if (!dim.AsSym() || std::string{dim.AsSym()}.empty()) {
2045
+ ORT_CXX_API_THROW("Symbolic dim must not be an empty string", ORT_INVALID_ARGUMENT);
2046
+ }
2047
+ integer_dims.push_back(SymbolicInteger::INVALID_INT_DIM);
2048
+ symbolic_dims.push_back(dim.AsSym());
2049
+ }
2050
+ }
2051
+
2052
+ ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size()));
2053
+ ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size()));
2054
+ ORT_CXX_RETURN_ON_API_FAIL(ort_api_->ShapeInferContext_SetOutputTypeShape(ctx_, indice, info));
2055
+ return Status{nullptr};
2056
+ }
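// Usage sketch for SetOutputShape above, as called from a custom operator's shape
// inference hook. The first output keeps a symbolic batch dimension and gets a
// fixed second dimension; `ctx` is an illustrative Ort::ShapeInferContext and the
// element type is assumed to be float.
//
//   Ort::ShapeInferContext::Shape out_shape;
//   out_shape.emplace_back("batch");       // symbolic dimension
//   out_shape.emplace_back(int64_t{128});  // concrete dimension
//   Ort::Status st = ctx.SetOutputShape(0, out_shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);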
2057
+
2058
+ inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) {
2059
+ const auto* attr = GetAttrHdl(attr_name);
2060
+ int64_t i = {};
2061
+ size_t out = {};
2062
+ Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out));
2063
+ return i;
2064
+ }
2065
+
2066
+ inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) {
2067
+ const auto* attr = GetAttrHdl(attr_name);
2068
+ int64_t i = {};
2069
+ size_t out = {};
2070
+ // first call to get the bytes needed
2071
+ // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure.
2072
+ // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == nullptr on success).
2073
+ // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}.
2074
+ auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out);
2075
+ if (status) {
2076
+ size_t num_i = out / sizeof(int64_t);
2077
+ ShapeInferContext::Ints ints(num_i, 0);
2078
+ Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out));
2079
+ return ints;
2080
+ } else {
2081
+ if (out == 0u) {
2082
+ return {};
2083
+ }
2084
+ return {i};
2085
+ }
2086
+ }
2087
+
2088
+ inline float ShapeInferContext::GetAttrFloat(const char* attr_name) {
2089
+ const auto* attr = GetAttrHdl(attr_name);
2090
+ float f = {};
2091
+ size_t out = {};
2092
+ Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out));
2093
+ return f;
2094
+ }
2095
+
2096
+ inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) {
2097
+ const auto* attr = GetAttrHdl(attr_name);
2098
+ float f = {};
2099
+ size_t out = {};
2100
+ // first call to get the bytes needed
2101
+ // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure.
2102
+ // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the floats (returns status == nullptr on success).
2103
+ // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&f) of size 1 just in case there is only 1 float. In this case, status == nullptr and we need to return {f}.
2104
+ auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out);
2105
+ if (status) {
2106
+ size_t num_f = out / sizeof(float);
2107
+ ShapeInferContext::Floats floats(num_f, 0);
2108
+ Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out));
2109
+ return floats;
2110
+ } else {
2111
+ if (out == 0u) {
2112
+ return {};
2113
+ }
2114
+ return {f};
2115
+ }
2116
+ }
2117
+
2118
+ inline std::string ShapeInferContext::GetAttrString(const char* attr_name) {
2119
+ const auto* attr = GetAttrHdl(attr_name);
2120
+ char c = {};
2121
+ size_t out = {};
2122
+ // first call to get the bytes needed
2123
+ auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out);
2124
+ if (status) {
2125
+ std::vector<char> chars(out, '\0');
2126
+ Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out));
2127
+ return {chars.data()};
2128
+ } else {
2129
+ return {c};
2130
+ }
2131
+ }
2132
+
2133
+ inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) {
2134
+ const auto* attr = GetAttrHdl(attr_name);
2135
+ char c = {};
2136
+ size_t out = {};
2137
+ // first call to get the bytes needed
2138
+ // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure.
2139
+ // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the strings (returns status == nullptr on success).
2140
+ // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&c) of size 1 just in case the attribute is a single character. In this case, status == nullptr and we need to return {std::string{c}}.
2141
+ auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out);
2142
+ if (status) {
2143
+ std::vector<char> chars(out, '\0');
2144
+ Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out));
2145
+ ShapeInferContext::Strings strings;
2146
+ char* char_st = chars.data();
2147
+ char* char_ed = char_st + out;
2148
+ while (char_st < char_ed) {
2149
+ strings.emplace_back(char_st);
2150
+ while (*char_st != '\0') {
2151
+ char_st++;
2152
+ }
2153
+ char_st++;
2154
+ }
2155
+ return strings;
2156
+ } else {
2157
+ if (out == 0u) {
2158
+ return {};
2159
+ }
2160
+ return {std::string{c}};
2161
+ }
2162
+ }
2163
+
2164
+ inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const {
2165
+ const OrtOpAttr* attr_hdl = {};
2166
+ Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl));
2167
+ return attr_hdl;
2168
+ }
2169
+
2170
+ } // namespace Ort
1.20.0/onnxruntime.xcframework/Headers/onnxruntime_float16.h ADDED
@@ -0,0 +1,535 @@
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #pragma once
5
+
6
+ #include <stdint.h>
7
+ #include <cmath>
8
+ #include <cstring>
9
+ #include <limits>
10
+
11
+ namespace onnxruntime_float16 {
12
+
13
+ namespace detail {
14
+
15
+ enum class endian {
16
+ #if defined(_WIN32)
17
+ little = 0,
18
+ big = 1,
19
+ native = little,
20
+ #elif defined(__GNUC__) || defined(__clang__)
21
+ little = __ORDER_LITTLE_ENDIAN__,
22
+ big = __ORDER_BIG_ENDIAN__,
23
+ native = __BYTE_ORDER__,
24
+ #else
25
+ #error onnxruntime_float16::detail::endian is not implemented in this environment.
26
+ #endif
27
+ };
28
+
29
+ static_assert(
30
+ endian::native == endian::little || endian::native == endian::big,
31
+ "Only little-endian or big-endian native byte orders are supported.");
32
+
33
+ } // namespace detail
34
+
35
+ /// <summary>
36
+ /// Shared implementation between public and internal classes. CRTP pattern.
37
+ /// </summary>
38
+ template <class Derived>
39
+ struct Float16Impl {
40
+ protected:
41
+ /// <summary>
42
+ /// Converts from float to uint16_t float16 representation
43
+ /// </summary>
44
+ /// <param name="v">float value to convert</param>
45
+ /// <returns>uint16_t bit pattern of the float16 value</returns>
46
+ constexpr static uint16_t ToUint16Impl(float v) noexcept;
47
+
48
+ /// <summary>
49
+ /// Converts float16 to float
50
+ /// </summary>
51
+ /// <returns>float representation of float16 value</returns>
52
+ float ToFloatImpl() const noexcept;
53
+
54
+ /// <summary>
55
+ /// Creates an instance that represents absolute value.
56
+ /// </summary>
57
+ /// <returns>Absolute value</returns>
58
+ uint16_t AbsImpl() const noexcept {
59
+ return static_cast<uint16_t>(val & ~kSignMask);
60
+ }
61
+
62
+ /// <summary>
63
+ /// Creates a new instance with the sign flipped.
64
+ /// </summary>
65
+ /// <returns>Flipped sign instance</returns>
66
+ uint16_t NegateImpl() const noexcept {
67
+ return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
68
+ }
69
+
70
+ public:
71
+ // uint16_t special values
72
+ static constexpr uint16_t kSignMask = 0x8000U;
73
+ static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
74
+ static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
75
+ static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
76
+ static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
77
+ static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
78
+ static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number
79
+ static constexpr uint16_t kOneBits = 0x3C00U;
80
+ static constexpr uint16_t kMinusOneBits = 0xBC00U;
81
+
82
+ uint16_t val{0};
83
+
84
+ Float16Impl() = default;
85
+
86
+ /// <summary>
87
+ /// Checks if the value is negative
88
+ /// </summary>
89
+ /// <returns>true if negative</returns>
90
+ bool IsNegative() const noexcept {
91
+ return static_cast<int16_t>(val) < 0;
92
+ }
93
+
94
+ /// <summary>
95
+ /// Tests if the value is NaN
96
+ /// </summary>
97
+ /// <returns>true if NaN</returns>
98
+ bool IsNaN() const noexcept {
99
+ return AbsImpl() > kPositiveInfinityBits;
100
+ }
101
+
102
+ /// <summary>
103
+ /// Tests if the value is finite
104
+ /// </summary>
105
+ /// <returns>true if finite</returns>
106
+ bool IsFinite() const noexcept {
107
+ return AbsImpl() < kPositiveInfinityBits;
108
+ }
109
+
110
+ /// <summary>
111
+ /// Tests if the value represents positive infinity.
112
+ /// </summary>
113
+ /// <returns>true if positive infinity</returns>
114
+ bool IsPositiveInfinity() const noexcept {
115
+ return val == kPositiveInfinityBits;
116
+ }
117
+
118
+ /// <summary>
119
+ /// Tests if the value represents negative infinity
120
+ /// </summary>
121
+ /// <returns>true if negative infinity</returns>
122
+ bool IsNegativeInfinity() const noexcept {
123
+ return val == kNegativeInfinityBits;
124
+ }
125
+
126
+ /// <summary>
127
+ /// Tests if the value is either positive or negative infinity.
128
+ /// </summary>
129
+ /// <returns>True if absolute value is infinity</returns>
130
+ bool IsInfinity() const noexcept {
131
+ return AbsImpl() == kPositiveInfinityBits;
132
+ }
133
+
134
+ /// <summary>
135
+ /// Tests if the value is NaN or zero. Useful for comparisons.
136
+ /// </summary>
137
+ /// <returns>True if NaN or zero.</returns>
138
+ bool IsNaNOrZero() const noexcept {
139
+ auto abs = AbsImpl();
140
+ return (abs == 0 || abs > kPositiveInfinityBits);
141
+ }
142
+
143
+ /// <summary>
144
+ /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
145
+ /// </summary>
146
+ /// <returns>True if so</returns>
147
+ bool IsNormal() const noexcept {
148
+ auto abs = AbsImpl();
149
+ return (abs < kPositiveInfinityBits) // is finite
150
+ && (abs != 0) // is not zero
151
+ && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
152
+ }
153
+
154
+ /// <summary>
155
+ /// Tests if the value is subnormal (denormal).
156
+ /// </summary>
157
+ /// <returns>True if so</returns>
158
+ bool IsSubnormal() const noexcept {
159
+ auto abs = AbsImpl();
160
+ return (abs < kPositiveInfinityBits) // is finite
161
+ && (abs != 0) // is not zero
162
+ && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
163
+ }
164
+
165
+ /// <summary>
166
+ /// Creates an instance that represents absolute value.
167
+ /// </summary>
168
+ /// <returns>Absolute value</returns>
169
+ Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
170
+
171
+ /// <summary>
172
+ /// Creates a new instance with the sign flipped.
173
+ /// </summary>
174
+ /// <returns>Flipped sign instance</returns>
175
+ Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
176
+
177
+ /// <summary>
178
+ /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
179
+ /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
180
+ /// and therefore equivalent, if the resulting value is still zero.
181
+ /// </summary>
182
+ /// <param name="lhs">first value</param>
183
+ /// <param name="rhs">second value</param>
184
+ /// <returns>True if both arguments represent zero</returns>
185
+ static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
186
+ return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
187
+ }
188
+
189
+ bool operator==(const Float16Impl& rhs) const noexcept {
190
+ if (IsNaN() || rhs.IsNaN()) {
191
+ // IEEE defines that NaN is not equal to anything, including itself.
192
+ return false;
193
+ }
194
+ return val == rhs.val;
195
+ }
196
+
197
+ bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }
198
+
199
+ bool operator<(const Float16Impl& rhs) const noexcept {
200
+ if (IsNaN() || rhs.IsNaN()) {
201
+ // IEEE defines that NaN is unordered with respect to everything, including itself.
202
+ return false;
203
+ }
204
+
205
+ const bool left_is_negative = IsNegative();
206
+ if (left_is_negative != rhs.IsNegative()) {
207
+ // When the signs of left and right differ, we know that left is less than right if it is
208
+ // the negative value. The exception to this is if both values are zero, in which case IEEE
209
+ // says they should be equal, even if the signs differ.
210
+ return left_is_negative && !AreZero(*this, rhs);
211
+ }
212
+ return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
213
+ }
214
+ };
215
+
216
+ // The following Float16_t conversions are based on the code from
217
+ // the Eigen library.
218
+
219
+ // The conversion routines are Copyright (c) Fabian Giesen, 2016.
220
+ // The original license follows:
221
+ //
222
+ // Copyright (c) Fabian Giesen, 2016
223
+ // All rights reserved.
224
+ // Redistribution and use in source and binary forms, with or without
225
+ // modification, are permitted.
226
+ // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
227
+ // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
228
+ // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
229
+ // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
230
+ // HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
231
+ // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
232
+ // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
233
+ // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
234
+ // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
235
+ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
236
+ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
237
+
238
+ namespace detail {
239
+ union float32_bits {
240
+ unsigned int u;
241
+ float f;
242
+ };
243
+ } // namespace detail
244
+
245
+ template <class Derived>
246
+ inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
247
+ detail::float32_bits f{};
248
+ f.f = v;
249
+
250
+ constexpr detail::float32_bits f32infty = {255 << 23};
251
+ constexpr detail::float32_bits f16max = {(127 + 16) << 23};
252
+ constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
253
+ constexpr unsigned int sign_mask = 0x80000000u;
254
+ uint16_t val = static_cast<uint16_t>(0x0u);
255
+
256
+ unsigned int sign = f.u & sign_mask;
257
+ f.u ^= sign;
258
+
259
+ // NOTE all the integer compares in this function can be safely
260
+ // compiled into signed compares since all operands are below
261
+ // 0x80000000. Important if you want fast straight SSE2 code
262
+ // (since there's no unsigned PCMPGTD).
263
+
264
+ if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
265
+ val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
266
+ } else { // (De)normalized number or zero
267
+ if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
268
+ // use a magic value to align our 10 mantissa bits at the bottom of
269
+ // the float. as long as FP addition is round-to-nearest-even this
270
+ // just works.
271
+ f.f += denorm_magic.f;
272
+
273
+ // and one integer subtract of the bias later, we have our final half-precision bits!
274
+ val = static_cast<uint16_t>(f.u - denorm_magic.u);
275
+ } else {
276
+ unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
277
+
278
+ // update exponent, rounding bias part 1
279
+ // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
280
+ // without arithmetic overflow.
281
+ f.u += 0xc8000fffU;
282
+ // rounding bias part 2
283
+ f.u += mant_odd;
284
+ // take the bits!
285
+ val = static_cast<uint16_t>(f.u >> 13);
286
+ }
287
+ }
288
+
289
+ val |= static_cast<uint16_t>(sign >> 16);
290
+ return val;
291
+ }
292
+
293
+ template <class Derived>
294
+ inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
295
+ constexpr detail::float32_bits magic = {113 << 23};
296
+ constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
297
+ detail::float32_bits o{};
298
+
299
+ o.u = (val & 0x7fff) << 13; // exponent/mantissa bits
300
+ unsigned int exp = shifted_exp & o.u; // just the exponent
301
+ o.u += (127 - 15) << 23; // exponent adjust
302
+
303
+ // handle exponent special cases
304
+ if (exp == shifted_exp) { // Inf/NaN?
305
+ o.u += (128 - 16) << 23; // extra exp adjust
306
+ } else if (exp == 0) { // Zero/Denormal?
307
+ o.u += 1 << 23; // extra exp adjust
308
+ o.f -= magic.f; // re-normalize
309
+ }
310
+
311
+ // Attempt to workaround the Internal Compiler Error on ARM64
312
+ // for bitwise | operator, including std::bitset
313
+ #if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
314
+ if (IsNegative()) {
315
+ return -o.f;
316
+ }
317
+ #else
318
+ // original code:
319
+ o.u |= (val & 0x8000U) << 16U; // sign bit
320
+ #endif
321
+ return o.f;
322
+ }
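// Usage sketch for the conversions above. The public half-precision wrapper
// (Ort::Float16_t in onnxruntime_cxx_api.h, assuming this version provides it)
// derives from Float16Impl, so ToUint16Impl/ToFloatImpl drive simple round trips:
//
//   Ort::Float16_t half(1.0f);    // stored bits are kOneBits == 0x3C00
//   float back = half.ToFloat();  // converts back via ToFloatImpl, yields 1.0f
//   bool nan = Ort::Float16_t::FromBits(0x7E00).IsNaN();  // kPositiveQNaNBits -> true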
323
+
324
+ /// Shared implementation between public and internal classes. CRTP pattern.
325
+ template <class Derived>
326
+ struct BFloat16Impl {
327
+ protected:
328
+ /// <summary>
329
+ /// Converts from float to the uint16_t bfloat16 representation
330
+ /// </summary>
331
+ /// <param name="v">float value to convert</param>
332
+ /// <returns>uint16_t bit pattern of the bfloat16 value</returns>
333
+ static uint16_t ToUint16Impl(float v) noexcept;
334
+
335
+ /// <summary>
336
+ /// Converts bfloat16 to float
337
+ /// </summary>
338
+ /// <returns>float representation of bfloat16 value</returns>
339
+ float ToFloatImpl() const noexcept;
340
+
341
+ /// <summary>
342
+ /// Creates an instance that represents absolute value.
343
+ /// </summary>
344
+ /// <returns>Absolute value</returns>
345
+ uint16_t AbsImpl() const noexcept {
346
+ return static_cast<uint16_t>(val & ~kSignMask);
347
+ }
348
+
349
+ /// <summary>
350
+ /// Creates a new instance with the sign flipped.
351
+ /// </summary>
352
+ /// <returns>Flipped sign instance</returns>
353
+ uint16_t NegateImpl() const noexcept {
354
+ return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
355
+ }
356
+
357
+ public:
358
+ // uint16_t special values
359
+ static constexpr uint16_t kSignMask = 0x8000U;
360
+ static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
361
+ static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
362
+ static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
363
+ static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
364
+ static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
365
+ static constexpr uint16_t kMaxValueBits = 0x7F7FU;
366
+ static constexpr uint16_t kRoundToNearest = 0x7FFFU;
367
+ static constexpr uint16_t kOneBits = 0x3F80U;
368
+ static constexpr uint16_t kMinusOneBits = 0xBF80U;
369
+
370
+ uint16_t val{0};
371
+
372
+ BFloat16Impl() = default;
373
+
374
+ /// <summary>
375
+ /// Checks if the value is negative
376
+ /// </summary>
377
+ /// <returns>true if negative</returns>
378
+ bool IsNegative() const noexcept {
379
+ return static_cast<int16_t>(val) < 0;
380
+ }
381
+
382
+ /// <summary>
383
+ /// Tests if the value is NaN
384
+ /// </summary>
385
+ /// <returns>true if NaN</returns>
386
+ bool IsNaN() const noexcept {
387
+ return AbsImpl() > kPositiveInfinityBits;
388
+ }
389
+
390
+ /// <summary>
391
+ /// Tests if the value is finite
392
+ /// </summary>
393
+ /// <returns>true if finite</returns>
394
+ bool IsFinite() const noexcept {
395
+ return AbsImpl() < kPositiveInfinityBits;
396
+ }
397
+
398
+ /// <summary>
399
+ /// Tests if the value represents positive infinity.
400
+ /// </summary>
401
+ /// <returns>true if positive infinity</returns>
402
+ bool IsPositiveInfinity() const noexcept {
403
+ return val == kPositiveInfinityBits;
404
+ }
405
+
406
+ /// <summary>
407
+ /// Tests if the value represents negative infinity
408
+ /// </summary>
409
+ /// <returns>true if negative infinity</returns>
410
+ bool IsNegativeInfinity() const noexcept {
411
+ return val == kNegativeInfinityBits;
412
+ }
413
+
414
+ /// <summary>
415
+ /// Tests if the value is either positive or negative infinity.
416
+ /// </summary>
417
+ /// <returns>True if absolute value is infinity</returns>
418
+ bool IsInfinity() const noexcept {
419
+ return AbsImpl() == kPositiveInfinityBits;
420
+ }
421
+
422
+ /// <summary>
423
+ /// Tests if the value is NaN or zero. Useful for comparisons.
424
+ /// </summary>
425
+ /// <returns>True if NaN or zero.</returns>
426
+ bool IsNaNOrZero() const noexcept {
427
+ auto abs = AbsImpl();
428
+ return (abs == 0 || abs > kPositiveInfinityBits);
429
+ }
430
+
431
+ /// <summary>
432
+ /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
433
+ /// </summary>
434
+ /// <returns>True if so</returns>
435
+ bool IsNormal() const noexcept {
436
+ auto abs = AbsImpl();
437
+ return (abs < kPositiveInfinityBits) // is finite
438
+ && (abs != 0) // is not zero
439
+ && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
440
+ }
441
+
442
+ /// <summary>
443
+ /// Tests if the value is subnormal (denormal).
444
+ /// </summary>
445
+ /// <returns>True if so</returns>
446
+ bool IsSubnormal() const noexcept {
447
+ auto abs = AbsImpl();
448
+ return (abs < kPositiveInfinityBits) // is finite
449
+ && (abs != 0) // is not zero
450
+ && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
451
+ }
452
+
453
+ /// <summary>
454
+ /// Creates an instance that represents absolute value.
455
+ /// </summary>
456
+ /// <returns>Absolute value</returns>
457
+ Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
458
+
459
+ /// <summary>
460
+ /// Creates a new instance with the sign flipped.
461
+ /// </summary>
462
+ /// <returns>Flipped sign instance</returns>
463
+ Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
464
+
465
+ /// <summary>
466
+ /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
467
+ /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
468
+ /// and therefore equivalent, if the resulting value is still zero.
469
+ /// </summary>
470
+ /// <param name="lhs">first value</param>
471
+ /// <param name="rhs">second value</param>
472
+ /// <returns>True if both arguments represent zero</returns>
473
+ static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
474
+ // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
475
+ // for two values by or'ing the private bits together and stripping the sign. They are both zero,
476
+ // and therefore equivalent, if the resulting value is still zero.
477
+ return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
478
+ }
479
+ };
480
+
481
+ template <class Derived>
482
+ inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
483
+ uint16_t result;
484
+ if (std::isnan(v)) {
485
+ result = kPositiveQNaNBits;
486
+ } else {
487
+ auto get_msb_half = [](float fl) {
488
+ uint16_t result;
489
+ #ifdef __cpp_if_constexpr
490
+ if constexpr (detail::endian::native == detail::endian::little) {
491
+ #else
492
+ if (detail::endian::native == detail::endian::little) {
493
+ #endif
494
+ std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
495
+ } else {
496
+ std::memcpy(&result, &fl, sizeof(uint16_t));
497
+ }
498
+ return result;
499
+ };
500
+
501
+ uint16_t upper_bits = get_msb_half(v);
502
+ union {
503
+ uint32_t U32;
504
+ float F32;
505
+ };
506
+ F32 = v;
507
+ U32 += (upper_bits & 1) + kRoundToNearest;
508
+ result = get_msb_half(F32);
509
+ }
510
+ return result;
511
+ }
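+ // A small worked illustration of the rounding above (a sketch for clarity, with made-up
+ // sample values): bfloat16 keeps the upper 16 bits of an IEEE-754 float, and the
+ // `(upper_bits & 1) + kRoundToNearest` increment implements round-half-to-even before
+ // the truncation. For example:
+ //   1.0f = 0x3F800000 -> upper = 0x3F80 (LSB 0), U32 += 0x7FFF -> 0x3F807FFF -> top 16 bits 0x3F80 (still 1.0)
+ //   0x3F808000 (exactly halfway between 0x3F80 and 0x3F81) -> upper LSB 0, so it rounds down to the even pattern 0x3F80
+ //   0x3F818000 (exactly halfway between 0x3F81 and 0x3F82) -> upper LSB 1, so it rounds up to the even pattern 0x3F82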
512
+
513
+ template <class Derived>
514
+ inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
515
+ if (IsNaN()) {
516
+ return std::numeric_limits<float>::quiet_NaN();
517
+ }
518
+ float result;
519
+ char* const first = reinterpret_cast<char*>(&result);
520
+ char* const second = first + sizeof(uint16_t);
521
+ #ifdef __cpp_if_constexpr
522
+ if constexpr (detail::endian::native == detail::endian::little) {
523
+ #else
524
+ if (detail::endian::native == detail::endian::little) {
525
+ #endif
526
+ std::memset(first, 0, sizeof(uint16_t));
527
+ std::memcpy(second, &val, sizeof(uint16_t));
528
+ } else {
529
+ std::memcpy(first, &val, sizeof(uint16_t));
530
+ std::memset(second, 0, sizeof(uint16_t));
531
+ }
532
+ return result;
533
+ }
534
+
535
+ } // namespace onnxruntime_float16
1.20.0/onnxruntime.xcframework/Headers/onnxruntime_lite_custom_op.h ADDED
@@ -0,0 +1,1119 @@
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ // Summary
5
+ // The header has APIs to save custom op authors the trouble of defining schemas,
6
+ // which will be inferred from the functions' signatures, as long as their argument lists use the types supported here.
7
+ // Input could be:
8
+ // 1. Tensor of onnx data types.
9
+ // 2. Span of onnx data types.
10
+ // 3. Scalar of onnx data types.
11
+ // An input could be optional if indicated as std::optional<...>.
12
+ // For an output, it must be a tensor of onnx data types.
13
+ // Further, the header also provides a utility to register a simple custom struct, which may hold resources, as a custom op.
14
+ // For concrete examples, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
15
+ // Note - all APIs in this header are ABI.
16
+
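+ // For illustration only, a compute function whose schema would be inferred from its
+ // argument list might look like this (op and argument names are hypothetical):
+ //   void MyKernel(const Ort::Custom::Tensor<float>& dense_in,             // tensor input
+ //                 const Ort::Custom::Span<int64_t>& lookup_ids,           // span input (CPU EP only)
+ //                 float scale,                                            // scalar input (CPU EP only)
+ //                 std::optional<const Ort::Custom::Tensor<float>*> bias,  // optional input
+ //                 Ort::Custom::Tensor<float>& dense_out);                 // tensor output
+ // Registering it via CreateLiteCustomOp (shown further below) derives the input/output
+ // types from this signature, so no explicit schema definition is needed.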
17
+ #pragma once
18
+ #include "onnxruntime_cxx_api.h"
19
+ #include <optional>
20
+ #include <numeric>
21
+ #include <functional>
22
+ #include <unordered_set>
23
+
24
+ namespace Ort {
25
+ namespace Custom {
26
+
27
+ class ArgBase {
28
+ public:
29
+ ArgBase(OrtKernelContext* ctx,
30
+ size_t indice,
31
+ bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {}
32
+ virtual ~ArgBase() {};
33
+
34
+ protected:
35
+ struct KernelContext ctx_;
36
+ size_t indice_;
37
+ bool is_input_;
38
+ };
39
+
40
+ using ArgPtr = std::unique_ptr<Custom::ArgBase>;
41
+ using ArgPtrs = std::vector<ArgPtr>;
42
+
43
+ class TensorBase : public ArgBase {
44
+ public:
45
+ TensorBase(OrtKernelContext* ctx,
46
+ size_t indice,
47
+ bool is_input) : ArgBase(ctx, indice, is_input) {}
48
+
49
+ operator bool() const {
50
+ return shape_.has_value();
51
+ }
52
+
53
+ const std::vector<int64_t>& Shape() const {
54
+ if (!shape_.has_value()) {
55
+ ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
56
+ }
57
+ return shape_.value();
58
+ }
59
+
60
+ ONNXTensorElementDataType Type() const {
61
+ return type_;
62
+ }
63
+
64
+ int64_t NumberOfElement() const {
65
+ if (shape_.has_value()) {
66
+ return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
67
+ } else {
68
+ return 0;
69
+ }
70
+ }
71
+
72
+ std::string Shape2Str() const {
73
+ if (shape_.has_value()) {
74
+ std::string shape_str;
75
+ for (const auto& dim : *shape_) {
76
+ shape_str.append(std::to_string(dim));
77
+ shape_str.append(", ");
78
+ }
79
+ return shape_str;
80
+ } else {
81
+ return "empty";
82
+ }
83
+ }
84
+
85
+ bool IsCpuTensor() const {
86
+ return strcmp("Cpu", mem_type_) == 0;
87
+ }
88
+
89
+ virtual const void* DataRaw() const = 0;
90
+ virtual size_t SizeInBytes() const = 0;
91
+
92
+ protected:
93
+ std::optional<std::vector<int64_t>> shape_;
94
+ ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
95
+ const char* mem_type_ = "Cpu";
96
+ };
97
+
98
+ template <typename T>
99
+ struct Span {
100
+ const T* data_ = {};
101
+ size_t size_ = {};
102
+ void Assign(const T* data, size_t size) {
103
+ data_ = data;
104
+ size_ = size;
105
+ }
106
+ size_t size() const { return size_; }
107
+ T operator[](size_t indice) const {
108
+ return data_[indice];
109
+ }
110
+ const T* data() const { return data_; }
111
+ };
112
+
113
+ template <typename T>
114
+ class Tensor : public TensorBase {
115
+ public:
116
+ using TT = typename std::remove_reference<T>::type;
117
+ Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
118
+ if (is_input_) {
119
+ if (indice >= ctx_.GetInputCount()) {
120
+ ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
121
+ }
122
+ const_value_ = ctx_.GetInput(indice);
123
+ auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo();
124
+ shape_ = type_shape_info.GetShape();
125
+ }
126
+ }
127
+ const TT* Data() const {
128
+ return reinterpret_cast<const TT*>(const_value_.GetTensorRawData());
129
+ }
130
+ TT* Allocate(const std::vector<int64_t>& shape) {
131
+ shape_ = shape;
132
+ if (!data_) {
133
+ shape_ = shape;
134
+ data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData<TT>();
135
+ }
136
+ return data_;
137
+ }
138
+ static TT GetT() { return (TT)0; }
139
+ const Span<T>& AsSpan() {
140
+ if (!shape_.has_value() || shape_->size() != 1) {
141
+ ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor",
142
+ OrtErrorCode::ORT_RUNTIME_EXCEPTION);
143
+ }
144
+ span_.Assign(Data(), static_cast<size_t>((*shape_)[0]));
145
+ return span_;
146
+ }
147
+ const T& AsScalar() {
148
+ if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) {
149
+ ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor",
150
+ OrtErrorCode::ORT_RUNTIME_EXCEPTION);
151
+ }
152
+ return *Data();
153
+ }
154
+ const void* DataRaw() const override {
155
+ return reinterpret_cast<const void*>(Data());
156
+ }
157
+
158
+ size_t SizeInBytes() const override {
159
+ return sizeof(TT) * static_cast<size_t>(NumberOfElement());
160
+ }
161
+
162
+ private:
163
+ ConstValue const_value_; // for input
164
+ TT* data_{}; // for output
165
+ Span<T> span_;
166
+ };
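+ // A minimal usage sketch (hypothetical names) of the class above inside a kernel body:
+ //   void Square(const Ort::Custom::Tensor<float>& in, Ort::Custom::Tensor<float>& out) {
+ //     const float* x = in.Data();            // read-only input buffer
+ //     float* y = out.Allocate(in.Shape());   // allocate the output with the same shape
+ //     for (int64_t i = 0; i < in.NumberOfElement(); ++i) {
+ //       y[i] = x[i] * x[i];
+ //     }
+ //   }
+ // AsSpan() and AsScalar() are only valid for 1-D and single-element shapes respectively,
+ // and throw otherwise, as enforced above.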
167
+
168
+ template <>
169
+ class Tensor<std::string> : public TensorBase {
170
+ public:
171
+ using strings = std::vector<std::string>;
172
+
173
+ Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
174
+ if (is_input_) {
175
+ if (indice >= ctx_.GetInputCount()) {
176
+ ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
177
+ }
178
+ auto const_value = ctx_.GetInput(indice);
179
+ auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
180
+ shape_ = type_shape_info.GetShape();
181
+ auto num_chars = const_value.GetStringTensorDataLength();
182
+ // note - there will be a copy ...
183
+ auto num_strings = static_cast<size_t>(NumberOfElement());
184
+ if (num_strings) {
185
+ std::vector<char> chars(num_chars + 1, '\0');
186
+ std::vector<size_t> offsets(num_strings);
187
+ const_value.GetStringTensorContent(static_cast<void*>(chars.data()), num_chars, offsets.data(), offsets.size());
188
+ auto upper_bound = num_strings - 1;
189
+ input_strings_.resize(num_strings);
190
+ for (size_t i = upper_bound;; --i) {
191
+ if (i < upper_bound) {
192
+ chars[offsets[i + 1]] = '\0';
193
+ }
194
+ input_strings_[i] = chars.data() + offsets[i];
195
+ if (0 == i) {
196
+ break;
197
+ }
198
+ }
199
+ }
200
+ }
201
+ }
202
+ const strings& Data() const {
203
+ return input_strings_;
204
+ }
205
+ const void* DataRaw() const override {
206
+ if (input_strings_.size() != 1) {
207
+ ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
208
+ }
209
+ return reinterpret_cast<const void*>(input_strings_[0].c_str());
210
+ }
211
+ size_t SizeInBytes() const override {
212
+ if (input_strings_.size() != 1) {
213
+ ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
214
+ }
215
+ return input_strings_[0].size();
216
+ }
217
+ void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
218
+ shape_ = dims;
219
+ std::vector<const char*> raw;
220
+ for (const auto& s : ss) {
221
+ raw.push_back(s.data());
222
+ }
223
+ auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
224
+ // note - there will be a copy ...
225
+ output.FillStringTensor(raw.data(), raw.size());
226
+ }
227
+ const Span<std::string>& AsSpan() {
228
+ ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
229
+ }
230
+ const std::string& AsScalar() {
231
+ if (input_strings_.size() != 1) {
232
+ ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor",
233
+ OrtErrorCode::ORT_RUNTIME_EXCEPTION);
234
+ }
235
+ return input_strings_[0];
236
+ }
237
+
238
+ private:
239
+ std::vector<std::string> input_strings_; // for input
240
+ };
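+ // A minimal sketch (hypothetical op) of producing a string output with the class above:
+ //   void Upper(const Ort::Custom::Tensor<std::string>& in, Ort::Custom::Tensor<std::string>& out) {
+ //     std::vector<std::string> results;
+ //     for (const auto& s : in.Data()) {
+ //       std::string u = s;
+ //       std::transform(u.begin(), u.end(), u.begin(), [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
+ //       results.push_back(std::move(u));
+ //     }
+ //     out.SetStringOutput(results, in.Shape());  // copies the strings into the output tensor
+ //   }
+ // Unlike the numeric Tensor<T>, string outputs are written via SetStringOutput rather than Allocate.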
241
+
242
+ template <>
243
+ class Tensor<std::string_view> : public TensorBase {
244
+ public:
245
+ using strings = std::vector<std::string>;
246
+ using string_views = std::vector<std::string_view>;
247
+
248
+ Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
249
+ if (is_input_) {
250
+ if (indice >= ctx_.GetInputCount()) {
251
+ ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
252
+ }
253
+ auto const_value = ctx_.GetInput(indice);
254
+ auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
255
+ shape_ = type_shape_info.GetShape();
256
+ auto num_chars = const_value.GetStringTensorDataLength();
257
+ chars_.resize(num_chars + 1, '\0');
258
+ auto num_strings = static_cast<size_t>(NumberOfElement());
259
+ if (num_strings) {
260
+ std::vector<size_t> offsets(num_strings);
261
+ const_value.GetStringTensorContent(static_cast<void*>(chars_.data()), num_chars, offsets.data(), offsets.size());
262
+ offsets.push_back(num_chars);
263
+ for (size_t i = 0; i < num_strings; ++i) {
264
+ input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
265
+ }
266
+ }
267
+ }
268
+ }
269
+ const string_views& Data() const {
270
+ return input_string_views_;
271
+ }
272
+ const void* DataRaw() const override {
273
+ if (input_string_views_.size() != 1) {
274
+ ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
275
+ }
276
+ return reinterpret_cast<const void*>(input_string_views_[0].data());
277
+ }
278
+ size_t SizeInBytes() const override {
279
+ if (input_string_views_.size() != 1) {
280
+ ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
281
+ }
282
+ return input_string_views_[0].size();
283
+ }
284
+ void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
285
+ shape_ = dims;
286
+ std::vector<const char*> raw;
287
+ for (const auto& s : ss) {
288
+ raw.push_back(s.data());
289
+ }
290
+ auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
291
+ // note - there will be a copy ...
292
+ output.FillStringTensor(raw.data(), raw.size());
293
+ }
294
+ const Span<std::string_view>& AsSpan() {
295
+ ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
296
+ }
297
+ std::string_view AsScalar() {
298
+ if (input_string_views_.size() != 1) {
299
+ ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor",
300
+ OrtErrorCode::ORT_RUNTIME_EXCEPTION);
301
+ }
302
+ return input_string_views_[0];
303
+ }
304
+
305
+ private:
306
+ std::vector<char> chars_; // for input
307
+ std::vector<std::string_view> input_string_views_; // for input
308
+ };
309
+
310
+ using TensorPtr = std::unique_ptr<Custom::TensorBase>;
311
+ using TensorPtrs = std::vector<TensorPtr>;
312
+
313
+ struct TensorArray : public ArgBase {
314
+ TensorArray(OrtKernelContext* ctx,
315
+ size_t start_indice,
316
+ bool is_input) : ArgBase(ctx,
317
+ start_indice,
318
+ is_input) {
319
+ if (is_input) {
320
+ auto input_count = ctx_.GetInputCount();
321
+ for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) {
322
+ auto const_value = ctx_.GetInput(start_indice);
323
+ auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
324
+ auto type = type_shape_info.GetElementType();
325
+ TensorPtr tensor;
326
+ switch (type) {
327
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
328
+ tensor = std::make_unique<Custom::Tensor<bool>>(ctx, ith_input, true);
329
+ break;
330
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
331
+ tensor = std::make_unique<Custom::Tensor<float>>(ctx, ith_input, true);
332
+ break;
333
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
334
+ tensor = std::make_unique<Custom::Tensor<double>>(ctx, ith_input, true);
335
+ break;
336
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
337
+ tensor = std::make_unique<Custom::Tensor<uint8_t>>(ctx, ith_input, true);
338
+ break;
339
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
340
+ tensor = std::make_unique<Custom::Tensor<int8_t>>(ctx, ith_input, true);
341
+ break;
342
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
343
+ tensor = std::make_unique<Custom::Tensor<uint16_t>>(ctx, ith_input, true);
344
+ break;
345
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
346
+ tensor = std::make_unique<Custom::Tensor<int16_t>>(ctx, ith_input, true);
347
+ break;
348
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
349
+ tensor = std::make_unique<Custom::Tensor<uint32_t>>(ctx, ith_input, true);
350
+ break;
351
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
352
+ tensor = std::make_unique<Custom::Tensor<int32_t>>(ctx, ith_input, true);
353
+ break;
354
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
355
+ tensor = std::make_unique<Custom::Tensor<uint64_t>>(ctx, ith_input, true);
356
+ break;
357
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
358
+ tensor = std::make_unique<Custom::Tensor<int64_t>>(ctx, ith_input, true);
359
+ break;
360
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
361
+ tensor = std::make_unique<Custom::Tensor<std::string>>(ctx, ith_input, true);
362
+ break;
363
+ default:
364
+ ORT_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
365
+ break;
366
+ }
367
+ tensors_.emplace_back(tensor.release());
368
+ } // for
369
+ }
370
+ }
371
+ template <typename T>
372
+ T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
373
+ // ith_output is the indice of output relative to the tensor array
374
+ // indice_ + ith_output is the indice relative to context
375
+ auto tensor = std::make_unique<Tensor<T>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
376
+ auto raw_output = tensor.get()->Allocate(shape);
377
+ tensors_.emplace_back(tensor.release());
378
+ return raw_output;
379
+ }
380
+ Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
381
+ // ith_output is the indice of output relative to the tensor array
382
+ // indice_ + ith_output is the indice relative to context
383
+ auto tensor = std::make_unique<Tensor<std::string>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
384
+ Tensor<std::string>& output = *tensor;
385
+ tensors_.emplace_back(tensor.release());
386
+ return output;
387
+ }
388
+ size_t Size() const {
389
+ return tensors_.size();
390
+ }
391
+ const TensorPtr& operator[](size_t ith_input) const {
392
+ // ith_input is the indice of output relative to the tensor array
393
+ return tensors_.at(ith_input);
394
+ }
395
+
396
+ private:
397
+ TensorPtrs tensors_;
398
+ };
399
+
400
+ using Variadic = TensorArray;
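+ // A minimal variadic sketch (hypothetical op): consume all remaining inputs and emit one output.
+ //   void Concat1D(const Ort::Custom::TensorArray& inputs, Ort::Custom::TensorArray& outputs) {
+ //     int64_t total = 0;
+ //     for (size_t i = 0; i < inputs.Size(); ++i) {
+ //       total += inputs[i]->NumberOfElement();
+ //     }
+ //     float* out = outputs.AllocateOutput<float>(0, {total});
+ //     for (size_t i = 0; i < inputs.Size(); ++i) {
+ //       const auto* src = static_cast<const float*>(inputs[i]->DataRaw());
+ //       out = std::copy(src, src + inputs[i]->NumberOfElement(), out);
+ //     }
+ //   }
+ // This sketch assumes every variadic input is a float tensor; inputs[i]->Type() can be checked when that is not guaranteed.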
401
+
402
+ /*
403
+ Note:
404
+ OrtLiteCustomOp inherits from OrtCustomOp to bridge between a custom func/struct and the ort core.
405
+ The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so:
406
+ 1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release since there is no virtual destructor in the hierarchy.
407
+ 2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in the form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp,
408
+ hence memory could still be recycled properly.
409
+ Further, OrtCustomOp is a C struct bearing no v-table, so derived structs are by design free of virtual functions to maintain cast safety.
410
+ */
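+ // Following the note above, a sketch of safe ownership (op, function, and domain names are hypothetical):
+ //   std::unique_ptr<Ort::Custom::OrtLiteCustomOp> my_op{
+ //       Ort::Custom::CreateLiteCustomOp("MyKernel", "CPUExecutionProvider", MyKernel)};
+ //   Ort::CustomOpDomain my_domain{"my.domain"};
+ //   my_domain.Add(my_op.get());
+ //   session_options.Add(my_domain);
+ // my_op (and my_domain) must stay alive for as long as any session created with session_options
+ // may use the op; releasing through an OrtLiteCustomOp* is fine, releasing through a plain
+ // OrtCustomOp* is not, per point 1 above.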
411
+ struct OrtLiteCustomOp : public OrtCustomOp {
412
+ using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
413
+ using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;
414
+
415
+ // CreateTuple
416
+ template <size_t ith_input, size_t ith_output, typename... Ts>
417
+ static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
418
+ CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) {
419
+ return std::make_tuple();
420
+ }
421
+
422
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
423
+ static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
424
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
425
+ std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
426
+ auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
427
+ return std::tuple_cat(current, next);
428
+ }
429
+
430
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
431
+ static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
432
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
433
+ std::tuple<T> current = std::tuple<OrtKernelContext&>{*context};
434
+ auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
435
+ return std::tuple_cat(current, next);
436
+ }
437
+
438
+ #ifdef ORT_CUDA_CTX
439
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
440
+ static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
441
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
442
+ thread_local CudaContext cuda_context;
443
+ cuda_context.Init(*context);
444
+ std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
445
+ auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
446
+ return std::tuple_cat(current, next);
447
+ }
448
+ #endif
449
+
450
+ #ifdef ORT_ROCM_CTX
451
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
452
+ static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
453
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
454
+ thread_local RocmContext rocm_context;
455
+ rocm_context.Init(*context);
456
+ std::tuple<T> current = std::tuple<const RocmContext&>{rocm_context};
457
+ auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
458
+ return std::tuple_cat(current, next);
459
+ }
460
+ #endif
461
+
462
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
463
+ static typename std::enable_if<std::is_same<T, const TensorArray*>::value, std::tuple<T, Ts...>>::type
464
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
465
+ args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
466
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
467
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
468
+ return std::tuple_cat(current, next);
469
+ }
470
+
471
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
472
+ static typename std::enable_if<std::is_same<T, const TensorArray&>::value, std::tuple<T, Ts...>>::type
473
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
474
+ args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
475
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
476
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
477
+ return std::tuple_cat(current, next);
478
+ }
479
+
480
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
481
+ static typename std::enable_if<std::is_same<T, TensorArray*>::value, std::tuple<T, Ts...>>::type
482
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
483
+ args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
484
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
485
+ auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
486
+ return std::tuple_cat(current, next);
487
+ }
488
+
489
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
490
+ static typename std::enable_if<std::is_same<T, TensorArray&>::value, std::tuple<T, Ts...>>::type
491
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
492
+ args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
493
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
494
+ auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
495
+ return std::tuple_cat(current, next);
496
+ }
497
+
498
+ #define CREATE_TUPLE_INPUT(data_type) \
499
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
500
+ static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
501
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
502
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
503
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
504
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
505
+ return std::tuple_cat(current, next); \
506
+ } \
507
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
508
+ static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
509
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
510
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
511
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
512
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
513
+ return std::tuple_cat(current, next); \
514
+ } \
515
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
516
+ static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
517
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
518
+ if (ith_input < num_input) { \
519
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
520
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
521
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
522
+ return std::tuple_cat(current, next); \
523
+ } else { \
524
+ std::tuple<T> current = std::tuple<T>{}; \
525
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
526
+ return std::tuple_cat(current, next); \
527
+ } \
528
+ } \
529
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
530
+ static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
531
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
532
+ if ("CPUExecutionProvider" != ep) { \
533
+ ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
534
+ } \
535
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
536
+ std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
537
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
538
+ return std::tuple_cat(current, next); \
539
+ } \
540
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
541
+ static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
542
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
543
+ if ("CPUExecutionProvider" != ep) { \
544
+ ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
545
+ } \
546
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
547
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
548
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
549
+ return std::tuple_cat(current, next); \
550
+ } \
551
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
552
+ static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
553
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
554
+ if (ith_input < num_input) { \
555
+ if ("CPUExecutionProvider" != ep) { \
556
+ ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
557
+ } \
558
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
559
+ std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
560
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
561
+ return std::tuple_cat(current, next); \
562
+ } else { \
563
+ std::tuple<T> current = std::tuple<T>{}; \
564
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
565
+ return std::tuple_cat(current, next); \
566
+ } \
567
+ } \
568
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
569
+ static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
570
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
571
+ if ("CPUExecutionProvider" != ep) { \
572
+ ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
573
+ } \
574
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
575
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
576
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
577
+ return std::tuple_cat(current, next); \
578
+ } \
579
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
580
+ static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
581
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
582
+ if (ith_input < num_input) { \
583
+ if ("CPUExecutionProvider" != ep) { \
584
+ ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
585
+ } \
586
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
587
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
588
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
589
+ return std::tuple_cat(current, next); \
590
+ } else { \
591
+ std::tuple<T> current = std::tuple<T>{}; \
592
+ auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
593
+ return std::tuple_cat(current, next); \
594
+ } \
595
+ }
596
+ #define CREATE_TUPLE_OUTPUT(data_type) \
597
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
598
+ static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
599
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
600
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
601
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
602
+ auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
603
+ return std::tuple_cat(current, next); \
604
+ } \
605
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
606
+ static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
607
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
608
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
609
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
610
+ auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
611
+ return std::tuple_cat(current, next); \
612
+ } \
613
+ template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
614
+ static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
615
+ CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
616
+ if (ith_output < num_output) { \
617
+ args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
618
+ std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
619
+ auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
620
+ return std::tuple_cat(current, next); \
621
+ } else { \
622
+ std::tuple<T> current = std::tuple<T>{}; \
623
+ auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
624
+ return std::tuple_cat(current, next); \
625
+ } \
626
+ }
627
+ #define CREATE_TUPLE(data_type) \
628
+ CREATE_TUPLE_INPUT(data_type) \
629
+ CREATE_TUPLE_OUTPUT(data_type)
630
+
631
+ CREATE_TUPLE(bool)
632
+ CREATE_TUPLE(float)
633
+ CREATE_TUPLE(Ort::Float16_t)
634
+ CREATE_TUPLE(Ort::BFloat16_t)
635
+ CREATE_TUPLE(double)
636
+ CREATE_TUPLE(int8_t)
637
+ CREATE_TUPLE(int16_t)
638
+ CREATE_TUPLE(int32_t)
639
+ CREATE_TUPLE(int64_t)
640
+ CREATE_TUPLE(uint8_t)
641
+ CREATE_TUPLE(uint16_t)
642
+ CREATE_TUPLE(uint32_t)
643
+ CREATE_TUPLE(uint64_t)
644
+ CREATE_TUPLE(std::string)
645
+ CREATE_TUPLE_INPUT(std::string_view)
646
+ CREATE_TUPLE(Ort::Float8E4M3FN_t)
647
+ CREATE_TUPLE(Ort::Float8E4M3FNUZ_t)
648
+ CREATE_TUPLE(Ort::Float8E5M2_t)
649
+ CREATE_TUPLE(Ort::Float8E5M2FNUZ_t)
650
+
651
+ // ParseArgs ...
652
+ template <typename... Ts>
653
+ static typename std::enable_if<0 == sizeof...(Ts)>::type
654
+ ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
655
+ }
656
+
657
+ template <typename T, typename... Ts>
658
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
659
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
660
+ ParseArgs<Ts...>(input_types, output_types);
661
+ }
662
+
663
+ template <typename T, typename... Ts>
664
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
665
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
666
+ ParseArgs<Ts...>(input_types, output_types);
667
+ }
668
+
669
+ #ifdef ORT_CUDA_CTX
670
+ template <typename T, typename... Ts>
671
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
672
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
673
+ ParseArgs<Ts...>(input_types, output_types);
674
+ }
675
+ #endif
676
+
677
+ #ifdef ORT_ROCM_CTX
678
+ template <typename T, typename... Ts>
679
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
680
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
681
+ ParseArgs<Ts...>(input_types, output_types);
682
+ }
683
+ #endif
684
+
685
+ template <typename T, typename... Ts>
686
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray&>::value>::type
687
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
688
+ input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
689
+ ParseArgs<Ts...>(input_types, output_types);
690
+ }
691
+
692
+ template <typename T, typename... Ts>
693
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray*>::value>::type
694
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
695
+ input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
696
+ ParseArgs<Ts...>(input_types, output_types);
697
+ }
698
+
699
+ template <typename T, typename... Ts>
700
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray&>::value>::type
701
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
702
+ output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
703
+ ParseArgs<Ts...>(input_types, output_types);
704
+ }
705
+
706
+ template <typename T, typename... Ts>
707
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray*>::value>::type
708
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
709
+ output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
710
+ ParseArgs<Ts...>(input_types, output_types);
711
+ }
712
+
713
+ #define PARSE_INPUT_BASE(pack_type, onnx_type) \
714
+ template <typename T, typename... Ts> \
715
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
716
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
717
+ input_types.push_back(onnx_type); \
718
+ ParseArgs<Ts...>(input_types, output_types); \
719
+ } \
720
+ template <typename T, typename... Ts> \
721
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const std::optional<pack_type>>::value>::type \
722
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
723
+ input_types.push_back(onnx_type); \
724
+ ParseArgs<Ts...>(input_types, output_types); \
725
+ } \
726
+ template <typename T, typename... Ts> \
727
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
728
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
729
+ input_types.push_back(onnx_type); \
730
+ ParseArgs<Ts...>(input_types, output_types); \
731
+ }
732
+
733
+ #define PARSE_INPUT(data_type, onnx_type) \
734
+ PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
735
+ PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
736
+ PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
737
+ PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
738
+ PARSE_INPUT_BASE(data_type, onnx_type)
739
+
740
+ #define PARSE_OUTPUT(data_type, onnx_type) \
741
+ template <typename T, typename... Ts> \
742
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
743
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
744
+ output_types.push_back(onnx_type); \
745
+ ParseArgs<Ts...>(input_types, output_types); \
746
+ } \
747
+ template <typename T, typename... Ts> \
748
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
749
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
750
+ output_types.push_back(onnx_type); \
751
+ ParseArgs<Ts...>(input_types, output_types); \
752
+ } \
753
+ template <typename T, typename... Ts> \
754
+ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
755
+ ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
756
+ output_types.push_back(onnx_type); \
757
+ ParseArgs<Ts...>(input_types, output_types); \
758
+ }
759
+
760
+ #define PARSE_ARGS(data_type, onnx_type) \
761
+ PARSE_INPUT(data_type, onnx_type) \
762
+ PARSE_OUTPUT(data_type, onnx_type)
763
+
764
+ PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
765
+ PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
766
+ PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
767
+ PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
768
+ PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
769
+ PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
770
+ PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
771
+ PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
772
+ PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
773
+ PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
774
+ PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
775
+ PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
776
+ PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
777
+ PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
778
+ PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
779
+ PARSE_ARGS(Ort::Float8E4M3FN_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN)
780
+ PARSE_ARGS(Ort::Float8E4M3FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ)
781
+ PARSE_ARGS(Ort::Float8E5M2_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2)
782
+ PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)
783
+
784
+ OrtLiteCustomOp(const char* op_name,
785
+ const char* execution_provider,
786
+ ShapeInferFn shape_infer_fn,
787
+ int start_ver = 1,
788
+ int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
789
+ execution_provider_(execution_provider),
790
+ shape_infer_fn_(shape_infer_fn),
791
+ start_ver_(start_ver),
792
+ end_ver_(end_ver) {
793
+ OrtCustomOp::version = ORT_API_VERSION;
794
+
795
+ OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
796
+ OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
797
+ OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; };
798
+
799
+ OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
800
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
801
+ return self->input_types_.size();
802
+ };
803
+
804
+ OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
805
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
806
+ return self->input_types_[indice];
807
+ };
808
+
809
+ OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
810
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
811
+ return self->output_types_.size();
812
+ };
813
+
814
+ OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
815
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
816
+ return self->output_types_[indice];
817
+ };
818
+
819
+ OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
820
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
821
+ return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
822
+ };
823
+
824
+ OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
825
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
826
+ return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
827
+ };
828
+
829
+ OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
830
+ return 1;
831
+ };
832
+
833
+ OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
834
+ return 0;
835
+ };
836
+
837
+ OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
838
+ return 1;
839
+ };
840
+
841
+ OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
842
+ return 0;
843
+ };
844
+
845
+ OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; };
846
+ OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; };
847
+ OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; };
848
+ OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; };
849
+
850
+ OrtCustomOp::CreateKernelV2 = {};
851
+ OrtCustomOp::KernelComputeV2 = {};
852
+ OrtCustomOp::KernelCompute = {};
853
+
854
+ OrtCustomOp::InferOutputShapeFn = {};
855
+
856
+ OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
857
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
858
+ return self->start_ver_;
859
+ };
860
+
861
+ OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
862
+ auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
863
+ return self->end_ver_;
864
+ };
865
+
866
+ OrtCustomOp::GetMayInplace = {};
867
+ OrtCustomOp::ReleaseMayInplace = {};
868
+ OrtCustomOp::GetAliasMap = {};
869
+ OrtCustomOp::ReleaseAliasMap = {};
870
+ }
871
+
872
+ const std::string op_name_;
873
+ const std::string execution_provider_;
874
+
875
+ std::vector<ONNXTensorElementDataType> input_types_;
876
+ std::vector<ONNXTensorElementDataType> output_types_;
877
+
878
+ ShapeInferFn shape_infer_fn_ = {};
879
+
880
+ int start_ver_ = 1;
881
+ int end_ver_ = MAX_CUSTOM_OP_END_VER;
882
+
883
+ void* compute_fn_ = {};
884
+ void* compute_fn_return_status_ = {};
885
+ };
886
+
887
+ //////////////////////////// OrtLiteCustomFunc ////////////////////////////////
888
+ // The struct is to implement function-as-op.
889
+ // E.g. a function might be defined as:
890
+ // void Filter(const Ort::Custom::Tensor<float>& floats_in, Ort::Custom::Tensor<float>& floats_out) { ... }
891
+ // It could be registered this way:
892
+ // Ort::CustomOpDomain v2_domain{"v2"};
893
+ // std::unique_ptr<OrtLiteCustomOp> fil_op_ptr{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
894
+ // v2_domain.Add(fil_op_ptr.get());
895
+ // session_options.Add(v2_domain);
896
+ // For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
897
+ template <typename... Args>
898
+ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
899
+ using ComputeFn = void (*)(Args...);
900
+ using ComputeFnReturnStatus = Status (*)(Args...);
901
+ using MyType = OrtLiteCustomFunc<Args...>;
902
+
903
+ struct Kernel {
904
+ size_t num_input_{};
905
+ size_t num_output_{};
906
+ ComputeFn compute_fn_{};
907
+ ComputeFnReturnStatus compute_fn_return_status_{};
908
+ std::string ep_{};
909
+ };
910
+
911
+ OrtLiteCustomFunc(const char* op_name,
912
+ const char* execution_provider,
913
+ ComputeFn compute_fn,
914
+ ShapeInferFn shape_infer_fn = {},
915
+ int start_ver = 1,
916
+ int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
917
+ compute_fn_ = reinterpret_cast<void*>(compute_fn);
918
+ ParseArgs<Args...>(input_types_, output_types_);
919
+
920
+ OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
921
+ auto kernel = reinterpret_cast<Kernel*>(op_kernel);
922
+ std::vector<ArgPtr> args;
923
+ auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
924
+ std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
925
+ };
926
+
927
+ OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
928
+ auto kernel = std::make_unique<Kernel>();
929
+ auto me = static_cast<const MyType*>(this_);
930
+ kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
931
+ Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
932
+ Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
933
+ auto self = static_cast<const OrtLiteCustomFunc*>(this_);
934
+ kernel->ep_ = self->execution_provider_;
935
+ return reinterpret_cast<void*>(kernel.release());
936
+ };
937
+
938
+ OrtCustomOp::KernelDestroy = [](void* op_kernel) {
939
+ delete reinterpret_cast<Kernel*>(op_kernel);
940
+ };
941
+
942
+ if (shape_infer_fn_) {
943
+ OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
944
+ auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
945
+ ShapeInferContext ctx(&GetApi(), ort_ctx);
946
+ return shape_info_fn(ctx);
947
+ };
948
+ }
949
+ }
950
+
951
+ OrtLiteCustomFunc(const char* op_name,
952
+ const char* execution_provider,
953
+ ComputeFnReturnStatus compute_fn_return_status,
954
+ ShapeInferFn shape_infer_fn = {},
955
+ int start_ver = 1,
956
+ int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
957
+ compute_fn_return_status_ = reinterpret_cast<void*>(compute_fn_return_status);
958
+ ParseArgs<Args...>(input_types_, output_types_);
959
+
960
+ OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
961
+ auto kernel = reinterpret_cast<Kernel*>(op_kernel);
962
+ std::vector<ArgPtr> args;
963
+ auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
964
+ return std::apply([kernel](Args const&... t_args) { Status status = kernel->compute_fn_return_status_(t_args...); return status.release(); }, t);
965
+ };
966
+
967
+ OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
968
+ auto kernel = std::make_unique<Kernel>();
969
+ auto me = static_cast<const MyType*>(this_);
970
+ kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
971
+ Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
972
+ Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
973
+ auto self = static_cast<const OrtLiteCustomFunc*>(this_);
974
+ kernel->ep_ = self->execution_provider_;
975
+ return reinterpret_cast<void*>(kernel.release());
976
+ };
977
+
978
+ OrtCustomOp::KernelDestroy = [](void* op_kernel) {
979
+ delete reinterpret_cast<Kernel*>(op_kernel);
980
+ };
981
+
982
+ if (shape_infer_fn_) {
983
+ OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
984
+ auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
985
+ ShapeInferContext ctx(&GetApi(), ort_ctx);
986
+ return shape_info_fn(ctx);
987
+ };
988
+ }
989
+ }
990
+ }; // struct OrtLiteCustomFunc
991
+
992
+ /////////////////////////// OrtLiteCustomStruct ///////////////////////////
993
+ // The struct is to implement struct-as-op.
994
+ // E.g. a struct might be defined as:
995
+ // struct Merge {
996
+ // Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...}
997
+ // void Compute(const Ort::Custom::Tensor<std::string_view>& strings_in,
998
+ // std::string_view string_in,
999
+ // Ort::Custom::Tensor<std::string>* strings_out) {...}
1000
+ // bool reverse_ = false;
1001
+ // };
1002
+ // It could be registered this way:
1003
+ // Ort::CustomOpDomain v2_domain{"v2"};
1004
+ // std::unique_ptr<OrtLiteCustomOp> mrg_op_ptr{Ort::Custom::CreateLiteCustomOp<Merge>("Merge", "CPUExecutionProvider")};
1005
+ // v2_domain.Add(mrg_op_ptr.get());
1006
+ // session_options.Add(v2_domain);
1007
+ // For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
1008
+ template <typename CustomOp>
1009
+ struct OrtLiteCustomStruct : public OrtLiteCustomOp {
1010
+ template <typename... Args>
1011
+ using CustomComputeFn = void (CustomOp::*)(Args...);
1012
+
1013
+ template <typename... Args>
1014
+ using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...);
1015
+
1016
+ using MyType = OrtLiteCustomStruct<CustomOp>;
1017
+
1018
+ struct Kernel {
1019
+ size_t num_input_{};
1020
+ size_t num_output_{};
1021
+ std::unique_ptr<CustomOp> custom_op_;
1022
+ std::string ep_{};
1023
+ };
1024
+
1025
+ OrtLiteCustomStruct(const char* op_name,
1026
+ const char* execution_provider,
1027
+ int start_ver = 1,
1028
+ int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {
1029
+ SetCompute(&CustomOp::Compute);
1030
+
1031
+ OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
1032
+ auto kernel = std::make_unique<Kernel>();
1033
+ Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
1034
+ Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
1035
+ kernel->custom_op_ = std::make_unique<CustomOp>(ort_api, info);
1036
+ auto self = static_cast<const OrtLiteCustomStruct*>(this_);
1037
+ kernel->ep_ = self->execution_provider_;
1038
+ return reinterpret_cast<void*>(kernel.release());
1039
+ };
1040
+
1041
+ OrtCustomOp::KernelDestroy = [](void* op_kernel) {
1042
+ delete reinterpret_cast<Kernel*>(op_kernel);
1043
+ };
1044
+
1045
+ SetShapeInfer<CustomOp>(0);
1046
+ }
1047
+
1048
+ template <typename... Args>
1049
+ void SetCompute(CustomComputeFn<Args...>) {
1050
+ ParseArgs<Args...>(input_types_, output_types_);
1051
+ OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
1052
+ auto kernel = reinterpret_cast<Kernel*>(op_kernel);
1053
+ ArgPtrs args;
1054
+ auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
1055
+ std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
1056
+ };
1057
+ }
1058
+
1059
+ template <typename... Args>
1060
+ void SetCompute(CustomComputeFnReturnStatus<Args...>) {
1061
+ ParseArgs<Args...>(input_types_, output_types_);
1062
+ OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
1063
+ auto kernel = reinterpret_cast<Kernel*>(op_kernel);
1064
+ ArgPtrs args;
1065
+ auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
1066
+ return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t);
1067
+ };
1068
+ }
1069
+
1070
+ template <typename C>
1071
+ decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) {
1072
+ OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
1073
+ ShapeInferContext ctx(&GetApi(), ort_ctx);
1074
+ return C::InferOutputShape(ctx);
1075
+ };
1076
+ return {};
1077
+ }
1078
+
1079
+ template <typename C>
1080
+ void SetShapeInfer(...) {
1081
+ OrtCustomOp::InferOutputShapeFn = {};
1082
+ }
1083
+ }; // struct OrtLiteCustomStruct
1084
+
1085
+ /////////////////////////// CreateLiteCustomOp ////////////////////////////
1086
+
1087
+ template <typename... Args>
1088
+ OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1089
+ const char* execution_provider,
1090
+ void (*custom_compute_fn)(Args...),
1091
+ Status (*shape_infer_fn)(ShapeInferContext&) = {},
1092
+ int start_ver = 1,
1093
+ int end_ver = MAX_CUSTOM_OP_END_VER) {
1094
+ using LiteOp = OrtLiteCustomFunc<Args...>;
1095
+ return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release();
1096
+ }
1097
+
1098
+ template <typename... Args>
1099
+ OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1100
+ const char* execution_provider,
1101
+ Status (*custom_compute_fn_v2)(Args...),
1102
+ Status (*shape_infer_fn)(ShapeInferContext&) = {},
1103
+ int start_ver = 1,
1104
+ int end_ver = MAX_CUSTOM_OP_END_VER) {
1105
+ using LiteOp = OrtLiteCustomFunc<Args...>;
1106
+ return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release();
1107
+ }
1108
+
1109
+ template <typename CustomOp>
1110
+ OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1111
+ const char* execution_provider,
1112
+ int start_ver = 1,
1113
+ int end_ver = MAX_CUSTOM_OP_END_VER) {
1114
+ using LiteOp = OrtLiteCustomStruct<CustomOp>;
1115
+ return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver).release();
1116
+ }
1117
+
1118
+ } // namespace Custom
1119
+ } // namespace Ort
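
For reference, a minimal usage sketch of the function-based path above (CreateLiteCustomOp plus CustomOpDomain). The op name "CustomOpOne", the element-wise add kernel, and the model path are hypothetical; only APIs shown or declared in these headers (Ort::Custom::Tensor, CreateLiteCustomOp, Ort::CustomOpDomain, Ort::SessionOptions, Ort::Session) are assumed.

#include <memory>
#include "onnxruntime_cxx_api.h"
#include "onnxruntime_lite_custom_op.h"  // assumed to be the header shown above

// Hypothetical element-wise float add kernel.
void KernelOne(const Ort::Custom::Tensor<float>& X,
               const Ort::Custom::Tensor<float>& Y,
               Ort::Custom::Tensor<float>& Z) {
  const float* x = X.Data();
  const float* y = Y.Data();
  float* z = Z.Allocate(X.Shape());  // output takes the input shape
  for (int64_t i = 0; i < X.NumberOfElement(); ++i) z[i] = x[i] + y[i];
}

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;

  // The op and domain objects must stay alive at least until the session is created.
  std::unique_ptr<OrtLiteCustomOp> op{
      Ort::Custom::CreateLiteCustomOp("CustomOpOne", "CPUExecutionProvider", KernelOne)};
  Ort::CustomOpDomain domain{"v1"};
  domain.Add(op.get());
  session_options.Add(domain);

  // Hypothetical model containing a node of type CustomOpOne in domain "v1".
  Ort::Session session(env, ORT_TSTR("model_with_custom_op.onnx"), session_options);
  return 0;
}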
1.20.0/onnxruntime.xcframework/Headers/onnxruntime_run_options_config_keys.h ADDED
@@ -0,0 +1,51 @@
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #pragma once
5
+
6
+ /*
7
+ * This file defines RunOptions Config Keys and format of the Config Values.
8
+ *
9
+ * The Naming Convention for a RunOptions Config Key,
10
+ * "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
11
+ * Such as "ep.cuda.use_arena"
12
+ * The Config Key cannot be empty
13
+ * The maximum length of the Config Key is 128
14
+ *
15
+ * The string format of a RunOptions Config Value is defined individually for each Config.
16
+ * The maximum length of the Config Value is 1024
17
+ */
18
+
19
+ // Key for enabling shrinkage of user-listed device memory arenas.
20
+ // Expects a semicolon-separated list of device:device_id pairs in the following format:
21
+ // "device_0:device_id_0;device_1:device_id_1"
22
+ // No white-spaces allowed in the provided list string.
23
+ // Currently, the only supported devices are: "cpu", "gpu" (case sensitive).
24
+ // If "cpu" is included in the list, the DisableCpuMemArena() API must not be called (i.e. the arena for CPU must stay enabled).
25
+ // Example usage: "cpu:0;gpu:0" (or) "gpu:0"
26
+ // By default, the value for this key is empty (i.e. no memory arenas are shrunk).
27
+ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage";
28
+
29
+ // Set to '1' to skip synchronizing execution providers with the CPU at the end of a session run.
30
+ // By default it is set to '0'.
31
+ // Taking the CUDA EP as an example, this omits triggering cudaStreamSynchronize on the compute stream.
32
+ static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers";
33
+
34
+ // Set HTP performance mode for QNN HTP backend before session run.
35
+ // options for HTP performance mode: "burst", "balanced", "default", "high_performance",
36
+ // "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
37
+ // "sustained_high_performance". Default to "default".
38
+ static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode";
39
+
40
+ // Set HTP performance mode for QNN HTP backend post session run.
41
+ static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run";
42
+
43
+ // Set RPC control latency for QNN HTP backend
44
+ static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";
45
+
46
+ // Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true.
47
+ // The value should be an integer. If the value is not set, the default value is 0 and
48
+ // ORT session only captures one cuda graph before another capture is requested.
49
+ // If the value is set to -1, cuda graph capture/replay is disabled in that run.
50
+ // Users are not expected to set the value to 0 as it is reserved for internal use.
51
+ static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id";
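
A minimal sketch of applying these run-level keys through the C++ API (Ort::RunOptions::AddConfigEntry from onnxruntime_cxx_api.h); the chosen keys and values are illustrative only:

#include "onnxruntime_cxx_api.h"
#include "onnxruntime_run_options_config_keys.h"

Ort::RunOptions MakeRunOptions() {
  Ort::RunOptions run_options;
  // Shrink the CPU arena once this run completes.
  run_options.AddConfigEntry(kOrtRunOptionsConfigEnableMemoryArenaShrinkage, "cpu:0");
  // Skip the end-of-run synchronization with the CPU (e.g. cudaStreamSynchronize for the CUDA EP).
  run_options.AddConfigEntry(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "1");
  return run_options;
}
// Usage: pass the returned RunOptions as the first argument to Ort::Session::Run().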
1.20.0/onnxruntime.xcframework/Headers/onnxruntime_session_options_config_keys.h ADDED
@@ -0,0 +1,291 @@
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #pragma once
5
+
6
+ /*
7
+ * This file defines SessionOptions Config Keys and format of the Config Values.
8
+ *
9
+ * The Naming Convention for a SessionOptions Config Key,
10
+ * "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
11
+ * Such as "ep.cuda.use_arena"
12
+ * The Config Key cannot be empty
13
+ * The maximum length of the Config Key is 128
14
+ *
15
+ * The string format of a SessionOptions Config Value is defined individually for each Config.
16
+ * The maximum length of the Config Value is 1024
17
+ */
18
+
19
+ // Key for disabling PrePacking.
20
+ // If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value)
21
+ static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking";
22
+
23
+ // A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session
24
+ // will be used. Use this to override the usage of env allocators on a per session level.
25
+ static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators";
26
+
27
+ // Set to 'ORT' (case sensitive) to load an ORT format model.
28
+ // If unset, model type will default to ONNX unless inferred from filename ('.ort' == ORT format) or bytes to be ORT
29
+ static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format";
30
+
31
+ // Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_path is set.
32
+ // If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'.
33
+ static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format";
34
+
35
+ // If a value is "1", flush-to-zero and denormal-as-zero are applied. The default is "0".
36
+ // When multiple sessions are created, a main thread doesn't override changes from succeeding session options,
37
+ // but threads in session thread pools follow option changes.
38
+ // When ORT runs with OpenMP, the same rule is applied, i.e. the first session option to flush-to-zero and
39
+ // denormal-as-zero is only applied to global OpenMP thread pool, which doesn't support per-session thread pool.
40
+ // Note that an alternative way not using this option at runtime is to train and export a model without denormals
41
+ // and that's recommended because turning this option on may hurt model accuracy.
42
+ static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.set_denormal_as_zero";
43
+
44
+ // Controls whether to run a quantized model in QDQ (QuantizeLinear/DequantizeLinear) format.
45
+ // "0": enable. ORT applies fusion logic for the QDQ format.
46
+ // "1": disable. ORT does not apply fusion logic for the QDQ format.
47
+ // Its default value is "0" unless the DirectML execution provider is registered, in which case it defaults to "1".
48
+ static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq";
49
+
50
+ // Controls whether to enable the Double QDQ remover and Identical Children Consolidation.
51
+ // "0": don't disable. ORT removes the middle 2 nodes from Q->(QD->Q)->QD pairs.
52
+ // "1": disable. ORT doesn't remove the middle 2 nodes from Q->(QD->Q)->QD pairs.
53
+ // Its default value is "0"
54
+ static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover";
55
+
56
+ // If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been
57
+ // completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the
58
+ // Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to
59
+ // 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on
60
+ // other factors like whether the model was created using Quantization Aware Training or Post Training Quantization.
61
+ // As such, it's best to test to determine if enabling this works well for your scenario.
62
+ // The default value is "0"
63
+ // Available since version 1.11.
64
+ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup";
65
+
66
+ // Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0".
67
+ // GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
68
+ static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
69
+
70
+ // This setting controls whether to enable AheadOfTime function inlining.
71
+ // AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model
72
+ // as possible with the help of enabled execution providers.
73
+ // This can reduce the number of function calls and improve performance because it is done before
74
+ // Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available,
75
+ // one can disable the AOT inlining, produce an optimized model and postpone AOT until run time.
76
+ // "0": enable; "1": disable.
77
+ // Its default value is "0".
78
+ static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining";
79
+
80
+ #ifdef ENABLE_TRAINING
81
+ // Specifies a path of the file containing a list of memory optimization configurations.
82
+ // The value should be a string indicating the file path of the config file.
83
+ // The content of the config file is a JSON struct like this:
84
+ // [
85
+ // "Gelu+Cast+:1:0",
86
+ // "Dropout+:1:1"
87
+ // ]
88
+ // Taking the example of "Gelu+Cast+:1:0",
89
+ // > "Gelu+Cast+" is the subgraph string, a valid "subgraph string" should be one subgraph representation
90
+ // output by ORT graph transformations.
91
+ // > "1" is "optimization strategy", valid values: 0 - disabled, 1 - recompute.
92
+ // > "0" is "number of subgraph to apply" which is used to control how many subgraphs to apply optimization,
93
+ // to avoid "oversaving" the memory.
94
+ static const char* const kOrtSessionOptionsMemoryOptimizerApplyConfig = "optimization.memory_optimizer_config";
95
+
96
+ // Specifies the config for detecting subgraphs for memory footprint reduction.
97
+ // The value should be a string contains int separated using commas. The default value is "0:0".
98
+ static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimization.enable_memory_probe_recompute_config";
99
+ #endif
100
+
101
+ // If set, this setting should contain a comma-separated list of optimizer names that should be disabled.
102
+ // Optimizers may take time to execute and affect model loading time. If you feel that a specific optimizer
103
+ // does not provide runtime benefits, but affects your model loading time, you may disable it using this config
104
+ // entry. This option is not available in ORT_MINIMAL_BUILD builds.
105
+ // A list of optimizers is available in onnxruntime/core/optimizer/graph_transformer_utils.cc
106
+ //
107
+ // Default is an empty string which means no optimizers are disabled.
108
+ static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers";
109
+
110
+ // Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0".
111
+ // Using device allocators means the memory allocation is made using malloc/new.
112
+ static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers";
113
+
114
+ // Configure whether to allow the inter_op/intra_op threads to spin a number of times before blocking
115
+ // "0": the thread will block if it finds no job to run
116
+ // "1": default, thread will spin a number of times before blocking
117
+ static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning";
118
+ static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning";
119
+
120
+ // Key for using model bytes directly for ORT format
121
+ // If a session is created using an input byte array that contains the ORT format model data,
122
+ // by default we will copy the model bytes at the time of session creation to ensure the model bytes
123
+ // buffer is valid.
124
+ // Setting this option to "1" will disable copying the model bytes; they will be used directly instead. The caller
125
+ // has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed.
126
+ static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly";
127
+
128
+ /// <summary>
129
+ /// Key for using the ORT format model flatbuffer bytes directly for initializers.
130
+ /// This avoids copying the bytes and reduces peak memory usage during model loading and initialization.
131
+ /// Requires `session.use_ort_model_bytes_directly` to be true.
132
+ /// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire
133
+ /// duration of the InferenceSession.
134
+ /// </summary>
135
+ static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers =
136
+ "session.use_ort_model_bytes_for_initializers";
137
+
138
+ // This should only be specified when exporting an ORT format model for use on a different platform.
139
+ // If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0"
140
+ // Available since version 1.11.
141
+ static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed";
142
+
143
+ // x64 SSE4.1/AVX2/AVX512 (with no VNNI) has an overflow problem with quantized matrix multiplication with U8S8.
144
+ // To avoid this we need to use slower U8U8 matrix multiplication instead. This option, if
145
+ // turned on, uses slower U8U8 matrix multiplications. Only effective with AVX2 or AVX512
146
+ // platforms.
147
+ static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision";
148
+
149
+ // Specifies how minimal build graph optimizations are handled in a full build.
150
+ // These optimizations are at the extended level or higher.
151
+ // Possible values and their effects are:
152
+ // "save": Save runtime optimizations when saving an ORT format model.
153
+ // "apply": Only apply optimizations available in a minimal build.
154
+ // ""/<unspecified>: Apply optimizations available in a full build.
155
+ // Available since version 1.11.
156
+ static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations =
157
+ "optimization.minimal_build_optimizations";
158
+
159
+ // Note: The options specific to an EP should be specified prior to appending that EP to the session options object in
160
+ // order for them to take effect.
161
+
162
+ // Specifies a list of stop op types. Nodes of a type in the stop op types and nodes downstream from them will not be
163
+ // run by the NNAPI EP.
164
+ // The value should be a ","-delimited list of op types. For example, "Add,Sub".
165
+ // If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op
166
+ // exclusion, set the value to "".
167
+ static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops";
168
+
169
+ // Enabling dynamic block-sizing for multithreading.
170
+ // With a positive value, thread pool will split a task of N iterations to blocks of size starting from:
171
+ // N / (num_of_threads * dynamic_block_base)
172
+ // As execution progresses, the size will decrease according to the diminishing residual of N,
173
+ // meaning the task will be distributed in smaller granularity for better parallelism.
174
+ // For some models, it helps to reduce the variance of E2E inference latency and boost performance.
175
+ // The feature is off by default; specify any positive integer, e.g. "4", to enable it.
176
+ // Available since version 1.11.
177
+ static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base";
178
+
179
+ // This option allows decreasing CPU usage between infrequent
180
+ // requests and forces any spinning TP threads to stop immediately when the last of the
181
+ // concurrent Run() calls returns.
182
+ // Spinning is restarted on the next Run() call.
183
+ // Applies only to internal thread-pools
184
+ static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop";
185
+
186
+ // "1": all inconsistencies encountered during shape and type inference
187
+ // will result in failures.
188
+ // "0": in some cases warnings will be logged but processing will continue. The default.
189
+ // May be useful to expose bugs in models.
190
+ static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference";
191
+
192
+ // "1": every model using a more recent opset than the latest released one will fail
193
+ // "0": the model may or may not work if onnxruntime cannot find an implementation, this option
194
+ // is used for development purpose.
195
+ static const char* const kOrtSessionOptionsConfigStrictAllowReleasedOpsetsOnly = "session.allow_released_opsets_only";
196
+
197
+ // The file saves the configuration for partitioning nodes among logical streams
198
+ static const char* const kNodePartitionConfigFile = "session.node_partition_config_file";
199
+
200
+ // This option allows setting affinities for intra-op threads.
201
+ // The affinity string follows this format:
202
+ // logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id
203
+ // A semicolon separates configurations for different threads, while commas separate the processors that the i-th thread is expected to attach to.
204
+ // e.g. 1,2,3;4,5
205
+ // specifies affinities for two threads, with the 1st thread attached to the 1st, 2nd, and 3rd processor, and the 2nd thread to the 4th and 5th.
206
+ // To ease the configuration, an "interval" is also allowed:
207
+ // e.g. 1-8;8-16;17-24
208
+ // meaning the 1st thread runs on the first eight processors, the 2nd thread runs on the next eight processors, and so forth.
209
+ // Note:
210
+ // 1. Once set, the number of thread affinities must equal intra_op_num_threads - 1, since ORT does not set affinity on the main thread, which
211
+ // is started and managed by the calling app;
212
+ // 2. For Windows, ORT will infer the group id from a logical processor id; for example, assuming there are two groups, each with 64 logical processors,
213
+ // an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group.
214
+ // Hence 64-65 is an invalid configuration, because a Windows thread cannot be attached to processors across a group boundary.
215
+ static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities";
216
+
217
+ // This option will dump out the model to assist debugging any issues with layout transformation,
218
+ // and is primarily intended for developer usage. It is only relevant if an execution provider that requests
219
+ // NHWC layout is enabled such as NNAPI, XNNPACK or QNN.
220
+ //
221
+ // Default is off. Set to "1" to enable.
222
+ //
223
+ // If modified by layout transformation the model will be dumped after these steps:
224
+ // 1) insertion of the layout transformation Transpose nodes
225
+ // 2) after those are optimized using the transpose optimizer,
226
+ // 3) after the L1 transformers are applied to the updated graph.
227
+ // The model will be saved to filename post_layout_transform_step_<step_number>.onnx.
228
+ static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation";
229
+
230
+ // Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are
231
+ // assigned (i.e., "fallback") to the CPU EP by default.
232
+ //
233
+ // This option allows the user to disable the fallback of unsupported graph nodes to the CPU EP.
234
+ // If this option is set to "1", session creation will fail if the execution providers other than the CPU EP cannot
235
+ // fully support all of the nodes in the graph.
236
+ //
237
+ // It is invalid to set this option and explicitly add the CPU EP to the session. In this case, session creation
238
+ // will also fail with an error.
239
+ //
240
+ // Option values:
241
+ // - "0": CPU EP fallback is not disabled. [DEFAULT]
242
+ // - "1": CPU EP fallback is disabled.
243
+ static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disable_cpu_ep_fallback";
244
+
245
+ // Use this config when serializing a large model after optimization to specify an external initializers file
246
+ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName =
247
+ "session.optimized_model_external_initializers_file_name";
248
+
249
+ // Use this config to control the minimum size of the initializer when externalizing it during serialization
250
+ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
251
+ "session.optimized_model_external_initializers_min_size_in_bytes";
252
+
253
+ // Enable the EP context feature to dump the partitioned graph, which includes the EP context, into an Onnx file.
254
+ // The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
255
+ // "0": disable. (default)
256
+ // "1": enable.
257
+ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";
258
+
259
+ // Specify the file path for the Onnx model which has EP context.
260
+ // Defaults to original_file_name_ctx.onnx if not specified
261
+ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";
262
+
263
+ // Flag to specify whether to dump the EP context into the Onnx model.
264
+ // "0": dump the EP context into separate file, keep the file name in the Onnx model.
265
+ // "1": dump the EP context into the Onnx model. (default).
266
+ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
267
+
268
+ // Specify the EPContext node name prefix to make it unique
269
+ // in case the user needs to merge/connect multiple EPContext nodes in one model
270
+ static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix";
271
+
272
+ // Share EP related resources across EPs
273
+ static const char* const kOrtSessionOptionShareEpContexts = "ep.share_ep_contexts";
274
+
275
+ // Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul.
276
+ // Option values:
277
+ // - "0": Gemm FastMath mode is not enabled. [DEFAULT]
278
+ // - "1": Gemm FastMath mode is enabled.
279
+ static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
280
+
281
+ // When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option.
282
+ // Refer to MatMulNBits op schema for more details.
283
+ // If not provided, default is 4.
284
+ static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";
285
+
286
+ // THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME
287
+ // Meant to be used with SetEpDynamicOptions
288
+ // Specify the type of workload for this session.
289
+ // “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default]
290
+ // “Efficient”: OS treats this workload as efficiency oriented with low scheduling priority and efficient processor performance.
291
+ static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type";
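
A corresponding sketch for the session-level keys via Ort::SessionOptions::AddConfigEntry; the key choices and values below are illustrative, not a recommended configuration:

#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"

Ort::SessionOptions MakeSessionOptions() {
  Ort::SessionOptions so;
  // Treat denormal floats as zero inside ORT thread pools.
  so.AddConfigEntry(kOrtSessionOptionsConfigSetDenormalAsZero, "1");
  // Three intra-op threads: the calling thread plus two pool threads,
  // pinned to logical processors 1,2 and 3,4 respectively.
  so.SetIntraOpNumThreads(3);
  so.AddConfigEntry(kOrtSessionOptionsConfigIntraOpThreadAffinities, "1,2;3,4");
  // Let intra-op threads block instead of spinning when there is no work.
  so.AddConfigEntry(kOrtSessionOptionsConfigAllowIntraOpSpinning, "0");
  return so;
}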
1.20.0/onnxruntime.xcframework/Info.plist ADDED
@@ -0,0 +1,59 @@
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3
+ <plist version="1.0">
4
+ <dict>
5
+ <key>AvailableLibraries</key>
6
+ <array>
7
+ <dict>
8
+ <key>BinaryPath</key>
9
+ <string>onnxruntime.a</string>
10
+ <key>LibraryIdentifier</key>
11
+ <string>ios-arm64</string>
12
+ <key>LibraryPath</key>
13
+ <string>onnxruntime.a</string>
14
+ <key>SupportedArchitectures</key>
15
+ <array>
16
+ <string>arm64</string>
17
+ </array>
18
+ <key>SupportedPlatform</key>
19
+ <string>ios</string>
20
+ </dict>
21
+ <dict>
22
+ <key>BinaryPath</key>
23
+ <string>onnxruntime.a</string>
24
+ <key>LibraryIdentifier</key>
25
+ <string>ios-arm64_x86_64-simulator</string>
26
+ <key>LibraryPath</key>
27
+ <string>onnxruntime.a</string>
28
+ <key>SupportedArchitectures</key>
29
+ <array>
30
+ <string>arm64</string>
31
+ <string>x86_64</string>
32
+ </array>
33
+ <key>SupportedPlatform</key>
34
+ <string>ios</string>
35
+ <key>SupportedPlatformVariant</key>
36
+ <string>simulator</string>
37
+ </dict>
38
+ <dict>
39
+ <key>BinaryPath</key>
40
+ <string>onnxruntime.a</string>
41
+ <key>LibraryIdentifier</key>
42
+ <string>macos-arm64_x86_64</string>
43
+ <key>LibraryPath</key>
44
+ <string>onnxruntime.a</string>
45
+ <key>SupportedArchitectures</key>
46
+ <array>
47
+ <string>arm64</string>
48
+ <string>x86_64</string>
49
+ </array>
50
+ <key>SupportedPlatform</key>
51
+ <string>macos</string>
52
+ </dict>
53
+ </array>
54
+ <key>CFBundlePackageType</key>
55
+ <string>XFWK</string>
56
+ <key>XCFrameworkFormatVersion</key>
57
+ <string>1.0</string>
58
+ </dict>
59
+ </plist>
1.20.0/onnxruntime.xcframework/ios-arm64/libonnxruntime.a ADDED
@@ -0,0 +1 @@
1
+ onnxruntime.a
1.20.0/onnxruntime.xcframework/ios-arm64/onnxruntime.a ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c035dc64372bb649dcb6dbbd1f2617f8443243f97296aa961e08a3b4b29a155
3
+ size 72140680
1.20.0/onnxruntime.xcframework/ios-arm64_x86_64-simulator/libonnxruntime.a ADDED
@@ -0,0 +1 @@
1
+ onnxruntime.a
1.20.0/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.a ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f62d10daab2e8e93a73804c6370e6bb5ed2ff8fccd1b64a57b35b47144d301d
3
+ size 147749872
1.20.0/onnxruntime.xcframework/macos-arm64_x86_64/libonnxruntime.a ADDED
@@ -0,0 +1 @@
1
+ onnxruntime.a
1.20.0/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.a ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db30a783941fd6f1abc2a7be47face4b511d56290fce79b01d96ef4cd80eba6b
3
+ size 140554576