File size: 2,446 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// Copyright (c) ONNX Project Contributors

/*

 * SPDX-License-Identifier: Apache-2.0

 */

#pragma once
#include "onnx/common/visitor.h"

namespace ONNX_NAMESPACE {
namespace internal { // internal/private API

using AttributeMap = std::unordered_map<std::string, const AttributeProto*>;

// Class for binding formal attribute-parameters (in a node or graph) to their values.

class AttributeBinder : public MutableVisitor {
 public:
  AttributeBinder(const AttributeMap& attr_map) : attr_map_(attr_map) {}

  // Binding a formal attribute-parameter to a value may, as a special case, also
  // remove the attribute from the list of attributes of a node (when the attribute
  // has no specified value). Hence, we need to do the processing at a Node level
  // rather than an attribute level.
  void VisitNode(NodeProto* node) override {
    auto& attributes = *node->mutable_attribute();
    for (auto attr_iter = attributes.begin(); attr_iter != attributes.end();) {
      auto& attr = *attr_iter;
      if (!attr.ref_attr_name().empty()) {
        // Attribute-references must be replaced by the corresponding attribute-value in the call-node
        // if the call-node contains the attribute. Otherwise, this attribute must be removed.
        auto it = attr_map_.find(attr.ref_attr_name());
        if (it != attr_map_.end()) {
          const AttributeProto* replacement = it->second;
          // Copy value of attribute, but retain original name:
          std::string name = attr.name();
          attr = *replacement;
          attr.set_name(name);
          ++attr_iter;
        } else {
          attr_iter = attributes.erase(attr_iter);
        }
      } else {
        // For regular attributes, we process subgraphs, if present, recursively.
        VisitAttribute(&attr);
        ++attr_iter;
      }
    }
  }

  // Updates a FunctionProto by replacing all attribute-references with the corresponding
  // attribute-values in the call-node, if present. Otherwise, the attribute is removed.
  static void BindAttributes(const NodeProto& callnode, FunctionProto& callee) {
    AttributeMap map;
    for (auto& attr : callnode.attribute()) {
      map[attr.name()] = &attr;
    }
    AttributeBinder attr_binder(map);
    attr_binder.VisitFunction(&callee);
  }

 private:
  const AttributeMap& attr_map_;
};

} // namespace internal
} // namespace ONNX_NAMESPACE