Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
8.56 kB
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#pragma once
#include <stdint.h>
#include <string>
#include <unordered_map>
#include <vector>
namespace ONNX_NAMESPACE {
// X-macro listing the name of every built-in symbol.
//
// FORALL_BUILTIN_SYMBOLS(M) expands to M(name) for each entry below, in the
// order written. The BuiltinSymbol enum in this file uses it to generate one
// `k<name>` enumerator per entry, so the position of an entry in this list
// determines that enumerator's numeric value: inserting, removing or
// reordering entries shifts the values of every enumerator after the change.
//
// Names are case-sensitive and several differ only by case (e.g. `add`/`Add`,
// `expand`/`Expand`); each is a distinct entry.
//
// NOTE: do not put `//` comments on the continuation lines themselves; the
// trailing backslash would splice the following entries into the comment.
#define FORALL_BUILTIN_SYMBOLS(_) \
_(spatial) \
_(select_last_index) \
_(coordinate_transformation_mode) \
_(PythonOp) \
_(CppOp) \
_(Param) \
_(Select) \
_(Return) \
_(Eval) \
_(add) \
_(Add) \
_(Div) \
_(Mul) \
_(Neg) \
_(Sub) \
_(Pow) \
_(Sigmoid) \
_(ArgMax) \
_(Concat) \
_(Softmax) \
_(LogSoftmax) \
_(Dropout) \
_(Tanh) \
_(mul) \
_(neg) \
_(sigmoid) \
_(tanh) \
_(Constant) \
_(cat) \
_(Slice) \
_(Squeeze) \
_(Undefined) \
_(FusionGroup) \
_(MatMul) \
_(Gemm) \
_(Tile) \
_(SubConstant) \
_(Scale) \
_(Transpose) \
_(Pad) \
_(Reshape) \
_(split) \
_(chunk) \
_(Offset) \
_(value) \
_(Subgraph) \
_(BatchNormalization) \
_(Conv) \
_(ConvTranspose) \
_(is_test) \
_(epsilon) \
_(expand) \
_(Expand) \
_(order) \
_(momentum) \
_(consumed_inputs) \
_(kernels) \
_(kernel_shape) \
_(kernel) \
_(scale) \
_(strides) \
_(stride) \
_(pads) \
_(pad) \
_(beta) \
_(alpha) \
_(dilations) \
_(dilation) \
_(broadcast) \
_(axis) \
_(ratio) \
_(size) \
_(dim) \
_(keepdims) \
_(perm) \
_(shape) \
_(axes) \
_(group) \
_(inplace) \
_(transA) \
_(transB) \
_(other) \
_(__and__) \
_(__lshift__) \
_(__or__) \
_(__rshift__) \
_(__xor__) \
_(abs) \
_(acos) \
_(asin) \
_(atan) \
_(atan2) \
_(ceil) \
_(clamp) \
_(cos) \
_(cosh) \
_(div) \
_(eq) \
_(equal) \
_(Exp) \
_(ends) \
_(expm1) \
_(floor) \
_(fmod) \
_(frac) \
_(ge) \
_(gt) \
_(le) \
_(lerp) \
_(lgamma) \
_(Log) \
_(log1p) \
_(lt) \
_(max) \
_(min) \
_(ne) \
_(ones) \
_(pow) \
_(reciprocal) \
_(remainder) \
_(round) \
_(rsqrt) \
_(sin) \
_(sinh) \
_(Sqrt) \
_(sub) \
_(starts) \
_(tan) \
_(trunc) \
_(zeros) \
_(exponent) \
_(device) \
_(mode) \
_(Identity) \
_(Loop) \
_(If) \
_(body) \
_(then_branch) \
_(else_branch) \
_(Captured) \
_(__control_inputs) \
_(count_include_pad) \
_(storage_order) \
_(Unsqueeze) \
_(ReduceL1) \
_(ReduceL2) \
_(ReduceLogSum) \
_(ReduceLogSumExp) \
_(ReduceMax) \
_(ReduceMean) \
_(ReduceMin) \
_(ReduceProd) \
_(ReduceSum) \
_(ReduceSumSquare) \
_(Cast) \
_(to) \
_(PRelu) \
_(Greater) \
_(Less) \
_(scales) \
_(Upsample) \
_(RNN) \
_(layout) \
_(k) \
_(Flatten) \
_(ScatterElements) \
_(Resize) \
_(ceil_mode) \
_(num_outputs) \
_(start) \
_(end) \
_(num_groups) \
_(stash_type) \
_(block_size) \
_(output_dtype)
// Numeric ids of the built-in symbols.
//
// Every name `foo` listed in FORALL_BUILTIN_SYMBOLS yields an enumerator
// `kfoo`; enumerators are numbered consecutively from 0 in list order.
enum BuiltinSymbol {
// Helper used only for the expansion below: turns a listed name into its
// `k`-prefixed enumerator followed by a comma.
#define ONNX_DEFINE_BUILTIN_SYMBOL_ID(name) k##name,
  FORALL_BUILTIN_SYMBOLS(ONNX_DEFINE_BUILTIN_SYMBOL_ID)
#undef ONNX_DEFINE_BUILTIN_SYMBOL_ID
  // Sentinel placed after the last built-in id; counting for newly created
  // symbols starts from this value.
  kLastSymbol,
};
// Small value type wrapping a 32-bit symbol id.
//
// A Symbol can be created from a BuiltinSymbol enumerator (implicitly), from a
// raw integer id, or from a string name. It converts implicitly back to its
// uint32_t id, which is what the comparison and hashing helpers in this file
// operate on.
struct Symbol {
  // Initializes the id to 0 so that a default-constructed Symbol never holds
  // an indeterminate value (reading one through operator uint32_t would be
  // undefined behaviour). Note that 0 is also the id of the first
  // BuiltinSymbol enumerator, so a default-constructed Symbol compares equal
  // to it.
  Symbol() : value(0) {}

