# SPDX-License-Identifier: Apache-2.0

# Copyright (c) ONNX Project Contributors
import unittest
from typing import Collection

import onnx
from onnx import checker, utils


class TestFunction(unittest.TestCase):
    """Checks which model-local functions survive ``utils.Extractor.extract_model``."""

    def _verify_function_set(
        self,
        extracted_model: onnx.ModelProto,
        function_set: Collection[str],
        func_domain: str,
    ) -> None:
        """Assert that ``extracted_model`` holds exactly the expected local functions.

        Args:
            extracted_model: Model returned by ``Extractor.extract_model``.
            function_set: Names of the functions that must be present.
            func_domain: Domain every expected function must belong to.

        The model is first run through ``checker.check_model``. The number of
        functions must equal ``len(function_set)`` and each expected name must
        be present in ``func_domain``.
        """
        checker.check_model(extracted_model)

        present = [(f.domain, f.name) for f in extracted_model.functions]
        self.assertEqual(
            len(present),
            len(function_set),
            msg=f"extracted functions {present} vs expected {sorted(function_set)}",
        )
        # assertIn reports the missing (domain, name) pair and what was found,
        # instead of an opaque "unexpectedly None".
        for expected_name in function_set:
            self.assertIn((func_domain, expected_name), present)

    def test_extract_model_with_local_function(self) -> None:
        """Extract sub-models and validate the local functions each one keeps.

        1. Build a model with the graph below.
        2. Extract sub-models for several output combinations.
        3. Validate the local functions of each extracted model.

        Model graph (inputs ``i0``, ``i1``, ``i2``)::

            t0          = func_add(i0, i1)
            t1          = func_identity(i1)
            t2          = Add(i1, i2)
            t3          = Identity(i1)
            o_func_add  = Add(t0, t2)
            o_no_func   = Add(t3, t2)
            o_all_func0 = func_nested_identity_add(t0, t1)
            o_all_func1 = func_nested_identity_add(t3, t2)

        where ``func_nested_identity_add(a, b) -> c`` is itself defined with
        local functions::

            a1 = func_identity(a)
            b1 = func_identity(b)
            c  = func_add(a1, b1)
        """
        # function common
        func_domain = "local"
        func_opset_imports = [onnx.helper.make_opsetid("", 14)]
        # The nested function calls other local functions, so it must also
        # import the local domain.
        func_nested_opset_imports = [
            onnx.helper.make_opsetid("", 14),
            onnx.helper.make_opsetid(func_domain, 1),
        ]

        # add function
        func_add_name = "func_add"
        func_add_inputs = ["a", "b"]
        func_add_outputs = ["c"]
        func_add_nodes = [onnx.helper.make_node("Add", ["a", "b"], ["c"])]
        func_add = onnx.helper.make_function(
            func_domain,
            func_add_name,
            func_add_inputs,
            func_add_outputs,
            func_add_nodes,
            func_opset_imports,
        )

        # identity function
        func_identity_name = "func_identity"
        func_identity_inputs = ["a"]
        func_identity_outputs = ["b"]
        func_identity_nodes = [onnx.helper.make_node("Identity", ["a"], ["b"])]
        func_identity = onnx.helper.make_function(
            func_domain,
            func_identity_name,
            func_identity_inputs,
            func_identity_outputs,
            func_identity_nodes,
            func_opset_imports,
        )

        # nested identity/add function
        func_nested_identity_add_name = "func_nested_identity_add"
        func_nested_identity_add_inputs = ["a", "b"]
        func_nested_identity_add_outputs = ["c"]
        func_nested_identity_add_nodes = [
            onnx.helper.make_node("func_identity", ["a"], ["a1"], domain=func_domain),
            onnx.helper.make_node("func_identity", ["b"], ["b1"], domain=func_domain),
            onnx.helper.make_node("func_add", ["a1", "b1"], ["c"], domain=func_domain),
        ]
        func_nested_identity_add = onnx.helper.make_function(
            func_domain,
            func_nested_identity_add_name,
            func_nested_identity_add_inputs,
            func_nested_identity_add_outputs,
            func_nested_identity_add_nodes,
            func_nested_opset_imports,
        )

        # create graph nodes
        node_func_add = onnx.helper.make_node(
            func_add_name, ["i0", "i1"], ["t0"], domain=func_domain
        )
        node_add0 = onnx.helper.make_node("Add", ["i1", "i2"], ["t2"])
        node_add1 = onnx.helper.make_node("Add", ["t0", "t2"], ["o_func_add"])
        node_func_identity = onnx.helper.make_node(
            func_identity_name, ["i1"], ["t1"], domain=func_domain
        )
        node_identity = onnx.helper.make_node("Identity", ["i1"], ["t3"])
        node_add2 = onnx.helper.make_node("Add", ["t3", "t2"], ["o_no_func"])
        node_func_nested0 = onnx.helper.make_node(
            func_nested_identity_add_name,
            ["t0", "t1"],
            ["o_all_func0"],
            domain=func_domain,
        )
        node_func_nested1 = onnx.helper.make_node(
            func_nested_identity_add_name,
            ["t3", "t2"],
            ["o_all_func1"],
            domain=func_domain,
        )

        graph_name = "graph_with_imbedded_functions"
        ir_version = 8
        opset_imports = [
            onnx.helper.make_opsetid("", 14),
            onnx.helper.make_opsetid(func_domain, 1),
        ]

        # Every graph input and output is a 1-D uint8 tensor of length 5.
        tensor_type_proto = onnx.helper.make_tensor_type_proto(
            elem_type=onnx.TensorProto.UINT8, shape=[5]
        )
        input_names = ["i0", "i1", "i2"]
        output_names = ["o_no_func", "o_func_add", "o_all_func0", "o_all_func1"]

        graph = onnx.helper.make_graph(
            [
                node_func_add,
                node_add0,
                node_add1,
                node_func_identity,
                node_identity,
                node_func_nested0,
                node_func_nested1,
                node_add2,
            ],
            graph_name,
            [
                onnx.helper.make_value_info(name=name, type_proto=tensor_type_proto)
                for name in input_names
            ],
            [
                onnx.helper.make_value_info(name=name, type_proto=tensor_type_proto)
                for name in output_names
            ],
        )

        meta = {
            "ir_version": ir_version,
            "opset_imports": opset_imports,
            "producer_name": "test_extract_model_with_local_function",
            "functions": [func_identity, func_add, func_nested_identity_add],
        }
        model = onnx.helper.make_model(graph, **meta)

        checker.check_model(model)

        all_functions = {
            func_add_name,
            func_identity_name,
            func_nested_identity_add_name,
        }
        # (requested outputs, local functions the extracted model must keep)
        cases = [
            (["o_no_func"], set()),
            (["o_func_add"], {func_add_name}),
            (["o_all_func0"], all_functions),
            (["o_all_func1"], all_functions),
            (output_names, all_functions),
        ]
        for requested_outputs, expected_functions in cases:
            # subTest reports every failing output combination rather than
            # stopping at the first one.
            with self.subTest(outputs=requested_outputs):
                extracted = utils.Extractor(model).extract_model(
                    input_names, requested_outputs
                )
                self._verify_function_set(extracted, expected_functions, func_domain)


if __name__ == "__main__":
    unittest.main()