csukuangfj committed on
Commit f1e2018 · 1 Parent(s): 41584bb

Add onnxruntime.xcframework 1.14.0

1.14.0/onnxruntime.xcframework/Headers/coreml_provider_factory.h ADDED
@@ -0,0 +1,41 @@
+ // Copyright (c) Microsoft Corporation. All rights reserved.
+ // Licensed under the MIT License.
+ #pragma once
+
+ #include "onnxruntime_c_api.h"
+
+ // COREMLFlags are boolean options to set for the CoreML EP.
+ // This enum is defined as bit flags and cannot have negative values.
+ // To generate a uint32_t coreml_flags for use with OrtSessionOptionsAppendExecutionProvider_CoreML below:
+ //   uint32_t coreml_flags = 0;
+ //   coreml_flags |= COREML_FLAG_USE_CPU_ONLY;
+ enum COREMLFlags {
+   COREML_FLAG_USE_NONE = 0x000,
+
+   // Use CPU only in the CoreML EP. This may decrease performance but provides
+   // reference output values without precision loss, which is useful for validation.
+   COREML_FLAG_USE_CPU_ONLY = 0x001,
+
+   // Enable the CoreML EP on subgraphs.
+   COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002,
+
+   // By default the CoreML execution provider is enabled for all compatible Apple devices.
+   // Enabling this option restricts the CoreML EP to Apple devices with an ANE (Apple Neural Engine).
+   // Note that enabling this option does not guarantee that the entire model will be executed on the ANE only.
+   COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004,
+
+   // Keep COREML_FLAG_LAST at the end of the enum definition
+   // and assign the last COREMLFlag to it.
+   COREML_FLAG_LAST = COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE,
+ };
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ ORT_EXPORT ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CoreML,
+                           _In_ OrtSessionOptions* options, uint32_t coreml_flags);
+
+ #ifdef __cplusplus
+ }
+ #endif
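For reviewers' context (not part of the committed header): a minimal sketch of calling this factory function from application code, following the flag-building recipe in the comment above. It assumes Ort::SessionOptions and Ort::ThrowOnError from onnxruntime_cxx_api.h; the helper name AppendCoreMLProvider is illustrative.

#include "coreml_provider_factory.h"
#include "onnxruntime_cxx_api.h"

void AppendCoreMLProvider(Ort::SessionOptions& session_options) {
  uint32_t coreml_flags = COREML_FLAG_USE_NONE;
  coreml_flags |= COREML_FLAG_ENABLE_ON_SUBGRAPH;  // optional: also offload subgraphs to CoreML
  // Ort::SessionOptions converts implicitly to OrtSessionOptions*;
  // Ort::ThrowOnError turns a failed OrtStatus* into an Ort::Exception.
  Ort::ThrowOnError(
      OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags));
}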
1.14.0/onnxruntime.xcframework/Headers/cpu_provider_factory.h ADDED
@@ -0,0 +1,19 @@
+ // Copyright (c) Microsoft Corporation. All rights reserved.
+ // Licensed under the MIT License.
+
+ #include "onnxruntime_c_api.h"
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ /**
+  * \param use_arena zero: false. non-zero: true.
+  */
+ ORT_EXPORT
+ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
+ ORT_ALL_ARGS_NONNULL;
+
+ #ifdef __cplusplus
+ }
+ #endif
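Similarly, a hedged sketch (not part of the commit) of appending the CPU execution provider with an explicit use_arena choice; the CPU EP is registered by default, so this call mainly matters when you want to control the arena allocator. The helper name AppendCpuProvider is illustrative, and Ort::ThrowOnError is assumed from onnxruntime_cxx_api.h.

#include "cpu_provider_factory.h"
#include "onnxruntime_cxx_api.h"

void AppendCpuProvider(Ort::SessionOptions& session_options, bool use_arena) {
  // use_arena: zero disables, non-zero enables the arena-based CPU allocator.
  Ort::ThrowOnError(
      OrtSessionOptionsAppendExecutionProvider_CPU(session_options, use_arena ? 1 : 0));
}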
1.14.0/onnxruntime.xcframework/Headers/onnxruntime_c_api.h ADDED
The diff for this file is too large to render. See raw diff
 
1.14.0/onnxruntime.xcframework/Headers/onnxruntime_cxx_api.h ADDED
@@ -0,0 +1,1876 @@
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ // Summary: The Ort C++ API is a header only wrapper around the Ort C API.
5
+ //
6
+ // The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
7
+ // and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
8
+ // all the resources follow RAII and do not leak memory.
9
+ //
10
+ // Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
11
+ // To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
12
+ // until you assign an instance that actually holds an underlying object.
13
+ //
14
+ // For Ort objects only move assignment between objects is allowed, there are no copy constructors.
15
+ // Some objects have explicit 'Clone' methods for this purpose.
16
+ //
17
+ // ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
18
+ // by value or by reference. ConstXXXX types are restricted to const only interfaces.
19
+ //
20
+ // UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
21
+ //
22
+ // The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
23
+ // have to fall back to the C types and API with the usual pitfalls. In general, do not use the C API from your C++ code.
24
+
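// [Editor's note: a brief illustrative sketch of the ownership rules summarized above; it is not part
//  of the upstream header. It assumes the Env/SessionOptions/Session wrappers declared later in this
//  file, and "model.onnx" is a placeholder path.]
//
//   Ort::Env env{nullptr};                                  // empty wrapper, holds no OrtEnv yet
//   env = Ort::Env{ORT_LOGGING_LEVEL_WARNING, "demo"};      // move-assign a real instance
//   Ort::SessionOptions so;                                 // owning, move-only
//   Ort::Session session{env, ORT_TSTR("model.onnx"), so};  // pass wrappers by reference, never copy
//   Ort::ConstSession view = session.GetConst();            // non-owning ConstXXXX view, copyable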
25
+ #pragma once
26
+ #include "onnxruntime_c_api.h"
27
+ #include <cstddef>
28
+ #include <array>
29
+ #include <memory>
30
+ #include <stdexcept>
31
+ #include <string>
32
+ #include <vector>
33
+ #include <unordered_map>
34
+ #include <utility>
35
+ #include <type_traits>
36
+
37
+ #ifdef ORT_NO_EXCEPTIONS
38
+ #include <iostream>
39
+ #endif
40
+
41
+ /** \brief All C++ Onnxruntime APIs are defined inside this namespace
42
+ *
43
+ */
44
+ namespace Ort {
45
+
46
+ /** \brief All C++ methods that can fail will throw an exception of this type
47
+ *
48
+ * If <tt>ORT_NO_EXCEPTIONS</tt> is defined, then any error will result in a call to abort()
49
+ */
50
+ struct Exception : std::exception {
51
+ Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
52
+
53
+ OrtErrorCode GetOrtErrorCode() const { return code_; }
54
+ const char* what() const noexcept override { return message_.c_str(); }
55
+
56
+ private:
57
+ std::string message_;
58
+ OrtErrorCode code_;
59
+ };
60
+
61
+ #ifdef ORT_NO_EXCEPTIONS
62
+ // The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
63
+ // NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
64
+ #ifndef ORT_CXX_API_THROW
65
+ #define ORT_CXX_API_THROW(string, code) \
66
+ do { \
67
+ std::cerr << Ort::Exception(string, code) \
68
+ .what() \
69
+ << std::endl; \
70
+ abort(); \
71
+ } while (false)
72
+ #endif
73
+ #else
74
+ #define ORT_CXX_API_THROW(string, code) \
75
+ throw Ort::Exception(string, code)
76
+ #endif
77
+
78
+ // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
79
+ // it's in a template so that we can define a global variable in a header and make
80
+ // it transparent to the users of the API.
81
+ template <typename T>
82
+ struct Global {
83
+ static const OrtApi* api_;
84
+ };
85
+
86
+ // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, the user must call InitApi() before using it.
87
+ template <typename T>
88
+ #ifdef ORT_API_MANUAL_INIT
89
+ const OrtApi* Global<T>::api_{};
90
+ inline void InitApi() { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
91
+
92
+ // Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
93
+ // required by C++ APIs.
94
+ //
95
+ // Example mycustomop.cc:
96
+ //
97
+ // #define ORT_API_MANUAL_INIT
98
+ // #include <onnxruntime_cxx_api.h>
99
+ // #undef ORT_API_MANUAL_INIT
100
+ //
101
+ // OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
102
+ // Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
103
+ // // ...
104
+ // }
105
+ //
106
+ inline void InitApi(const OrtApi* api) { Global<void>::api_ = api; }
107
+ #else
108
+ #if defined(_MSC_VER) && !defined(__clang__)
109
+ #pragma warning(push)
110
+ // "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
111
+ // Please define ORT_API_MANUAL_INIT if it concerns you.
112
+ #pragma warning(disable : 26426)
113
+ #endif
114
+ const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
115
+ #if defined(_MSC_VER) && !defined(__clang__)
116
+ #pragma warning(pop)
117
+ #endif
118
+ #endif
119
+
120
+ /// This returns a reference to the OrtApi interface in use
121
+ inline const OrtApi& GetApi() { return *Global<void>::api_; }
122
+
123
+ /// <summary>
124
+ /// This is a C++ wrapper for OrtApi::GetAvailableProviders() and
125
+ /// returns a vector of strings representing the available execution providers.
126
+ /// </summary>
127
+ /// <returns>vector of strings</returns>
128
+ std::vector<std::string> GetAvailableProviders();
129
+
130
+ /** \brief IEEE 754 half-precision floating point data type
131
+ * \details It is necessary for type dispatching to make use of C++ API
132
+ * The type is implicitly convertible to/from uint16_t.
133
+ * The size of the structure should align with uint16_t and one can freely cast
134
+ * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
135
+ *
136
+ * Generally, you can feed any of your types as float16/bfloat16 data to create a tensor
137
+ * on top of it, providing it can form a continuous buffer with 16-bit elements with no padding.
138
+ * You can also feed an array of uint16_t elements directly. For example,
139
+ *
140
+ * \code{.unparsed}
141
+ * uint16_t values[] = { 15360, 16384, 16896, 17408, 17664};
142
+ * constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
143
+ * std::vector<int64_t> dims = {values_length}; // one dimensional example
144
+ * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
145
+ * // Note we are passing bytes count in this api, not number of elements -> sizeof(values)
146
+ * auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values),
147
+ * dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
148
+ * \endcode
149
+ *
150
+ * Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use
151
+ * a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra
152
+ * template specialization.
153
+ *
154
+ * \code{.unparsed}
155
+ * namespace yours { struct half {}; } // assume this is your type, define this:
156
+ * namespace Ort {
157
+ * template<>
158
+ * struct TypeToTensorType<yours::half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
159
+ * } //namespace Ort
160
+ *
161
+ * std::vector<yours::half> values;
162
+ * std::vector<int64_t> dims = {values.size()}; // one dimensional example
163
+ * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
164
+ * // Here we are passing element count -> values.size()
165
+ * auto float16_tensor = Ort::Value::CreateTensor<yours::half>(info, values.data(), values.size(), dims.data(), dims.size());
166
+ *
167
+ * \endcode
168
+ */
169
+ struct Float16_t {
170
+ uint16_t value;
171
+ constexpr Float16_t() noexcept : value(0) {}
172
+ constexpr Float16_t(uint16_t v) noexcept : value(v) {}
173
+ constexpr operator uint16_t() const noexcept { return value; }
174
+ constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; };
175
+ constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; };
176
+ };
177
+
178
+ static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
179
+
180
+ /** \brief bfloat16 (Brain Floating Point) data type
181
+ * \details It is necessary for type dispatching to make use of C++ API
182
+ * The type is implicitly convertible to/from uint16_t.
183
+ * The size of the structure should align with uint16_t and one can freely cast
184
+ * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
185
+ *
186
+ * See also code examples for Float16_t above.
187
+ */
188
+ struct BFloat16_t {
189
+ uint16_t value;
190
+ constexpr BFloat16_t() noexcept : value(0) {}
191
+ constexpr BFloat16_t(uint16_t v) noexcept : value(v) {}
192
+ constexpr operator uint16_t() const noexcept { return value; }
193
+ constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; };
194
+ constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; };
195
+ };
196
+
197
+ static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
198
+
199
+ namespace detail {
200
+ // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
201
+ // This can't be done in the C API since C doesn't have function overloading.
202
+ #define ORT_DEFINE_RELEASE(NAME) \
203
+ inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
204
+
205
+ ORT_DEFINE_RELEASE(Allocator);
206
+ ORT_DEFINE_RELEASE(MemoryInfo);
207
+ ORT_DEFINE_RELEASE(CustomOpDomain);
208
+ ORT_DEFINE_RELEASE(ThreadingOptions);
209
+ ORT_DEFINE_RELEASE(Env);
210
+ ORT_DEFINE_RELEASE(RunOptions);
211
+ ORT_DEFINE_RELEASE(Session);
212
+ ORT_DEFINE_RELEASE(SessionOptions);
213
+ ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
214
+ ORT_DEFINE_RELEASE(SequenceTypeInfo);
215
+ ORT_DEFINE_RELEASE(MapTypeInfo);
216
+ ORT_DEFINE_RELEASE(TypeInfo);
217
+ ORT_DEFINE_RELEASE(Value);
218
+ ORT_DEFINE_RELEASE(ModelMetadata);
219
+ ORT_DEFINE_RELEASE(IoBinding);
220
+ ORT_DEFINE_RELEASE(ArenaCfg);
221
+ ORT_DEFINE_RELEASE(Status);
222
+ ORT_DEFINE_RELEASE(OpAttr);
223
+ ORT_DEFINE_RELEASE(Op);
224
+ ORT_DEFINE_RELEASE(KernelInfo);
225
+
226
+ #undef ORT_DEFINE_RELEASE
227
+
228
+ /** \brief This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object
229
+ * has no ownership of the underlying C object.
230
+ */
231
+ template <typename T>
232
+ struct Unowned {
233
+ using Type = T;
234
+ };
235
+
236
+ /** \brief Used internally by the C++ API. C++ wrapper types inherit from this.
237
+ * This is a zero cost abstraction to wrap the C API objects and delete them on destruction.
238
+ *
239
+ * All of the C++ classes
240
+ * a) serve as containers for pointers to objects that are created by the underlying C API.
241
+ * Their size is just a pointer size, no need to dynamically allocate them. Use them by value.
242
+ * b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects.
243
+ * they automatically release the objects they own when going out of scope; they are move-only.
244
+ * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers.
245
+ * ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else
246
+ * such as Onnxruntime or instances of XXXX classes.
247
+ * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used
248
+ * in C++ code.
249
+ *
250
+ */
251
+
252
+ /// <summary>
253
+ /// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction.
254
+ /// </summary>
255
+ template <typename T>
256
+ struct Base {
257
+ using contained_type = T;
258
+
259
+ constexpr Base() = default;
260
+ constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
261
+ ~Base() { OrtRelease(p_); }
262
+
263
+ Base(const Base&) = delete;
264
+ Base& operator=(const Base&) = delete;
265
+
266
+ Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
267
+ Base& operator=(Base&& v) noexcept {
268
+ OrtRelease(p_);
269
+ p_ = v.release();
270
+ return *this;
271
+ }
272
+
273
+ constexpr operator contained_type*() const noexcept { return p_; }
274
+
275
+ /// \brief Relinquishes ownership of the contained C object pointer
276
+ /// The underlying object is not destroyed
277
+ contained_type* release() {
278
+ T* p = p_;
279
+ p_ = nullptr;
280
+ return p;
281
+ }
282
+
283
+ protected:
284
+ contained_type* p_{};
285
+ };
286
+
287
+ // Undefined. For const types use Base<Unowned<const T>>
288
+ template <typename T>
289
+ struct Base<const T>;
290
+
291
+ /// <summary>
292
+ /// Covers unowned pointers owned by either the ORT
293
+ /// or some other instance of CPP wrappers.
294
+ /// Used for ConstXXX and UnownedXXXX types that are copyable.
295
+ /// Also convenient for wrapping raw OrtXX pointers.
296
+ /// </summary>
297
+ /// <typeparam name="T"></typeparam>
298
+ template <typename T>
299
+ struct Base<Unowned<T>> {
300
+ using contained_type = typename Unowned<T>::Type;
301
+
302
+ constexpr Base() = default;
303
+ constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
304
+
305
+ ~Base() = default;
306
+
307
+ Base(const Base&) = default;
308
+ Base& operator=(const Base&) = default;
309
+
310
+ Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
311
+ Base& operator=(Base&& v) noexcept {
312
+ p_ = nullptr;
313
+ std::swap(p_, v.p_);
314
+ return *this;
315
+ }
316
+
317
+ constexpr operator contained_type*() const noexcept { return p_; }
318
+
319
+ protected:
320
+ contained_type* p_{};
321
+ };
322
+
323
+ // Light functor to release memory with OrtAllocator
324
+ struct AllocatedFree {
325
+ OrtAllocator* allocator_;
326
+ explicit AllocatedFree(OrtAllocator* allocator)
327
+ : allocator_(allocator) {}
328
+ void operator()(void* ptr) const {
329
+ if (ptr) allocator_->Free(allocator_, ptr);
330
+ }
331
+ };
332
+
333
+ } // namespace detail
334
+
335
+ struct AllocatorWithDefaultOptions;
336
+ struct Env;
337
+ struct TypeInfo;
338
+ struct Value;
339
+ struct ModelMetadata;
340
+
341
+ /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
342
+ * and release them at the end of the scope. The lifespan of the given allocator
343
+ * must eclipse the lifespan of AllocatedStringPtr instance
344
+ */
345
+ using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
346
+
347
+ /** \brief The Status that holds ownership of OrtStatus received from C API
348
+ * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate
349
+ * constructors to construct an instance of a Status object from exceptions.
350
+ */
351
+ struct Status : detail::Base<OrtStatus> {
352
+ explicit Status(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used
353
+ explicit Status(OrtStatus* status); ///< Takes ownership of OrtStatus instance returned from the C API. Must be non-null
354
+ explicit Status(const Exception&); ///< Creates status instance out of exception
355
+ explicit Status(const std::exception&); ///< Creates status instance out of exception
356
+ std::string GetErrorMessage() const;
357
+ OrtErrorCode GetErrorCode() const;
358
+ };
359
+
360
+ /** \brief The ThreadingOptions
361
+ *
362
+ * ThreadingOptions is used to set the options of the Env's global thread pools.
363
+ */
364
+ struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
365
+ /// \brief Wraps OrtApi::CreateThreadingOptions
366
+ ThreadingOptions();
367
+
368
+ /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads
369
+ ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads);
370
+
371
+ /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads
372
+ ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads);
373
+
374
+ /// \brief Wraps OrtApi::SetGlobalSpinControl
375
+ ThreadingOptions& SetGlobalSpinControl(int allow_spinning);
376
+
377
+ /// \brief Wraps OrtApi::SetGlobalDenormalAsZero
378
+ ThreadingOptions& SetGlobalDenormalAsZero();
379
+
380
+ /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn
381
+ ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
382
+
383
+ /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions
384
+ ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
385
+
386
+ /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn
387
+ ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
388
+ };
389
+
390
+ /** \brief The Env (Environment)
391
+ *
392
+ * The Env holds the logging state used by all other objects.
393
+ * <b>Note:</b> One Env must be created before using any other Onnxruntime functionality
394
+ */
395
+ struct Env : detail::Base<OrtEnv> {
396
+ explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used
397
+
398
+ /// \brief Wraps OrtApi::CreateEnv
399
+ Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
400
+
401
+ /// \brief Wraps OrtApi::CreateEnvWithCustomLogger
402
+ Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
403
+
404
+ /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools
405
+ Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
406
+
407
+ /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools
408
+ Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
409
+ OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
410
+
411
+ /// \brief C Interop Helper
412
+ explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
413
+
414
+ Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents
415
+ Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents
416
+
417
+ Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel
418
+
419
+ Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator
420
+ };
421
+
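// [Editor's note: sketch only, not part of the upstream header. It shows an Env created with global
//  thread pools configured via the ThreadingOptions wrapper above; a session opts in to the global
//  pools with SessionOptions::DisablePerSessionThreads(), declared further below.]
//
//   Ort::ThreadingOptions tp;
//   tp.SetGlobalIntraOpNumThreads(4).SetGlobalSpinControl(0);
//   Ort::Env env{tp, ORT_LOGGING_LEVEL_WARNING, "app"};  // ThreadingOptions converts to OrtThreadingOptions*
//   Ort::SessionOptions so;
//   so.DisablePerSessionThreads();                       // this session uses the Env's global pools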
422
+ /** \brief Custom Op Domain
423
+ *
424
+ */
425
+ struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
426
+ explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used
427
+
428
+ /// \brief Wraps OrtApi::CreateCustomOpDomain
429
+ explicit CustomOpDomain(const char* domain);
430
+
431
+ // This does not take ownership of the op, simply registers it.
432
+ void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add
433
+ };
434
+
435
+ /** \brief RunOptions
436
+ *
437
+ */
438
+ struct RunOptions : detail::Base<OrtRunOptions> {
439
+ explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used
440
+ RunOptions(); ///< Wraps OrtApi::CreateRunOptions
441
+
442
+ RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
443
+ int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
444
+
445
+ RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
446
+ int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
447
+
448
+ RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag
449
+ const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
450
+
451
+ RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry
452
+
453
+ /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance
454
+ *
455
+ * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error
456
+ * Wraps OrtApi::RunOptionsSetTerminate
457
+ */
458
+ RunOptions& SetTerminate();
459
+
460
+ /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating
461
+ *
462
+ * Wraps OrtApi::RunOptionsUnsetTerminate
463
+ */
464
+ RunOptions& UnsetTerminate();
465
+ };
466
+
467
+
468
+ namespace detail {
469
+ // Utility function that returns a SessionOption config entry key for a specific custom operator.
470
+ // Ex: custom_op.[custom_op_name].[config]
471
+ std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config);
472
+ } // namespace detail
473
+
474
+ /// <summary>
475
+ /// Class that represents session configuration entries for one or more custom operators.
476
+ ///
477
+ /// Example:
478
+ /// Ort::CustomOpConfigs op_configs;
479
+ /// op_configs.AddConfig("my_custom_op", "device_type", "CPU");
480
+ ///
481
+ /// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary.
482
+ /// </summary>
483
+ struct CustomOpConfigs {
484
+ CustomOpConfigs() = default;
485
+ ~CustomOpConfigs() = default;
486
+ CustomOpConfigs(const CustomOpConfigs&) = default;
487
+ CustomOpConfigs& operator=(const CustomOpConfigs&) = default;
488
+ CustomOpConfigs(CustomOpConfigs&& o) = default;
489
+ CustomOpConfigs& operator=(CustomOpConfigs&& o) = default;
490
+
491
+ /** \brief Adds a session configuration entry/value for a specific custom operator.
492
+ *
493
+ * \param custom_op_name The name of the custom operator for which to add a configuration entry.
494
+ * Must match the name returned by the CustomOp's GetName() method.
495
+ * \param config_key The name of the configuration entry.
496
+ * \param config_value The value of the configuration entry.
497
+ * \return A reference to this object to enable call chaining.
498
+ */
499
+ CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
500
+
501
+ /** \brief Returns a flattened map of custom operator configuration entries and their values.
502
+ *
503
+ * The keys have been flattened to include both the custom operator name and the configuration entry key name.
504
+ * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair
505
+ * {"my_op.key", "value"}.
506
+ *
507
+ * \return An unordered map of flattened configurations.
508
+ */
509
+ const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
510
+
511
+ private:
512
+ std::unordered_map<std::string, std::string> flat_configs_;
513
+ };
514
+
515
+ /** \brief Options object used when creating a new Session object
516
+ *
517
+ * Wraps ::OrtSessionOptions object and methods
518
+ */
519
+
520
+ struct SessionOptions;
521
+
522
+ namespace detail {
523
+ // we separate const-only methods because passing const ptr to non-const methods
524
+ // is only discovered when inline methods are compiled which is counter-intuitive
525
+ template <typename T>
526
+ struct ConstSessionOptionsImpl : Base<T> {
527
+ using B = Base<T>;
528
+ using B::B;
529
+
530
+ SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions
531
+
532
+ std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry
533
+ bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry
534
+ std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def);
535
+ };
536
+
537
+ template <typename T>
538
+ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
539
+ using B = ConstSessionOptionsImpl<T>;
540
+ using B::B;
541
+
542
+ SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
543
+ SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
544
+ SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
545
+
546
+ SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
547
+ SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
548
+
549
+ SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath
550
+
551
+ SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling
552
+ SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling
553
+
554
+ SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps
555
+
556
+ SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern
557
+ SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern
558
+
559
+ SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
560
+
561
+ SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
562
+ SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
563
+
564
+ SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain
565
+
566
+ SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads
567
+
568
+ SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
569
+
570
+ SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
571
+ SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers
572
+
573
+ SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
574
+ SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
575
+ SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
576
+ SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
577
+ SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
578
+ SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
579
+ SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
580
+ ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
581
+ SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
582
+ /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports SNPE and XNNPACK.
583
+ SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
584
+ const std::unordered_map<std::string, std::string>& provider_options = {});
585
+
586
+ SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
587
+ SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
588
+ SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
589
+
590
+ ///< Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2.
591
+ ///< The custom operator configurations are optional. If provided, custom operator configs are set via
592
+ ///< OrtApi::AddSessionConfigEntry.
593
+ SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
594
+
595
+ SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction
596
+ };
597
+ } // namespace detail
598
+
599
+ using UnownedSessionOptions = detail::SessionOptionsImpl<detail::Unowned<OrtSessionOptions>>;
600
+ using ConstSessionOptions = detail::ConstSessionOptionsImpl<detail::Unowned<const OrtSessionOptions>>;
601
+
602
+ /** \brief Wrapper around ::OrtSessionOptions
603
+ *
604
+ */
605
+ struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
606
+ explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used
607
+ SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions
608
+ explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {} ///< Used for interop with the C API
609
+ UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; }
610
+ ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; }
611
+ };
612
+
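// [Editor's note: a minimal, non-authoritative sketch of configuring SessionOptions with the methods
//  declared above; not part of the upstream header. The "XNNPACK" provider name and its
//  "intra_op_num_threads" option are illustrative assumptions.]
//
//   Ort::SessionOptions so;
//   so.SetIntraOpNumThreads(4)
//     .SetGraphOptimizationLevel(ORT_ENABLE_ALL)
//     .EnableCpuMemArena();
//   so.AppendExecutionProvider("XNNPACK", {{"intra_op_num_threads", "4"}});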
613
+ /** \brief Wrapper around ::OrtModelMetadata
614
+ *
615
+ */
616
+ struct ModelMetadata : detail::Base<OrtModelMetadata> {
617
+ explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
618
+ explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API
619
+
620
+ /** \brief Returns a copy of the producer name.
621
+ *
622
+ * \param allocator to allocate memory for the copy of the name returned
623
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
624
+ * The OrtAllocator instances must be valid at the point of memory release.
625
+ */
626
+ AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
627
+
628
+ /** \brief Returns a copy of the graph name.
629
+ *
630
+ * \param allocator to allocate memory for the copy of the name returned
631
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
632
+ * The OrtAllocator instances must be valid at the point of memory release.
633
+ */
634
+ AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
635
+
636
+ /** \brief Returns a copy of the domain name.
637
+ *
638
+ * \param allocator to allocate memory for the copy of the name returned
639
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
640
+ * The OrtAllocator instances must be valid at the point of memory release.
641
+ */
642
+ AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
643
+
644
+ /** \brief Returns a copy of the description.
645
+ *
646
+ * \param allocator to allocate memory for the copy of the string returned
647
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
648
+ * The OrtAllocator instances must be valid at the point of memory release.
649
+ */
650
+ AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription
651
+
652
+ /** \brief Returns a copy of the graph description.
653
+ *
654
+ * \param allocator to allocate memory for the copy of the string returned
655
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
656
+ * The OrtAllocator instances must be valid at the point of memory release.
657
+ */
658
+ AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription
659
+
660
+ /** \brief Returns a vector of copies of the custom metadata keys.
661
+ *
662
+ * \param allocator to allocate memory for the copy of the string returned
663
+ * \return a instance std::vector of smart pointers that would deallocate the buffers when out of scope.
664
+ * The OrtAllocator instance must be valid at the point of memory release.
665
+ */
666
+ std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys
667
+
668
+ /** \brief Looks up a value by a key in the Custom Metadata map
669
+ *
670
+ * \param key zero terminated string key to lookup
671
+ * \param allocator to allocate memory for the copy of the string returned
672
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
673
+ * maybe nullptr if key is not found.
674
+ *
675
+ * The OrtAllocator instances must be valid at the point of memory release.
676
+ */
677
+ AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap
678
+
679
+ int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
680
+ };
681
+
682
+ struct IoBinding;
683
+
684
+ namespace detail {
685
+
686
+ // we separate const-only methods because passing const ptr to non-const methods
687
+ // is only discovered when inline methods are compiled which is counter-intuitive
688
+ template <typename T>
689
+ struct ConstSessionImpl : Base<T> {
690
+ using B = Base<T>;
691
+ using B::B;
692
+
693
+ size_t GetInputCount() const; ///< Returns the number of model inputs
694
+ size_t GetOutputCount() const; ///< Returns the number of model outputs
695
+ size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden
696
+
697
+ /** \brief Returns a copy of input name at the specified index.
698
+ *
699
+ * \param index must be less than the value returned by GetInputCount()
700
+ * \param allocator to allocate memory for the copy of the name returned
701
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
702
+ * The OrtAllocator instances must be valid at the point of memory release.
703
+ */
704
+ AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;
705
+
706
+ /** \brief Returns a copy of the output name at the specified index.
707
+ *
708
+ * \param index must be less than the value returned by GetOutputCount()
709
+ * \param allocator to allocate memory for the copy of the name returned
710
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
711
+ * The OrtAllocator instances must be valid at the point of memory release.
712
+ */
713
+ AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const;
714
+
715
+ /** \brief Returns a copy of the overridable initializer name at the specified index.
716
+ *
717
+ * \param index must be less than the value returned by GetOverridableInitializerCount()
718
+ * \param allocator to allocate memory for the copy of the name returned
719
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
720
+ * The OrtAllocator instances must be valid at the point of memory release.
721
+ */
722
+ AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName
723
+
724
+ uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
725
+ ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata
726
+
727
+ TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
728
+ TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
729
+ TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
730
+ };
731
+
732
+ template <typename T>
733
+ struct SessionImpl : ConstSessionImpl<T> {
734
+ using B = ConstSessionImpl<T>;
735
+ using B::B;
736
+
737
+ /** \brief Run the model returning results in an Ort allocated vector.
738
+ *
739
+ * Wraps OrtApi::Run
740
+ *
741
+ * The caller provides a list of inputs and a list of the desired outputs to return.
742
+ *
743
+ * See the output logs for more information on warnings/errors that occur while processing the model.
744
+ * Common errors are.. (TODO)
745
+ *
746
+ * \param[in] run_options
747
+ * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names
748
+ * \param[in] input_values Array of Value objects of length input_count that is the list of input values
749
+ * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays)
750
+ * \param[in] output_names Array of C style strings of length output_count that is the list of output names
751
+ * \param[in] output_count Number of outputs (the size of the output_names array)
752
+ * \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector)
753
+ */
754
+ std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
755
+ const char* const* output_names, size_t output_count);
756
+
757
+ /** \brief Run the model returning results in user provided outputs
758
+ * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t)
759
+ */
760
+ void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
761
+ const char* const* output_names, Value* output_values, size_t output_count);
762
+
763
+ void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding
764
+
765
+ /** \brief End profiling and return a copy of the profiling file name.
766
+ *
767
+ * \param allocator to allocate memory for the copy of the string returned
768
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
769
+ * The OrtAllocator instances must be valid at the point of memory release.
770
+ */
771
+ AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling
772
+ };
773
+
774
+ } // namespace detail
775
+
776
+ using ConstSession = detail::ConstSessionImpl<detail::Unowned<const OrtSession>>;
777
+ using UnownedSession = detail::SessionImpl<detail::Unowned<OrtSession>>;
778
+
779
+ /** \brief Wrapper around ::OrtSession
780
+ *
781
+ */
782
+ struct Session : detail::SessionImpl<OrtSession> {
783
+ explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
784
+ Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
785
+ Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
786
+ OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
787
+ Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
788
+ Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
789
+ OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
790
+
791
+ ConstSession GetConst() const { return ConstSession{this->p_}; }
792
+ UnownedSession GetUnowned() const { return UnownedSession{this->p_}; }
793
+ };
794
+
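// [Editor's note: an end-to-end sketch of the Session API declared above; not part of the upstream
//  header. It assumes a single-input/single-output float model at the placeholder path "model.onnx"
//  and uses Ort::AllocatorWithDefaultOptions, Ort::MemoryInfo, and Ort::Value::CreateTensor declared
//  elsewhere in this header.]
//
//   Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "demo"};
//   Ort::Session session{env, ORT_TSTR("model.onnx"), Ort::SessionOptions{}};
//
//   Ort::AllocatorWithDefaultOptions alloc;
//   Ort::AllocatedStringPtr in_name = session.GetInputNameAllocated(0, alloc);
//   Ort::AllocatedStringPtr out_name = session.GetOutputNameAllocated(0, alloc);
//
//   std::vector<float> input(1 * 3 * 224 * 224, 0.f);
//   std::vector<int64_t> shape{1, 3, 224, 224};
//   auto mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
//   Ort::Value tensor = Ort::Value::CreateTensor<float>(mem, input.data(), input.size(),
//                                                       shape.data(), shape.size());
//
//   const char* in_names[] = {in_name.get()};
//   const char* out_names[] = {out_name.get()};
//   std::vector<Ort::Value> outputs =
//       session.Run(Ort::RunOptions{nullptr}, in_names, &tensor, 1, out_names, 1);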
795
+ namespace detail {
796
+ template <typename T>
797
+ struct MemoryInfoImpl : Base<T> {
798
+ using B = Base<T>;
799
+ using B::B;
800
+
801
+ std::string GetAllocatorName() const;
802
+ OrtAllocatorType GetAllocatorType() const;
803
+ int GetDeviceId() const;
804
+ OrtMemoryInfoDeviceType GetDeviceType() const;
805
+ OrtMemType GetMemoryType() const;
806
+
807
+ template <typename U>
808
+ bool operator==(const MemoryInfoImpl<U>& o) const;
809
+ };
810
+ } // namespace detail
811
+
812
+ // Const object holder that does not own the underlying object
813
+ using ConstMemoryInfo = detail::MemoryInfoImpl<detail::Unowned<const OrtMemoryInfo>>;
814
+
815
+ /** \brief Wrapper around ::OrtMemoryInfo
816
+ *
817
+ */
818
+ struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
819
+ static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
820
+ explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created
821
+ explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {} ///< Take ownership of a pointer created by C Api
822
+ MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
823
+ ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; }
824
+ };
825
+
826
+ namespace detail {
827
+ template <typename T>
828
+ struct TensorTypeAndShapeInfoImpl : Base<T> {
829
+ using B = Base<T>;
830
+ using B::B;
831
+
832
+ ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType
833
+ size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount
834
+
835
+ size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount
836
+
837
+ /** \deprecated use GetShape() returning std::vector
838
+ * [[deprecated]]
839
+ * This interface is unsafe to use
840
+ */
841
+ [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions
842
+
843
+ void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions
844
+
845
+ std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
846
+ };
847
+
848
+ } // namespace detail
849
+
850
+ using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl<detail::Unowned<const OrtTensorTypeAndShapeInfo>>;
851
+
852
+ /** \brief Wrapper around ::OrtTensorTypeAndShapeInfo
853
+ *
854
+ */
855
+ struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> {
856
+ explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used
857
+ explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API
858
+ ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; }
859
+ };
860
+
861
+ namespace detail {
862
+ template <typename T>
863
+ struct SequenceTypeInfoImpl : Base<T> {
864
+ using B = Base<T>;
865
+ using B::B;
866
+ TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType
867
+ };
868
+
869
+ } // namespace detail
870
+
871
+ using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl<detail::Unowned<const OrtSequenceTypeInfo>>;
872
+
873
+ /** \brief Wrapper around ::OrtSequenceTypeInfo
874
+ *
875
+ */
876
+ struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
877
+ explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used
878
+ explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API
879
+ ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; }
880
+ };
881
+
882
+ namespace detail {
883
+ template <typename T>
884
+ struct MapTypeInfoImpl : detail::Base<T> {
885
+ using B = Base<T>;
886
+ using B::B;
887
+ ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType
888
+ TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType
889
+ };
890
+
891
+ } // namespace detail
892
+
893
+ using ConstMapTypeInfo = detail::MapTypeInfoImpl<detail::Unowned<const OrtMapTypeInfo>>;
894
+
895
+ /** \brief Wrapper around ::OrtMapTypeInfo
896
+ *
897
+ */
898
+ struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
899
+ explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used
900
+ explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API
901
+ ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; }
902
+ };
903
+
904
+ namespace detail {
905
+ template <typename T>
906
+ struct TypeInfoImpl : detail::Base<T> {
907
+ using B = Base<T>;
908
+ using B::B;
909
+
910
+ ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo
911
+ ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo
912
+ ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo
913
+
914
+ ONNXType GetONNXType() const;
915
+ };
916
+ } // namespace detail
917
+
918
+ /// <summary>
919
+ /// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value.
920
+ /// Provides access to const OrtTypeInfo APIs.
921
+ /// </summary>
922
+ using ConstTypeInfo = detail::TypeInfoImpl<detail::Unowned<const OrtTypeInfo>>;
923
+
924
+ /// <summary>
925
+ /// Type information that may contain either TensorTypeAndShapeInfo or
926
+ /// the information about contained sequence or map depending on the ONNXType.
927
+ /// </summary>
928
+ struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
929
+ explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used
930
+ explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {} ///< C API Interop
931
+
932
+ ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; }
933
+ };
934
+
935
+ namespace detail {
936
+ // This structure is used to feed sparse tensor values
937
+ // information for use with FillSparseTensor<Format>() API
938
+ // if the data type for the sparse tensor values is numeric
939
+ // use data.p_data, otherwise, use data.str pointer to feed
940
+ // values. data.str is an array of const char* that are zero terminated.
941
+ // number of strings in the array must match shape size.
942
+ // For fully sparse tensors use shape {0} and set p_data/str
943
+ // to nullptr.
944
+ struct OrtSparseValuesParam {
945
+ const int64_t* values_shape;
946
+ size_t values_shape_len;
947
+ union {
948
+ const void* p_data;
949
+ const char** str;
950
+ } data;
951
+ };
952
+
953
+ // Provides a way to pass shape in a single
954
+ // argument
955
+ struct Shape {
956
+ const int64_t* shape;
957
+ size_t shape_len;
958
+ };
959
+
960
+ template <typename T>
961
+ struct ConstValueImpl : Base<T> {
962
+ using B = Base<T>;
963
+ using B::B;
964
+
965
+ /// <summary>
966
+ /// Obtains a pointer to a user defined data for experimental purposes
967
+ /// </summary>
968
+ template <typename R>
969
+ void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue
970
+
971
+ bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc
972
+ bool HasValue() const; ///< Returns true if the OrtValue contains data, false if the OrtValue is None
973
+
974
+ size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
975
+ Value GetValue(int index, OrtAllocator* allocator) const;
976
+
977
+ /// <summary>
978
+ /// This API returns a full length of string data contained within either a tensor or a sparse Tensor.
979
+ /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful
980
+ /// for allocating necessary memory and calling GetStringTensorContent().
981
+ /// </summary>
982
+ /// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns>
983
+ size_t GetStringTensorDataLength() const;
984
+
985
+ /// <summary>
986
+ /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
987
+ /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
988
+ /// The user must also allocate offsets buffer with the number of entries equal to that of the contained
989
+ /// strings.
990
+ ///
991
+ /// Strings are always assumed to be on CPU, no X-device copy.
992
+ /// </summary>
993
+ /// <param name="buffer">user allocated buffer</param>
994
+ /// <param name="buffer_length">length in bytes of the allocated buffer</param>
995
+ /// <param name="offsets">a pointer to the offsets user allocated buffer</param>
996
+ /// <param name="offsets_count">count of offsets, must be equal to the number of strings contained.
997
+ /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
998
+ /// for sparse tensors</param>
999
+ void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
1000
+
1001
+ /// <summary>
1002
+ /// Returns a const typed pointer to the tensor contained data.
1003
+ /// No type checking is performed, the caller must ensure the type matches the tensor type.
1004
+ /// </summary>
1005
+ /// <typeparam name="T"></typeparam>
1006
+ /// <returns>const pointer to data, no copies made</returns>
1007
+ template <typename R>
1008
+ const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData
1009
+
1010
+ /// <summary>
1011
+ /// Returns a non-typed pointer to a tensor contained data.
1012
+ /// </summary>
1013
+ /// <returns>const pointer to data, no copies made</returns>
1014
+ const void* GetTensorRawData() const;
1015
+
1016
+ /// <summary>
1017
+ /// The API returns type information for data contained in a tensor. For sparse
1018
+ /// tensors it returns type information for contained non-zero values.
1019
+ /// It returns dense shape for sparse tensors.
1020
+ /// </summary>
1021
+ /// <returns>TypeInfo</returns>
1022
+ TypeInfo GetTypeInfo() const;
1023
+
1024
+ /// <summary>
1025
+ /// The API returns type information for data contained in a tensor. For sparse
1026
+ /// tensors it returns type information for contained non-zero values.
1027
+ /// It returns dense shape for sparse tensors.
1028
+ /// </summary>
1029
+ /// <returns>TensorTypeAndShapeInfo</returns>
1030
+ TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
1031
+
1032
+ /// <summary>
1033
+ /// This API returns information about the memory allocation used to hold data.
1034
+ /// </summary>
1035
+ /// <returns>Non owning instance of MemoryInfo</returns>
1036
+ ConstMemoryInfo GetTensorMemoryInfo() const;
1037
+
1038
+ /// <summary>
1039
+ /// The API copies UTF-8 encoded bytes for the requested string element
1040
+ /// contained within a tensor or a sparse tensor into a provided buffer.
1041
+ /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
1042
+ /// </summary>
1043
+ /// <param name="buffer_length"></param>
1044
+ /// <param name="element_index"></param>
1045
+ /// <param name="buffer"></param>
1046
+ void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
1047
+
1048
+ /// <summary>
1049
+ /// The API returns a byte length of UTF-8 encoded string element
1050
+ /// contained in either a tensor or a sparse tensor's values.
1051
+ /// </summary>
1052
+ /// <param name="element_index"></param>
1053
+ /// <returns>byte length for the specified string element</returns>
1054
+ size_t GetStringTensorElementLength(size_t element_index) const;
1055
+
1056
+ #if !defined(DISABLE_SPARSE_TENSORS)
1057
+ /// <summary>
1058
+ /// The API returns the sparse data format this OrtValue holds in a sparse tensor.
1059
+ /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
1060
+ /// the value returned is ORT_SPARSE_UNDEFINED.
1061
+ /// </summary>
1062
+ /// <returns>Format enum</returns>
1063
+ OrtSparseFormat GetSparseFormat() const;
1064
+
1065
+ /// <summary>
1066
+ /// The API returns type and shape information for stored non-zero values of the
1067
+ /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer.
1068
+ /// </summary>
1069
+ /// <returns>TensorTypeAndShapeInfo values information</returns>
1070
+ TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const;
1071
+
1072
+ /// <summary>
1073
+ /// The API returns type and shape information for the specified indices. Each supported
1074
+ /// indices have their own enum values even if a given format has more than one kind of indices.
1075
+ /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
1076
+ /// </summary>
1077
+ /// <param name="format">enum requested</param>
1078
+ /// <returns>type and shape information</returns>
1079
+ TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const;
1080
+
1081
+ /// <summary>
1082
+ /// The API retrieves a pointer to the internal indices buffer. The API merely performs
1083
+ /// a convenience data type casting on the return type pointer. Make sure you are requesting
1084
+ /// the right type, use GetSparseTensorIndicesTypeShapeInfo();
1085
+ /// </summary>
1086
+ /// <typeparam name="R">type to cast to</typeparam>
1087
+ /// <param name="indices_format">requested indices kind</param>
1088
+ /// <param name="num_indices">number of indices entries</param>
1089
+ /// <returns>Pointer to the internal sparse tensor buffer containing indices. Do not free this pointer.</returns>
1090
+ template <typename R>
1091
+ const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
1092
+
1093
+ /// <summary>
1094
+ /// Returns true if the OrtValue contains a sparse tensor
1095
+ /// </summary>
1096
+ /// <returns></returns>
1097
+ bool IsSparseTensor() const;
1098
+
1099
+ /// <summary>
1100
+ /// The API returns a pointer to an internal buffer of the sparse tensor
1101
+ /// containing non-zero values. The API merely does casting. Make sure you
1102
+ /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
1103
+ /// first.
1104
+ /// </summary>
1105
+ /// <typeparam name="R">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam>
1106
+ /// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns>
1107
+ template <typename R>
1108
+ const R* GetSparseTensorValues() const;
1109
+
1110
+ #endif
1111
+ };
1112
+
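To show how the read-only accessors above combine, a minimal sketch, assuming `val` is a float tensor (as noted above, the element type is not checked):

    // Illustrative sketch; `val` is a hypothetical Ort::ConstValue / Ort::Value.
    auto info = val.GetTensorTypeAndShapeInfo();
    std::vector<int64_t> shape = info.GetShape();
    size_t count = info.GetElementCount();
    const float* data = val.GetTensorData<float>();   // no copy; caller must know the type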
1113
+ template <typename T>
1114
+ struct ValueImpl : ConstValueImpl<T> {
1115
+ using B = ConstValueImpl<T>;
1116
+ using B::B;
1117
+
1118
+ /// <summary>
1119
+ /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer
1120
+ /// No type checking is performed, the caller must ensure the type matches the tensor type.
1121
+ /// </summary>
1122
+ /// <returns>non-const pointer to data, no copies made</returns>
1123
+ template <typename R>
1124
+ R* GetTensorMutableData();
1125
+
1126
+ /// <summary>
1127
+ /// Returns a non-typed non-const pointer to a tensor contained data.
1128
+ /// </summary>
1129
+ /// <returns>pointer to data, no copies made</returns>
1130
+ void* GetTensorMutableRawData();
1131
+
1132
+ /// <summary>
1133
+ /// Obtain a reference to an element of data at the location specified
1134
+ /// by the vector of dims.
1135
+ /// </summary>
1136
+ /// <typeparam name="R"></typeparam>
1137
+ /// <param name="location">[in] expressed by a vector of dimension offsets</param>
1138
+ /// <returns></returns>
1139
+ template <typename R>
1140
+ R& At(const std::vector<int64_t>& location);
1141
+
1142
+ /// <summary>
1143
+ /// Set all strings at once in a string tensor
1144
+ /// </summary>
1145
+ /// <param name="s">[in] An array of strings. Each string in this array must be null terminated.</param>
1146
+ /// <param name="s_len">[in] Count of strings in s (Must match the size of \p value's tensor shape)</param>
1147
+ void FillStringTensor(const char* const* s, size_t s_len);
1148
+
1149
+ /// <summary>
1150
+ /// Set a single string in a string tensor
1151
+ /// </summary>
1152
+ /// <param name="s">[in] A null terminated UTF-8 encoded string</param>
1153
+ /// <param name="index">[in] Index of the string in the tensor to set</param>
1154
+ void FillStringTensorElement(const char* s, size_t index);
1155
+
1156
+ #if !defined(DISABLE_SPARSE_TENSORS)
1157
+ /// <summary>
1158
+ /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
1159
+ /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1160
+ /// allocated buffers' lifespan must eclipse that of the OrtValue.
1161
+ /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1162
+ /// </summary>
1163
+ /// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param>
1164
+ /// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param>
1165
+ void UseCooIndices(int64_t* indices_data, size_t indices_num);
1166
+
1167
+ /// <summary>
1168
+ /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
1169
+ /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1170
+ /// allocated buffers' lifespan must eclipse that of the OrtValue.
1171
+ /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1172
+ /// </summary>
1173
+ /// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param>
1174
+ /// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param>
1175
+ /// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param>
1176
+ /// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param>
1177
+ void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
1178
+
1179
+ /// <summary>
1180
+ /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
1181
+ /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1182
+ /// allocated buffers lifespan must eclipse that of the OrtValue.
1183
+ /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1184
+ /// </summary>
1185
+ /// <param name="indices_shape">indices shape or a {0} for fully sparse</param>
1186
+ /// <param name="indices_data">user allocated buffer with indices or nullptr for fully sparse tensors</param>
1187
+ void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
1188
+
1189
+ /// <summary>
1190
+ /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1191
+ /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
1192
+ /// on a different device than the allocator, an X-device copy will be performed if possible.
1193
+ /// </summary>
1194
+ /// <param name="data_mem_info">specified buffer memory description</param>
1195
+ /// <param name="values_param">values buffer information.</param>
1196
+ /// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param>
1197
+ /// <param name="indices_num">number of COO indices or 0 for fully sparse data</param>
1198
+ void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
1199
+ const int64_t* indices_data, size_t indices_num);
1200
+
1201
+ /// <summary>
1202
+ /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1203
+ /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
1204
+ /// on a different device than the allocator, an X-device copy will be performed if possible.
1205
+ /// </summary>
1206
+ /// <param name="data_mem_info">specified buffer memory description</param>
1207
+ /// <param name="values">values buffer information</param>
1208
+ /// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
1209
+ /// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
1210
+ /// <param name="outer_indices_data">pointer to csr indices data or nullptr for fully sparse tensors</param>
1211
+ /// <param name="outer_indices_num">number of csr outer indices or 0</param>
1212
+ void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1213
+ const OrtSparseValuesParam& values,
1214
+ const int64_t* inner_indices_data, size_t inner_indices_num,
1215
+ const int64_t* outer_indices_data, size_t outer_indices_num);
1216
+
1217
+ /// <summary>
1218
+ /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1219
+ /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
1220
+ /// on a different device than the allocator, an X-device copy will be performed if possible.
1221
+ /// </summary>
1222
+ /// <param name="data_mem_info">specified buffer memory description</param>
1223
+ /// <param name="values">values buffer information</param>
1224
+ /// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
1225
+ /// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
1226
+ void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1227
+ const OrtSparseValuesParam& values,
1228
+ const Shape& indices_shape,
1229
+ const int32_t* indices_data);
1230
+
1231
+ #endif
1232
+ };
1233
+
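A minimal sketch of the mutable accessors, assuming `val` is a 2x2 float tensor and `sval` a string tensor with three elements (both hypothetical):

    // Illustrative sketch; `val` and `sval` are hypothetical Ort::Value objects.
    val.At<float>({0, 1}) = 42.0f;                    // element at row 0, col 1
    float* buf = val.GetTensorMutableData<float>();   // raw typed access, no copy

    const char* strings[] = {"a", "b", "c"};
    sval.FillStringTensor(strings, 3);                // count must match the tensor shape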
1234
+ } // namespace detail
1235
+
1236
+ using ConstValue = detail::ConstValueImpl<detail::Unowned<const OrtValue>>;
1237
+ using UnownedValue = detail::ValueImpl<detail::Unowned<OrtValue>>;
1238
+
1239
+ /** \brief Wrapper around ::OrtValue
1240
+ *
1241
+ */
1242
+ struct Value : detail::ValueImpl<OrtValue> {
1243
+ using Base = detail::ValueImpl<OrtValue>;
1244
+ using OrtSparseValuesParam = detail::OrtSparseValuesParam;
1245
+ using Shape = detail::Shape;
1246
+
1247
+ explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
1248
+ explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API
1249
+ Value(Value&&) = default;
1250
+ Value& operator=(Value&&) = default;
1251
+
1252
+ ConstValue GetConst() const { return ConstValue{this->p_}; }
1253
+ UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }
1254
+
1255
+ /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1256
+ * \tparam T The numeric datatype. This API is not suitable for strings.
1257
+ * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1258
+ * \param p_data Pointer to the data buffer.
1259
+ * \param p_data_element_count The number of elements in the data buffer.
1260
+ * \param shape Pointer to the tensor shape dimensions.
1261
+ * \param shape_len The number of tensor shape dimensions.
1262
+ */
1263
+ template <typename T>
1264
+ static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
1265
+
1266
+ /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1267
+ * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1268
+ * \param p_data Pointer to the data buffer.
1269
+ * \param p_data_byte_count The number of bytes in the data buffer.
1270
+ * \param shape Pointer to the tensor shape dimensions.
1271
+ * \param shape_len The number of tensor shape dimensions.
1272
+ * \param type The data type.
1273
+ */
1274
+ static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1275
+ ONNXTensorElementDataType type);
1276
+
1277
+ /** \brief Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
1278
+ * \tparam T The numeric datatype. This API is not suitable for strings.
1279
+ * \param allocator The allocator to use.
1280
+ * \param shape Pointer to the tensor shape dimensions.
1281
+ * \param shape_len The number of tensor shape dimensions.
1282
+ */
1283
+ template <typename T>
1284
+ static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
1285
+
1286
+ /** \brief Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
1287
+ * \param allocator The allocator to use.
1288
+ * \param shape Pointer to the tensor shape dimensions.
1289
+ * \param shape_len The number of tensor shape dimensions.
1290
+ * \param type The data type.
1291
+ */
1292
+ static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
1293
+
1294
+ static Value CreateMap(Value& keys, Value& values); ///< Wraps OrtApi::CreateValue
1295
+ static Value CreateSequence(std::vector<Value>& values); ///< Wraps OrtApi::CreateValue
1296
+
1297
+ template <typename T>
1298
+ static Value CreateOpaque(const char* domain, const char* type_name, const T&); ///< Wraps OrtApi::CreateOpaqueValue
1299
+
1300
+ #if !defined(DISABLE_SPARSE_TENSORS)
1301
+ /// <summary>
1302
+ /// This is a simple forwarding method to the other overload that helps deducing
1303
+ /// data type enum value from the type of the buffer.
1304
+ /// </summary>
1305
+ /// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam>
1306
+ /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1307
+ /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1308
+ /// <param name="dense_shape">a would be dense shape of the tensor</param>
1309
+ /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1310
+ /// <returns></returns>
1311
+ template <typename T>
1312
+ static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1313
+ const Shape& values_shape);
1314
+
1315
+ /// <summary>
1316
+ /// Creates an OrtValue instance containing SparseTensor. This constructs
1317
+ /// a sparse tensor that makes use of user allocated buffers. It does not make copies
1318
+ /// of the user provided data and does not modify it. The lifespan of user provided buffers should
1319
+ /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contains
1320
+ /// a pointer to non-zero values. To fully populate the sparse tensor call Use<Format>Indices() API below
1321
+ /// to supply a sparse format specific indices.
1322
+ /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
1323
+ /// can be properly copied into the allocated buffer.
1324
+ /// </summary>
1325
+ /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1326
+ /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1327
+ /// <param name="dense_shape">a would be dense shape of the tensor</param>
1328
+ /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1329
+ /// <param name="type">data type</param>
1330
+ /// <returns>Ort::Value instance containing SparseTensor</returns>
1331
+ static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1332
+ const Shape& values_shape, ONNXTensorElementDataType type);
1333
+
1334
+ /// <summary>
1335
+ /// This is a simple forwarding method to the below CreateSparseTensor.
1336
+ /// This helps to specify data type enum in terms of C++ data type.
1337
+ /// Use CreateSparseTensor<T>
1338
+ /// </summary>
1339
+ /// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam>
1340
+ /// <param name="allocator">allocator to use</param>
1341
+ /// <param name="dense_shape">a would be dense shape of the tensor</param>
1342
+ /// <returns>Ort::Value</returns>
1343
+ template <typename T>
1344
+ static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
1345
+
1346
+ /// <summary>
1347
+ /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
1348
+ /// The data must be supplied by one of the FillSparseTensor<Format>() methods that take both non-zero values
1349
+ /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
1350
+ /// Use this API to create OrtValues that contain sparse tensors with all supported data types including
1351
+ /// strings.
1352
+ /// </summary>
1353
+ /// <param name="allocator">allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue</param>
1354
+ /// <param name="dense_shape">a would be dense shape of the tensor</param>
1355
+ /// <param name="type">data type</param>
1356
+ /// <returns>an instance of Ort::Value</returns>
1357
+ static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
1358
+
1359
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
1360
+ };
1361
+
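A minimal sketch of the most common factory, wrapping a caller-owned buffer (the buffer must stay alive for as long as the Value does); all names are hypothetical:

    // Illustrative sketch: a 2x3 float tensor over a user-owned CPU buffer.
    float data[] = {1, 2, 3, 4, 5, 6};
    int64_t shape[] = {2, 3};
    auto mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    Ort::Value tensor = Ort::Value::CreateTensor<float>(mem_info, data, 6, shape, 2);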
1362
+ /// <summary>
1363
+ /// Represents native memory allocation coming from one of the
1364
+ /// OrtAllocators registered with OnnxRuntime.
1365
+ /// Use it to wrap an allocation made by an allocator
1366
+ /// so it can be automatically released when no longer needed.
1367
+ /// </summary>
1368
+ struct MemoryAllocation {
1369
+ MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
1370
+ ~MemoryAllocation();
1371
+ MemoryAllocation(const MemoryAllocation&) = delete;
1372
+ MemoryAllocation& operator=(const MemoryAllocation&) = delete;
1373
+ MemoryAllocation(MemoryAllocation&&) noexcept;
1374
+ MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
1375
+
1376
+ void* get() { return p_; }
1377
+ size_t size() const { return size_; }
1378
+
1379
+ private:
1380
+ OrtAllocator* allocator_;
1381
+ void* p_;
1382
+ size_t size_;
1383
+ };
1384
+

1385
+ namespace detail {
1386
+ template <typename T>
1387
+ struct AllocatorImpl : Base<T> {
1388
+ using B = Base<T>;
1389
+ using B::B;
1390
+
1391
+ void* Alloc(size_t size);
1392
+ MemoryAllocation GetAllocation(size_t size);
1393
+ void Free(void* p);
1394
+ ConstMemoryInfo GetInfo() const;
1395
+ };
1396
+
1397
+ } // namespace detail
1398
+
1399
+ /** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime
1400
+ *
1401
+ */
1402
+ struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
1403
+ explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1404
+ AllocatorWithDefaultOptions();
1405
+ };
1406
+
1407
+ /** \brief Wrapper around ::OrtAllocator
1408
+ *
1409
+ */
1410
+ struct Allocator : detail::AllocatorImpl<OrtAllocator> {
1411
+ explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1412
+ Allocator(const Session& session, const OrtMemoryInfo*);
1413
+ };
1414
+
1415
+ using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;
1416
+
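A minimal sketch of the allocator wrappers together with MemoryAllocation for scope-bound cleanup (sizes are arbitrary):

    // Illustrative sketch: RAII allocation through the default CPU allocator.
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::MemoryAllocation block = allocator.GetAllocation(1024);   // freed when block leaves scope
    void* raw = allocator.Alloc(512);
    allocator.Free(raw);                                           // Alloc/Free must be paired manually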
1417
+ namespace detail {
1418
+ namespace binding_utils {
1419
+ // Bring these out of template
1420
+ std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
1421
+ std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
1422
+ } // namespace binding_utils
1423
+
1424
+ template <typename T>
1425
+ struct ConstIoBindingImpl : Base<T> {
1426
+ using B = Base<T>;
1427
+ using B::B;
1428
+
1429
+ std::vector<std::string> GetOutputNames() const;
1430
+ std::vector<std::string> GetOutputNames(OrtAllocator*) const;
1431
+ std::vector<Value> GetOutputValues() const;
1432
+ std::vector<Value> GetOutputValues(OrtAllocator*) const;
1433
+ };
1434
+
1435
+ template <typename T>
1436
+ struct IoBindingImpl : ConstIoBindingImpl<T> {
1437
+ using B = ConstIoBindingImpl<T>;
1438
+ using B::B;
1439
+
1440
+ void BindInput(const char* name, const Value&);
1441
+ void BindOutput(const char* name, const Value&);
1442
+ void BindOutput(const char* name, const OrtMemoryInfo*);
1443
+ void ClearBoundInputs();
1444
+ void ClearBoundOutputs();
1445
+ void SynchronizeInputs();
1446
+ void SynchronizeOutputs();
1447
+ };
1448
+
1449
+ } // namespace detail
1450
+
1451
+ using ConstIoBinding = detail::ConstIoBindingImpl<detail::Unowned<const OrtIoBinding>>;
1452
+ using UnownedIoBinding = detail::IoBindingImpl<detail::Unowned<OrtIoBinding>>;
1453
+
1454
+ /** \brief Wrapper around ::OrtIoBinding
1455
+ *
1456
+ */
1457
+ struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
1458
+ explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later.
1459
+ explicit IoBinding(Session& session);
1460
+ ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }
1461
+ UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }
1462
+ };
1463
+
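A minimal sketch of device binding, assuming `session`, `input_tensor` and `output_mem_info` already exist (names are hypothetical) and using the Session::Run overload that accepts an IoBinding:

    // Illustrative sketch: bind by name, run, then collect outputs.
    Ort::IoBinding binding(session);
    binding.BindInput("input", input_tensor);          // input_tensor: Ort::Value
    binding.BindOutput("output", output_mem_info);     // let ORT allocate on this device
    session.Run(Ort::RunOptions{nullptr}, binding);
    std::vector<Ort::Value> outputs = binding.GetOutputValues();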
1464
+ /*! \struct Ort::ArenaCfg
1465
+ * \brief it is a structure that represents the configuration of an arena based allocator
1466
+ * \details Please see docs/C_API.md for details
1467
+ */
1468
+ struct ArenaCfg : detail::Base<OrtArenaCfg> {
1469
+ explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used
1470
+ /**
1471
+ * Wraps OrtApi::CreateArenaCfg
1472
+ * \param max_mem - use 0 to allow ORT to choose the default
1473
+ * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
1474
+ * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
1475
+ * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
1476
+ * See docs/C_API.md for details on what the following parameters mean and how to choose these values
1477
+ */
1478
+ ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
1479
+ };
1480
+
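A minimal sketch showing the all-defaults configuration described in the parameter comments above:

    // Illustrative sketch: let ORT choose every arena default.
    Ort::ArenaCfg arena_cfg(0 /*max_mem*/, -1 /*arena_extend_strategy*/,
                            -1 /*initial_chunk_size_bytes*/, -1 /*max_dead_bytes_per_chunk*/);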
1481
+ //
1482
+ // Custom OPs (only needed to implement custom OPs)
1483
+ //
1484
+
1485
+ /// <summary>
1486
+ /// This struct provides lifetime management for a custom op attribute
1487
+ /// </summary>
1488
+ struct OpAttr : detail::Base<OrtOpAttr> {
1489
+ OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
1490
+ };
1491
+
1492
+ /// <summary>
1493
+ /// This class wraps a raw pointer OrtKernelContext* that is being passed
1494
+ /// to the custom kernel Compute() method. Use it to safely access context
1495
+ /// attributes, input and output parameters with exception safety guarantees.
1496
+ /// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
1497
+ /// </summary>
1498
+ struct KernelContext {
1499
+ explicit KernelContext(OrtKernelContext* context);
1500
+ size_t GetInputCount() const;
1501
+ size_t GetOutputCount() const;
1502
+ ConstValue GetInput(size_t index) const;
1503
+ UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
1504
+ UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
1505
+ void* GetGPUComputeStream() const;
1506
+
1507
+ private:
1508
+ OrtKernelContext* ctx_;
1509
+ };
1510
+
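A minimal sketch of a custom kernel Compute() body built on KernelContext; the identity copy and the float element type are assumptions made for illustration:

    // Illustrative sketch: copy the first input to the first output.
    void Compute(OrtKernelContext* context) {
      Ort::KernelContext ctx(context);
      Ort::ConstValue input = ctx.GetInput(0);
      auto shape = input.GetTensorTypeAndShapeInfo().GetShape();
      Ort::UnownedValue output = ctx.GetOutput(0, shape);
      const float* in = input.GetTensorData<float>();
      float* out = output.GetTensorMutableData<float>();
      size_t count = input.GetTensorTypeAndShapeInfo().GetElementCount();
      for (size_t i = 0; i < count; ++i) out[i] = in[i];
    }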
1511
+ struct KernelInfo;
1512
+
1513
+ namespace detail {
1514
+ namespace attr_utils {
1515
+ void GetAttr(const OrtKernelInfo* p, const char* name, float&);
1516
+ void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
1517
+ void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
1518
+ void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
1519
+ void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
1520
+ } // namespace attr_utils
1521
+
1522
+ template <typename T>
1523
+ struct KernelInfoImpl : Base<T> {
1524
+ using B = Base<T>;
1525
+ using B::B;
1526
+
1527
+ KernelInfo Copy() const;
1528
+
1529
+ template <typename R> // R is only implemented for float, int64_t, and string
1530
+ R GetAttribute(const char* name) const {
1531
+ R val;
1532
+ attr_utils::GetAttr(this->p_, name, val);
1533
+ return val;
1534
+ }
1535
+
1536
+ template <typename R> // R is only implemented for std::vector<float>, std::vector<int64_t>
1537
+ std::vector<R> GetAttributes(const char* name) const {
1538
+ std::vector<R> result;
1539
+ attr_utils::GetAttrs(this->p_, name, result);
1540
+ return result;
1541
+ }
1542
+
1543
+ Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
1544
+
1545
+ size_t GetInputCount() const;
1546
+ size_t GetOutputCount() const;
1547
+
1548
+ std::string GetInputName(size_t index) const;
1549
+ std::string GetOutputName(size_t index) const;
1550
+
1551
+ TypeInfo GetInputTypeInfo(size_t index) const;
1552
+ TypeInfo GetOutputTypeInfo(size_t index) const;
1553
+ };
1554
+
1555
+ } // namespace detail
1556
+
1557
+ using ConstKernelInfo = detail::KernelInfoImpl<detail::Unowned<const OrtKernelInfo>>;
1558
+
1559
+ /// <summary>
1560
+ /// This struct owns the OrtKernelInfo* pointer when a copy is made.
1561
+ /// For convenient wrapping of OrtKernelInfo* passed to kernel constructor
1562
+ /// and to query attributes, wrap the pointer in an Ort::Unowned<KernelInfo> instance
1563
+ /// so it does not destroy the pointer the kernel does not own.
1564
+ /// </summary>
1565
+ struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
1566
+ explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later
1567
+ explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance
1568
+ ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }
1569
+ };
1570
+
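A minimal sketch of attribute access from a custom kernel constructor; MyKernel and the attribute names are hypothetical:

    // Illustrative sketch: read node attributes without taking ownership of the OrtKernelInfo.
    MyKernel(const OrtKernelInfo* info) {
      Ort::ConstKernelInfo ki{info};
      float alpha = ki.GetAttribute<float>("alpha");
      std::vector<int64_t> axes = ki.GetAttributes<int64_t>("axes");
    }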
1571
+ /// <summary>
1572
+ /// Create and own custom defined operation.
1573
+ /// </summary>
1574
+ struct Op : detail::Base<OrtOp> {
1575
+ explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used
1576
+
1577
+ explicit Op(OrtOp*); ///< Take ownership of the OrtOp
1578
+
1579
+ static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
1580
+ int version, const char** type_constraint_names,
1581
+ const ONNXTensorElementDataType* type_constraint_values,
1582
+ size_t type_constraint_count,
1583
+ const OpAttr* attr_values,
1584
+ size_t attr_count,
1585
+ size_t input_count, size_t output_count);
1586
+
1587
+ void Invoke(const OrtKernelContext* context,
1588
+ const Value* input_values,
1589
+ size_t input_count,
1590
+ Value* output_values,
1591
+ size_t output_count);
1592
+
1593
+ // For easier refactoring
1594
+ void Invoke(const OrtKernelContext* context,
1595
+ const OrtValue* const* input_values,
1596
+ size_t input_count,
1597
+ OrtValue* const* output_values,
1598
+ size_t output_count);
1599
+ };
1600
+
1601
+ /// <summary>
1602
+ /// This entire structure is deprecated, but we are not marking
1603
+ /// it as a whole yet since we want to preserve it for the next release.
1604
+ /// </summary>
1605
+ struct CustomOpApi {
1606
+ CustomOpApi(const OrtApi& api) : api_(api) {}
1607
+
1608
+ /** \deprecated use Ort::Value::GetTensorTypeAndShape()
1609
+ * [[deprecated]]
1610
+ * This interface produces a pointer that must be released. Not exception safe.
1611
+ */
1612
+ [[deprecated("use Ort::Value::GetTensorTypeAndShape()")]] OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value);
1613
+
1614
+ /** \deprecated use Ort::TensorTypeAndShapeInfo::GetElementCount()
1615
+ * [[deprecated]]
1616
+ * This interface is redundant.
1617
+ */
1618
+ [[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementCount()")]] size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info);
1619
+
1620
+ /** \deprecated use Ort::TensorTypeAndShapeInfo::GetElementType()
1621
+ * [[deprecated]]
1622
+ * This interface is redundant.
1623
+ */
1624
+ [[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementType()")]] ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info);
1625
+
1626
+ /** \deprecated use Ort::TensorTypeAndShapeInfo::GetDimensionsCount()
1627
+ * [[deprecated]]
1628
+ * This interface is redundant.
1629
+ */
1630
+ [[deprecated("use Ort::TensorTypeAndShapeInfo::GetDimensionsCount()")]] size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info);
1631
+
1632
+ /** \deprecated use Ort::TensorTypeAndShapeInfo::GetShape()
1633
+ * [[deprecated]]
1634
+ * This interface is redundant.
1635
+ */
1636
+ [[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
1637
+
1638
+ /** \deprecated
1639
+ * [[deprecated]]
1640
+ * This interface sets dimensions to TensorTypeAndShapeInfo, but has no effect on the OrtValue.
1641
+ */
1642
+ [[deprecated("Do not use")]] void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
1643
+
1644
+ /** \deprecated use Ort::Value::GetTensorMutableData()
1645
+ * [[deprecated]]
1646
+ * This interface is redundant.
1647
+ */
1648
+ template <typename T>
1649
+ [[deprecated("use Ort::Value::GetTensorMutableData()")]] T* GetTensorMutableData(_Inout_ OrtValue* value);
1650
+
1651
+ /** \deprecated use Ort::Value::GetTensorData()
1652
+ * [[deprecated]]
1653
+ * This interface is redundant.
1654
+ */
1655
+ template <typename T>
1656
+ [[deprecated("use Ort::Value::GetTensorData()")]] const T* GetTensorData(_Inout_ const OrtValue* value);
1657
+
1658
+ /** \deprecated use Ort::Value::GetTensorMemoryInfo()
1659
+ * [[deprecated]]
1660
+ * This interface is redundant.
1661
+ */
1662
+ [[deprecated("use Ort::Value::GetTensorMemoryInfo()")]] const OrtMemoryInfo* GetTensorMemoryInfo(_In_ const OrtValue* value);
1663
+
1664
+ /** \deprecated use Ort::TensorTypeAndShapeInfo::GetShape()
1665
+ * [[deprecated]]
1666
+ * This interface is redundant.
1667
+ */
1668
+ [[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info);
1669
+
1670
+ /** \deprecated use TensorTypeAndShapeInfo instances for automatic ownership.
1671
+ * [[deprecated]]
1672
+ * This interface is not exception safe.
1673
+ */
1674
+ [[deprecated("use TensorTypeAndShapeInfo")]] void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input);
1675
+
1676
+ /** \deprecated use Ort::KernelContext::GetInputCount
1677
+ * [[deprecated]]
1678
+ * This interface is redundant.
1679
+ */
1680
+ [[deprecated("use Ort::KernelContext::GetInputCount")]] size_t KernelContext_GetInputCount(const OrtKernelContext* context);
1681
+
1682
+ /** \deprecated use Ort::KernelContext::GetInput
1683
+ * [[deprecated]]
1684
+ * This interface is redundant.
1685
+ */
1686
+ [[deprecated("use Ort::KernelContext::GetInput")]] const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index);
1687
+
1688
+ /** \deprecated use Ort::KernelContext::GetOutputCount
1689
+ * [[deprecated]]
1690
+ * This interface is redundant.
1691
+ */
1692
+ [[deprecated("use Ort::KernelContext::GetOutputCount")]] size_t KernelContext_GetOutputCount(const OrtKernelContext* context);
1693
+
1694
+ /** \deprecated use Ort::KernelContext::GetOutput
1695
+ * [[deprecated]]
1696
+ * This interface is redundant.
1697
+ */
1698
+ [[deprecated("use Ort::KernelContext::GetOutput")]] OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count);
1699
+
1700
+ /** \deprecated use Ort::KernelContext::GetGPUComputeStream
1701
+ * [[deprecated]]
1702
+ * This interface is redundant.
1703
+ */
1704
+ [[deprecated("use Ort::KernelContext::GetGPUComputeStream")]] void* KernelContext_GetGPUComputeStream(const OrtKernelContext* context);
1705
+
1706
+ /** \deprecated use Ort::ThrowOnError()
1707
+ * [[deprecated]]
1708
+ * This interface is redundant.
1709
+ */
1710
+ [[deprecated("use Ort::ThrowOnError()")]] void ThrowOnError(OrtStatus* result);
1711
+
1712
+ /** \deprecated use Ort::OpAttr
1713
+ * [[deprecated]]
1714
+ * This interface is not exception safe.
1715
+ */
1716
+ [[deprecated("use Ort::OpAttr")]] OrtOpAttr* CreateOpAttr(_In_ const char* name,
1717
+ _In_ const void* data,
1718
+ _In_ int len,
1719
+ _In_ OrtOpAttrType type);
1720
+
1721
+ /** \deprecated use Ort::OpAttr
1722
+ * [[deprecated]]
1723
+ * This interface is not exception safe.
1724
+ */
1725
+ [[deprecated("use Ort::OpAttr")]] void ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr);
1726
+
1727
+ /** \deprecated use Ort::Op
1728
+ * [[deprecated]]
1729
+ * This interface is not exception safe.
1730
+ */
1731
+ [[deprecated("use Ort::Op")]] OrtOp* CreateOp(_In_ const OrtKernelInfo* info,
1732
+ _In_ const char* op_name,
1733
+ _In_ const char* domain,
1734
+ _In_ int version,
1735
+ _In_opt_ const char** type_constraint_names,
1736
+ _In_opt_ const ONNXTensorElementDataType* type_constraint_values,
1737
+ _In_opt_ int type_constraint_count,
1738
+ _In_opt_ const OrtOpAttr* const* attr_values,
1739
+ _In_opt_ int attr_count,
1740
+ _In_ int input_count,
1741
+ _In_ int output_count);
1742
+
1743
+ /** \deprecated use Ort::Op::Invoke
1744
+ * [[deprecated]]
1745
+ * This interface is redundant
1746
+ */
1747
+ [[deprecated("use Ort::Op::Invoke")]] void InvokeOp(_In_ const OrtKernelContext* context,
1748
+ _In_ const OrtOp* ort_op,
1749
+ _In_ const OrtValue* const* input_values,
1750
+ _In_ int input_count,
1751
+ _Inout_ OrtValue* const* output_values,
1752
+ _In_ int output_count);
1753
+
1754
+ /** \deprecated use Ort::Op for automatic lifespan management.
1755
+ * [[deprecated]]
1756
+ * This interface is not exception safe.
1757
+ */
1758
+ [[deprecated("use Ort::Op")]] void ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op);
1759
+
1760
+ /** \deprecated use Ort::KernelInfo for automatic lifespan management or for
1761
+ * querying attributes
1762
+ * [[deprecated]]
1763
+ * This interface is redundant
1764
+ */
1765
+ template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
1766
+ [[deprecated("use Ort::KernelInfo::GetAttribute")]] T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
1767
+
1768
+ /** \deprecated use Ort::KernelInfo::Copy
1769
+ * querying attributes
1770
+ * [[deprecated]]
1771
+ * This interface is not exception safe
1772
+ */
1773
+ [[deprecated("use Ort::KernelInfo::Copy")]] OrtKernelInfo* CopyKernelInfo(_In_ const OrtKernelInfo* info);
1774
+
1775
+ /** \deprecated use Ort::KernelInfo for lifespan management
1776
+ * querying attributes
1777
+ * [[deprecated]]
1778
+ * This interface is not exception safe
1779
+ */
1780
+ [[deprecated("use Ort::KernelInfo")]] void ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy);
1781
+
1782
+ private:
1783
+ const OrtApi& api_;
1784
+ };
1785
+
1786
+ template <typename TOp, typename TKernel>
1787
+ struct CustomOpBase : OrtCustomOp {
1788
+ CustomOpBase() {
1789
+ OrtCustomOp::version = ORT_API_VERSION;
1790
+ OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
1791
+ OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
1792
+
1793
+ OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
1794
+
1795
+ OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
1796
+ OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
1797
+ OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
1798
+
1799
+ OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
1800
+ OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
1801
+
1802
+ OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
1803
+ #if defined(_MSC_VER) && !defined(__clang__)
1804
+ #pragma warning(push)
1805
+ #pragma warning(disable : 26409)
1806
+ #endif
1807
+ OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
1808
+ #if defined(_MSC_VER) && !defined(__clang__)
1809
+ #pragma warning(pop)
1810
+ #endif
1811
+ OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
1812
+ OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
1813
+
1814
+ OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
1815
+ OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
1816
+ OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
1817
+ OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
1818
+ }
1819
+
1820
+ // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
1821
+ const char* GetExecutionProviderType() const { return nullptr; }
1822
+
1823
+ // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
1824
+ // (inputs and outputs are required by default)
1825
+ OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
1826
+ return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
1827
+ }
1828
+
1829
+ OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
1830
+ return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
1831
+ }
1832
+
1833
+ // Default implementation of GetInputMemoryType() that returns OrtMemTypeDefault
1834
+ OrtMemType GetInputMemoryType(size_t /*index*/) const {
1835
+ return OrtMemTypeDefault;
1836
+ }
1837
+
1838
+ // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
1839
+ // should expect at least 1 argument.
1840
+ int GetVariadicInputMinArity() const {
1841
+ return 1;
1842
+ }
1843
+
1844
+ // Default implementation of GetVariadicInputHomogeneity() returns true to specify that all arguments
1845
+ // to a variadic input should be of the same type.
1846
+ bool GetVariadicInputHomogeneity() const {
1847
+ return true;
1848
+ }
1849
+
1850
+ // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
1851
+ // should produce at least 1 output value.
1852
+ int GetVariadicOutputMinArity() const {
1853
+ return 1;
1854
+ }
1855
+
1856
+ // Default implementation of GetVariadicOutputHomogeneity() returns true to specify that all output values
1857
+ // produced by a variadic output should be of the same type.
1858
+ bool GetVariadicOutputHomogeneity() const {
1859
+ return true;
1860
+ }
1861
+
1862
+ // Declare list of session config entries used by this Custom Op.
1863
+ // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
1864
+ // This default implementation returns an empty vector of config entries.
1865
+ std::vector<std::string> GetSessionConfigKeys() const {
1866
+ return std::vector<std::string>{};
1867
+ }
1868
+
1869
+ protected:
1870
+ // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
1871
+ void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
1872
+ };
1873
+
1874
+ } // namespace Ort
1875
+
1876
+ #include "onnxruntime_cxx_inline.h"
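A minimal sketch of the CustomOpBase pattern declared above; MyOp and MyKernel are hypothetical types and only the required overrides are shown (the optional ones fall back to the defaults in CustomOpBase):

    // Illustrative sketch: a single-input, single-output float custom op.
    struct MyOp : Ort::CustomOpBase<MyOp, MyKernel> {
      void* CreateKernel(const OrtApi& /*api*/, const OrtKernelInfo* info) const {
        return new MyKernel(info);
      }
      const char* GetName() const { return "MyOp"; }
      size_t GetInputTypeCount() const { return 1; }
      ONNXTensorElementDataType GetInputType(size_t) const {
        return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
      }
      size_t GetOutputTypeCount() const { return 1; }
      ONNXTensorElementDataType GetOutputType(size_t) const {
        return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
      }
    };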
1.14.0/onnxruntime.xcframework/Headers/onnxruntime_cxx_inline.h ADDED
@@ -0,0 +1,1874 @@
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ // Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
5
+ // If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
6
+ //
7
+ // These are the inline implementations of the C++ header APIs. They're in this separate file so as to not clutter
8
+ // the main C++ file with implementation details.
9
+
10
+ namespace Ort {
11
+
12
+ namespace detail {
13
+ inline void ThrowStatus(const Status& st) {
14
+ std::string error_message = st.GetErrorMessage();
15
+ OrtErrorCode error_code = st.GetErrorCode();
16
+ ORT_CXX_API_THROW(std::move(error_message), error_code);
17
+ }
18
+ } // namespace detail
19
+
20
+ inline void ThrowOnError(OrtStatus* ort_status) {
21
+ if (ort_status) {
22
+ Ort::Status st(ort_status);
23
+ detail::ThrowStatus(st);
24
+ }
25
+ }
26
+
27
+ inline void ThrowOnError(const Status& st) {
28
+ if (st) {
29
+ detail::ThrowStatus(st);
30
+ }
31
+ }
32
+
33
+ inline Status::Status(OrtStatus* status) : Base<OrtStatus>{status} {
34
+ }
35
+
36
+ inline Status::Status(const std::exception& e) {
37
+ p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
38
+ }
39
+
40
+ inline Status::Status(const Exception& e) {
41
+ p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
42
+ }
43
+
44
+ inline std::string Status::GetErrorMessage() const {
45
+ std::string message(GetApi().GetErrorMessage(p_));
46
+ return message;
47
+ }
48
+
49
+ inline OrtErrorCode Status::GetErrorCode() const {
50
+ return GetApi().GetErrorCode(p_);
51
+ }
52
+
53
+ // This template converts a C++ type into its ONNXTensorElementDataType
54
+ template <typename T>
55
+ struct TypeToTensorType;
56
+ template <>
57
+ struct TypeToTensorType<float> {
58
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
59
+ };
60
+ template <>
61
+ struct TypeToTensorType<Float16_t> {
62
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
63
+ };
64
+ template <>
65
+ struct TypeToTensorType<BFloat16_t> {
66
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
67
+ };
68
+ template <>
69
+ struct TypeToTensorType<double> {
70
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
71
+ };
72
+ template <>
73
+ struct TypeToTensorType<int8_t> {
74
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
75
+ };
76
+ template <>
77
+ struct TypeToTensorType<int16_t> {
78
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
79
+ };
80
+ template <>
81
+ struct TypeToTensorType<int32_t> {
82
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
83
+ };
84
+ template <>
85
+ struct TypeToTensorType<int64_t> {
86
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
87
+ };
88
+ template <>
89
+ struct TypeToTensorType<uint8_t> {
90
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
91
+ };
92
+ template <>
93
+ struct TypeToTensorType<uint16_t> {
94
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
95
+ };
96
+ template <>
97
+ struct TypeToTensorType<uint32_t> {
98
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
99
+ };
100
+ template <>
101
+ struct TypeToTensorType<uint64_t> {
102
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
103
+ };
104
+ template <>
105
+ struct TypeToTensorType<bool> {
106
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
107
+ };
108
+
109
+ inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
110
+ : allocator_(allocator), p_(p), size_(size) {
111
+ }
112
+
113
+ inline MemoryAllocation::~MemoryAllocation() {
114
+ if (p_ != nullptr) {
115
+ // We do not throw out of destructor
116
+ auto ret = GetApi().AllocatorFree(allocator_, p_);
117
+ static_cast<void>(ret);
118
+ }
119
+ }
120
+
121
+ inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
122
+ *this = std::move(o);
123
+ }
124
+
125
+ inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept {
126
+ OrtAllocator* alloc = nullptr;
127
+ void* p = nullptr;
128
+ size_t sz = 0;
129
+
130
+ // Swap out this
131
+ std::swap(alloc, allocator_);
132
+ std::swap(p, p_);
133
+ std::swap(sz, size_);
134
+
135
+ // Swap with incoming
136
+ std::swap(allocator_, o.allocator_);
137
+ std::swap(p_, o.p_);
138
+ std::swap(size_, o.size_);
139
+
140
+ // Destroy this instance if needed
141
+ MemoryAllocation this_alloc(alloc, p, sz);
142
+ return *this;
143
+ }
144
+
145
+ namespace detail {
146
+
147
+ template <typename T>
148
+ inline void* AllocatorImpl<T>::Alloc(size_t size) {
149
+ void* out;
150
+ ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
151
+ return out;
152
+ }
153
+
154
+ template <typename T>
155
+ inline MemoryAllocation AllocatorImpl<T>::GetAllocation(size_t size) {
156
+ void* out;
157
+ ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
158
+ MemoryAllocation result(this->p_, out, size);
159
+ return result;
160
+ }
161
+
162
+ template <typename T>
163
+ inline void AllocatorImpl<T>::Free(void* p) {
164
+ ThrowOnError(GetApi().AllocatorFree(this->p_, p));
165
+ }
166
+
167
+ template <typename T>
168
+ inline ConstMemoryInfo AllocatorImpl<T>::GetInfo() const {
169
+ const OrtMemoryInfo* out;
170
+ ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
171
+ return ConstMemoryInfo{out};
172
+ }
173
+
174
+ } // namespace detail
175
+
176
+ inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
177
+ ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
178
+ }
179
+
180
+ inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
181
+ ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
182
+ }
183
+
184
+ namespace detail {
185
+
186
+ template <typename T>
187
+ inline std::string MemoryInfoImpl<T>::GetAllocatorName() const {
188
+ const char* name = nullptr;
189
+ ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
190
+ return std::string(name);
191
+ }
192
+
193
+ template <typename T>
194
+ inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
195
+ OrtAllocatorType type;
196
+ ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
197
+ return type;
198
+ }
199
+
200
+ template <typename T>
201
+ inline int MemoryInfoImpl<T>::GetDeviceId() const {
202
+ int id = 0;
203
+ ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
204
+ return id;
205
+ }
206
+
207
+ template <typename T>
208
+ inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
209
+ OrtMemoryInfoDeviceType type;
210
+ GetApi().MemoryInfoGetDeviceType(this->p_, &type);
211
+ return type;
212
+ }
213
+
214
+ template <typename T>
215
+ inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
216
+ OrtMemType type;
217
+ ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
218
+ return type;
219
+ }
220
+
221
+ template <typename T>
222
+ template <typename U>
223
+ inline bool MemoryInfoImpl<T>::operator==(const MemoryInfoImpl<U>& o) const {
224
+ int comp_result = 0;
225
+ ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
226
+ return comp_result == 0;
227
+ }
228
+
229
+ } // namespace detail
230
+
231
+ inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
232
+ OrtMemoryInfo* p;
233
+ ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
234
+ return MemoryInfo(p);
235
+ }
236
+
237
+ inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
238
+ ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
239
+ }
240
+
241
+ namespace detail {
242
+ template <typename T>
243
+ inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
244
+ AllocatorWithDefaultOptions allocator;
245
+ return binding_utils::GetOutputNamesHelper(this->p_, allocator);
246
+ }
247
+
248
+ template <typename T>
249
+ inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
250
+ return binding_utils::GetOutputNamesHelper(this->p_, allocator);
251
+ }
252
+
253
+ template <typename T>
254
+ inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
255
+ AllocatorWithDefaultOptions allocator;
256
+ return binding_utils::GetOutputValuesHelper(this->p_, allocator);
257
+ }
258
+
259
+ template <typename T>
260
+ inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
261
+ return binding_utils::GetOutputValuesHelper(this->p_, allocator);
262
+ }
263
+
264
+ template <typename T>
265
+ inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
266
+ ThrowOnError(GetApi().BindInput(this->p_, name, value));
267
+ }
268
+
269
+ template <typename T>
270
+ inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
271
+ ThrowOnError(GetApi().BindOutput(this->p_, name, value));
272
+ }
273
+
274
+ template <typename T>
275
+ inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
276
+ ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
277
+ }
278
+
279
+ template <typename T>
280
+ inline void IoBindingImpl<T>::ClearBoundInputs() {
281
+ GetApi().ClearBoundInputs(this->p_);
282
+ }
283
+
284
+ template <typename T>
285
+ inline void IoBindingImpl<T>::ClearBoundOutputs() {
286
+ GetApi().ClearBoundOutputs(this->p_);
287
+ }
288
+
289
+ template <typename T>
290
+ inline void IoBindingImpl<T>::SynchronizeInputs() {
291
+ ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
292
+ }
293
+
294
+ template <typename T>
295
+ inline void IoBindingImpl<T>::SynchronizeOutputs() {
296
+ ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
297
+ }
298
+
299
+ namespace binding_utils {
300
+ inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
301
+ std::vector<std::string> result;
302
+ auto free_fn = detail::AllocatedFree(allocator);
303
+ using Ptr = std::unique_ptr<void, decltype(free_fn)>;
304
+
305
+ char* buffer = nullptr;
306
+ size_t* lengths = nullptr;
307
+ size_t count = 0;
308
+ ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
309
+
310
+ if (count == 0) {
311
+ return result;
312
+ }
313
+
314
+ Ptr buffer_g(buffer, free_fn);
315
+ Ptr lengths_g(lengths, free_fn);
316
+
317
+ result.reserve(count);
318
+ for (size_t i = 0; i < count; ++i) {
319
+ auto sz = *lengths;
320
+ result.emplace_back(buffer, sz);
321
+ buffer += sz;
322
+ ++lengths;
323
+ }
324
+ return result;
325
+ }
326
+
327
+ inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
328
+ std::vector<Value> result;
329
+ size_t owned = 0;
330
+ size_t output_count = 0;
331
+ // Lambda that releases the buffer when it is no longer needed and
332
+ // makes sure we destroy all the instances if an exception is thrown
333
+ auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
334
+ if (buffer) {
335
+ while (owned < output_count) {
336
+ auto* p = buffer + owned++;
337
+ GetApi().ReleaseValue(*p);
338
+ }
339
+ allocator->Free(allocator, buffer);
340
+ }
341
+ };
342
+ using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
343
+
344
+ OrtValue** output_buffer = nullptr;
345
+ ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
346
+ if (output_count == 0) {
347
+ return result;
348
+ }
349
+
350
+ Ptr buffer_g(output_buffer, free_fn);
351
+
352
+ result.reserve(output_count);
353
+ for (size_t i = 0; i < output_count; ++i) {
354
+ result.emplace_back(output_buffer[i]);
355
+ ++owned;
356
+ }
357
+ return result;
358
+ }
359
+
360
+ } // namespace binding_utils
361
+ } // namespace detail
362
+
363
+ inline IoBinding::IoBinding(Session& session) {
364
+ ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
365
+ }
366
+
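// Illustrative usage sketch (comment only): binding inputs and outputs and running through an
// IoBinding. Assumes an existing Ort::Session `session` and Ort::Value `input_tensor`;
// "input" / "output" are example node names.
//
//   Ort::IoBinding binding(session);
//   binding.BindInput("input", input_tensor);
//   Ort::MemoryInfo output_mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
//   binding.BindOutput("output", output_mem_info);   // let onnxruntime allocate the output
//
//   session.Run(Ort::RunOptions{nullptr}, binding);
//   std::vector<std::string> output_names = binding.GetOutputNames();
//   std::vector<Ort::Value> output_values = binding.GetOutputValues();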
367
+ inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
368
+ ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
369
+ }
370
+
371
+ inline ThreadingOptions::ThreadingOptions() {
372
+ ThrowOnError(GetApi().CreateThreadingOptions(&p_));
373
+ }
374
+
375
+ inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) {
376
+ ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
377
+ return *this;
378
+ }
379
+
380
+ inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) {
381
+ ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
382
+ return *this;
383
+ }
384
+
385
+ inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) {
386
+ ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
387
+ return *this;
388
+ }
389
+
390
+ inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() {
391
+ ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_));
392
+ return *this;
393
+ }
394
+
395
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
396
+ ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
397
+ return *this;
398
+ }
399
+
400
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
401
+ ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
402
+ return *this;
403
+ }
404
+
405
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
406
+ ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
407
+ return *this;
408
+ }
409
+
410
+ inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
411
+ ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
412
+ if (strcmp(logid, "onnxruntime-node") == 0) {
413
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
414
+ } else {
415
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
416
+ }
417
+ }
418
+
419
+ inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
420
+ ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
421
+ if (strcmp(logid, "onnxruntime-node") == 0) {
422
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
423
+ } else {
424
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
425
+ }
426
+ }
427
+
428
+ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
429
+ ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
430
+ if (strcmp(logid, "onnxruntime-node") == 0) {
431
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
432
+ } else {
433
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
434
+ }
435
+ }
436
+
437
+ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
438
+ OrtLoggingLevel logging_level, _In_ const char* logid) {
439
+ ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
440
+ if (strcmp(logid, "onnxruntime-node") == 0) {
441
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
442
+ } else {
443
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
444
+ }
445
+ }
446
+
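// Illustrative usage sketch (comment only): an Env that owns global thread pools built from the
// ThreadingOptions defined above. The log id "example-app" is a placeholder.
//
//   Ort::ThreadingOptions tp_options;
//   tp_options.SetGlobalIntraOpNumThreads(2)
//             .SetGlobalInterOpNumThreads(1)
//             .SetGlobalSpinControl(0);   // 0: let idle threads block instead of busy-spin
//   Ort::Env env(tp_options, ORT_LOGGING_LEVEL_WARNING, "example-app");
//
//   // Sessions that should share the global pools must opt out of per-session threads:
//   Ort::SessionOptions session_options;
//   session_options.DisablePerSessionThreads();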
447
+ inline Env& Env::EnableTelemetryEvents() {
448
+ ThrowOnError(GetApi().EnableTelemetryEvents(p_));
449
+ return *this;
450
+ }
451
+
452
+ inline Env& Env::DisableTelemetryEvents() {
453
+ ThrowOnError(GetApi().DisableTelemetryEvents(p_));
454
+ return *this;
455
+ }
456
+
457
+ inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
458
+ ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
459
+ return *this;
460
+ }
461
+
462
+ inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
463
+ ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
464
+ return *this;
465
+ }
466
+
467
+ inline CustomOpDomain::CustomOpDomain(const char* domain) {
468
+ ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
469
+ }
470
+
471
+ inline void CustomOpDomain::Add(const OrtCustomOp* op) {
472
+ ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
473
+ }
474
+
475
+ inline RunOptions::RunOptions() {
476
+ ThrowOnError(GetApi().CreateRunOptions(&p_));
477
+ }
478
+
479
+ inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
480
+ ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
481
+ return *this;
482
+ }
483
+
484
+ inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
485
+ ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
486
+ return *this;
487
+ }
488
+
489
+ inline int RunOptions::GetRunLogVerbosityLevel() const {
490
+ int out;
491
+ ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
492
+ return out;
493
+ }
494
+
495
+ inline int RunOptions::GetRunLogSeverityLevel() const {
496
+ int out;
497
+ ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
498
+ return out;
499
+ }
500
+
501
+ inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
502
+ ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
503
+ return *this;
504
+ }
505
+
506
+ inline const char* RunOptions::GetRunTag() const {
507
+ const char* out;
508
+ ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
509
+ return out;
510
+ }
511
+
512
+ inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
513
+ ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
514
+ return *this;
515
+ }
516
+
517
+ inline RunOptions& RunOptions::SetTerminate() {
518
+ ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
519
+ return *this;
520
+ }
521
+
522
+ inline RunOptions& RunOptions::UnsetTerminate() {
523
+ ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
524
+ return *this;
525
+ }
526
+
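// Illustrative usage sketch (comment only): RunOptions carries per-Run settings such as a tag,
// log severity and a cooperative terminate flag. The tag value is a placeholder.
//
//   Ort::RunOptions run_options;
//   run_options.SetRunTag("request-42")
//              .SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);
//
//   // From another thread, an in-flight Run() that was given this object can be asked to stop
//   // early with run_options.SetTerminate(), and re-armed later with UnsetTerminate().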
527
+ namespace detail {
528
+
529
+ template <typename T>
530
+ inline Ort::SessionOptions ConstSessionOptionsImpl<T>::Clone() const {
531
+ OrtSessionOptions* out;
532
+ ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
533
+ return SessionOptions{out};
534
+ }
535
+
536
+ template <typename T>
537
+ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
538
+ size_t size = 0;
539
+ // Feed nullptr for the data buffer to query the true size of the string value
540
+ Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
541
+
542
+ std::string out;
543
+ out.resize(size);
544
+ Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
545
+ out.resize(size - 1); // remove the terminating character '\0'
546
+
547
+ return out;
548
+ }
549
+
550
+ template <typename T>
551
+ inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
552
+ int out = 0;
553
+ Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
554
+ return static_cast<bool>(out);
555
+ }
556
+
557
+ template <typename T>
558
+ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
559
+ if (!this->HasConfigEntry(config_key)) {
560
+ return def;
561
+ }
562
+
563
+ return this->GetConfigEntry(config_key);
564
+ }
565
+
566
+ template <typename T>
567
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetIntraOpNumThreads(int intra_op_num_threads) {
568
+ ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
569
+ return *this;
570
+ }
571
+
572
+ template <typename T>
573
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetInterOpNumThreads(int inter_op_num_threads) {
574
+ ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
575
+ return *this;
576
+ }
577
+
578
+ template <typename T>
579
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
580
+ ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
581
+ return *this;
582
+ }
583
+
584
+ template <typename T>
585
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
586
+ ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
587
+ return *this;
588
+ }
589
+
590
+ template <typename T>
591
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
592
+ ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
593
+ return *this;
594
+ }
595
+
596
+ template <typename T>
597
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableProfiling() {
598
+ ThrowOnError(GetApi().DisableProfiling(this->p_));
599
+ return *this;
600
+ }
601
+
602
+ template <typename T>
603
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableOrtCustomOps() {
604
+ ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
605
+ return *this;
606
+ }
607
+
608
+ template <typename T>
609
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableMemPattern() {
610
+ ThrowOnError(GetApi().EnableMemPattern(this->p_));
611
+ return *this;
612
+ }
613
+
614
+ template <typename T>
615
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableMemPattern() {
616
+ ThrowOnError(GetApi().DisableMemPattern(this->p_));
617
+ return *this;
618
+ }
619
+
620
+ template <typename T>
621
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableCpuMemArena() {
622
+ ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
623
+ return *this;
624
+ }
625
+
626
+ template <typename T>
627
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableCpuMemArena() {
628
+ ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
629
+ return *this;
630
+ }
631
+
632
+ template <typename T>
633
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
634
+ ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
635
+ return *this;
636
+ }
637
+
638
+ template <typename T>
639
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
640
+ ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
641
+ return *this;
642
+ }
643
+
644
+ template <typename T>
645
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogSeverityLevel(int level) {
646
+ ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
647
+ return *this;
648
+ }
649
+
650
+ template <typename T>
651
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::Add(OrtCustomOpDomain* custom_op_domain) {
652
+ ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
653
+ return *this;
654
+ }
655
+
656
+ template <typename T>
657
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
658
+ ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
659
+ return *this;
660
+ }
661
+
662
+ template <typename T>
663
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
664
+ ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
665
+ return *this;
666
+ }
667
+
668
+ template <typename T>
669
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisablePerSessionThreads() {
670
+ ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
671
+ return *this;
672
+ }
673
+
674
+ template <typename T>
675
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
676
+ const std::vector<Value>& ort_values) {
677
+ const size_t inputs_num = names.size();
678
+ if (inputs_num != ort_values.size()) {
679
+ ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
680
+ }
681
+ std::vector<const char*> names_ptr;
682
+ std::vector<const OrtValue*> ort_values_ptrs;
683
+ names_ptr.reserve(inputs_num);
684
+ ort_values_ptrs.reserve(inputs_num);
685
+ for (size_t i = 0; i < inputs_num; ++i) {
686
+ names_ptr.push_back(names[i].c_str());
687
+ ort_values_ptrs.push_back(ort_values[i]);
688
+ }
689
+ ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
690
+ return *this;
691
+ }
692
+
693
+ template <typename T>
694
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
695
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
696
+ return *this;
697
+ }
698
+
699
+ template <typename T>
700
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) {
701
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
702
+ return *this;
703
+ }
704
+
705
+ template <typename T>
706
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
707
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
708
+ return *this;
709
+ }
710
+
711
+ template <typename T>
712
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
713
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
714
+ return *this;
715
+ }
716
+
717
+ template <typename T>
718
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
719
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
720
+ return *this;
721
+ }
722
+
723
+ template <typename T>
724
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
725
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
726
+ return *this;
727
+ }
728
+
729
+ template <typename T>
730
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
731
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
732
+ return *this;
733
+ }
734
+
735
+ template <typename T>
736
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
737
+ const std::string& provider_name,
738
+ const std::unordered_map<std::string, std::string>& provider_options) {
739
+ auto num_entries = provider_options.size();
740
+ std::vector<const char*> keys, values;
741
+ if (num_entries > 0) {
742
+ keys.reserve(num_entries);
743
+ values.reserve(num_entries);
744
+
745
+ for (const auto& entry : provider_options) {
746
+ keys.push_back(entry.first.c_str());
747
+ values.push_back(entry.second.c_str());
748
+ }
749
+ }
750
+
751
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
752
+ keys.data(), values.data(), num_entries));
753
+
754
+ return *this;
755
+ }
756
+
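// Illustrative usage sketch (comment only): typical SessionOptions configuration, including the
// string-keyed AppendExecutionProvider overload defined above. "XNNPACK", the config entry and
// the option key below are examples; which providers accept this path depends on the build.
//
//   Ort::SessionOptions session_options;
//   session_options.SetIntraOpNumThreads(1)
//                  .SetGraphOptimizationLevel(ORT_ENABLE_ALL)
//                  .AddConfigEntry("session.use_env_allocators", "1");
//
//   std::unordered_map<std::string, std::string> provider_options{{"intra_op_num_threads", "1"}};
//   session_options.AppendExecutionProvider("XNNPACK", provider_options);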
757
+ template <typename T>
758
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
759
+ ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
760
+ return *this;
761
+ }
762
+
763
+ template <typename T>
764
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
765
+ ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
766
+ return *this;
767
+ }
768
+
769
+ template <typename T>
770
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
771
+ ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
772
+ return *this;
773
+ }
774
+
775
+ template <typename T>
776
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
777
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
778
+ return *this;
779
+ }
780
+
781
+ template <typename T>
782
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name,
783
+ const CustomOpConfigs& custom_op_configs) {
784
+ // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
785
+ // the custom op library.
786
+ for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
787
+ AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
788
+ }
789
+
790
+ ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
791
+ return *this;
792
+ }
793
+
794
+ template <typename T>
795
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
796
+ ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
797
+ return *this;
798
+ }
799
+
800
+ /// Session
801
+ template <typename T>
802
+ inline size_t ConstSessionImpl<T>::GetInputCount() const {
803
+ size_t out;
804
+ ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
805
+ return out;
806
+ }
807
+
808
+ template <typename T>
809
+ inline size_t ConstSessionImpl<T>::GetOutputCount() const {
810
+ size_t out;
811
+ ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
812
+ return out;
813
+ }
814
+
815
+ template <typename T>
816
+ inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
817
+ size_t out;
818
+ ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
819
+ return out;
820
+ }
821
+
822
+ template <typename T>
823
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
824
+ char* out;
825
+ ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
826
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
827
+ }
828
+
829
+ template <typename T>
830
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
831
+ char* out;
832
+ ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
833
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
834
+ }
835
+
836
+ template <typename T>
837
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
838
+ char* out;
839
+ ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
840
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
841
+ }
842
+
843
+ template <typename T>
844
+ inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
845
+ uint64_t out;
846
+ ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
847
+ return out;
848
+ }
849
+
850
+ template <typename T>
851
+ inline ModelMetadata ConstSessionImpl<T>::GetModelMetadata() const {
852
+ OrtModelMetadata* out;
853
+ ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
854
+ return ModelMetadata{out};
855
+ }
856
+
857
+ template <typename T>
858
+ inline TypeInfo ConstSessionImpl<T>::GetInputTypeInfo(size_t index) const {
859
+ OrtTypeInfo* out;
860
+ ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
861
+ return TypeInfo{out};
862
+ }
863
+
864
+ template <typename T>
865
+ inline TypeInfo ConstSessionImpl<T>::GetOutputTypeInfo(size_t index) const {
866
+ OrtTypeInfo* out;
867
+ ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
868
+ return TypeInfo{out};
869
+ }
870
+
871
+ template <typename T>
872
+ inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t index) const {
873
+ OrtTypeInfo* out;
874
+ ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
875
+ return TypeInfo{out};
876
+ }
877
+
878
+ template <typename T>
879
+ inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
880
+ const char* const* output_names, size_t output_count) {
881
+ std::vector<Value> output_values;
882
+ output_values.reserve(output_count);
883
+ for (size_t i = 0; i < output_count; i++)
884
+ output_values.emplace_back(nullptr);
885
+ Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
886
+ return output_values;
887
+ }
888
+
889
+ template <typename T>
890
+ inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
891
+ const char* const* output_names, Value* output_values, size_t output_count) {
892
+ static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
893
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
894
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
895
+ ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
896
+ }
897
+
898
+ template <typename T>
899
+ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
900
+ ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
901
+ }
902
+
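// Illustrative usage sketch (comment only): the two Run overloads above differ only in who owns
// the output Values. Assumes `session` and `input_tensor` already exist; the node names are
// placeholders.
//
//   const char* input_names[]  = {"input"};
//   const char* output_names[] = {"output"};
//
//   // 1) Let the API create and return the outputs:
//   std::vector<Ort::Value> outputs =
//       session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1);
//
//   // 2) Or supply output Values yourself (possibly pre-allocated on a device):
//   Ort::Value output_value{nullptr};
//   session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1,
//               output_names, &output_value, 1);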
903
+ template <typename T>
904
+ inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
905
+ char* out = nullptr;
906
+ ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
907
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
908
+ }
909
+
910
+ } // namespace detail
911
+
912
+ inline SessionOptions::SessionOptions() {
913
+ ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
914
+ }
915
+
916
+ /// CustomOpConfigs
917
+ inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
918
+ std::string config_key = "custom_op.";
919
+
920
+ config_key += custom_op_name;
921
+ config_key += ".";
922
+ config_key += config;
923
+
924
+ return config_key;
925
+ }
926
+
927
+ inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
928
+ const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
929
+ flat_configs_[full_flat_key] = config_value;
930
+ return *this;
931
+ }
932
+
933
+ inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
934
+ return flat_configs_;
935
+ }
936
+
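// Illustrative usage sketch (comment only): CustomOpConfigs entries are flattened into
// "custom_op.<op>.<key>" session config entries and should be supplied when registering the
// custom-op library. The library path, op name and key below are placeholders.
//
//   Ort::CustomOpConfigs custom_op_configs;
//   custom_op_configs.AddConfig("MyCustomOp", "device_type", "CPU");
//
//   Ort::SessionOptions session_options;
//   session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libmy_custom_ops.so"), custom_op_configs);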
937
+ inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
938
+ ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
939
+ }
940
+
941
+ inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
942
+ OrtPrepackedWeightsContainer* prepacked_weights_container) {
943
+ ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
944
+ }
945
+
946
+ inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
947
+ ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
948
+ }
949
+
950
+ inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
951
+ const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
952
+ ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
953
+ prepacked_weights_container, &this->p_));
954
+ }
955
+
956
+ inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
957
+ char* out;
958
+ ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
959
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
960
+ }
961
+
962
+ inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
963
+ char* out;
964
+ ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
965
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
966
+ }
967
+
968
+ inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
969
+ char* out;
970
+ ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
971
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
972
+ }
973
+
974
+ inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const {
975
+ char* out;
976
+ ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
977
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
978
+ }
979
+
980
+ inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const {
981
+ char* out;
982
+ ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
983
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
984
+ }
985
+
986
+ inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
987
+ char* out;
988
+ ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
989
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
990
+ }
991
+
992
+ inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
993
+ auto deletor = detail::AllocatedFree(allocator);
994
+ std::vector<AllocatedStringPtr> result;
995
+
996
+ char** out = nullptr;
997
+ int64_t num_keys = 0;
998
+ ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
999
+ if (num_keys <= 0) {
1000
+ return result;
1001
+ }
1002
+
1003
+ // array of pointers will be freed
1004
+ std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
1005
+ // reserve may throw
1006
+ auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
1007
+ std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
1008
+ result.reserve(static_cast<size_t>(num_keys));
1009
+ strings_guard.release();
1010
+ for (int64_t i = 0; i < num_keys; ++i) {
1011
+ result.push_back(AllocatedStringPtr(out[i], deletor));
1012
+ }
1013
+
1014
+ return result;
1015
+ }
1016
+
1017
+ inline int64_t ModelMetadata::GetVersion() const {
1018
+ int64_t out;
1019
+ ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
1020
+ return out;
1021
+ }
1022
+
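// Illustrative usage sketch (comment only): reading model metadata through the allocated-string
// accessors above. Each AllocatedStringPtr frees itself with the allocator it was created from.
// Assumes `session` exists.
//
//   Ort::AllocatorWithDefaultOptions allocator;
//   Ort::ModelMetadata metadata = session.GetModelMetadata();
//
//   Ort::AllocatedStringPtr producer = metadata.GetProducerNameAllocated(allocator);
//   int64_t model_version = metadata.GetVersion();
//
//   std::vector<Ort::AllocatedStringPtr> keys = metadata.GetCustomMetadataMapKeysAllocated(allocator);
//   for (const Ort::AllocatedStringPtr& key : keys) {
//     Ort::AllocatedStringPtr value = metadata.LookupCustomMetadataMapAllocated(key.get(), allocator);
//   }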
1023
+ namespace detail {
1024
+
1025
+ template <typename T>
1026
+ inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
1027
+ ONNXTensorElementDataType out;
1028
+ ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
1029
+ return out;
1030
+ }
1031
+
1032
+ template <typename T>
1033
+ inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
1034
+ size_t out;
1035
+ ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
1036
+ return static_cast<size_t>(out);
1037
+ }
1038
+
1039
+ template <typename T>
1040
+ inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
1041
+ size_t out;
1042
+ ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
1043
+ return out;
1044
+ }
1045
+
1046
+ template <typename T>
1047
+ inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
1048
+ ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
1049
+ }
1050
+
1051
+ template <typename T>
1052
+ inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
1053
+ ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
1054
+ }
1055
+
1056
+ template <typename T>
1057
+ inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
1058
+ std::vector<int64_t> out(GetDimensionsCount(), 0);
1059
+ ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
1060
+ return out;
1061
+ }
1062
+
1063
+ } // namespace detail
1064
+
1065
+ namespace detail {
1066
+ template <typename T>
1067
+ inline ConstTensorTypeAndShapeInfo TypeInfoImpl<T>::GetTensorTypeAndShapeInfo() const {
1068
+ const OrtTensorTypeAndShapeInfo* out;
1069
+ ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
1070
+ return ConstTensorTypeAndShapeInfo{out};
1071
+ }
1072
+
1073
+ template <typename T>
1074
+ inline ConstSequenceTypeInfo TypeInfoImpl<T>::GetSequenceTypeInfo() const {
1075
+ const OrtSequenceTypeInfo* out;
1076
+ ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
1077
+ return ConstSequenceTypeInfo{out};
1078
+ }
1079
+
1080
+ template <typename T>
1081
+ inline ConstMapTypeInfo TypeInfoImpl<T>::GetMapTypeInfo() const {
1082
+ const OrtMapTypeInfo* out;
1083
+ ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
1084
+ return ConstMapTypeInfo{out};
1085
+ }
1086
+
1087
+ template <typename T>
1088
+ inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
1089
+ ONNXType out;
1090
+ ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
1091
+ return out;
1092
+ }
1093
+
1094
+ } // namespace detail
1095
+
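// Illustrative usage sketch (comment only): walking from a session input to its element type and
// shape. Assumes `session` exists; dynamic dimensions come back as -1 from GetShape().
//
//   Ort::TypeInfo type_info = session.GetInputTypeInfo(0);
//   if (type_info.GetONNXType() == ONNX_TYPE_TENSOR) {
//     Ort::ConstTensorTypeAndShapeInfo tensor_info = type_info.GetTensorTypeAndShapeInfo();
//     ONNXTensorElementDataType element_type = tensor_info.GetElementType();
//     std::vector<int64_t> shape = tensor_info.GetShape();
//     size_t rank = tensor_info.GetDimensionsCount();
//   }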
1096
+ namespace detail {
1097
+ template <typename T>
1098
+ inline TypeInfo SequenceTypeInfoImpl<T>::GetSequenceElementType() const {
1099
+ OrtTypeInfo* output;
1100
+ ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
1101
+ return TypeInfo{output};
1102
+ }
1103
+
1104
+ } // namespace detail
1105
+
1106
+ namespace detail {
1107
+ template <typename T>
1108
+ inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
1109
+ ONNXTensorElementDataType out;
1110
+ ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
1111
+ return out;
1112
+ }
1113
+
1114
+ template <typename T>
1115
+ inline TypeInfo MapTypeInfoImpl<T>::GetMapValueType() const {
1116
+ OrtTypeInfo* output;
1117
+ ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
1118
+ return TypeInfo{output};
1119
+ }
1120
+ } // namespace detail
1121
+
1122
+ namespace detail {
1123
+
1124
+ template <typename T>
1125
+ template <typename R>
1126
+ inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
1127
+ ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
1128
+ }
1129
+
1130
+ template <typename T>
1131
+ inline bool ConstValueImpl<T>::IsTensor() const {
1132
+ int out;
1133
+ ThrowOnError(GetApi().IsTensor(this->p_, &out));
1134
+ return out != 0;
1135
+ }
1136
+
1137
+ template <typename T>
1138
+ inline bool ConstValueImpl<T>::HasValue() const {
1139
+ int out;
1140
+ ThrowOnError(GetApi().HasValue(this->p_, &out));
1141
+ return out != 0;
1142
+ }
1143
+
1144
+ template <typename T>
1145
+ inline size_t ConstValueImpl<T>::GetCount() const {
1146
+ size_t out;
1147
+ ThrowOnError(GetApi().GetValueCount(this->p_, &out));
1148
+ return out;
1149
+ }
1150
+
1151
+ template <typename T>
1152
+ inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
1153
+ OrtValue* out;
1154
+ ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
1155
+ return Value{out};
1156
+ }
1157
+
1158
+ template <typename T>
1159
+ inline size_t ConstValueImpl<T>::GetStringTensorDataLength() const {
1160
+ size_t out;
1161
+ ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
1162
+ return out;
1163
+ }
1164
+
1165
+ template <typename T>
1166
+ inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
1167
+ size_t out;
1168
+ ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
1169
+ return out;
1170
+ }
1171
+
1172
+ template <typename T>
1173
+ template <typename R>
1174
+ inline const R* ConstValueImpl<T>::GetTensorData() const {
1175
+ R* out;
1176
+ ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
1177
+ return out;
1178
+ }
1179
+
1180
+ template <typename T>
1181
+ inline const void* ConstValueImpl<T>::GetTensorRawData() const {
1182
+ void* out;
1183
+ ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
1184
+ return out;
1185
+ }
1186
+
1187
+ template <typename T>
1188
+ inline TypeInfo ConstValueImpl<T>::GetTypeInfo() const {
1189
+ OrtTypeInfo* output;
1190
+ ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
1191
+ return TypeInfo{output};
1192
+ }
1193
+
1194
+ template <typename T>
1195
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetTensorTypeAndShapeInfo() const {
1196
+ OrtTensorTypeAndShapeInfo* output;
1197
+ ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
1198
+ return TensorTypeAndShapeInfo{output};
1199
+ }
1200
+
1201
+ template <typename T>
1202
+ inline ConstMemoryInfo ConstValueImpl<T>::GetTensorMemoryInfo() const {
1203
+ const OrtMemoryInfo* mem_info;
1204
+ ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
1205
+ return ConstMemoryInfo(mem_info);
1206
+ }
1207
+
1208
+ template <typename T>
1209
+ inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
1210
+ ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
1211
+ }
1212
+
1213
+ template <typename T>
1214
+ inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
1215
+ ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
1216
+ }
1217
+
1218
+ #if !defined(DISABLE_SPARSE_TENSORS)
1219
+ template <typename T>
1220
+ inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
1221
+ OrtSparseFormat format;
1222
+ ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
1223
+ return format;
1224
+ }
1225
+
1226
+ template <typename T>
1227
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorValuesTypeAndShapeInfo() const {
1228
+ OrtTensorTypeAndShapeInfo* output;
1229
+ ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
1230
+ return TensorTypeAndShapeInfo{output};
1231
+ }
1232
+
1233
+ template <typename T>
1234
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
1235
+ OrtTensorTypeAndShapeInfo* output;
1236
+ ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
1237
+ return TensorTypeAndShapeInfo{output};
1238
+ }
1239
+
1240
+ template <typename T>
1241
+ template <typename R>
1242
+ inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
1243
+ const void* out;
1244
+ ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
1245
+ return reinterpret_cast<const R*>(out);
1246
+ }
1247
+
1248
+ template <typename T>
1249
+ inline bool ConstValueImpl<T>::IsSparseTensor() const {
1250
+ int out;
1251
+ ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
1252
+ return out != 0;
1253
+ }
1254
+
1255
+ template <typename T>
1256
+ template <typename R>
1257
+ inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
1258
+ const void* out;
1259
+ ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
1260
+ return reinterpret_cast<const R*>(out);
1261
+ }
1262
+
1263
+ #endif
1264
+
1265
+ template <typename T>
1266
+ void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
1267
+ ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
1268
+ }
1269
+
1270
+ template <typename T>
1271
+ void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
1272
+ ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
1273
+ }
1274
+
1275
+ template <typename T>
1276
+ void* ValueImpl<T>::GetTensorMutableRawData() {
1277
+ void* out;
1278
+ ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
1279
+ return out;
1280
+ }
1281
+
1282
+ template <typename T>
1283
+ template <typename R>
1284
+ R* ValueImpl<T>::GetTensorMutableData() {
1285
+ R* out;
1286
+ ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
1287
+ return out;
1288
+ }
1289
+
1290
+ template <typename T>
1291
+ template <typename R>
1292
+ R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
1293
+ static_assert(!std::is_same<R, std::string>::value, "this api does not support std::string");
1294
+ R* out;
1295
+ ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
1296
+ return *out;
1297
+ }
1298
+
1299
+ #if !defined(DISABLE_SPARSE_TENSORS)
1300
+ template <typename T>
1301
+ void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
1302
+ ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
1303
+ }
1304
+
1305
+ template <typename T>
1306
+ void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
1307
+ ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
1308
+ }
1309
+
1310
+ template <typename T>
1311
+ void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
1312
+ ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
1313
+ }
1314
+
1315
+ template <typename T>
1316
+ void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
1317
+ const int64_t* indices_data, size_t indices_num) {
1318
+ ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
1319
+ values_param.values_shape_len, values_param.data.p_data,
1320
+ indices_data, indices_num));
1321
+ }
1322
+
1323
+ template <typename T>
1324
+ void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1325
+ const OrtSparseValuesParam& values,
1326
+ const int64_t* inner_indices_data, size_t inner_indices_num,
1327
+ const int64_t* outer_indices_data, size_t outer_indices_num) {
1328
+ ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1329
+ inner_indices_data, inner_indices_num,
1330
+ outer_indices_data, outer_indices_num));
1331
+ }
1332
+
1333
+ template <typename T>
1334
+ void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1335
+ const OrtSparseValuesParam& values,
1336
+ const Shape& indices_shape,
1337
+ const int32_t* indices_data) {
1338
+ ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1339
+ indices_shape.shape, indices_shape.shape_len,
1340
+ indices_data));
1341
+ }
1342
+
1343
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
1344
+
1345
+ } // namespace detail
1346
+
1347
+ template <typename T>
1348
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
1349
+ return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
1350
+ }
1351
+
1352
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1353
+ ONNXTensorElementDataType type) {
1354
+ OrtValue* out;
1355
+ ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
1356
+ return Value{out};
1357
+ }
1358
+
1359
+ template <typename T>
1360
+ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
1361
+ return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
1362
+ }
1363
+
1364
+ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
1365
+ OrtValue* out;
1366
+ ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
1367
+ return Value{out};
1368
+ }
1369
+
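// Illustrative usage sketch (comment only): the CreateTensor overloads above either wrap
// caller-owned memory (no copy; the buffer must outlive the Value) or let an allocator own the
// storage. The data and shape below are examples.
//
//   std::vector<float> data{1.f, 2.f, 3.f, 4.f};
//   std::vector<int64_t> shape{2, 2};
//
//   // 1) Tensor over a user-provided buffer:
//   Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
//   Ort::Value user_tensor = Ort::Value::CreateTensor<float>(
//       mem_info, data.data(), data.size(), shape.data(), shape.size());
//
//   // 2) Tensor whose storage is owned by the allocator; fill it afterwards:
//   Ort::AllocatorWithDefaultOptions allocator;
//   Ort::Value owned_tensor = Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
//   std::copy(data.begin(), data.end(), owned_tensor.GetTensorMutableData<float>());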
1370
+ #if !defined(DISABLE_SPARSE_TENSORS)
1371
+
1372
+ template <typename T>
1373
+ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1374
+ const Shape& values_shape) {
1375
+ return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
1376
+ }
1377
+
1378
+ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1379
+ const Shape& values_shape, ONNXTensorElementDataType type) {
1380
+ OrtValue* out;
1381
+ ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
1382
+ values_shape.shape, values_shape.shape_len, type, &out));
1383
+ return Value{out};
1384
+ }
1385
+
1386
+ template <typename T>
1387
+ inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
1388
+ return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
1389
+ }
1390
+
1391
+ inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
1392
+ ONNXTensorElementDataType type) {
1393
+ OrtValue* out;
1394
+ ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
1395
+ return Value{out};
1396
+ }
1397
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
1398
+
1399
+ inline Value Value::CreateMap(Value& keys, Value& values) {
1400
+ OrtValue* out;
1401
+ OrtValue* inputs[2] = {keys, values};
1402
+ ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
1403
+ return Value{out};
1404
+ }
1405
+
1406
+ inline Value Value::CreateSequence(std::vector<Value>& values) {
1407
+ OrtValue* out;
1408
+ std::vector<OrtValue*> values_ort{values.data(), values.data() + values.size()};
1409
+ ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
1410
+ return Value{out};
1411
+ }
1412
+
1413
+ template <typename T>
1414
+ inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
1415
+ OrtValue* out;
1416
+ ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
1417
+ return Value{out};
1418
+ }
1419
+
1420
+ //
1421
+ // Custom OP Inlines
1422
+ //
1423
+ inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
1424
+ }
1425
+
1426
+ inline size_t KernelContext::GetInputCount() const {
1427
+ size_t out = 0;
1428
+ Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
1429
+ return out;
1430
+ }
1431
+
1432
+ inline size_t KernelContext::GetOutputCount() const {
1433
+ size_t out = 0;
1434
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
1435
+ return out;
1436
+ }
1437
+
1438
+ inline ConstValue KernelContext::GetInput(size_t index) const {
1439
+ const OrtValue* out = nullptr;
1440
+ Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
1441
+ return ConstValue{out};
1442
+ }
1443
+
1444
+ inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
1445
+ OrtValue* out = nullptr;
1446
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
1447
+ return UnownedValue(out);
1448
+ }
1449
+
1450
+ inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
1451
+ OrtValue* out = nullptr;
1452
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
1453
+ return UnownedValue(out);
1454
+ }
1455
+
1456
+ inline void* KernelContext::GetGPUComputeStream() const {
1457
+ void* out = nullptr;
1458
+ Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
1459
+ return out;
1460
+ }
1461
+
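// Illustrative usage sketch (comment only): a custom operator's Compute callback typically wraps
// the raw OrtKernelContext as shown below; the element-wise copy is just an example body.
//
//   void KernelCompute(OrtKernelContext* context) {
//     Ort::KernelContext ctx(context);
//     Ort::ConstValue input = ctx.GetInput(0);
//
//     auto type_shape = input.GetTensorTypeAndShapeInfo();
//     std::vector<int64_t> shape = type_shape.GetShape();
//     const float* in = input.GetTensorData<float>();
//
//     Ort::UnownedValue output = ctx.GetOutput(0, shape);   // same shape as the input
//     float* out = output.GetTensorMutableData<float>();
//     std::copy(in, in + type_shape.GetElementCount(), out);
//   }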
1462
+ inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
1463
+ Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
1464
+ }
1465
+
1466
+ namespace detail {
1467
+ template <typename T>
1468
+ inline KernelInfo KernelInfoImpl<T>::Copy() const {
1469
+ OrtKernelInfo* info_copy = nullptr;
1470
+ Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
1471
+ return KernelInfo{info_copy};
1472
+ }
1473
+
1474
+ template <typename T>
1475
+ inline size_t KernelInfoImpl<T>::GetInputCount() const {
1476
+ size_t out = 0;
1477
+ ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
1478
+ return out;
1479
+ }
1480
+
1481
+ template <typename T>
1482
+ inline size_t KernelInfoImpl<T>::GetOutputCount() const {
1483
+ size_t out = 0;
1484
+ ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
1485
+ return out;
1486
+ }
1487
+
1488
+ template <typename T>
1489
+ inline std::string KernelInfoImpl<T>::GetInputName(size_t index) const {
1490
+ size_t size = 0;
1491
+
1492
+ // Feed nullptr for the data buffer to query the true size of the string value
1493
+ Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
1494
+
1495
+ std::string out;
1496
+ out.resize(size);
1497
+ Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
1498
+ out.resize(size - 1); // remove the terminating character '\0'
1499
+
1500
+ return out;
1501
+ }
1502
+
1503
+ template <typename T>
1504
+ inline std::string KernelInfoImpl<T>::GetOutputName(size_t index) const {
1505
+ size_t size = 0;
1506
+
1507
+ // Feed nullptr for the data buffer to query the true size of the string value
1508
+ Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
1509
+
1510
+ std::string out;
1511
+ out.resize(size);
1512
+ Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
1513
+ out.resize(size - 1); // remove the terminating character '\0'
1514
+
1515
+ return out;
1516
+ }
1517
+
1518
+ template <typename T>
1519
+ inline TypeInfo KernelInfoImpl<T>::GetInputTypeInfo(size_t index) const {
1520
+ OrtTypeInfo* out = nullptr;
1521
+ ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
1522
+ return TypeInfo{out};
1523
+ }
1524
+
1525
+ template <typename T>
1526
+ inline TypeInfo KernelInfoImpl<T>::GetOutputTypeInfo(size_t index) const {
1527
+ OrtTypeInfo* out = nullptr;
1528
+ ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
1529
+ return TypeInfo{out};
1530
+ }
1531
+
1532
+ template <typename T>
1533
+ inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
1534
+ OrtValue* out = nullptr;
1535
+ ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
1536
+ return Value{out};
1537
+ }
1538
+
1539
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
1540
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
1541
+ }
1542
+
1543
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
1544
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
1545
+ }
1546
+
1547
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
1548
+ size_t size = 0;
1549
+ // Feed nullptr for the data buffer to query the true size of the string attribute
1550
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
1551
+
1552
+ std::string out;
1553
+ out.resize(size);
1554
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
1555
+ out.resize(size - 1); // remove the terminating character '\0'
1556
+ out.swap(result);
1557
+ }
1558
+
1559
+ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
1560
+ size_t size = 0;
1561
+ // Feed nullptr for the data buffer to query the true size of the attribute
1562
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
1563
+
1564
+ std::vector<float> out;
1565
+ out.resize(size);
1566
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
1567
+ out.swap(result);
1568
+ }
1569
+
1570
+ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
1571
+ size_t size = 0;
1572
+
1573
+ // Feed nullptr for the data buffer to query the true size of the attribute
1574
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
1575
+
1576
+ std::vector<int64_t> out;
1577
+ out.resize(size);
1578
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
1579
+ out.swap(result);
1580
+ }
1581
+ } // namespace detail
1582
+
1583
+ inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
1584
+
1585
+ inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
1586
+
1587
+ inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
1588
+ const char** type_constraint_names,
1589
+ const ONNXTensorElementDataType* type_constraint_values,
1590
+ size_t type_constraint_count,
1591
+ const OpAttr* attr_values, size_t attr_count,
1592
+ size_t input_count, size_t output_count) {
1593
+ static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
1594
+ "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
1595
+ auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
1596
+ OrtOp* op;
1597
+ Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1598
+ static_cast<int>(type_constraint_count),
1599
+ attr_input_values,
1600
+ static_cast<int>(attr_count),
1601
+ static_cast<int>(input_count),
1602
+ static_cast<int>(output_count), &op));
1603
+ return Op{op};
1604
+ }
1605
+
1606
+ inline void Op::Invoke(const OrtKernelContext* context,
1607
+ const Value* input_values,
1608
+ size_t input_count,
1609
+ Value* output_values,
1610
+ size_t output_count) {
1611
+ static_assert(sizeof(Value) == sizeof(OrtValue*),
1612
+ "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1613
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1614
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1615
+ Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
1616
+ ort_output_values, static_cast<int>(output_count)));
1617
+ }
1618
+
1619
+ inline void Op::Invoke(const OrtKernelContext* context,
1620
+ const OrtValue* const* input_values,
1621
+ size_t input_count,
1622
+ OrtValue* const* output_values,
1623
+ size_t output_count) {
1624
+ Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
1625
+ output_values, static_cast<int>(output_count)));
1626
+ }
1627
+
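// Illustrative usage sketch (comment only): from inside a custom kernel, a built-in ONNX operator
// can be pre-built with Op::Create and later run with Op::Invoke. "Mul", opset 14 and the single
// float type constraint are examples.
//
//   // At kernel-construction time (an OrtKernelInfo* `info` is available there):
//   const char* type_constraint_names[] = {"T"};
//   ONNXTensorElementDataType type_constraint_values[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT};
//   Ort::Op mul_op = Ort::Op::Create(info, "Mul", "", 14,
//                                    type_constraint_names, type_constraint_values, 1,
//                                    nullptr, 0,   // no attributes
//                                    2, 1);        // two inputs, one output
//
//   // At compute time, with `inputs` / `outputs` pointing at Ort::Value arrays:
//   //   mul_op.Invoke(context, inputs, 2, outputs, 1);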
1628
+ inline void CustomOpApi::ThrowOnError(OrtStatus* status) {
1629
+ Ort::ThrowOnError(status);
1630
+ }
1631
+
1632
+ template <>
1633
+ inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1634
+ float out;
1635
+ Ort::ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
1636
+ return out;
1637
+ }
1638
+
1639
+ template <>
1640
+ inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1641
+ int64_t out;
1642
+ Ort::ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
1643
+ return out;
1644
+ }
1645
+
1646
+ template <>
1647
+ inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1648
+ size_t size = 0;
1649
+ std::string out;
1650
+
1651
+ // Feed nullptr for the data buffer to query the true size of the string attribute
1652
+ OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
1653
+
1654
+ if (status == nullptr) {
1655
+ out.resize(size);
1656
+ Ort::ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
1657
+ out.resize(size - 1); // remove the terminating character '\0'
1658
+ } else {
1659
+ Ort::ThrowOnError(status);
1660
+ }
1661
+ return out;
1662
+ }
1663
+
1664
+ template <>
1665
+ inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1666
+ size_t size = 0;
1667
+ std::vector<float> out;
1668
+
1669
+ // Feed nullptr for the data buffer to query the true size of the attribute
1670
+ OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
1671
+
1672
+ if (status == nullptr) {
1673
+ out.resize(size);
1674
+ Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
1675
+ } else {
1676
+ Ort::ThrowOnError(status);
1677
+ }
1678
+ return out;
1679
+ }
1680
+
1681
+ template <>
1682
+ inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1683
+ size_t size = 0;
1684
+ std::vector<int64_t> out;
1685
+
1686
+ // Feed nullptr for the data buffer to query the true size of the attribute
1687
+ OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
1688
+
1689
+ if (status == nullptr) {
1690
+ out.resize(size);
1691
+ Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
1692
+ } else {
1693
+ Ort::ThrowOnError(status);
1694
+ }
1695
+ return out;
1696
+ }
1697
+ inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) {
1698
+ OrtTensorTypeAndShapeInfo* out;
1699
+ Ort::ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
1700
+ return out;
1701
+ }
1702
+
1703
+ inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
1704
+ size_t out;
1705
+ Ort::ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
1706
+ return out;
1707
+ }
1708
+
1709
+ inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) {
1710
+ ONNXTensorElementDataType out;
1711
+ Ort::ThrowOnError(api_.GetTensorElementType(info, &out));
1712
+ return out;
1713
+ }
1714
+
1715
+ inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
1716
+ size_t out;
1717
+ Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
1718
+ return out;
1719
+ }
1720
+
1721
+ inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) {
1722
+ Ort::ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
1723
+ }
1724
+
1725
+ inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) {
1726
+ Ort::ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
1727
+ }
1728
+
1729
+ template <typename T>
1730
+ inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) {
1731
+ T* data;
1732
+ Ort::ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
1733
+ return data;
1734
+ }
1735
+
1736
+ inline const OrtMemoryInfo* CustomOpApi::GetTensorMemoryInfo(_In_ const OrtValue* value) {
1737
+ const OrtMemoryInfo* mem_info;
1738
+ Ort::ThrowOnError(api_.GetTensorMemoryInfo(value, &mem_info));
1739
+ return mem_info;
1740
+ }
1741
+
1742
+ template <typename T>
1743
+ inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) {
1744
+ T* data = nullptr;
1745
+ Ort::ThrowOnError(api_.GetTensorMutableData(const_cast<OrtValue*>(value), reinterpret_cast<void**>(&data)));
1746
+ return data;
1747
+ }
1748
+
1749
+ inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) {
1750
+ size_t out;
1751
+ Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
1752
+ std::vector<int64_t> output(out);
1753
+ Ort::ThrowOnError(api_.GetDimensions(info, output.data(), out));
1754
+ return output;
1755
+ }
1756
+
1757
+ inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) {
1758
+ api_.ReleaseTensorTypeAndShapeInfo(input);
1759
+ }
1760
+
1761
+ inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) {
1762
+ size_t out;
1763
+ Ort::ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
1764
+ return out;
1765
+ }
1766
+
1767
+ inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) {
1768
+ const OrtValue* out;
1769
+ Ort::ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
1770
+ return out;
1771
+ }
1772
+
1773
+ inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) {
1774
+ size_t out;
1775
+ Ort::ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
1776
+ return out;
1777
+ }
1778
+
1779
+ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
1780
+ _In_ const int64_t* dim_values, size_t dim_count) {
1781
+ OrtValue* out;
1782
+ Ort::ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
1783
+ return out;
1784
+ }
1785
+
1786
+ inline void* CustomOpApi::KernelContext_GetGPUComputeStream(const OrtKernelContext* context) {
1787
+ void* out;
1788
+ Ort::ThrowOnError(api_.KernelContext_GetGPUComputeStream(context, &out));
1789
+ return out;
1790
+ }
1791
+
1792
+ inline OrtOpAttr* CustomOpApi::CreateOpAttr(_In_ const char* name,
1793
+ _In_ const void* data,
1794
+ _In_ int len,
1795
+ _In_ OrtOpAttrType type) {
1796
+ OrtOpAttr* op_attr{};
1797
+ Ort::ThrowOnError(api_.CreateOpAttr(name, data, len, type, &op_attr));
1798
+ return op_attr;
1799
+ }
1800
+
1801
+ inline void CustomOpApi::ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr) {
1802
+ api_.ReleaseOpAttr(op_attr);
1803
+ }
1804
+
1805
+ inline OrtOp* CustomOpApi::CreateOp(_In_ const OrtKernelInfo* info,
1806
+ _In_ const char* op_name,
1807
+ _In_ const char* domain,
1808
+ _In_ int version,
1809
+ _In_opt_ const char** type_constraint_names,
1810
+ _In_opt_ const ONNXTensorElementDataType* type_constraint_values,
1811
+ _In_opt_ int type_constraint_count,
1812
+ _In_opt_ const OrtOpAttr* const* attr_values,
1813
+ _In_opt_ int attr_count,
1814
+ _In_ int input_count,
1815
+ _In_ int output_count) {
1816
+ OrtOp* ort_op{};
1817
+ Ort::ThrowOnError(api_.CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1818
+ type_constraint_count, attr_values, attr_count, input_count, output_count, &ort_op));
1819
+ return ort_op;
1820
+ }
1821
+
1822
+ inline void CustomOpApi::InvokeOp(_In_ const OrtKernelContext* context,
1823
+ _In_ const OrtOp* ort_op,
1824
+ _In_ const OrtValue* const* input_values,
1825
+ _In_ int input_count,
1826
+ _Inout_ OrtValue* const* output_values,
1827
+ _In_ int output_count) {
1828
+ Ort::ThrowOnError(api_.InvokeOp(context, ort_op, input_values, input_count, output_values, output_count));
1829
+ }
1830
+
1831
+ inline void CustomOpApi::ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op) {
1832
+ api_.ReleaseOp(ort_op);
1833
+ }
1834
+
1835
+ inline OrtKernelInfo* CustomOpApi::CopyKernelInfo(_In_ const OrtKernelInfo* info) {
1836
+ OrtKernelInfo* info_copy{};
1837
+ Ort::ThrowOnError(api_.CopyKernelInfo(info, &info_copy));
1838
+ return info_copy;
1839
+ }
1840
+
1841
+ inline void CustomOpApi::ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy) {
1842
+ api_.ReleaseKernelInfo(info_copy);
1843
+ }
1844
+
1845
+ inline std::vector<std::string> GetAvailableProviders() {
1846
+ int len;
1847
+ char** providers;
1848
+ ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
1849
+ std::vector<std::string> available_providers(providers, providers + len);
1850
+ ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
1851
+ return available_providers;
1852
+ }
1853
+
1854
+ SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
1855
+
1856
+ template <typename TOp, typename TKernel>
1857
+ void CustomOpBase<TOp, TKernel>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
1858
+ ConstSessionOptions options) const {
1859
+ const TOp* derived = static_cast<const TOp*>(this);
1860
+ std::vector<std::string> keys = derived->GetSessionConfigKeys();
1861
+
1862
+ out.reserve(keys.size());
1863
+
1864
+ std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
1865
+ const size_t prefix_size = config_entry_key.length();
1866
+
1867
+ for (const auto& key : keys) {
1868
+ config_entry_key.resize(prefix_size);
1869
+ config_entry_key.append(key);
1870
+ out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
1871
+ }
1872
+ }
1873
+
1874
+ } // namespace Ort
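
Note on the inline definitions above: the `CustomOpApi` attribute getters use a query-then-fill pattern (first call with a nullptr buffer to learn the size, second call to fill it), and `Ort::GetAvailableProviders` wraps the C API's provider enumeration. A minimal usage sketch follows; the kernel type `MyKernel` and the attribute name `"weights"` are hypothetical and only serve to illustrate the calls defined in this header.

    #include <iostream>
    #include <vector>
    #include "onnxruntime_cxx_api.h"

    // Hypothetical custom-op kernel: only the Ort:: calls come from this header.
    struct MyKernel {
      MyKernel(const OrtApi& api, const OrtKernelInfo* info) : ort_(api) {
        // Internally performs the two-call size query shown in the header above.
        weights_ = ort_.KernelInfoGetAttribute<std::vector<float>>(info, "weights");
      }

      Ort::CustomOpApi ort_;
      std::vector<float> weights_;
    };

    int main() {
      // Lists the execution providers compiled into this onnxruntime build,
      // e.g. "CPUExecutionProvider" (plus "CoreMLExecutionProvider" when enabled).
      for (const auto& provider : Ort::GetAvailableProviders()) {
        std::cout << provider << "\n";
      }
      return 0;
    }
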
1.14.0/onnxruntime.xcframework/Info.plist ADDED
@@ -0,0 +1,40 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+ <plist version="1.0">
+ <dict>
+   <key>AvailableLibraries</key>
+   <array>
+     <dict>
+       <key>LibraryIdentifier</key>
+       <string>ios-arm64_x86_64-simulator</string>
+       <key>LibraryPath</key>
+       <string>onnxruntime.a</string>
+       <key>SupportedArchitectures</key>
+       <array>
+         <string>arm64</string>
+         <string>x86_64</string>
+       </array>
+       <key>SupportedPlatform</key>
+       <string>ios</string>
+       <key>SupportedPlatformVariant</key>
+       <string>simulator</string>
+     </dict>
+     <dict>
+       <key>LibraryIdentifier</key>
+       <string>ios-arm64</string>
+       <key>LibraryPath</key>
+       <string>onnxruntime.a</string>
+       <key>SupportedArchitectures</key>
+       <array>
+         <string>arm64</string>
+       </array>
+       <key>SupportedPlatform</key>
+       <string>ios</string>
+     </dict>
+   </array>
+   <key>CFBundlePackageType</key>
+   <string>XFWK</string>
+   <key>XCFrameworkFormatVersion</key>
+   <string>1.0</string>
+ </dict>
+ </plist>
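
The plist advertises two static-library slices: ios-arm64 for devices and ios-arm64_x86_64-simulator for the simulator; Xcode selects the matching slice from the xcframework at build time. A purely illustrative sketch of how client code could report which environment it was built for (TARGET_OS_SIMULATOR comes from Apple's TargetConditionals.h, not from this framework):

    #include <TargetConditionals.h>
    #include <cstdio>

    // Illustrative only: reports whether this binary was built against the
    // simulator slice or the device slice listed in the Info.plist above.
    void LogSelectedSlice() {
    #if TARGET_OS_SIMULATOR
      std::printf("linked against the ios-arm64_x86_64-simulator slice\n");
    #else
      std::printf("linked against the ios-arm64 device slice\n");
    #endif
    }
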
1.14.0/onnxruntime.xcframework/ios-arm64/libonnxruntime.a ADDED
@@ -0,0 +1 @@
+ onnxruntime.a
1.14.0/onnxruntime.xcframework/ios-arm64/onnxruntime.a ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b052d5e5e605673827875492753a7fe58c46b4a6d68f40e28894a25f24f28493
+ size 56896832
1.14.0/onnxruntime.xcframework/ios-arm64_x86_64-simulator/libonnxruntime.a ADDED
@@ -0,0 +1 @@
+ onnxruntime.a
1.14.0/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.a ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a9cbe4e1556cab85a01dfdebad14eeab2d80423b58278964c6f82ae44ec7c360
+ size 116205304
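
The two pointers above are git-LFS stubs that resolve to the actual static libraries (roughly 57 MB for the device slice and 116 MB for the simulator slice). Once the xcframework is linked into an app target, a minimal session-creation sketch might look like the following; the model path "model.onnx" is a placeholder, while Ort::Env, Ort::SessionOptions, and Ort::Session come from the bundled onnxruntime_cxx_api.h.

    #include <iostream>
    #include "onnxruntime_cxx_api.h"

    int main() {
      // A minimal sketch, assuming model.onnx is bundled with the app.
      Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "demo");
      Ort::SessionOptions session_options;
      Ort::Session session(env, "model.onnx", session_options);

      std::cout << "model has " << session.GetInputCount() << " input(s) and "
                << session.GetOutputCount() << " output(s)\n";
      return 0;
    }
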