  // Intentionally implicit: lets a BuiltinSymbol enumerator be passed
  // wherever a Symbol is expected.
  /*implicit*/ Symbol(BuiltinSymbol value) : value(value) {}

  // Creates the Symbol for the given name. Defined out of line; presumably
  // looks up or registers `s` in an intern table -- confirm against the
  // implementation file.
  explicit Symbol(const std::string& s);

  // Wraps a raw id as-is; no validation is performed here.
  explicit Symbol(uint32_t value) : value(value) {}

  // Implicit conversion to the underlying id.
  operator uint32_t() const {
    return value;
  }

  // Returns the symbol's name as a C string. Defined out of line; ownership
  // and lifetime of the returned pointer are not visible from this header.
  const char* toString() const;

 private:
  // The symbol's numeric id.
  uint32_t value;
};
// Two Symbols are equal exactly when their underlying 32-bit ids are equal.
static inline bool operator==(Symbol lhs, Symbol rhs) {
  const uint32_t left_id = static_cast<uint32_t>(lhs);
  const uint32_t right_id = static_cast<uint32_t>(rhs);
  return left_id == right_id;
}
// Exact-match overload for `enumerator == symbol`.
// Symbol is implicitly constructible from BuiltinSymbol and implicitly
// convertible to uint32_t, so without this overload a mixed comparison could
// be resolved in more than one way and would be ambiguous.
static inline bool operator==(BuiltinSymbol lhs, Symbol rhs) {
  const uint32_t left_id = static_cast<uint32_t>(lhs);
  const uint32_t right_id = static_cast<uint32_t>(rhs);
  return left_id == right_id;
}
// Exact-match overload for `symbol == enumerator`; compares the Symbol's id
// with the enumerator's numeric value.
static inline bool operator==(Symbol lhs, BuiltinSymbol rhs) {
  const uint32_t left_id = static_cast<uint32_t>(lhs);
  const uint32_t right_id = static_cast<uint32_t>(rhs);
  return left_id == right_id;
}
// User-defined literal: "name"_sym is shorthand for Symbol(std::string("name")).
//
// Declared as `operator""_sym` with no space before the suffix: the
// whitespace-separated spelling `operator"" _sym` is deprecated and triggers
// -Wdeprecated-literal-operator on recent Clang.
//
// The length argument supplied by the compiler is intentionally unused; the
// name is read up to the first NUL character, exactly as before.
inline Symbol operator""_sym(const char* s, size_t) {
  return Symbol(std::string(s));
}
} // namespace ONNX_NAMESPACE
// Hash support so that Symbol can be used as a key in hash-based containers:
// a Symbol hashes exactly like its underlying uint32_t id.
namespace std {
template <>
struct hash<ONNX_NAMESPACE::Symbol> {
  std::size_t operator()(ONNX_NAMESPACE::Symbol s) const {
    const uint32_t id = static_cast<uint32_t>(s);
    std::hash<uint32_t> id_hasher{};
    return id_hasher(id);
  }
};
} // namespace std