// Copyright (c) ONNX Project Contributors
//
// SPDX-License-Identifier: Apache-2.0

#include "onnx/shape_inference/implementation.h"

#include <algorithm>
#include <fstream>
#include <list>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "onnx/checker.h"
#include "onnx/common/common.h"
#include "onnx/common/file_utils.h"
#include "onnx/defs/data_type_utils.h"
#include "onnx/proto_utils.h"
#include "onnx/shape_inference/attribute_binder.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {
namespace shape_inference {
namespace {

std::string GetValueCaseString(const TypeProto& type) {
  switch (type.value_case()) {
    case TypeProto::ValueCase::kTensorType:
      return "tensor_type";
    case TypeProto::ValueCase::kSequenceType:
      return "sequence_type";
    case TypeProto::ValueCase::kMapType:
      return "map_type";
    case TypeProto::ValueCase::kOptionalType:
      return "optional_type";
#ifdef ONNX_ML
    case TypeProto::ValueCase::kOpaqueType:
      return "opaque_type";
#endif
    case TypeProto::ValueCase::kSparseTensorType:
      return "sparse_tensor_type";
    case TypeProto::ValueCase::VALUE_NOT_SET:
      return "NOT_SET";
  }
  return ONNX_NAMESPACE::to_string(type.value_case());
}

std::string GetElemTypeString(const TypeProto_Tensor& type) {
#ifndef ONNX_USE_LITE_PROTO
  std::string type_str = TensorProto::DataType_Name(static_cast<TensorProto::DataType>(type.elem_type()));
  if (!type_str.empty()) {
    return type_str;
  }
#endif
  return ONNX_NAMESPACE::to_string(type.elem_type());
}

std::string GetElemTypeString(const TypeProto_SparseTensor& type) {
#ifndef ONNX_USE_LITE_PROTO
  std::string type_str = TensorProto::DataType_Name(static_cast<TensorProto::DataType>(type.elem_type()));
  if (!type_str.empty()) {
    return type_str;
  }
#endif
  return ONNX_NAMESPACE::to_string(type.elem_type());
}

inline bool IsOnnxDomainOp(const NodeProto& node, const std::string& op_type) {
  return (IsOnnxDomain(node.domain()) && (node.op_type() == op_type));
}

} // namespace

template <typename T>
void CheckTensorShapesAndTypes(const T& inferred_type, const T& existing_type) {
  if (inferred_type.elem_type() != TensorProto::UNDEFINED && existing_type.elem_type() != TensorProto::UNDEFINED &&
      existing_type.elem_type() != inferred_type.elem_type()) {
    std::stringstream ss;
    ss << "Inferred elem type differs from existing elem type: (" << GetElemTypeString(inferred_type) << ") vs ("
       << GetElemTypeString(existing_type) << ")";
    fail_type_inference(ss.str());
  }

  if (!inferred_type.has_shape() || !existing_type.has_shape()) {
    return;
  }

  if (inferred_type.shape().dim_size() != existing_type.shape().dim_size()) {
    std::stringstream ss;
    ss << "Inferred shape and existing shape differ in rank: (" << inferred_type.shape().dim_size() << ") vs ("
       << existing_type.shape().dim_size() << ")";
    fail_shape_inference(ss.str());
  }

  for (int i = 0; i < inferred_type.shape().dim_size(); ++i) {
    const auto& inferred_dim = inferred_type.shape().dim(i);
    const auto& existing_dim = existing_type.shape().dim(i);
    if (inferred_dim.has_dim_value() && existing_dim.has_dim_value() &&
        inferred_dim.dim_value() != existing_dim.dim_value()) {
      std::stringstream ss;
      ss << "Inferred shape and existing shape differ in dimension " << i << ": (" << inferred_dim.dim_value()
         << ") vs (" << existing_dim.dim_value() << ")";
      fail_shape_inference(ss.str());
    }
  }
}
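// checkShapesAndTypes: verifies that an inferred TypeProto is compatible with a pre-existing
// TypeProto of the same value case, recursing into sequence/optional/map element types.
// For example, an inferred tensor(float) with shape [N, 4] is compatible with an existing
// tensor(float) with shape [?, 4], while a rank or dim_value mismatch raises an inference error.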
void checkShapesAndTypes(const TypeProto& inferred_type, const TypeProto& existing_type) {
  const auto inferred_value_case = inferred_type.value_case();
  const auto existing_value_case = existing_type.value_case();
  if (inferred_value_case == TypeProto::ValueCase::VALUE_NOT_SET ||
      existing_value_case == TypeProto::ValueCase::VALUE_NOT_SET) {
    // nothing to check; will assign inferred type to undefined existing type
    return;
  }
  if (inferred_value_case != existing_value_case) {
    fail_type_inference(
        "type case mismatch. existing=",
        GetValueCaseString(existing_type),
        " inferred=",
        GetValueCaseString(inferred_type));
  }

  if (inferred_value_case == TypeProto::kTensorType && existing_value_case == TypeProto::kTensorType) {
    CheckTensorShapesAndTypes(inferred_type.tensor_type(), existing_type.tensor_type());
  } else if (
      inferred_value_case == TypeProto::kSparseTensorType && existing_value_case == TypeProto::kSparseTensorType) {
    CheckTensorShapesAndTypes(inferred_type.sparse_tensor_type(), existing_type.sparse_tensor_type());
  } else if (inferred_value_case == TypeProto::kSequenceType && existing_value_case == TypeProto::kSequenceType) {
    checkShapesAndTypes(inferred_type.sequence_type().elem_type(), existing_type.sequence_type().elem_type());
  } else if (inferred_value_case == TypeProto::kOptionalType && existing_value_case == TypeProto::kOptionalType) {
    checkShapesAndTypes(inferred_type.optional_type().elem_type(), existing_type.optional_type().elem_type());
  } else if (inferred_value_case == TypeProto::kMapType && existing_value_case == TypeProto::kMapType) {
    if (inferred_type.map_type().key_type() != existing_type.map_type().key_type()) {
      fail_type_inference(
          "key type mismatch from MapProto. existing=",
          Utils::DataTypeUtils::ToDataTypeString(existing_type.map_type().key_type()),
          " inferred=",
          Utils::DataTypeUtils::ToDataTypeString(inferred_type.map_type().key_type()));
    }
    checkShapesAndTypes(inferred_type.map_type().value_type(), existing_type.map_type().value_type());
  } else {
    fail_type_inference("type case unsupported. existing=", existing_value_case, " inferred=", inferred_value_case);
  }
}

void mergeShapesAndTypes(const TypeProto_Tensor& inferred_type, TypeProto_Tensor* existing_type) {
  if (existing_type->elem_type() == TensorProto::UNDEFINED) {
    existing_type->set_elem_type(inferred_type.elem_type());
  }

  if (!inferred_type.has_shape()) {
    return;
  }

  if (!existing_type->has_shape()) {
    *existing_type->mutable_shape() = inferred_type.shape();
    return;
  }

  for (int i = 0; i < inferred_type.shape().dim_size(); ++i) {
    const auto& inferred_dim = inferred_type.shape().dim(i);
    auto* existing_dim = existing_type->mutable_shape()->mutable_dim(i);
    if ((!existing_dim->has_dim_value() && !existing_dim->has_dim_param()) || inferred_dim.has_dim_value()) {
      *existing_dim = inferred_dim;
    }
  }
}

void mergeShapesAndTypes(const TypeProto_SparseTensor& inferred_type, TypeProto_SparseTensor* existing_type) {
  if (existing_type->elem_type() == TensorProto::UNDEFINED) {
    existing_type->set_elem_type(inferred_type.elem_type());
  }

  if (!inferred_type.has_shape()) {
    return;
  }

  if (!existing_type->has_shape()) {
    *existing_type->mutable_shape() = inferred_type.shape();
    return;
  }

  for (int i = 0; i < inferred_type.shape().dim_size(); ++i) {
    const auto& inferred_dim = inferred_type.shape().dim(i);
    auto* existing_dim = existing_type->mutable_shape()->mutable_dim(i);
    if ((!existing_dim->has_dim_value() && !existing_dim->has_dim_param()) || inferred_dim.has_dim_value()) {
      *existing_dim = inferred_dim;
    }
  }
}
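// Merge semantics, illustrated: the merge is element-wise over dims (ranks already agree, per the
// check above). A concrete inferred dim_value always wins; an existing dim_value or dim_param is
// kept when the inferred dim is unknown. E.g., merging inferred ["N", 3, 224, 224] into existing
// [?, 3, ?, 224] yields ["N", 3, 224, 224].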
void mergeShapesAndTypes(const TypeProto& inferred_type, TypeProto* existing_type) {
  // Check before merge
  checkShapesAndTypes(inferred_type, *existing_type);

  const auto inferred_val_case = inferred_type.value_case();
  if (inferred_val_case == TypeProto::kTensorType) {
    mergeShapesAndTypes(inferred_type.tensor_type(), existing_type->mutable_tensor_type());
  } else if (inferred_val_case == TypeProto::kSparseTensorType) {
    mergeShapesAndTypes(inferred_type.sparse_tensor_type(), existing_type->mutable_sparse_tensor_type());
  } else if (inferred_val_case == TypeProto::kSequenceType) {
    mergeShapesAndTypes(
        inferred_type.sequence_type().elem_type(), existing_type->mutable_sequence_type()->mutable_elem_type());
  } else if (inferred_val_case == TypeProto::kOptionalType) {
    mergeShapesAndTypes(
        inferred_type.optional_type().elem_type(), existing_type->mutable_optional_type()->mutable_elem_type());
  } else if (inferred_val_case == TypeProto::kMapType) {
    if (existing_type->map_type().key_type() == TensorProto::UNDEFINED) {
      existing_type->mutable_map_type()->set_key_type(inferred_type.map_type().key_type());
    }
    mergeShapesAndTypes(
        inferred_type.map_type().value_type(), existing_type->mutable_map_type()->mutable_value_type());
  }
}

// TypeProto_Tensor or TypeProto_SparseTensor
template <typename TensorTypeProto>
void GenerateSymbolicShape(TensorTypeProto* inferred_type, SymbolTable& symbol_table) {
  if (!inferred_type->has_shape()) {
    return;
  }
  for (int i = 0; i < inferred_type->shape().dim_size(); ++i) {
    // set a symbol if it doesn't have dim_value and dim_param
    auto* dim = inferred_type->mutable_shape()->mutable_dim(i);
    if (!dim->has_dim_value() && !dim->has_dim_param()) {
      dim->set_dim_param(symbol_table.createNew());
    }
  }
}

void MaterializeSymbolicShape(TypeProto* inferred_type, SymbolTable& symbol_table) {
  const auto inferred_val_case = inferred_type->value_case();
  if (inferred_val_case == TypeProto::ValueCase::VALUE_NOT_SET) {
    return;
  }

  if (inferred_val_case == TypeProto::kTensorType) {
    GenerateSymbolicShape(inferred_type->mutable_tensor_type(), symbol_table);
  } else if (inferred_val_case == TypeProto::kSparseTensorType) {
    GenerateSymbolicShape(inferred_type->mutable_sparse_tensor_type(), symbol_table);
  } else if (inferred_val_case == TypeProto::kSequenceType) {
    MaterializeSymbolicShape(inferred_type->mutable_sequence_type()->mutable_elem_type(), symbol_table);
  } else if (inferred_val_case == TypeProto::kOptionalType) {
    MaterializeSymbolicShape(inferred_type->mutable_optional_type()->mutable_elem_type(), symbol_table);
  } else if (inferred_val_case == TypeProto::kMapType) {
    MaterializeSymbolicShape(inferred_type->mutable_map_type()->mutable_value_type(), symbol_table);
  } else {
    fail_shape_inference("type case unsupported for symbolic shape inference. inferred=", inferred_val_case);
  }
}

std::string GetFunctionIdentifier(const FunctionProto& function) {
  // Note: Models with IR version < 10 do not have the overload attribute.
  // However, that will be mapped to an empty identifier.
  std::string overload = function.overload();
  if (overload.empty()) {
    return function.domain() + ":" + function.name();
  }
  return function.domain() + ":" + function.name() + ":" + overload;
}

std::string GetFunctionIdentifier(const NodeProto& node) {
  // Note: Models with IR version < 10 do not have the overload attribute.
  // However, that will be mapped to an empty identifier.
  std::string overload = node.overload();
  if (overload.empty()) {
    return node.domain() + ":" + node.op_type();
  }
  return node.domain() + ":" + node.op_type() + ":" + overload;
}
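// Identifier examples (illustrative names): a function {domain: "custom.ns", name: "Gelu"} maps to
// "custom.ns:Gelu", and a call-site node with overload "fast" maps to "custom.ns:Gelu:fast". Both
// forms key into the same ModelLocalFunctionsMap, so a node resolves to a model-local function only
// when domain, name/op_type, and overload all match.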
// InferredTypes: abstracts the differences between FunctionProto and GraphProto
// for inference. For GraphProto, inferred types are stored in the GraphProto
// but FunctionProto does not have a place to store inferred types. So, we
// use a temporary vector (for the duration of inference) to store these.
class InferredTypes {
 public:
  explicit InferredTypes(GraphProto* graph = nullptr) : graph_ptr(graph) {}

  TypeProto* Add(const std::string& var_name, const TypeProto& type) {
    if (graph_ptr != nullptr) {
      auto* p = graph_ptr->add_value_info();
      p->set_name(var_name);
      *p->mutable_type() = type;
      return p->mutable_type();
    } else {
      auto* p = new TypeProto(type);
      types.push_back(p);
      return p;
    }
  }

  ~InferredTypes() {
    for (auto* p : types) {
      delete p;
    }
  }

 private:
  std::vector<TypeProto*> types;
  GraphProto* graph_ptr;
  ONNX_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferredTypes);
};

// Initialize a DataValueMap for a called function from the DataValueMap of the caller
void BindValuesOnCall(
    const DataValueMap& caller_map,
    const NodeProto& caller,
    DataValueMap& callee_map,
    const FunctionProto& callee) {
  auto num_inputs = (std::min)(caller.input_size(), callee.input_size());
  for (int i = 0; i < num_inputs; ++i) {
    const std::string& actual = caller.input(i);
    const std::string& formal = callee.input(i);
    if (!actual.empty()) {
      auto it = caller_map.find(actual);
      if (it != caller_map.end()) {
        callee_map[formal] = it->second;
      }
    }
  }
}

// Update a DataValueMap for a calling function from the DataValueMap of the callee
void BindValuesOnReturn(
    const DataValueMap& callee_map,
    const FunctionProto& callee,
    DataValueMap& caller_map,
    const NodeProto& caller) {
  auto num_outputs = (std::min)(caller.output_size(), callee.output_size());
  for (int i = 0; i < num_outputs; ++i) {
    const std::string& actual = caller.output(i);
    const std::string& formal = callee.output(i);
    if (!actual.empty()) {
      auto it = callee_map.find(formal);
      if (it != callee_map.end()) {
        caller_map[actual] = it->second;
      }
    }
  }
}
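// ShapeInferenceImplBase: the shared inference driver. It is instantiated either with a GraphProto
// (graph inference, where inferred types are written back into value_info) or without one
// (function-body inference, where inferred types live only for the duration of the call). It tracks
// known types, statically known tensor values (from initializers and Constant nodes), and per-node
// inference errors.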
class ShapeInferenceImplBase {
 public:
  void UpdateType(const std::string& name, TypeProto* inferred_type) {
    if (inferred_type->value_case() == TypeProto::ValueCase::VALUE_NOT_SET) {
      return;
    }

    if (symbol_table) {
      MaterializeSymbolicShape(inferred_type, *symbol_table);
    }

    // Find any pre-existing type and shape info. If there is such,
    // then check for compatibility with the inferred
    // information. Otherwise, initialize it in an empty state.
    auto iter = value_types_by_name.find(name);
    if (iter != value_types_by_name.end()) {
      mergeShapesAndTypes(*inferred_type, iter->second);
    } else {
      value_types_by_name[name] = inferred_types.Add(name, *inferred_type);
      // For undefined output type, update both value_info and output for now
      // Update existing output with undefined type: assign inferred type to it
      iter = undefined_value_types_by_name.find(name);
      if (iter != undefined_value_types_by_name.end()) {
        *iter->second = *inferred_type;
      }
    }
  }

  void UpdateType(ValueInfoProto& valueInfo) {
    if (valueInfo.has_type()) {
      value_types_by_name[valueInfo.name()] = valueInfo.mutable_type();
    } else {
      undefined_value_types_by_name[valueInfo.name()] = valueInfo.mutable_type();
    }
  }

  template <typename T>
  void AddTemporaryConstant(const std::string& name, const T& vector) {
    input_data_by_name_holder[name] = ToTensor(vector);
    input_data_by_name[name] = &input_data_by_name_holder[name];
  }

  void ProcessConstant(const NodeProto& n) {
    if (IsOnnxDomainOp(n, "Constant") && n.output().size() == 1) {
      const std::string& output_name = n.output(0);
      for (const auto& attr : n.attribute()) {
        if (attr.name() == "value") {
          if (attr.type() == AttributeProto::TENSOR && attr.has_t()) {
            if (reuse_constant_tensors) {
              input_data_by_name[output_name] = &attr.t();
            } else {
              input_data_by_name_holder[output_name] = attr.t();
              input_data_by_name[output_name] = &input_data_by_name_holder[output_name];
            }
          } else if (attr.type() == AttributeProto::SPARSE_TENSOR && attr.has_sparse_tensor()) {
            if (reuse_constant_tensors) {
              input_sparse_data_by_name[output_name] = &attr.sparse_tensor();
            }
          }
        } else {
          switch (attr.type()) {
            case AttributeProto::INTS: {
              std::vector<int64_t> ints{attr.ints().begin(), attr.ints().end()};
              AddTemporaryConstant(output_name, ints);
              break;
            }
            case AttributeProto::INT: {
              std::vector<int64_t> ints({attr.i()});
              AddTemporaryConstant(output_name, ints);
              break;
            }
            case AttributeProto::FLOATS: {
              std::vector<float> floats{attr.floats().begin(), attr.floats().end()};
              AddTemporaryConstant(output_name, floats);
              break;
            }
            case AttributeProto::FLOAT: {
              std::vector<float> floats({attr.f()});
              AddTemporaryConstant(output_name, floats);
              break;
            }
            default:
              break;
          }
        }
      }
    }
  }

  void ProcessCall(const NodeProto& caller, const FunctionProto& callee, InferenceContext& ctx) {
    DataValueMap callee_value_map;
    if (generated_shape_data_by_name != nullptr) {
      BindValuesOnCall(*generated_shape_data_by_name, caller, callee_value_map, callee);
    }
    InferShapeForFunctionNode(
        callee, schema_registry, ctx, options, model_local_functions_map, symbol_table, &callee_value_map);
    if (generated_shape_data_by_name != nullptr) {
      BindValuesOnReturn(callee_value_map, callee, *generated_shape_data_by_name, caller);
    }
  }
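  // Process(NodeProto): per-node inference. Resolves the node's opset/schema, runs the schema's
  // type-and-shape inference function (or inlines a schema/model-local function body), merges the
  // inferred output types into the known-types map, records Constant outputs, and optionally runs
  // the schema's data-propagation function when enable_data_propagation is set.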
  void Process(NodeProto& n) {
    // Resolve domain for node
    auto dit = opset_imports.find(n.domain());
    if (dit == opset_imports.end()) {
      // Both "" (ONNX_DOMAIN) and "ai.onnx" (AI_ONNX_DOMAIN) refer to the default ONNX domain
      if (n.domain() == ONNX_DOMAIN) {
        dit = opset_imports.find(AI_ONNX_DOMAIN);
      }
      if (dit == opset_imports.end()) {
        fail_type_inference(
            "Cannot infer type and shape for node name ",
            n.name(),
            ". No opset import for domain ",
            n.domain(),
            " optype ",
            n.op_type());
      }
    }
    auto domain_version = dit->second;
    const auto schema = schema_registry->GetSchema(n.op_type(), domain_version, n.domain());
    InferenceContextImpl ctx(
        n,
        value_types_by_name,
        input_data_by_name,
        input_sparse_data_by_name,
        options,
        generated_shape_data_by_name,
        &graph_inference_context);

    ONNX_TRY {
      if (schema) {
        if (schema->has_type_and_shape_inference_function()) {
          schema->GetTypeAndShapeInferenceFunction()(ctx);
        } else if (schema->HasFunction()) {
          ProcessCall(n, *(schema->GetFunction()), ctx);
        }
        // else: rely on schema->CheckInputOutputType() down below.

        // check type-constraints specified via type variables
        if (options.check_type) {
          schema->CheckInputOutputType(ctx);
        }
      } else if (model_local_functions_map.size() > 0) {
        auto iter = model_local_functions_map.find(GetFunctionIdentifier(n));
        if (iter != model_local_functions_map.end()) {
          ProcessCall(n, *(iter->second), ctx);
        } else {
          has_unsupported_op = true;
          return;
        }
      } else {
        has_unsupported_op = true;
        return;
      }
      for (int i = 0; i < n.output_size(); ++i) {
        // skip type and shape propagation for missing optional outputs.
        if (!n.output(i).empty())
          UpdateType(n.output(i), ctx.getOutputType(i));
      }
      // Constant values are tracked to improve inference/checking for subsequent nodes.
      ProcessConstant(n);
      // If data-propagation is enabled, partial-evaluation (aka data-propagation) is performed
      // to improve inference/checking for subsequent nodes.
      if (options.enable_data_propagation && schema && schema->has_data_propagation_function()) {
        if (generated_shape_data_by_name == nullptr) {
          fail_shape_inference(
              "Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
        }
        DataPropagationContextImpl data_propagation_ctx(
            n, value_types_by_name, input_data_by_name, *generated_shape_data_by_name);
        schema->GetDataPropagationFunction()(data_propagation_ctx);
      }
    }
    ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
      ONNX_HANDLE_EXCEPTION([&]() {
        // Note: The following special handling is to accommodate custom-ops. Ideally, custom-ops
        // should be registered with a schema in the schema registry, allowing inference to handle
        // them. As things stand, this special handling is somewhat fragile and is not fully
        // general either. Eg., a custom-op suppresses error-messages for subsequent nodes, but
        // this does not work across graphs. If special handling is required, a user-option may
        // be a better way to do it. The fragility comes from the fact that the types of the
        // returned-values of the custom-op are unknown, and subsequent node-level inference
        // may fail because of this.
        if (!has_unsupported_op) {
          inference_errors.push_back(GetErrorWithNodeInfo(n, ex));
        }
      });
    }
    ONNX_CATCH(const std::runtime_error& err) {
      // TODO: Fix this. Unclear if this should be remapped to a shape inference error.
      // Need to rationalize the different types of exceptions that can be thrown.
      // See: https://github.com/onnx/onnx/pull/5519
      ONNX_HANDLE_EXCEPTION([&]() { fail_shape_inference(GetErrorWithNodeInfo(n, err)); });
    }
  }
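  // ProcessInitializer: records an initializer's tensor value for constant folding during
  // inference and, for IR >= 4, registers its type when no matching graph input exists
  // (type/shape info on a graph input takes priority over the initializer's).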
  // T is TensorProto or SparseTensorProto
  template <typename T>
  void ProcessInitializer(
      const std::string& name,
      const T& tensorValue,
      TypeProto& initializer_type,
      std::unordered_map<std::string, const T*>& map) {
    map[name] = &tensorValue;
    auto iter = value_types_by_name.find(name);
    // If it already exists in input, check that input and initializer are in sync;
    // use shape info from input (input has priority over initializer)
    if (iter != value_types_by_name.end()) {
      checkShapesAndTypes(initializer_type, *iter->second);
      // CheckTensorShapesAndTypes(*initializer_tensor_type, *iter->second->mutable_tensor_type());
    }
    // Support IR>=4: some tensors can only exist in initializer and not in input
    // So shape_inference should make use of initializer shapes
    // Store initializer shape info in value_info as well
    else if (ir_version >= 4) {
      initializer_type_list.push_back(std::move(initializer_type));
      value_types_by_name[name] = &initializer_type_list.back();
    }
  }

  void Process(GraphProto& graph) {
    if (symbol_table) {
      TraverseGraphsToAddExistingSymbols(graph, *symbol_table);
    }
    for (auto& vi : *graph.mutable_value_info()) {
      UpdateType(vi);
    }
    for (auto& vi : *graph.mutable_input()) {
      UpdateType(vi);
    }
    for (auto& vi : *graph.mutable_output()) {
      UpdateType(vi);
    }
    for (const auto& tp : graph.initializer()) {
      TypeProto initializer_type;
      TypeProto_Tensor* initializer_tensor_type = initializer_type.mutable_tensor_type();
      initializer_tensor_type->set_elem_type(tp.data_type());
      // set the shape according to the initializer shape info
      auto* shape = initializer_tensor_type->mutable_shape();
      for (int i = 0; i < tp.dims_size(); ++i) {
        shape->add_dim()->set_dim_value(tp.dims(i));
      }
      ProcessInitializer(tp.name(), tp, initializer_type, input_data_by_name);
    }
    for (const auto& tp : graph.sparse_initializer()) {
      TypeProto initializer_type;
      auto* initializer_sparse_tensor_type = initializer_type.mutable_sparse_tensor_type();
      initializer_sparse_tensor_type->set_elem_type(tp.values().data_type());
      // set the shape according to the initializer shape info
      auto* shape = initializer_sparse_tensor_type->mutable_shape();
      for (int i = 0; i < tp.dims_size(); ++i) {
        shape->add_dim()->set_dim_value(tp.dims(i));
      }
      ProcessInitializer(tp.values().name(), tp, initializer_type, input_sparse_data_by_name);
    }
    for (auto& n : *graph.mutable_node()) {
      Process(n);
    }
  }

  void Process(const NodeProto& n, internal::AttributeBinder& attribute_binder) {
    NodeProto copy_n(n);
    attribute_binder.VisitNode(&copy_n);
    Process(copy_n);
  }
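  // Process(FunctionProto): infers types for a function body in the context of a call site.
  // Actual input types/values from the InferenceContext are bound to the formal parameters,
  // attribute references in the body are bound via AttributeBinder (call-site attributes first,
  // then the function's declared defaults), each node is processed, and the resulting types of
  // the formal outputs are copied back into the caller's output slots.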
  void Process(const FunctionProto& func_proto, InferenceContext& ctx) {
    // Ensure Constant node tensor-attributes are copied
    bool old_reuse_constant_tensors = reuse_constant_tensors;
    reuse_constant_tensors = false;

    // Get a temporary tensor-shape map
    const int num_actual_inputs = static_cast<int>(ctx.getNumInputs());
    const auto num_func_inputs = func_proto.input_size();
    std::vector<TypeProto> types_cache(num_func_inputs);
    for (int i = 0; i < num_func_inputs; ++i) {
      auto& parameter_name = func_proto.input().Get(i);
      auto* type_ptr = (i < num_actual_inputs) ? ctx.getInputType(i) : nullptr;
      // nullptr is valid, and indicates a missing optional input
      if (type_ptr != nullptr) {
        // Use a temporary copy of original type.
        // TODO: investigate whether we can eliminate use of temporary copy
        types_cache[i] = *type_ptr;
        value_types_by_name[parameter_name] = &types_cache[i];
      } else {
        value_types_by_name[parameter_name] = nullptr;
      }
    }

    // Create a temporary initializer value map
    for (int i = 0; i < num_actual_inputs && i < num_func_inputs; ++i) {
      const TypeProto* type = ctx.getInputType(i);
      if (type != nullptr) {
        if (type->value_case() == TypeProto::kTensorType && ctx.getInputData(i) != nullptr) {
          input_data_by_name[func_proto.input().Get(i)] = ctx.getInputData(i);
        } else if (type->value_case() == TypeProto::kSparseTensorType && ctx.getInputSparseData(i) != nullptr) {
          input_sparse_data_by_name[func_proto.input().Get(i)] = ctx.getInputSparseData(i);
        }
      }
    }

    std::unordered_map<std::string, const AttributeProto*> attr_map;
    for (auto& attr : func_proto.attribute()) {
      if (ctx.getAttribute(attr) != nullptr) {
        attr_map[attr] = ctx.getAttribute(attr);
      }
    }

    for (auto& default_value : func_proto.attribute_proto()) {
      const std::string& name = default_value.name();
      const AttributeProto* value = ctx.getAttribute(name);
      attr_map[name] = (value != nullptr) ? value : &default_value;
    }

    internal::AttributeBinder attribute_binder(attr_map);
    for (auto& n : func_proto.node()) {
      Process(n, attribute_binder);
    }

    for (int i = 0; i < func_proto.output_size(); ++i) {
      const std::string& output_name = func_proto.output().Get(i);
      // Skip if no type inferred for the tensor
      auto iter = value_types_by_name.find(output_name);
      if (iter != value_types_by_name.cend()) {
        // Copy the type info to ctx to pass back to the main graph
        auto type_proto = ctx.getOutputType(i);
        type_proto->CopyFrom(*(iter->second));
      }
    }

    reuse_constant_tensors = old_reuse_constant_tensors;
  }

 public:
  ShapeInferenceImplBase(
      GraphProto* graph, // nullptr for FunctionProto inference
      const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name_in,
      const std::unordered_map<std::string, int>& opset_imports_in,
      const ShapeInferenceOptions& options_in,
      SymbolTable* symbol_table_in,
      const ModelLocalFunctionsMap& model_local_functions_map_in,
      const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance(),
      DataValueMap* generated_shape_data_by_name_in = nullptr,
      const int ir_version_in = IR_VERSION // default the latest one
      )
      : inferred_types(graph),
        value_types_by_name(outer_scope_value_types_by_name_in),
        opset_imports(opset_imports_in),
        options(options_in),
        symbol_table(symbol_table_in),
        model_local_functions_map(model_local_functions_map_in),
        schema_registry(schema_registry_in),
        generated_shape_data_by_name(generated_shape_data_by_name_in),
        ir_version(ir_version_in),
        graph_inference_context{
            value_types_by_name,
            opset_imports,
            symbol_table,
            model_local_functions_map,
            schema_registry,
            generated_shape_data_by_name,
            ir_version} {
    if (options.enable_data_propagation && generated_shape_data_by_name == nullptr) {
      fail_shape_inference(
          "Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
    }
  }
  void FinalizeShapeInference() {
    auto& errors = getErrors();
    // Throw shape inference error if any. Error mode right now only supports 0 and 1.
    // When set to 0, any node-level shape inference errors are not thrown. This is to support backward compatibility
    // with 1.7 and earlier releases. When set to 1 it will throw all exceptions.
    // TODO: Add a more granular way for exception handling.
    if (!errors.empty() && options.error_mode > 0) {
      std::string full_errors = "Inference error(s): ";
      for (const std::string& error : inference_errors) {
        full_errors += error + "\n";
      }
      fail_shape_inference(full_errors);
    }
  }

  const std::vector<std::string>& getErrors() const {
    return inference_errors;
  }

 private:
  InferredTypes inferred_types;
  std::unordered_map<std::string, TypeProto*> value_types_by_name;
  const std::unordered_map<std::string, int>& opset_imports;
  const ShapeInferenceOptions& options;
  SymbolTable* symbol_table;
  const ModelLocalFunctionsMap& model_local_functions_map;
  const ISchemaRegistry* schema_registry;
  DataValueMap* generated_shape_data_by_name;
  int ir_version;
  GraphInferenceContext graph_inference_context;

  std::unordered_map<std::string, TypeProto*> undefined_value_types_by_name;
  std::unordered_map<std::string, const TensorProto*> input_data_by_name;
  std::unordered_map<std::string, TensorProto> input_data_by_name_holder;
  std::unordered_map<std::string, const SparseTensorProto*> input_sparse_data_by_name;

  bool has_unsupported_op = false;
  std::vector<std::string> inference_errors;
  std::list<TypeProto> initializer_type_list;

  // reuse_constant_tensors: controls whether we need to copy tensors occurring as attributes
  // in Constant nodes. We avoid it for inference for graphs, but must make a copy for functions.
  bool reuse_constant_tensors = true;
};

static void InferShapesImpl(
    GraphProto* g,
    const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name,
    const std::unordered_map<std::string, int>& opset_imports,
    const ShapeInferenceOptions& options,
    SymbolTable* symbol_table,
    const ModelLocalFunctionsMap& model_local_functions_map,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    DataValueMap* generated_shape_data_by_name = nullptr,
    const int ir_version = IR_VERSION // default the latest one
) {
  DataValueMap empty;
  if (generated_shape_data_by_name == nullptr) {
    generated_shape_data_by_name = &empty;
  }
  ShapeInferenceImplBase base(
      g,
      outer_scope_value_types_by_name,
      opset_imports,
      options,
      symbol_table,
      model_local_functions_map,
      schema_registry,
      generated_shape_data_by_name,
      ir_version);
  base.Process(*g);
  base.FinalizeShapeInference();
}

// Either ModelProto or FunctionProto
template <typename T>
std::unordered_map<std::string, int> GetOpsetImportsFromProto(const T& proto) {
  std::unordered_map<std::string, int> opset_imports;
  for (const auto& opset_import : proto.opset_import()) {
    opset_imports[opset_import.domain()] = static_cast<int>(opset_import.version());
  }
  return opset_imports;
}

void InferShapes(
    GraphProto* g,
    const std::unordered_map<std::string, int>& opset_imports,
    const ISchemaRegistry* schema_registry,
    const ShapeInferenceOptions& options,
    const std::unordered_map<std::string, const FunctionProto*>& model_local_functions) {
  SymbolTableImpl symbol_table;
  InferShapesImpl(
      g,
      std::unordered_map<std::string, TypeProto*>(0),
      opset_imports,
      options,
      &symbol_table,
      model_local_functions,
      schema_registry);
}

void InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry,
    const ShapeInferenceOptions& options,
    DataValueMap* generated_shape_data_by_name) {
  auto opset_imports = GetOpsetImportsFromProto(m);
  SymbolTableImpl symbol_table;
  ModelLocalFunctionsMap model_local_functions_by_id;
  for (const auto& function_proto : m.functions()) {
    model_local_functions_by_id.insert({GetFunctionIdentifier(function_proto), &function_proto});
  }
  InferShapesImpl(
      m.mutable_graph(),
      std::unordered_map<std::string, TypeProto*>(0),
      opset_imports,
      options,
      &symbol_table,
      model_local_functions_by_id,
      schema_registry,
      generated_shape_data_by_name,
      m.ir_version());
}
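// Typical usage (illustrative): load a ModelProto, call InferShapes(model, ...) on it, and the
// inferred value_info entries are written back into model.graph(). The overload below does the
// same for an on-disk model and re-serializes the result to save_path.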
void InferShapes(
    const std::string& model_path,
    const std::string& save_path,
    const ISchemaRegistry* schema_registry,
    const ShapeInferenceOptions& options,
    DataValueMap* generated_shape_data_by_name) {
  ModelProto model;
  LoadProtoFromPath(model_path, model);
  InferShapes(model, schema_registry, options, generated_shape_data_by_name);
  // Save the inferred model to the target save path
  // Use SerializeToString instead of SerializeToOstream due to LITE_PROTO
  std::fstream output(save_path, std::ios::out | std::ios::trunc | std::ios::binary);
  std::string model_string;
  ONNX_TRY {
    model.SerializeToString(&model_string);
    output << model_string;
  }
  ONNX_CATCH(...) {
    fail_check("Unable to save inferred model to the target path:", save_path);
  }
}

// Infer shape for functions
void InferShapeForFunctionNode(
    const FunctionProto& func_proto,
    const std::unordered_map<std::string, int>& func_opset_imports,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options,
    const std::unordered_map<std::string, const FunctionProto*>& model_local_functions_map,
    SymbolTable* symbol_table,
    DataValueMap* generated_shape_data_by_name) {
  ShapeInferenceImplBase base(
      nullptr, // no graph
      {}, // outer_scope_value_types_by_name
      func_opset_imports,
      options,
      symbol_table,
      model_local_functions_map,
      schema_registry,
      generated_shape_data_by_name);
  base.Process(func_proto, ctx);
  base.FinalizeShapeInference();
}

void InferShapeForFunctionNode(
    const FunctionProto& function_proto,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options,
    const std::unordered_map<std::string, const FunctionProto*>& model_local_functions_map,
    SymbolTable* symbol_table,
    DataValueMap* generated_shape_data_by_name) {
  auto opset_imports = GetOpsetImportsFromProto(function_proto);
  InferShapeForFunctionNode(
      function_proto,
      opset_imports,
      schema_registry,
      ctx,
      options,
      model_local_functions_map,
      symbol_table,
      generated_shape_data_by_name);
}
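// FunctionInferenceContext: a minimal InferenceContext used to infer a function's output types
// from caller-supplied input types and attributes alone (no graph, no statically known input
// values, no data propagation). It backs InferFunctionOutputTypes below.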
struct FunctionInferenceContext : public InferenceContext {
  FunctionInferenceContext(
      const FunctionProto& func_proto,
      const std::vector<TypeProto>& input_types,
      const std::vector<AttributeProto>& attributes,
      const ShapeInferenceOptions& options)
      : input_types_(input_types), options_(options) {
    for (const auto& attr : attributes) {
      attributesByName_[attr.name()] = &attr;
    }
    auto num_outputs = func_proto.output_size();
    for (int i = 0; i < num_outputs; i++) {
      output_types_.push_back(TypeProto());
    }
  }

  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  size_t getNumInputs() const override {
    return input_types_.size();
  }

  size_t getNumOutputs() const override {
    return output_types_.size();
  }

  const TypeProto* getInputType(size_t index) const override {
    // We should return nullptr for missing optional parameters.
    // An uninitialized TypeProto() is used for missing optional parameters, and
    // is mapped to a nullptr here.
    if (index >= input_types_.size())
      return nullptr;
    if (input_types_[index].value_case() == TypeProto::ValueCase::VALUE_NOT_SET)
      return nullptr;
    return &input_types_[index];
  }

  TypeProto* getOutputType(size_t index) override {
    return (index < output_types_.size()) ? &output_types_[index] : nullptr;
  }

  GraphInferencer* getGraphAttributeInferencer(const std::string& attribute_name) override {
    ONNX_UNUSED_PARAMETER(attribute_name);
    // This method is unused for function-type-inference.
    return nullptr;
  }

  const TensorProto* getInputData(size_t index) const override {
    ONNX_UNUSED_PARAMETER(index);
    // This inference doesn't take advantage of statically known input values.
    return nullptr;
  }

  const SparseTensorProto* getInputSparseData(size_t index) const override {
    ONNX_UNUSED_PARAMETER(index);
    // This inference doesn't take advantage of statically known input values.
    return nullptr;
  }

  const TensorShapeProto* getSymbolicInput(size_t index) const override {
    ONNX_UNUSED_PARAMETER(index);
    // This inference doesn't take advantage of data-propagation.
    return nullptr;
  }

  std::vector<TypeProto> popOutputTypes() {
    return std::move(output_types_);
  }

 private:
  const std::vector<TypeProto>& input_types_;
  std::vector<TypeProto> output_types_;
  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
  ShapeInferenceOptions options_;
};

std::vector<TypeProto> InferFunctionOutputTypes(
    const FunctionProto& function_proto,
    const std::vector<TypeProto>& input_types,
    const std::vector<AttributeProto>& attributes) {
  // TODO: if it is desirable for infer_function_output_types to provide check_type, strict_mode, data_prop,
  // we can add them to the Python API. For now we just assume the default options.
  ShapeInferenceOptions options{true, 1, false};
  FunctionInferenceContext ctx(function_proto, input_types, attributes, options);
  auto opset_imports = GetOpsetImportsFromProto(function_proto);
  ShapeInferenceImplBase base(
      nullptr, // no graph
      {}, // outer_scope_value_types_by_name
      opset_imports,
      options,
      /*symbol_table*/ nullptr,
      /*model_local_functions_map*/ {},
      /*schema_registry*/ OpSchemaRegistry::Instance(),
      /*generated_shape_data_by_name*/ nullptr);
  base.Process(function_proto, ctx);
  base.FinalizeShapeInference();
  return ctx.popOutputTypes();
}
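// GraphInferencerImpl::doInferencing: infers types for a subgraph (e.g., an If/Loop body attribute)
// given the types the enclosing operator supplies for the subgraph inputs. It validates the
// input/initializer naming rules for the relevant IR version, merges the supplied input types into
// the subgraph inputs, runs InferShapesImpl over the subgraph, and returns the subgraph output types.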
std::vector<const TypeProto*> GraphInferencerImpl::doInferencing(
    const std::vector<const TypeProto*>& input_types,
    const std::vector<const TensorProto*>& input_data) {
  SymbolTable* symbol_table = context_->symbol_table;
  int num_inputs = int(input_types.size());
  std::unordered_set<std::string> initializer_name_set;
  for (const auto& tp : g_->initializer()) {
    initializer_name_set.insert(tp.name());
  }

  if (context_->ir_version >= 4) {
    if (g_->input_size() != num_inputs) {
      fail_shape_inference("Graph has ", g_->input_size(), " inputs but ", num_inputs, " were provided");
    }
    for (int i = 0; i < g_->input_size(); ++i) {
      if (initializer_name_set.count(g_->input(i).name()) > 0) {
        fail_shape_inference(
            "Cannot use the same name as both a subgraph initializer and subgraph input: ", g_->input(i).name());
      }
    }
  } else {
    // IR < 4 requires all initializers to be optional inputs,
    // so the number of graph inputs can be larger than the number of node inputs.
    if (num_inputs > g_->input_size()) {
      fail_shape_inference(
          "Graph has ",
          g_->input_size(),
          " inputs but ",
          num_inputs,
          " were provided. ",
          "The number of graph inputs cannot be smaller than the number of node inputs.");
    } else if (num_inputs < g_->input_size()) {
      for (int i = 0; i < g_->input_size(); ++i) {
        if (i < num_inputs && initializer_name_set.count(g_->input(i).name()) > 0) {
          fail_shape_inference("Graph initializer names must appear after the actual inputs: ", g_->input(i).name());
        } else if (i >= num_inputs && initializer_name_set.count(g_->input(i).name()) == 0) {
          // Further check whether the additional input is in initializers
          fail_shape_inference("Cannot find missing input: ", g_->input(i).name(), " in initializers.");
        }
      }
    }
  }

  for (int i = 0, end = num_inputs; i < end; ++i) {
    const TypeProto* inferred_input = input_types[i];

    if (!inferred_input)
      continue;

    TypeProto* graph_input = g_->mutable_input(i)->mutable_type();
    // Even if the graph input doesn't have a defined type, the inferred type will be assigned to it.
    mergeShapesAndTypes(*inferred_input, graph_input);
    if (symbol_table) {
      MaterializeSymbolicShape(graph_input, *symbol_table);
    }
  }

  // future: pass inputData into InferShapes either directly, or indirectly by
  // updating initializers that match subgraph inputs.
  (void)input_data;

  InferShapesImpl(
      g_,
      *context_->outer_scope_value_types_by_name, // never null
      context_->opset_imports,
      options_,
      symbol_table,
      context_->model_local_functions,
      context_->schema_registry,
      context_->generated_shape_data_by_name);

  std::vector<const TypeProto*> graph_output_types;
  graph_output_types.reserve(g_->output().size());
  for (const ValueInfoProto& output : g_->output()) {
    graph_output_types.push_back(&output.type());
  }

  return graph_output_types;
}

std::string GetErrorWithNodeInfo(const NodeProto& n, const std::runtime_error& err) {
  std::string op_name = n.has_name() ? (", node name: " + n.name()) : "";
  return "(op_type:" + n.op_type() + op_name + "): " + err.what();
}

void TraverseGraphsToAddExistingSymbols(const GraphProto& g, SymbolTable& symbol_table) {
  symbol_table.addFromGraph(g);
  for (const auto& n : g.node()) {
    for (auto& attr : n.attribute()) {
      if (attr.has_g()) {
        TraverseGraphsToAddExistingSymbols(attr.g(), symbol_table);
      }
    }
  }
}

} // namespace shape_inference
} // namespace ONNX_NAMESPACE