Spaces:
Sleeping
Sleeping
// Copyright (c) ONNX Project Contributors | |
/* | |
* SPDX-License-Identifier: Apache-2.0 | |
*/ | |
// ATTENTION: The code in this file is highly EXPERIMENTAL. | |
// Adventurous users should note that the APIs will probably change. | |
// Deletes the copy constructor and the copy-assignment operator of TypeName.
// Invoke inside the class body and terminate the invocation with a semicolon:
//   ONNX_DISALLOW_COPY_AND_ASSIGN(Foo);
#define ONNX_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName&) = delete;           \
  TypeName& operator=(const TypeName&) = delete
namespace ONNX_NAMESPACE { | |
namespace { // internal/private API | |
// Builds the default name for the value with unique id `i`: "_v_" followed by
// the decimal id.
std::string toVarName(size_t i) {
  std::ostringstream name;
  name << "_v_";
  name << i;
  return name.str();
}
} // namespace | |
// Graph represents one "function" of computation. | |
// It uses a simple ownership model where the graph owns all the nodes inside it. | |
// All references inside the graph are raw pointers. | |
// Destroying the Graph will invalidate any pointers to nodes in the graph. | |
struct Graph; | |
// Node is the base class of the IR graph. It represents one computation | |
// and dependencies on a list of Values. The "prim-ops", so to speak. | |
struct Node; | |
// A Value represents an input or output of a node that is either a
// Tensor or an opaque Handle object, as determined by type().
struct Value; | |
// Scope guard: invokes the supplied callback when the guard is destroyed,
// unless release() has been called first. The guard is move-only; moving
// transfers the pending callback and disarms the source, so the callback runs
// at most once.
class ResourceGuard final {
  std::function<void()> destructor_;
  bool released_;

 public:
  // Copying is disabled: two guards must never own the same cleanup.
  ResourceGuard(const ResourceGuard&) = delete;
  ResourceGuard& operator=(const ResourceGuard&) = delete;

  // Arms the guard with `destructor`, the callback to run on destruction.
  explicit ResourceGuard(std::function<void()> destructor) : destructor_(std::move(destructor)), released_(false) {}

  // Takes over other's callback. `other` is disarmed so that its destructor
  // does not invoke the moved-from std::function.
  ResourceGuard(ResourceGuard&& other) noexcept
      : destructor_(std::move(other.destructor_)), released_(other.released_) {
    other.released_ = true;
  }

  // Runs this guard's own pending callback (if still armed), then takes over
  // other's callback and disarms `other`.
  ResourceGuard& operator=(ResourceGuard&& other) {
    if (this != &other) {
      runIfArmed();
      destructor_ = std::move(other.destructor_);
      released_ = other.released_;
      other.released_ = true;
    }
    return *this;
  }

  ~ResourceGuard() {
    runIfArmed();
  }

  // Disarms the guard; the callback will not be invoked.
  void release() {
    released_ = true;
  }

 private:
  // Invokes the callback if the guard is armed. An empty std::function is
  // skipped because calling it would throw std::bad_function_call.
  void runIfArmed() {
    if (!released_ && destructor_) {
      destructor_();
    }
  }
};
// One entry of a shape. A Dimension is in exactly one of three states:
//  - unknown (default constructed): is_unknown == true
//  - integer: is_int == true and `dim` holds the extent
//  - parameter: neither flag is set and `param` holds the string
struct Dimension final {
  // Unknown dimension.
  Dimension() {}
  // Dimension described by the string `param`; implicit conversion is intended.
  Dimension(std::string param) : is_unknown(false), param(std::move(param)) {} // NOLINT
  // Dimension with the integer extent `dim`; implicit conversion is intended.
  Dimension(int64_t dim) : is_unknown(false), is_int(true), dim(dim) {} // NOLINT

