// Copyright (c) ONNX Project Contributors // // SPDX-License-Identifier: Apache-2.0 #pragma once #include #include #include #include #include #include #include #include "onnx/defs/function.h" #include "onnx/defs/schema.h" #include "onnx/proto_utils.h" #include "onnx/string_utils.h" namespace ONNX_NAMESPACE { namespace shape_inference { using ModelLocalFunctionsMap = std::unordered_map; // We reuse TensorShapeProto to propagate statically known (partial) information about // the values of tensors. It is intended for tensors used to store shape information // (the return values of ops like Shape and input values of ops like Reshape/Expand). // A DataValueMap is used to store the statically known (partial) values of variables. using DataValueMap = std::unordered_map; class SymbolTableImpl : public SymbolTable { public: SymbolTableImpl() : index_(0) {} void addFromGraph(const GraphProto& g) override { AddExistingSymbolicDims(g.input()); AddExistingSymbolicDims(g.output()); AddExistingSymbolicDims(g.value_info()); } // Creates a new unique symbol with the given prefix and adds it to the SymbolTable // Returns the newly created symbol std::string createNew(const std::string& symbol_prefix) override { std::string newSymbol; do { newSymbol = symbol_prefix + std::to_string(index_++); } while (existing_symbols.count(newSymbol) > 0); existing_symbols.insert(newSymbol); return newSymbol; } private: unsigned int index_; std::unordered_set existing_symbols; // TypeProto_Tensor or TypeProto_SparseTensor template void AddExistingSymbolicDims(const TensorTypeProto& tensorType) { if (tensorType.has_shape()) { for (int i = 0; i < tensorType.shape().dim_size(); ++i) { if (tensorType.shape().dim(i).has_dim_param()) { existing_symbols.insert(tensorType.shape().dim(i).dim_param()); } } } } void AddExistingSymbolicDims(const TypeProto& typeProto) { const auto val_case = typeProto.value_case(); switch (val_case) { case TypeProto::kTensorType: AddExistingSymbolicDims(typeProto.tensor_type()); break; case TypeProto::kSparseTensorType: AddExistingSymbolicDims(typeProto.sparse_tensor_type()); break; case TypeProto::kSequenceType: AddExistingSymbolicDims(typeProto.sequence_type().elem_type()); break; case TypeProto::kOptionalType: AddExistingSymbolicDims(typeProto.optional_type().elem_type()); break; case TypeProto::kMapType: AddExistingSymbolicDims(typeProto.map_type().value_type()); break; default: break; } } void AddExistingSymbolicDims(const google::protobuf::RepeatedPtrField& protos) { for (const auto& proto : protos) { AddExistingSymbolicDims(proto.type()); } } }; struct GraphInferenceContext { GraphInferenceContext( const std::unordered_map& outer_scope_value_types_by_name_in, const std::unordered_map opset_imports_in, SymbolTable* symbol_table_in = nullptr, const ModelLocalFunctionsMap& model_local_functions_in = {}, const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance(), DataValueMap* generated_shape_data_by_name_in = nullptr, const int ir_version_in = IR_VERSION) : outer_scope_value_types_by_name{&outer_scope_value_types_by_name_in}, opset_imports{opset_imports_in}, symbol_table{symbol_table_in}, model_local_functions{model_local_functions_in}, schema_registry{schema_registry_in}, generated_shape_data_by_name{generated_shape_data_by_name_in}, ir_version{ir_version_in} {} const std::unordered_map* outer_scope_value_types_by_name; const std::unordered_map opset_imports; SymbolTable* symbol_table; const ModelLocalFunctionsMap& model_local_functions; const ISchemaRegistry* schema_registry; DataValueMap* generated_shape_data_by_name; const int ir_version; }; class GraphInferencerImpl : public GraphInferencer { public: GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context) : g_{&g}, context_{&context}, options_() {} GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context, const ShapeInferenceOptions& options) : g_{&g}, context_{&context}, options_(options) {} std::vector doInferencing( const std::vector& inputTypes, const std::vector& inputData) override; private: GraphProto* g_; GraphInferenceContext* context_; ShapeInferenceOptions options_; }; struct InferenceContextImpl : public InferenceContext { InferenceContextImpl( NodeProto& n, const std::unordered_map& valueTypesByName, const std::unordered_map& inputDataByName, const std::unordered_map& inputSparseDataByName, const ShapeInferenceOptions& options, DataValueMap* generatedShapeData = nullptr, GraphInferenceContext* graphInferenceContext = nullptr) : graphInferenceContext_{graphInferenceContext}, options_(options) { for (auto& attr : *n.mutable_attribute()) { attributesByName_[attr.name()] = &attr; if (attr.has_g()) { // need a mutable GraphProto to run inferencing on this attribute graphProtoAttributesByName_[attr.name()] = attr.mutable_g(); } } for (const auto& input : n.input()) { auto valueTypesIter = valueTypesByName.find(input); if (valueTypesIter != valueTypesByName.end()) { allInputTypes_.push_back(valueTypesIter->second); } else { allInputTypes_.push_back(nullptr); } // input data can be in 1 of the 3 containers // inputDataByName - this is when input is TensorProto // inputSparseDataByName - this is when input is SparseTensorProto // generatedShapeData - this is when input was generated as part of partial data propagation const auto inputDataIter = inputDataByName.find(input); if (inputDataIter != inputDataByName.cend()) { allInputData_.push_back(inputDataIter->second); allInputSparseData_.push_back(nullptr); allShapeInputData_.push_back(nullptr); } else { allInputData_.push_back(nullptr); const auto inputSparseDataIter = inputSparseDataByName.find(input); if (inputSparseDataIter != inputSparseDataByName.cend()) { allInputSparseData_.push_back(inputSparseDataIter->second); allShapeInputData_.push_back(nullptr); } else { allInputSparseData_.push_back(nullptr); if (generatedShapeData != nullptr) { const auto inputShapeDataIter = generatedShapeData->find(input); if (inputShapeDataIter != generatedShapeData->cend()) { allShapeInputData_.push_back(&inputShapeDataIter->second); } else { allShapeInputData_.push_back(nullptr); } } else { allShapeInputData_.push_back(nullptr); } } } } allOutputTypes_.resize(n.output_size()); } const AttributeProto* getAttribute(const std::string& name) const override { auto iter = attributesByName_.find(name); if (iter == attributesByName_.end()) { return nullptr; } else { return iter->second; } } size_t getNumInputs() const override { return allInputTypes_.size(); } const TypeProto* getInputType(size_t index) const override { if (index >= allInputTypes_.size()) { ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds."); } return allInputTypes_[index]; } const TensorProto* getInputData(size_t index) const override { if (index >= allInputData_.size()) { ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds."); } return allInputData_[index]; } const TensorShapeProto* getSymbolicInput(size_t index) const override { if (index >= allShapeInputData_.size()) { ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds."); } return allShapeInputData_[index]; } const SparseTensorProto* getInputSparseData(size_t index) const override { if (index >= allInputSparseData_.size()) { ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds."); } return allInputSparseData_[index]; } size_t getNumOutputs() const override { return allOutputTypes_.size(); } TypeProto* getOutputType(size_t index) override { if (index >= allOutputTypes_.size()) { ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds."); } return &allOutputTypes_[index]; } GraphInferencer* getGraphAttributeInferencer(const std::string& attr_name) override { if (!graphInferenceContext_) { fail_type_inference("GraphProto attribute inferencing is not enabled in this InferenceContextImpl instance."); } GraphInferencer* inferencer = nullptr; auto entry = graphAttributeInferencers_.find(attr_name); if (entry == graphAttributeInferencers_.cend()) { // create GraphInferencer instance auto attrNameToGraphProto = graphProtoAttributesByName_.find(attr_name); if (attrNameToGraphProto == graphProtoAttributesByName_.cend()) { fail_type_inference("Attribute ", attr_name, " does not contain a graph."); } std::unique_ptr new_inferencer{ new GraphInferencerImpl(*attrNameToGraphProto->second, *graphInferenceContext_, options_)}; inferencer = new_inferencer.get(); graphAttributeInferencers_.emplace(attr_name, std::move(new_inferencer)); } else { inferencer = entry->second.get(); } return inferencer; } std::vector allInputData_; std::vector allInputSparseData_; std::vector allShapeInputData_; std::unordered_map attributesByName_; std::unordered_map graphProtoAttributesByName_; std::vector allInputTypes_; std::vector allOutputTypes_; GraphInferenceContext* graphInferenceContext_; // mutable as internal cache of GraphInferencer instances mutable std::unordered_map> graphAttributeInferencers_; ShapeInferenceOptions options_; }; struct DataPropagationContextImpl : public DataPropagationContext { DataPropagationContextImpl( NodeProto& n, const std::unordered_map& valueTypesByName, const std::unordered_map& inputDataByName, DataValueMap& generatedShapeData) : generatedShapeData_(generatedShapeData) { size_t input_idx = 0; for (auto& attr : *n.mutable_attribute()) { attributesByName_[attr.name()] = &attr; } for (const auto& input : n.input()) { inputIndexToNameMap_.insert({input_idx++, input}); auto valueTypesIter = valueTypesByName.find(input); if (valueTypesIter != valueTypesByName.end()) { allInputTypes_.push_back(valueTypesIter->second); } else { allInputTypes_.push_back(nullptr); } const auto inputDataIter = inputDataByName.find(input); if (inputDataIter != inputDataByName.cend()) { allInputData_.push_back(inputDataIter->second); } else { allInputData_.push_back(nullptr); } } size_t output_idx = 0; for (const auto& output : n.output()) { outputIndexToNameMap_.insert({output_idx++, output}); } allOutputTypes_.resize(n.output_size()); } const AttributeProto* getAttribute(const std::string& name) const override { auto iter = attributesByName_.find(name); if (iter == attributesByName_.end()) { return nullptr; } else { return iter->second; } } size_t getNumInputs() const override { return allInputTypes_.size(); } const TypeProto* getInputType(size_t index) const override { if (index >= allInputTypes_.size()) { ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds."); } return allInputTypes_[index]; } size_t getNumOutputs() const override { return allOutputTypes_.size(); } const TypeProto* getOutputType(size_t index) const override { if (index >= allOutputTypes_.size()) { ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds."); } return &allOutputTypes_[index]; } // Convert integer vector into TensorShapeProto template void vectorToTensorShapeProto(const std::vector& input_vals, TensorShapeProto& converted_tsp) const { for (unsigned int i = 0; i < input_vals.size(); ++i) { converted_tsp.mutable_dim()->Add()->set_dim_value(input_vals[i]); } } const TensorShapeProto* getInputData(size_t index) override { if (index >= allInputData_.size()) { ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds."); } const std::string input_name = inputIndexToNameMap_.at(index); // Gets it from previous data propagation auto iter = generatedShapeData_.find(input_name); if (iter != generatedShapeData_.end()) { return &iter->second; } // Otherwise, gets it from initializer if it exists const auto* input_data = allInputData_[index]; // Only scalar (0D tensor) or 1D tensor can be converted for now // TODO: It should support tensors with more dimension on demand if (input_data != nullptr && (input_data->dims_size() == 0 || input_data->dims_size() == 1)) { TensorShapeProto tsp; if (input_data->data_type() == TensorProto_DataType_INT64) { vectorToTensorShapeProto(ParseData(input_data), tsp); } else if (input_data->data_type() == TensorProto_DataType_INT32) { vectorToTensorShapeProto(ParseData(input_data), tsp); } else { // Only supports integer type to form a shape return nullptr; } // Adds this TensorShapeProto from initializer into generatedShapeData // for future use auto result = generatedShapeData_.insert({input_name, std::move(tsp)}); if (result.second) { return &(result.first->second); } } return nullptr; } void addOutputData(size_t index, TensorShapeProto&& tsp) override { if (index >= outputIndexToNameMap_.size()) { ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds."); } auto result = generatedShapeData_.insert({outputIndexToNameMap_.at(index), std::move(tsp)}); if (!result.second) { fail_shape_inference("Data for input " + ONNX_NAMESPACE::to_string(index) + " already exists."); } } std::vector allInputData_; std::unordered_map inputIndexToNameMap_; std::unordered_map outputIndexToNameMap_; std::vector allInputTypes_; std::vector allOutputTypes_; DataValueMap& generatedShapeData_; std::unordered_map attributesByName_; }; void checkShapesAndTypes(const TypeProto_Sequence& inferredType, const TypeProto_Sequence& existingType); void checkShapesAndTypes(const TypeProto& inferredType, const TypeProto& existingType); template void GenerateSymbolicShape(TensorTypeProto* inferredType, SymbolTable& symbolTable); void MaterializeSymbolicShape(TypeProto* inferredType, SymbolTable& symbolTable); void mergeShapesAndTypes(const TypeProto_Tensor& inferredType, TypeProto_Tensor* existingType); void mergeShapesAndTypes(const TypeProto_SparseTensor& inferredType, TypeProto_SparseTensor* existingType); void mergeShapesAndTypes(const TypeProto_Sequence& inferredType, TypeProto_Tensor* existingType); void mergeShapesAndTypes(const TypeProto& inferredType, TypeProto* existingType); /// /// ModelLocalFunctionsMap is a map of function id -> model local function proto /// All the ONNX helper utilities expect the function id == : /// void InferShapes( GraphProto* g, const std::unordered_map& opset_imports, const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(), const ShapeInferenceOptions& options = {}, const ModelLocalFunctionsMap& in_model_functions = {}); void InferShapes( ModelProto& m, const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(), const ShapeInferenceOptions& options = {}, DataValueMap* generated_shape_data_by_name = nullptr); void InferShapes( const std::string& model_path, const std::string& save_path = "", const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(), const ShapeInferenceOptions& options = {}, DataValueMap* generated_shape_data_by_name = nullptr); /// /// ModelLocalFunctionsMap is a map of function id -> model local function proto /// All the ONNX helper utilities expect the function id == : /// void InferShapeForFunctionNode( const FunctionProto& func, const ISchemaRegistry* schema_registry, InferenceContext& ctx, const ShapeInferenceOptions& options = {}, const ModelLocalFunctionsMap& model_local_functions_map = {}, SymbolTable* symbolTable = nullptr, DataValueMap* generated_shape_data_by_name = nullptr); /// /// ModelLocalFunctionsMap is a map of function id -> model local function proto /// All the ONNX helper utilities expect the function id == : /// void InferShapeForFunctionNode( const FunctionProto& func_proto, const std::unordered_map& func_opset_imports, const ISchemaRegistry* schema_registry, InferenceContext& ctx, const ShapeInferenceOptions& options = {}, const ModelLocalFunctionsMap& model_local_functions_map = {}, SymbolTable* symbolTable = nullptr, DataValueMap* generated_shape_data_by_name = nullptr); /// /// Apply type-and-shape-inference based checks to a Function body. /// Returns the inferred types of the outputs of the function. /// Inference depends on the types of the inputs of the function as well as /// the attribute values supplied. /// A TypeProto with value_case() == TypeProto::ValueCase::VALUE_NOT_SET is used /// for missing optional parameters. /// std::vector InferFunctionOutputTypes( const FunctionProto& func_proto, const std::vector& input_types, const std::vector& attributes); std::string GetErrorWithNodeInfo(const NodeProto& n, const std::runtime_error& err); void TraverseGraphsToAddExistingSymbols(const GraphProto& g, SymbolTable& symbolTable); } // namespace shape_inference } // namespace ONNX_NAMESPACE