Spaces:
Sleeping
Sleeping
File size: 2,446 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <string>
#include <unordered_map>
#include <utility>

#include "onnx/common/visitor.h"
namespace ONNX_NAMESPACE {
namespace internal { // internal/private API
// Maps an attribute name (as it appears on a call-node) to that call-node's AttributeProto.
// The pointers are non-owning: the pointed-to AttributeProtos must outlive every use of the map.
using AttributeMap = std::unordered_map<std::string, const AttributeProto*>;

// Class for binding formal attribute-parameters (in a node or graph) to their values.
//
// For every node visited, each attribute with a non-empty `ref_attr_name` is either
//   * replaced by a copy of the AttributeProto that `ref_attr_name` maps to (the original
//     attribute's `name` is kept), or
//   * removed from the node, when the map has no entry for `ref_attr_name`.
// Attributes that are not references are passed to VisitAttribute so that subgraphs they
// contain are processed recursively.
//
// Lifetime: the binder keeps a reference to the AttributeMap given to its constructor; it does
// not copy it. The map (and the AttributeProtos it points to) must stay alive while the binder
// is in use.
class AttributeBinder : public MutableVisitor {
 public:
  // `explicit`: an AttributeMap should never silently convert into a binder.
  explicit AttributeBinder(const AttributeMap& attr_map) : attr_map_(attr_map) {}

  // Binding a formal attribute-parameter to a value may, as a special case, also
  // remove the attribute from the list of attributes of a node (when the attribute
  // has no specified value). Hence, we need to do the processing at a Node level
  // rather than an attribute level.
  //
  // A substituted value is not visited further: only attributes that were not references
  // are recursed into.
  void VisitNode(NodeProto* node) override {
    auto& attributes = *node->mutable_attribute();
    for (auto attr_iter = attributes.begin(); attr_iter != attributes.end();) {
      auto& attr = *attr_iter;
      if (!attr.ref_attr_name().empty()) {
        // Attribute-references must be replaced by the corresponding attribute-value in the call-node
        // if the call-node contains the attribute. Otherwise, this attribute must be removed.
        const auto it = attr_map_.find(attr.ref_attr_name());
        if (it != attr_map_.end()) {
          // Copy value of attribute, but retain original name. The name must be saved before
          // the assignment overwrites it.
          std::string name = attr.name();
          attr = *it->second;
          attr.set_name(std::move(name));
          ++attr_iter;
        } else {
          // erase() yields the iterator following the removed element, so do not advance.
          attr_iter = attributes.erase(attr_iter);
        }
      } else {
        // For regular attributes, we process subgraphs, if present, recursively.
        VisitAttribute(&attr);
        ++attr_iter;
      }
    }
  }

  // Updates a FunctionProto by replacing all attribute-references with the corresponding
  // attribute-values in the call-node, if present. Otherwise, the attribute is removed.
  //
  // @param callnode  node whose attributes supply the actual values; only read.
  // @param callee    function body to rewrite in place.
  // If `callnode` lists the same attribute name more than once, the last occurrence wins.
  static void BindAttributes(const NodeProto& callnode, FunctionProto& callee) {
    AttributeMap map;
    for (const auto& attr : callnode.attribute()) {
      map[attr.name()] = &attr;
    }
    // `map` outlives `attr_binder`, satisfying the binder's lifetime requirement.
    AttributeBinder attr_binder(map);
    attr_binder.VisitFunction(&callee);
  }

 private:
  // Non-owning reference to the caller's map; see the lifetime note on the class.
  const AttributeMap& attr_map_;
};
} // namespace internal
} // namespace ONNX_NAMESPACE
|