  bool is_unknown = true;
  bool is_int = false;
  int64_t dim = -1; // stays -1 unless the integer constructor is used
  std::string param; // stays empty unless the string constructor is used
};
// Tag describing the payload type held by an attribute. The enumerator order
// is significant: toString() indexes a name table with the enumerator value.
enum class AttributeKind : uint8_t {
  f, // float
  fs, // float list
  i, // int
  is, // int list
  s, // string
  ss, // string list
  t, // tensor
  ts, // tensor list
  g, // subgraph
  gs, // subgraph list
  tp, // type proto
  tps // type proto list
};
// Returns the short name of `kind` ("f", "fs", ..., "tps"). The table below
// must list the names in the same order as the AttributeKind enumerators.
static inline const char* toString(AttributeKind kind) {
  static constexpr const char* names[] = {"f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs", "tp", "tps"};
  // Guard against an enumerator value that has no entry in the table.
  ONNX_ASSERT(size_t(kind) < sizeof(names) / sizeof(const char*));
  return names[static_cast<int>(kind)];
}
struct AttributeValue { | |
explicit AttributeValue(Symbol name) : name(name) {} | |
using Ptr = std::unique_ptr<AttributeValue>; | |
Symbol name; | |
virtual AttributeKind kind() const = 0; | |
virtual Ptr clone() const = 0; | |
virtual ~AttributeValue() = default; | |
}; | |
template <typename T, AttributeKind Kind> | |
struct ScalarAttributeValue final : public AttributeValue { | |
using ConstructorType = const T&; | |
using ValueType = T; | |
ScalarAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(value_) {} | |
ValueType& value() { | |
return value_; | |
} | |
virtual Ptr clone() const override { | |
return Ptr(new ScalarAttributeValue(name, value_)); | |
} | |
virtual AttributeKind kind() const override { | |
return Kind; | |
} | |
private: | |
ValueType value_; | |
}; | |
template <typename T, AttributeKind Kind> | |
struct VectorAttributeValue final : public AttributeValue { | |
using ConstructorType = const std::vector<T>&&; | |
using ValueType = std::vector<T>; | |
VectorAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {} | |
ValueType& value() { | |
return value_; | |
} | |
virtual AttributeKind kind() const override { | |
return Kind; | |
} | |
virtual std::unique_ptr<AttributeValue> clone() const override { | |
auto copy = value_; | |
return Ptr(new VectorAttributeValue(name, std::move(copy))); | |
} | |
private: | |
ValueType value_; | |
}; | |
using FloatAttr = ScalarAttributeValue<double, AttributeKind::f>; | |
using FloatsAttr = VectorAttributeValue<double, AttributeKind::fs>; | |
using IntAttr = ScalarAttributeValue<int64_t, AttributeKind::i>; | |
using IntsAttr = VectorAttributeValue<int64_t, AttributeKind::is>; | |
using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>; | |
using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>; | |
using TensorAttr = ScalarAttributeValue<Tensor, AttributeKind::t>; | |
using TensorsAttr = VectorAttributeValue<Tensor, AttributeKind::ts>; | |
using GraphAttr = ScalarAttributeValue<std::shared_ptr<Graph>, AttributeKind::g>; | |
using GraphsAttr = VectorAttributeValue<std::shared_ptr<Graph>, AttributeKind::gs>; | |
using TypeProtoAttr = ScalarAttributeValue<TypeProto, AttributeKind::tp>; | |
using TypeProtosAttr = VectorAttributeValue<TypeProto, AttributeKind::tps>; | |
// CRTP so that Node which inherits Attributes can be return for | |
// method chaining e.g: | |
// Node * n = g->create(kSelect)->set_i(kOffset,3)->set_f(kValue,3.5); | |
// we return Derived* pointers because Nodes are normally held as pointers. | |
template <typename Derived> | |
struct Attributes { | |
Attributes() {} | |
void copyAttributes(const Attributes& rhs) { | |
values_.clear(); | |
values_.reserve(rhs.values_.size()); | |
for (auto& i : rhs.values_) { | |
values_.push_back(i->clone()); | |
} | |
} | |
bool hasAttribute(Symbol name) const { | |
return find(name, false) != values_.end(); | |
} | |
AttributeKind kindOf(Symbol name) const { | |
return (*find(name, true))->kind(); | |
} | |
Derived* removeAttribute(Symbol name) { | |
values_.erase(find(name, true)); | |
return This(); | |
} | |
bool hasAttributes() const { | |
return !values_.empty(); | |
} | |
// The names are returned in order, since name actually is the index. | |
std::vector<Symbol> attributeNames() const { | |
std::vector<Symbol> names; | |
names.reserve(values_.size()); | |
for (auto& a : values_) | |
names.push_back(a->name); | |
return names; | |
} | |
Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) { \ | |
return set<Kind##Attr>(name, std::forward<Kind##Attr::ConstructorType>(v)); \ | |
} \ | |
const Kind##Attr::ValueType& method(Symbol name) const { \ | |
return get<Kind##Attr>(name); \ | |
} | |
CREATE_ACCESSOR(Float, f) | |
CREATE_ACCESSOR(Floats, fs) | |
CREATE_ACCESSOR(String, s) | |
CREATE_ACCESSOR(Strings, ss) | |
CREATE_ACCESSOR(Int, i) | |
CREATE_ACCESSOR(Ints, is) | |
CREATE_ACCESSOR(Tensor, t) | |
CREATE_ACCESSOR(Tensors, ts) | |
CREATE_ACCESSOR(Graph, g) | |
CREATE_ACCESSOR(Graphs, gs) | |
CREATE_ACCESSOR(TypeProto, tp) | |
CREATE_ACCESSOR(TypeProtos, tps) | |
private: | |
Derived* This() { | |
return static_cast<Derived*>(this); | |
} | |
template <typename T> | |
Derived* set(Symbol name, typename T::ConstructorType v) { | |
auto it = find(name, false); | |
auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v))); | |
if (it == values_.end()) { | |
values_.push_back(std::move(nv)); | |
} else { | |
*it = std::move(nv); | |
} | |
return This(); | |
} | |
template <typename T> | |
typename T::ValueType& get(Symbol name) const { | |
auto it = find(name, true); | |
T* child = static_cast<T*>(it->get()); | |
return child->value(); | |
} | |
using AVPtr = AttributeValue::Ptr; | |
// NB: For determinism, we use a vector rather than a hash map. This does | |
// mean that lookups are O(n), so you shouldn't use Attributes to store | |
// a big pile of messages. | |
std::vector<AVPtr> values_; | |
using iterator = std::vector<AVPtr>::iterator; | |
iterator find(Symbol name, bool required) { | |
auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; }); | |
ONNX_ASSERT(!required || it != values_.end()); | |
return it; | |
} | |
using const_iterator = std::vector<AVPtr>::const_iterator; | |
const_iterator find(Symbol name, bool required) const { | |
auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; }); | |
ONNX_ASSERTM( | |
!required || it != values_.end(), | |
"%s:%u: %s: required undefined attribute '%s'", | |
__FILE__, | |
__LINE__, | |
__func__, | |
name.toString()); | |
return it; | |
} | |
}; | |
// Each use is represented by this type, see Node::uses() | |
// 'user' is the consumer of the value, offset is the index into | |
// 'user's input this where the produces will be found. | |
struct Use final { | |
Use(Node* user, size_t offset) : user(user), offset(offset) {} | |
Node* user; | |
size_t offset; | |
}; | |
// Two uses are equal when they name the same consumer node and input slot.
static inline bool operator==(const Use& a, const Use& b) {
  if (a.user != b.user) {
    return false;
  }
  return a.offset == b.offset;
}
// the list types are intentionally simple, but we type-def | |
// them here so if we need to change them, refactoring will be easier | |
using node_list = std::vector<Node*>; | |
using value_list = std::vector<Value*>; | |
using use_list = std::vector<Use>; | |
using NodeKind = Symbol; | |
struct Value final { | |
ONNX_DISALLOW_COPY_AND_ASSIGN(Value); | |
Value(Node* node_, size_t offset_); | |
Value(Value&&) = default; | |
Value& operator=(Value&&) = default; | |
~Value() = default; | |
private: | |
friend struct Node; | |
friend struct Graph; | |
Node* node_; | |
size_t offset_; | |
size_t unique_ = 0; // unique id | |
size_t stage_ = 0; // 0-forward, 1-backward, 2-double-backward,... | |
use_list uses_in_current_graph_; | |
bool has_unique_name_; | |
std::string unique_name_; | |
int32_t elem_type_; | |
bool has_sizes_; | |
std::vector<Dimension> sizes_; | |
public: | |
Value* setElemType(int32_t elem_type) { | |
elem_type_ = elem_type; | |
return this; | |
} | |
int32_t elemType() const { | |
return elem_type_; | |
} | |
bool has_sizes() const { | |
return has_sizes_; | |
} | |
Value* setSizes(std::vector<Dimension> sizes) { | |
has_sizes_ = true; | |
sizes_ = std::move(sizes); | |
return this; | |
} | |
Value* wipeSizes() { | |
has_sizes_ = false; | |
sizes_ = std::vector<Dimension>(); | |
return this; | |
} | |
const std::vector<Dimension>& sizes() const { | |
return sizes_; | |
} | |
size_t unique() const { | |
return unique_; | |
} | |
bool has_unique_name() const { | |
return has_unique_name_; | |
} | |
std::string uniqueName() const { | |
if (has_unique_name()) | |
return unique_name_; | |
return toVarName(unique()); | |
} | |
Value* setUniqueName(const std::string& name, bool rename_subgraph_captured_nodes = true); | |
Value* setStage(size_t s) { | |
stage_ = s; | |
return this; | |
} | |
size_t stage() const { | |
return stage_; | |
} | |
Node* node() { | |
return node_; | |
} | |
size_t offset() const { | |
return offset_; | |
} | |
const Node* node() const { | |
return node_; | |
} | |
Graph* owningGraph(); | |
const Graph* owningGraph() const; | |
// TODO: make this more const correct | |
const use_list uses() const; | |
// Replaces all uses of this node with 'newValue'. | |
// | |
// Given: %3 = f(%1, %2) | |
// %4 = g(%3) | |
// %5 = h(%3, %3) | |
// Execute: %3.replaceAllUsesWith(%6) | |
// Result: %3 = f(%1, %2) | |
// %4 = g(%6) | |
// %5 = h(%6, %6) | |
void replaceAllUsesWith(Value* newValue); | |
Value* copyMetadata(Value* from) { | |
setElemType(from->elemType()); | |
setSizes(from->sizes()); | |
if (from->has_unique_name()) { | |
setUniqueName(from->uniqueName()); | |
} | |
return this; | |
} | |
}; | |
// One computation in a Graph: a kind, a list of input Values, a list of
// output Values and (through Attributes<Node>) a set of named attributes.
// The constructor is protected; Graph, a friend, allocates nodes.
struct Node : public Attributes<Node> {
  ONNX_DISALLOW_COPY_AND_ASSIGN(Node);
  friend struct Graph;
  friend struct Value;
  friend graph_node_list;
  friend const_graph_node_list;
  friend graph_node_list_iterator;
  friend const_graph_node_list_iterator;

