# Page metadata captured together with this file (not Python code):
# uploaded by Kano001 -- "Upload 2707 files", commit dc2106c (verified), 8.8 kB.
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) ONNX Project Contributors
import unittest
import onnx
from onnx import checker, utils
class TestFunction(unittest.TestCase):
    """Tests for how ``utils.Extractor`` carries model-local functions into
    extracted sub-models."""

    def _verify_function_set(self, extracted_model, function_set, func_domain):  # type: ignore
        """Check *extracted_model* and assert it defines exactly *function_set*.

        Args:
            extracted_model: model returned by ``Extractor.extract_model``;
                it is first validated with ``checker.check_model``.
            function_set: names of the local functions expected in the model,
                all looked up in *func_domain*. Pass ``set()`` when no
                function is expected.
            func_domain: domain every expected function belongs to.
        """
        checker.check_model(extracted_model)
        # Compare (domain, name) pairs as multisets: one assertion catches
        # missing, unexpected and duplicated functions, and the failure
        # message shows the complete difference.
        actual = [(f.domain, f.name) for f in extracted_model.functions]
        expected = [(func_domain, name) for name in function_set]
        self.assertCountEqual(actual, expected)

    def test_extract_model_with_local_function(self) -> None:
        """Extract sub-models for several output combinations and check that
        each keeps exactly the local functions it (transitively) needs.

        Model graph (inputs i0, i1, i2; [local] marks a local-function call):

            t0          = func_add(i0, i1)                  [local]
            t2          = Add(i1, i2)
            o_func_add  = Add(t0, t2)
            t1          = func_identity(i1)                 [local]
            t3          = Identity(i1)
            o_all_func0 = func_nested_identity_add(t0, t1)  [local]
            o_all_func1 = func_nested_identity_add(t3, t2)  [local]
            o_no_func   = Add(t3, t2)

        where func_nested_identity_add(a, b) is itself defined with local
        functions:

            c = func_add(func_identity(a), func_identity(b))
        """
        # Domain and opsets shared by the local functions. The nested function
        # calls other local functions, so it must also import ``func_domain``.
        func_domain = "local"
        func_opset_imports = [onnx.helper.make_opsetid("", 14)]
        func_nested_opset_imports = [
            onnx.helper.make_opsetid("", 14),
            onnx.helper.make_opsetid(func_domain, 1),
        ]

        # func_add(a, b) -> c with c = Add(a, b)
        func_add_name = "func_add"
        func_add = onnx.helper.make_function(
            func_domain,
            func_add_name,
            ["a", "b"],
            ["c"],
            [onnx.helper.make_node("Add", ["a", "b"], ["c"])],
            func_opset_imports,
        )

        # func_identity(a) -> b with b = Identity(a)
        func_identity_name = "func_identity"
        func_identity = onnx.helper.make_function(
            func_domain,
            func_identity_name,
            ["a"],
            ["b"],
            [onnx.helper.make_node("Identity", ["a"], ["b"])],
            func_opset_imports,
        )

        # func_nested_identity_add(a, b) -> c, built only from local functions.
        func_nested_identity_add_name = "func_nested_identity_add"
        func_nested_identity_add = onnx.helper.make_function(
            func_domain,
            func_nested_identity_add_name,
            ["a", "b"],
            ["c"],
            [
                onnx.helper.make_node(
                    func_identity_name, ["a"], ["a1"], domain=func_domain
                ),
                onnx.helper.make_node(
                    func_identity_name, ["b"], ["b1"], domain=func_domain
                ),
                onnx.helper.make_node(
                    func_add_name, ["a1", "b1"], ["c"], domain=func_domain
                ),
            ],
            func_nested_opset_imports,
        )

        # Graph nodes, listed in the same order as in the docstring.
        nodes = [
            onnx.helper.make_node(
                func_add_name, ["i0", "i1"], ["t0"], domain=func_domain
            ),
            onnx.helper.make_node("Add", ["i1", "i2"], ["t2"]),
            onnx.helper.make_node("Add", ["t0", "t2"], ["o_func_add"]),
            onnx.helper.make_node(
                func_identity_name, ["i1"], ["t1"], domain=func_domain
            ),
            onnx.helper.make_node("Identity", ["i1"], ["t3"]),
            onnx.helper.make_node(
                func_nested_identity_add_name,
                ["t0", "t1"],
                ["o_all_func0"],
                domain=func_domain,
            ),
            onnx.helper.make_node(
                func_nested_identity_add_name,
                ["t3", "t2"],
                ["o_all_func1"],
                domain=func_domain,
            ),
            onnx.helper.make_node("Add", ["t3", "t2"], ["o_no_func"]),
        ]

        # Every graph input and output is a uint8 tensor of shape [5].
        tensor_type_proto = onnx.helper.make_tensor_type_proto(
            elem_type=onnx.TensorProto.UINT8, shape=[5]
        )
        input_names = ["i0", "i1", "i2"]
        output_names = ["o_no_func", "o_func_add", "o_all_func0", "o_all_func1"]
        graph = onnx.helper.make_graph(
            nodes,
            "graph_with_imbedded_functions",
            [
                onnx.helper.make_value_info(name=name, type_proto=tensor_type_proto)
                for name in input_names
            ],
            [
                onnx.helper.make_value_info(name=name, type_proto=tensor_type_proto)
                for name in output_names
            ],
        )
        model = onnx.helper.make_model(
            graph,
            ir_version=8,
            opset_imports=[
                onnx.helper.make_opsetid("", 14),
                onnx.helper.make_opsetid(func_domain, 1),
            ],
            producer_name="test_extract_model_with_local_function",
            functions=[func_identity, func_add, func_nested_identity_add],
        )
        checker.check_model(model)

        all_functions = {
            func_add_name,
            func_identity_name,
            func_nested_identity_add_name,
        }
        # (requested outputs, local functions the extracted model must define)
        cases = [
            (["o_no_func"], set()),
            (["o_func_add"], {func_add_name}),
            (["o_all_func0"], all_functions),
            # The graph calls only the nested function on this path, but that
            # function calls func_identity and func_add, so both must be kept.
            (["o_all_func1"], all_functions),
            (output_names, all_functions),
        ]
        for requested_outputs, expected_functions in cases:
            # subTest reports every failing combination instead of stopping
            # at the first one.
            with self.subTest(outputs=requested_outputs):
                extracted = utils.Extractor(model).extract_model(
                    input_names, requested_outputs
                )
                self._verify_function_set(
                    extracted, expected_functions, func_domain
                )
if __name__ == "__main__":
    # Executed directly (not imported): run this module's tests.
    unittest.main(module="__main__")