 private:
  // Each node except Return/Param is associated with exactly one place in
  // the node list of graph_.
  // The list is circular and doubly linked; the Return node is used as the
  // sentinel for the beginning and end of the list, so the list never holds
  // null pointers. next_in_graph[0] is the next pointer and next_in_graph[1]
  // is the prev pointer; an array is used so that the same iterator class
  // can serve forward and reverse node lists.
  // This list represents a topological sort.
  Node* next_in_graph[2] = {nullptr, nullptr};

  Node*& next() { return next_in_graph[kNextDirection]; }
  Node*& prev() { return next_in_graph[kPrevDirection]; }
  Node* const& next() const { return next_in_graph[kNextDirection]; }
  Node* const& prev() const { return next_in_graph[kPrevDirection]; }

  const NodeKind kind_;
  std::vector<Value*> inputs_;
  std::vector<Value*> outputs_;
  Graph* graph_;
  size_t stage_;
  bool has_name_;
  std::string name_;
  bool has_domain_;
  std::string domain_;
  bool has_doc_string_;
  std::string doc_string_;
  bool has_overload_;
  std::string overload_;

 protected:
  Node(Graph* graph_, NodeKind kind_); // defined after graph

 public:
  // ---- optional string properties; each setter also marks the property as present ----

  bool has_name() const { return has_name_; }
  const std::string& name() const { return name_; }
  void setName(std::string name) {
    name_ = std::move(name);
    has_name_ = true;
  }

  bool has_domain() const { return has_domain_; }
  const std::string& domain() const { return domain_; }
  void setDomain(std::string domain) {
    domain_ = std::move(domain);
    has_domain_ = true;
  }

  bool has_overload() const { return has_overload_; }
  const std::string& overload() const { return overload_; }
  void setOverload(std::string overload) {
    overload_ = std::move(overload);
    has_overload_ = true;
  }

  bool has_doc_string() const { return has_doc_string_; }
  const std::string& docString() const { return doc_string_; }
  void setDocString(std::string doc_string) {
    doc_string_ = std::move(doc_string);
    has_doc_string_ = true;
  }

  // ---- basic accessors ----

  NodeKind kind() const { return kind_; }
  Graph* owningGraph() { return graph_; }
  const Graph* owningGraph() const { return graph_; }
  size_t stage() const { return stage_; }
  Node* setStage(size_t s) {
    stage_ = s;
    return this;
  }

  // NB: This returns an ArrayRef; that means that it will
  // get invalidated if you resize inputs (e.g., using addInput)
  // We can't return a std::vector<Node*>& because there's no
  // way to soundly cast to std::vector<const Node*> (an insane
  // implementation of std::vector could make this representationally
  // different.)
  ArrayRef<Value*> inputs() { return inputs_; }
  ArrayRef<const Value*> inputs() const {
    // Vectors are not convertible in const-ness of elements, but
    // raw pointers are.
    return {inputs_.data(), inputs_.size()};
  }

  // NB: As with inputs(), the returned ArrayRef is invalidated when the
  // output list is resized (e.g., using addOutput).
  ArrayRef<Value*> outputs() { return outputs_; }
  ArrayRef<const Value*> outputs() const {
    // Vectors are not convertible in const-ness of elements, but
    // raw pointers are.
    return {outputs_.data(), outputs_.size()};
  }

  // True if any output of this node has at least one use.
  bool hasUses() const {
    const auto outs = outputs();
    return std::any_of(outs.begin(), outs.end(), [](const Value* out) { return !out->uses().empty(); });
  }

  // Redirects every use of output i of this node to output i of 'n'.
  // Both nodes must have the same number of outputs.
  void replaceAllUsesWith(Node* n) {
    ONNX_ASSERT(outputs().size() == n->outputs().size());
    const size_t count = outputs().size();
    for (size_t idx = 0; idx != count; ++idx) {
      outputs()[idx]->replaceAllUsesWith(n->outputs()[idx]);
    }
  }

  // Lots of things like chunk have a single input or single output, so we
  // have helpers to make accessing it easier. Each asserts that there is
  // exactly one input/output.
  Value* input() {
    ONNX_ASSERT(inputs_.size() == 1);
    return inputs_.at(0);
  }
  Value* output() {
    ONNX_ASSERT(outputs_.size() == 1);
    return outputs_.at(0);
  }
  const Value* input() const {
    ONNX_ASSERT(inputs_.size() == 1);
    return inputs_.at(0);
  }
  Value* output() const {
    ONNX_ASSERT(outputs_.size() == 1);
    return outputs_.at(0);
  }

  // Access a particular input. This is a checked index.
  Value* input(size_t i) { return inputs_.at(i); }
  const Value* input(size_t i) const { return inputs_.at(i); }

  // Graphs

  // Note [Topological invariant]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // We always maintain an up-to-date topological ordering of all nodes via
  // the next()/prev() links. All transformations to graphs must preserve
  // this topological ordering: for example, it is only valid to 'addInput'
  // with an input which is topologically before the current node.
  //
  // Usually, it is obvious whether or not topological order is maintained;
  // for example, if you are adding nodes to the end of the topsort, it's
  // impossible for them to refer to inputs that are not in the topsort.
  // If it is not obvious, please comment accordingly.

  // Add 'node' as an input to 'this' at the end of existing
  // arguments. Returns the added value for ease of chaining.
  //
  // Given:   %3 = f(%1, %2)
  // Execute: %3.addInput(%4)
  // Result:  %3 = f(%1, %2, %4)
  Value* addInput(Value* node) {
    ONNX_ASSERT(graph_ == node->owningGraph());
    const size_t slot = inputs_.size();
    node->uses_in_current_graph_.emplace_back(this, slot);
    inputs_.push_back(node);
    return node;
  }

  // Replace the input of 'this' at position 'i' with
  // 'newValue', returning the old value.
  //
  // Given:   %3 = f(%1, %2)
  // Execute: %3.replaceInput(1, %4)
  // Result:  %3 = f(%1, %4)
  Value* replaceInput(size_t i, Value* newValue) {
    ONNX_ASSERT(newValue->owningGraph() == graph_);
    Value* const previous = dropInput(i);
    inputs_[i] = newValue;
    newValue->uses_in_current_graph_.emplace_back(this, i);
    return previous;
  }

  // Replace all occurrences of 'from' in the inputs of this
  // node with 'to'. Corresponds to llvm's replaceUsesOfWith.
  //
  // Given:   %3 = f(%1, %2, %1)
  // Execute: %3.replaceInputWith(%1, %4)
  // Result:  %3 = f(%4, %2, %4)
  void replaceInputWith(Value* from, Value* to) {
    ONNX_ASSERT(from->owningGraph() == graph_);
    ONNX_ASSERT(to->owningGraph() == graph_);
    // replaceInput never resizes inputs_, so indexing stays valid.
    for (size_t pos = 0; pos < inputs_.size(); ++pos) {
      if (inputs_[pos] == from) {
        replaceInput(pos, to);
      }
    }
  }

  // Appends a freshly allocated output Value and returns it.
  Value* addOutput() {
    outputs_.push_back(new Value(this, outputs_.size()));
    return outputs_.back();
  }

  void eraseOutput(size_t i);

  // Insert unattached 'this' node before 'n' in the topological order.
  // Returns this (for chaining).
  //
  // Given:   %3 = f(%1, %2)
  //          %4 = g(%3)
  // and unattached: %5 = h(%1)
  // Execute: %5.insertBefore(%4)
  // Result:  %3 = f(%1, %2)
  //          %5 = h(%1)
  //          %4 = g(%3)
  Node* insertBefore(Node* n) {
    ONNX_ASSERT(n->inGraphList());
    insertAfter(n->prev());
    return this;
  }

  // Insert unattached 'this' node after 'n' in the topological order.
  // Returns this (for chaining).
  //
  // Given:   %3 = f(%1, %2)
  //          %4 = g(%3)
  // and unattached: %5 = h(%1)
  // Execute: %5.insertAfter(%4)
  // Result:  %3 = f(%1, %2)
  //          %4 = g(%3)
  //          %5 = h(%1)
  Node* insertAfter(Node* n) {
    ONNX_ASSERT(!inGraphList() && n->inGraphList());
    Node* const successor = n->next();
    n->next() = this;
    this->prev() = n;
    this->next() = successor;
    successor->prev() = this;
    return this;
  }

  // Move 'this' (already in the graph) after 'n' in the topological order.
  //
  // Given:   %2 = f(%1)
  //          %3 = g(%1)
  // Execute: %2.moveAfter(%3)
  // Result:  %3 = g(%1)
  //          %2 = f(%1)
  //
  void moveAfter(Node* n) {
    removeFromList();
    insertAfter(n);
  }

  // Move 'this' (already in the graph) before 'n' in the topological order.
  //
  // Given:   %2 = f(%1)
  //          %3 = g(%1)
  // Execute: %3.moveBefore(%2)
  // Result:  %3 = g(%1)
  //          %2 = f(%1)
  void moveBefore(Node* n) {
    removeFromList();
    insertBefore(n);
  }

  // Remove the input at 'i' from this node.
  //
  // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
  // removeInput.
  //
  // Given:   %3 = f(%1, %2)
  // Execute: %3.removeInput(1)
  // Result:  %3 = f(%1)
  void removeInput(size_t i) {
    dropInput(i);
    // Everything after this input shifts left, so the offsets recorded in
    // the use lists of the later inputs have to shift with it.
    for (size_t later = i + 1; later < inputs_.size(); later++) {
      auto use = findUseForInput(later);
      use->offset--;
    }
    inputs_.erase(inputs_.begin() + i);
  }

  // Remove all inputs from a node.
  //
  // Given:   %3 = f(%1, %2)
  // Execute: %3.removeAllInputs()
  // Result:  %3 = f()
  void removeAllInputs() {
    const size_t count = inputs().size();
    for (size_t pos = 0; pos < count; ++pos) {
      dropInput(pos);
    }
    inputs_.clear();
  }

  // Check whether this node is before node n in the graph.
  bool isBefore(Node* n);

  // iterators of the node list starting at this node
  // useful for resuming a search starting at this node
  graph_node_list_iterator iterator();
  graph_node_list_iterator reverseIterator();
  const_graph_node_list_iterator iterator() const;
  const_graph_node_list_iterator reverseIterator() const;

  // Remove 'this' from the instruction list and deallocate it.
  //
  // Invariant: no outputs of 'this' may have any uses.
  //
  // Given:   %2 = f(%1)
  //          %3 = g(%1)
  // Execute: %2.destroy()
  // Result:  %3 = g(%1)
  void destroy();

  // Cast this node to the subclass indicated by the template variable,
  // returning nullptr if T::Kind does not match this node's kind.
  //
  // Example usage: if(auto s = n.cast<Select>()) { ... }
  //
  // TODO: Make this const correct
  template <typename T>
  T* cast() {
    return (T::Kind == kind()) ? static_cast<T*>(this) : nullptr;
  }

  // Like cast<T>(), but a kind mismatch is an assertion failure.
  template <typename T>
  T* expect() {
    ONNX_ASSERTM(T::Kind == kind(), "expected a %s but found a %s", T::Kind.toString(), kind().toString());
    return static_cast<T*>(this);
  }

  virtual ~Node() = default;

 private:
  // Lookup iterator in use list of _input i_ that corresponds to its use of _this_
  use_list::iterator findUseForInput(size_t i) {
    auto& input_uses = inputs_[i]->uses_in_current_graph_;
    // O(N) on the use list, but unless we get nodes with +100 uses
    // vector traversal still is probably faster than linked list
    const Use wanted(this, i);
    auto use_it = std::find(input_uses.begin(), input_uses.end(), wanted);
    ONNX_ASSERT(use_it != input_uses.end());
    return use_it;
  }

  // Remove the use of input i. This sets input i to nullptr, but is only
  // used internally to Node before setting it to a new value or erasing the
  // entry from the list. Returns the value that was dropped.
  Value* dropInput(size_t i) {
    ONNX_ASSERT(i < inputs_.size());
    Value* const dropped = inputs_[i];
    dropped->uses_in_current_graph_.erase(findUseForInput(i));
    inputs_[i] = nullptr;
    return dropped;
  }

  // True if this node is currently linked into a node list.
  bool inGraphList() const {
    ONNX_ASSERT(next() != nullptr || prev() == nullptr);
    return next() != nullptr;
  }

  // Unlinks this node from the node list and clears its own links.
  void removeFromList() {
    ONNX_ASSERT(inGraphList());
    Node* const after = this->next();
    Node* const before = this->prev();
    before->next() = after;
    after->prev() = before;
    this->next() = nullptr;
    this->prev() = nullptr;
  }

 protected:
  // subclasses must override
  // this function is used by createClone to initialize a new version
  // of a node in another graph. It should allocate a new instance of the same
  // concrete type as 'this', but in graph 'g' which might be different
  // than graph_
  virtual Node* allocNewInstance(Graph* g) {
    return new Node(g, kind());
  }

  // create a copy of all properties of Node s into this.
  // subclasses should extend if they have additional information to copy.
  // 'this' will be allocated with s->allocNewInstance(g) so it should have
  // the same concrete type as 's'
  //
  // NB: This does NOT clone stages. You're expected to set the stage correctly
  // if you are going to preserve it.
  virtual void cloneFrom(Node* s) {
    copyAttributes(*s);
  }
};
// A class with the same properties as OperatorSetIdProto, but without protobuf | |
// overhead, resulting in a simpler and more readable workflow. | |
class OpSetID final { | |
private: | |
std::string domain_; | |
int64_t version_; | |
public: | |
explicit OpSetID(const OperatorSetIdProto& proto) : domain_(proto.domain()), version_(proto.version()) {} | |
// Default Domain Constructor | |
explicit OpSetID(const int64_t version) : domain_(""), version_(version) {} | |
explicit OpSetID(const std::string& domain, int64_t version) : domain_(domain), version_(version) {} | |
// target must be in the form "<domain>&<version>" | |
std::string toString() const { | |
return domain_ + "$" + ONNX_NAMESPACE::to_string(version_); | |
} | |
// target must be in the form "<domain>&<version>" | |
static OpSetID fromString(const std::string& target) { | |
ONNX_TRY { | |
std::string new_domain = target.substr(0, target.find("$")); | |
int new_version = ONNX_NAMESPACE::stoi(target.substr(target.find("$") + 1, target.length()).c_str()); | |
return OpSetID(new_domain, new_version); | |
} | |
ONNX_CATCH(const std::runtime_error& e) { | |
ONNX_HANDLE_EXCEPTION([&]() { ONNX_ASSERTM(false, "Error in fromString: %s", e.what()); }); | |
} | |
// The control will never reach here. | |
// In the default build where exceptions are turned on in case of any error | |
// the control will enter catch block where an exception will be thrown again. | |
// In case of "no exception build" the code aborts at the site of first exception. | |
// Adding this to appease the warning "control may reach end of non-void function" | |
// as the mac build fails when ONNX_WERROR==ON | |
return OpSetID("", 0); | |
} | |
const std::string& domain() const { | |
return domain_; | |
} | |
int64_t version() const { | |
return version_; | |
} | |
void incrementVersion(int64_t step) { | |
version_ += step; | |
} | |
void setVersion(int64_t newVal) { | |
version_ = newVal; | |
} | |
}; | |
struct Graph final { | |
ONNX_DISALLOW_COPY_AND_ASSIGN(Graph); | |
friend struct Node; | |
friend struct Value; | |
private: | |
// only used to keep track of allocated nodes | |
// actual representation of Graph is done with | |
// inputs, outputs, nodes | |
std::unordered_set<const Node*> all_nodes; | |
std::unordered_set<const Value*> all_values; | |
size_t next_unique_; | |
size_t new_node_stage_; | |
// holds outputs in a way that can be reflected | |
// as a Use object | |
// also used as the beginning/end of the circular node list to avoid | |
// having corner cases where the list is empty. | |
Node* const output_; | |
Node* const input_; | |
// Create an independent node list for those initializers do not exist in input | |
Node* const initializer_node_; | |
std::vector<Tensor> initializers_; | |
std::vector<std::string> initializer_names_; | |
bool has_name_; | |
std::string name_; | |
bool has_doc_string_; | |
std::string doc_string_; | |
std::vector<OpSetID> opset_versions_; | |
bool isNameUnique(const std::string& name) const { | |
if (std::find(initializer_names_.cbegin(), initializer_names_.cend(), name) != initializer_names_.cend()) { | |
return false; | |
} | |
const auto f = [&name](const Value* v) { return v->uniqueName() == name; }; | |
for (const Node* node : all_nodes) { | |
for (const auto& attr : node->attributeNames()) { | |
if (node->kindOf(attr) == AttributeKind::g) { | |
const auto& subgraph = node->g(attr); | |
if (!subgraph->isNameUnique(name)) { | |
return false; | |
} | |
} else if (node->kindOf(attr) == AttributeKind::gs) { | |
for (const auto& subgraph : node->gs(attr)) { | |
if (!subgraph->isNameUnique(name)) { | |
return false; | |
} | |
} | |
} | |
} | |
const auto found_in = std::find_if(node->inputs().begin(), node->inputs().end(), f); | |
if (found_in != node->inputs().end()) { | |
return false; | |
} | |
const auto found_out = std::find_if(node->outputs().begin(), node->outputs().end(), f); | |
if (found_out != node->outputs().end()) { | |
return false; | |
} | |
} | |
return true; | |
} | |
public: | |
Graph() | |
: next_unique_(0), | |
new_node_stage_(0), | |
output_(initOutput(create(kReturn, 0))), | |
input_(create(kParam, 0)), | |
initializer_node_(create(kParam, 0)), | |
has_name_(false), | |
has_doc_string_(false) {} | |
bool has_doc_string() const { | |
return has_doc_string_; | |
} | |
const std::string& docString() { | |
return doc_string_; | |
} | |
void setDocString(std::string doc_string) { | |
has_doc_string_ = true; | |
doc_string_ = std::move(doc_string); | |
} | |
void addInitializer(Tensor& initializer) { | |
if (initializer.name().empty()) { | |
initializer.setName(toVarName(getNextUnique())); | |
} | |
initializers_.push_back(initializer); | |
initializer_names_.push_back(initializer.name()); | |
} | |
// For IR >= 4, initializer is not required to exist in input | |
// Add initializer into initializer node list and return its Value | |
Value* addInitializerAndCreateValue(Tensor& initializer) { | |
addInitializer(initializer); | |
auto* init_value = initializer_node_->addOutput(); | |
std::vector<Dimension> dim_sizes{initializer.sizes().cbegin(), initializer.sizes().cend()}; | |
init_value->setUniqueName(initializer.name()); | |
init_value->setSizes(dim_sizes); | |
init_value->setElemType(initializer.elem_type()); | |
return init_value; | |
} | |
void eraseInitializer(const std::string& name) { | |
initializers_.erase( | |
std::remove_if( | |
initializers_.begin(), | |
initializers_.end(), | |
[&name](Tensor& initializer) { return initializer.name() == name; }), | |
initializers_.end()); | |
initializer_names_.erase( | |
std::remove(initializer_names_.begin(), initializer_names_.end(), name), initializer_names_.end()); | |
for (size_t i = 0; i < initializer_node_->outputs().size(); i++) { | |
if (initializer_node_->outputs()[i]->uniqueName() == name) { | |
initializer_node_->eraseOutput(i); | |
break; | |
} | |
} | |
} | |
void clearInitializers() { | |
initializers_.clear(); | |
initializer_names_.clear(); | |
} | |
const std::vector<Tensor>& initializers() const { | |
return initializers_; | |
} | |
const std::vector<std::string>& initializer_names() const { | |
return initializer_names_; | |
} | |
std::vector<Tensor>::const_iterator getInitializer(const std::string& name) const { | |
for (auto it = initializers_.cbegin(); it != initializers_.cend(); ++it) { | |
if (name == it->name()) { | |
return it; | |
} | |
} | |
return initializers_.end(); | |
} | |
bool is_constant_initializer(const Value* value) const { | |
return value->node() == initializer_node_; | |
} | |
ArrayRef<Value*> inputs() { | |
return input_->outputs(); | |
} | |
ArrayRef<const Value*> inputs() const { | |
const auto& inputs = input_->outputs(); | |
return {inputs.data(), inputs.size()}; | |
} | |
ArrayRef<Value*> outputs() { | |
return output_->inputs(); | |
} | |
ArrayRef<const Value*> outputs() const { | |
return static_cast<const Node*>(output_)->inputs(); | |
} | |
graph_node_list nodes() { | |
return graph_node_list(output_, kNextDirection); | |
} | |
const_graph_node_list nodes() const { | |
return const_graph_node_list(output_, kNextDirection); | |
} | |
std::vector<OpSetID>& opset_versions_mutable() { | |
return opset_versions_; | |
} | |
size_t getNextUnique() { | |
std::string next_unique_name = toVarName(++next_unique_); | |
while (!isNameUnique(next_unique_name)) { | |
next_unique_name = toVarName(++next_unique_); | |
} | |
return next_unique_; | |
} | |
// Taking begin()/end() from the list returned by nodes() is fine:
// graph_node_list is non-owning, so the iterator stays usable after the
// list object itself has gone away.
graph_node_list_iterator begin() {
  auto node_list = nodes();
  return node_list.begin();
}
// Const counterpart of begin(): begin() of the const node list.
const_graph_node_list_iterator begin() const {
  auto node_list = nodes();
  return node_list.begin();
}
// end() of the list returned by nodes().
graph_node_list_iterator end() {
  auto node_list = nodes();
  return node_list.end();
}
// Const counterpart of end(): end() of the const node list.
const_graph_node_list_iterator end() const {
  auto node_list = nodes();
  return node_list.end();
}
// rbegin() of the list returned by nodes().
graph_node_list_iterator rbegin() {
  auto node_list = nodes();
  return node_list.rbegin();
}
// Const counterpart of rbegin(): rbegin() of the const node list.
const_graph_node_list_iterator rbegin() const {
  auto node_list = nodes();
  return node_list.rbegin();
}
// rend() of the list returned by nodes().
graph_node_list_iterator rend() {
  auto node_list = nodes();
  return node_list.rend();
}
// Const counterpart of rend(): rend() of the const node list.
const_graph_node_list_iterator rend() const {
  auto node_list = nodes();
  return node_list.rend();
}
// The output_ node; its inputs are the graph outputs (see outputs()).
Node* return_node() {
  return this->output_;
}
// Const access to the output_ node.
const Node* return_node() const {
  return this->output_;
}
// Adds a new graph input by appending an output to the input_ node and
// returns the Value that was created.
Value* addInput() {
  Value* fresh_input = input_->addOutput();
  return fresh_input;
}
// Removes the i-th graph input by erasing output `i` of the input_ node.
void eraseInput(size_t i) {
  auto* source = input_;
  source->eraseOutput(i);
}
// Moves the stage given to newly created nodes/values forward by one.
void advanceStage() {
  new_node_stage_ += 1;
}
void setStage(size_t new_stage) { | |
new_node_stage_ = new_stage; | |
} | |
size_t stage() const { | |
return new_node_stage_; | |
} | |
// Switches the new-node stage to `s` and returns a ResourceGuard whose
// destructor callback puts the previous stage back.
ResourceGuard setStageTemporary(size_t s) {
  const auto saved_stage = new_node_stage_;
  auto restore_stage = [saved_stage, this]() { this->new_node_stage_ = saved_stage; };
  new_node_stage_ = s;
  return ResourceGuard(restore_stage);
}
// Registers `n` as a graph output by adding it as an input of the output_
// node. Returns the index of the last entry of outputs() afterwards.
size_t registerOutput(Value* n) {
  output_->addInput(n);
  const size_t output_count = outputs().size();
  return output_count - 1;
}
// Allocates a node of the given kind belonging to this graph and gives it
// `num_outputs` outputs. The node is not inserted into the node list here.
Node* create(NodeKind kind, size_t num_outputs = 1) {
  // NB: the Node constructor registers the node in all_nodes.
  Node* const fresh = new Node(this, kind);
  for (size_t remaining = num_outputs; remaining > 0; --remaining) {
    fresh->addOutput();
  }
  return fresh;
}
// Same as create(kind, num_outputs), and additionally wires every Value in
// `inputs` to the new node, in order.
Node* create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs = 1) {
  Node* const fresh = create(kind, num_outputs);
  for (auto input_value : inputs) {
    fresh->addInput(input_value);
  }
  return fresh;
}
// Inserts `n` immediately before the output_ node in the node list.
// `n` must belong to this graph and must not already be in the list.
// Returns `n`.
Node* appendNode(Node* n) {
  ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
  auto* anchor = output_;
  n->insertBefore(anchor);
  return n;
}
// Inserts `n` immediately after the output_ node in the node list.
// `n` must belong to this graph and must not already be in the list.
// Returns `n`.
Node* prependNode(Node* n) {
  ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
  auto* anchor = output_;
  n->insertAfter(anchor);
  return n;
}
// Registers a copy of `initializer` under `name` both as an initializer and
// as a new graph input. The tensor copy and the new input Value are given the
// same name, and the Value takes its sizes and element type from the tensor.
// Returns the new input Value.
Value* addInitializerAndInput(const Tensor& initializer, const std::string& name) {
  Tensor named_copy = initializer;
  named_copy.setName(name);
  const auto& shape = named_copy.sizes();
  std::vector<Dimension> dims(shape.cbegin(), shape.cend());

  Value* graph_input = addInput();
  graph_input->setUniqueName(name);
  graph_input->setSizes(dims);
  graph_input->setElemType(named_copy.elem_type());

  addInitializer(named_copy);
  return graph_input;
}
// Convenience overload: generates a name from getNextUnique() and forwards to
// the two-argument addInitializerAndInput.
Value* addInitializerAndInput(const Tensor& initializer) {
  const std::string generated_name = toVarName(getNextUnique());
  return addInitializerAndInput(initializer, generated_name);
}
// Erases from graph initializer list, initializer names list, and as a graph input.
// `v` must have no uses.
//
// Whether `v` is a graph input is decided BEFORE eraseInitializer() runs:
// eraseInitializer() may erase (and free) the output of initializer_node_
// that carries this name, and if `v` is that very Value, reading v->node()
// afterwards would touch freed memory. When `v` is an output of input_ it is
// not freed by eraseInitializer(), so v->offset() is still valid below.
void eraseInitializerAndInput(Value* v) {
  const bool is_graph_input = (v->node() == input_);
  eraseInitializer(v->uniqueName());
  if (is_graph_input) {
    eraseInput(v->offset());
  }
}
// The graph owns every Node and Value registered in all_nodes / all_values;
// all nodes are deleted first, then all values.
~Graph() {
  for (auto node_it = all_nodes.begin(); node_it != all_nodes.end(); ++node_it) {
    delete *node_it;
  }
  for (auto value_it = all_values.begin(); value_it != all_values.end(); ++value_it) {
    delete *value_it;
  }
}
// Renders the graph through its operator<< and returns the text.
std::string toString() const {
  std::ostringstream buffer;
  buffer << *this;
  return buffer.str();
}
bool has_name() const { | |
return has_name_; | |
} | |
const std::string& name() const { | |
return name_; | |
} | |
// Stores `name` as the graph name and marks the graph as named.
void setName(std::string name) {
  name_ = std::move(name);
  has_name_ = true;
}
friend std::ostream& operator<<(std::ostream& out, const Graph& g); | |
// Calls `fn` on this graph first, then recurses into every subgraph held in a
// graph-valued (`g`) or graph-list-valued (`gs`) attribute of any node in
// all_nodes.
void forSelfAndEachSubGraph(const std::function<void(Graph*)>& fn) {
  fn(this);

  for (const Node* node : all_nodes) {
    for (const auto& attr_name : node->attributeNames()) {
      const auto attr_kind = node->kindOf(attr_name);
      if (attr_kind == AttributeKind::g) {
        // Single subgraph attribute.
        const std::shared_ptr<Graph> child = node->g(attr_name);
        child->forSelfAndEachSubGraph(fn);
      } else if (attr_kind == AttributeKind::gs) {
        // List-of-subgraphs attribute.
        for (const auto& child : node->gs(attr_name)) {
          child->forSelfAndEachSubGraph(fn);
        }
      }
    }
  }
}
void forSelfAndEachSubGraph(const std::function<void(const Graph*)>& fn) const { | |
std::function<void(Graph*)> tmp_fn = [fn](Graph* graph) { fn(graph); }; | |
const_cast<Graph*>(this)->forSelfAndEachSubGraph(tmp_fn); | |
} | |
// Calls `fn` on every node in nodes() of this graph and of each subgraph
// reached by forSelfAndEachSubGraph.
void forEachNode(const std::function<void(Node*)>& fn) {
  const std::function<void(Graph*)> visit_graph = [fn](Graph* graph) {
    for (Node* node : graph->nodes()) {
      fn(node);
    }
  };
  forSelfAndEachSubGraph(visit_graph);
}
void forEachNode(const std::function<void(const Node*)>& fn) const { | |
std::function<void(Node*)> tmp_fn = [fn](Node* node) { fn(node); }; | |
const_cast<Graph*>(this)->forEachNode(tmp_fn); | |
} | |
private: | |
// Constructor-only helper: links `p` to itself in both list directions and
// sets its stage to the largest size_t value. Returns `p`.
Node* initOutput(Node* p) {
  p->prev() = p;
  p->next() = p;
  p->setStage(std::numeric_limits<size_t>::max());
  return p;
}
// Deletes `n` and removes it from all_nodes. `n` must be registered there.
void freeNode(Node* n) {
  auto it = all_nodes.find(n);
  ONNX_ASSERT(it != all_nodes.end());
  const auto* doomed = *it;
  delete doomed;
  all_nodes.erase(it);
}
// Deletes `v` and removes it from all_values. `v` must be registered there.
void freeValue(Value* v) {
  auto it = all_values.find(v);
  ONNX_ASSERT(it != all_values.end());
  const auto* doomed = *it;
  delete doomed;
  all_values.erase(it);
}
}; | |
// Creates output number `offset_` of `node_`. The value takes its numeric id
// from the owning graph's getNextUnique() and its stage from the graph's
// current new-node stage, starts without a unique name, element type or
// sizes, and registers itself in the graph's all_values.
inline Value::Value(Node* node_, size_t offset_)
    : node_(node_),
      offset_(offset_),
      unique_(node_->graph_->getNextUnique()),
      stage_(node_->graph_->new_node_stage_),
      has_unique_name_(false),
      elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED),
      has_sizes_(false) {
  // `node_` here is the constructor parameter (it shadows the member).
  Graph* const owner = node_->graph_;
  owner->all_values.insert(this);
}
// The graph that owns the node producing this value.
inline Graph* Value::owningGraph() {
  Node* producer = node();
  return producer->owningGraph();
}
// Const access to the graph that owns the node producing this value.
inline const Graph* Value::owningGraph() const {
  const Node* producer = node();
  return producer->owningGraph();
}
// Renames this value to `name` and returns `this`.
//
// When the value already has a unique name and `update_related_names` is set,
// everything that refers to the old name is updated as well:
//  * `captured` nodes in subgraphs identify the value they capture by its
//    unique name, so matching captured outputs are renamed too;
//  * initializer names are also stored in graph.initializer_names_ (and on the
//    tensor at the same index in graph.initializers_), so those are updated.
inline Value* Value::setUniqueName(const std::string& name, bool update_related_names) {
  if (has_unique_name() && update_related_names) {
    Graph* const graph = owningGraph();
    const std::string old_name = unique_name_;

    // Keep the initializer name list and the tensor at the same index in sync.
    auto& init_names = graph->initializer_names_;
    for (size_t idx = 0; idx < init_names.size(); ++idx) {
      if (init_names[idx] != old_name) {
        continue;
      }
      init_names[idx] = name;
      graph->initializers_[idx].setName(name);
    }

    // Rename matching outputs of kCaptured nodes that live in subgraphs.
    graph->forEachNode([this, &name, &old_name](Node* node) {
      const bool in_subgraph = node->owningGraph() != this->owningGraph();
      if (!in_subgraph || node->kind() != kCaptured) {
        return;
      }
      Value* captured = node->output();
      if (captured->uniqueName() == old_name) {
        // `false`: do not run the related-name update again for this rename.
        captured->setUniqueName(name, false);
      }
    });
  }
  unique_name_ = name;
  has_unique_name_ = true;
  return this;
}
// Redirects every use of this value to `newValue`, which must belong to the
// same graph.
//
// Steps, in order:
//  1. Known sizes / element type of this value are copied onto `newValue`.
//  2. If this value is a graph output, `newValue` takes over its unique name
//     (so the graph output name is unchanged) and this value gets a freshly
//     generated name.
//  3. Every use recorded in uses_in_current_graph_ is rewired to `newValue`.
//  4. Outputs of kCaptured nodes in subgraphs that referred to this value by
//     name are renamed to `newValue`'s name.
//  5. This value's use list is cleared.
//
// Replacing a value with itself is a no-op.
inline void Value::replaceAllUsesWith(Value* newValue) {
  auto* graph = owningGraph();
  ONNX_ASSERT(graph == newValue->owningGraph());
  // Self-replacement guard: without it the rewiring loop below would
  // push_back into the very vector it is iterating, and the final clear()
  // would then drop every use of the value.
  if (newValue == this) {
    return;
  }
  // propagate sizes and elem type
  if (this->has_sizes()) {
    newValue->setSizes(this->sizes());
  }
  if (this->elemType() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
    newValue->setElemType(this->elemType());
  }
  const auto unique_name = this->uniqueName();
  // We do not want the optimization to change the graph output name
  if (std::find(graph->outputs().rbegin(), graph->outputs().rend(), this) != graph->outputs().rend()) {
    newValue->setUniqueName(unique_name);
    // The "unique" semantic of unique_name should be kept or uses()
    // will return an incorrect result when the value is used in subgraph
    this->setUniqueName(toVarName(graph->getNextUnique()), false);
  }
  // Reserve room for the uses newValue already has plus the ones moved over.
  newValue->uses_in_current_graph_.reserve(
      newValue->uses_in_current_graph_.size() + this->uses_in_current_graph_.size());
  for (auto u : uses_in_current_graph_) {
    u.user->inputs_[u.offset] = newValue;
    newValue->uses_in_current_graph_.push_back(u);
  }
  graph->forEachNode([this, &newValue, &unique_name](Node* node) {
    if (node->owningGraph() == this->owningGraph()) {
      // skip non-subgraph
      return;
    }
    if (node->kind() == kCaptured) {
      Value* output = node->output();
      if (output->uniqueName() == unique_name) {
        output->setUniqueName(newValue->uniqueName());
      }
    }
  });
  uses_in_current_graph_.clear();
  assert(this->uses().empty());
}
// Creates a node of kind `kind_` belonging to `graph_`. The node takes the
// graph's current new-node stage, starts with no name, domain, doc string or
// overload, and registers itself in the graph's all_nodes.
inline Node::Node(Graph* graph_, NodeKind kind_)
    : kind_(kind_),
      graph_(graph_),
      stage_(graph_->new_node_stage_),
      has_name_(false),
      has_domain_(false),
      has_doc_string_(false),
      has_overload_(false) {
  // `graph_` here is the constructor parameter (it shadows the member).
  Graph* const owner = graph_;
  owner->all_nodes.insert(this);
}
inline void Node::eraseOutput(size_t i) { | |
ONNX_ASSERT(i < outputs_.size()); | |
ONNX_ASSERT(outputs_[i]->uses().empty()); | |
Value* n = outputs_[i]; | |
outputs_.erase(outputs_.begin() + i); | |
owningGraph()->freeValue(n); | |
for (size_t j = i; j < outputs_.size(); j++) { | |
outputs_[j]->offset_--; | |
} | |
} | |
// Reports whether this node comes before `n`.
//  * false for a null `n` or for the node itself;
//  * true when this node is a kParam node;
//  * false when `n` is a kParam node;
//  * otherwise `n` must be in the graph list, and the answer is whether `n`
//    is reached by following next() from this node up to the graph's end().
inline bool Node::isBefore(Node* n) {
  if (n == nullptr || this == n) {
    return false;
  }
  if (kind_ == kParam) {
    return true;
  }
  if (n->kind() == kParam) {
    return false;
  }
  ONNX_ASSERT(n->inGraphList());
  Node* const list_end = *graph_->end();
  Node* cursor = next();
  while (cursor != list_end) {
    if (cursor == n) {
      return true;
    }
    cursor = cursor->next();
  }
  return false;
}
inline void Node::destroy() { | |
ONNX_ASSERT(inGraphList()); | |
while (!outputs().empty()) | |
eraseOutput(outputs().size() - 1); | |
removeAllInputs(); | |
removeFromList(); | |
graph_->freeNode(this); | |
} | |
/************* All nodes not required to be defined before Graph **************/ | |
// Node-list iterator positioned at this node, constructed with direction
// argument 0.
inline graph_node_list_iterator Node::iterator() {
  graph_node_list_iterator here(this, 0);
  return here;
}
// The reverse() of iterator(): an iterator at this node going the other way.
inline graph_node_list_iterator Node::reverseIterator() {
  auto forward = iterator();
  return forward.reverse();
}
// Const node-list iterator positioned at this node, constructed with
// direction argument 0.
inline const_graph_node_list_iterator Node::iterator() const {
  const_graph_node_list_iterator here(this, 0);
  return here;
}
// The reverse() of the const iterator(): a const iterator at this node going
// the other way.
inline const_graph_node_list_iterator Node::reverseIterator() const {
  auto forward = iterator();
  return forward.reverse();
}
// Collects every use of this value: the uses recorded in the current graph
// plus, for each kCaptured node in a subgraph whose output carries this
// value's unique name, the uses of that captured output.
// Typically consulted to decide whether a Value can safely be deleted.
inline const use_list Value::uses() const {
  use_list collected = uses_in_current_graph_;
  const Graph* const home_graph = owningGraph();
  home_graph->forEachNode([this, home_graph, &collected](const Node* node) {
    // Only kCaptured nodes that live in a subgraph are of interest.
    const bool in_subgraph = node->owningGraph() != home_graph;
    if (!in_subgraph || node->kind() != kCaptured) {
      return;
    }
    const Value* captured = node->outputs()[0];
    if (captured->uniqueName() == this->uniqueName()) {
      const auto nested_uses = captured->uses();
      collected.insert(collected.end(), nested_uses.begin(), nested_uses.end());
    }
  });
  return collected;
}
} // namespace ONNX_NAMESPACE | |