# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0

import os
import shutil
import tempfile
import unittest

import onnx
import onnx.utils  # explicit: extract_model lives in this submodule
from onnx import TensorProto, helper


class TestUtilityFunctions(unittest.TestCase):
    """Tests for helpers exposed by ``onnx.utils``."""

    def test_extract_model(self) -> None:
        """Extract the middle layer of a small three-layer graph.

        The model built here is::

            A0, A1 --Add--> B0   A0, A1 --Sub--> B1   A0, A1 --Mul--> B2
            B0, B1 --Add--> C0   B1, B2 --Sub--> C1
            C0, C1 --Mul--> D0

        Extracting with inputs ``B0, B1, B2`` and outputs ``C0, C1`` is
        expected to yield a model with exactly two nodes, those three
        inputs and those two outputs (in that order), while carrying over
        the original IR version and opset imports.
        """

        def create_tensor(name: str) -> onnx.ValueInfoProto:
            # Every value in this test graph is a FLOAT tensor of shape [1, 2].
            return helper.make_tensor_value_info(name, TensorProto.FLOAT, [1, 2])

        A0 = create_tensor("A0")
        A1 = create_tensor("A1")
        B0 = create_tensor("B0")
        B1 = create_tensor("B1")
        B2 = create_tensor("B2")
        C0 = create_tensor("C0")
        C1 = create_tensor("C1")
        D0 = create_tensor("D0")

        # Layer 0: three nodes consuming the graph inputs.
        L0_0 = helper.make_node("Add", ["A0", "A1"], ["B0"])
        L0_1 = helper.make_node("Sub", ["A0", "A1"], ["B1"])
        L0_2 = helper.make_node("Mul", ["A0", "A1"], ["B2"])

        # Layer 1: the two nodes the extraction should keep.
        L1_0 = helper.make_node("Add", ["B0", "B1"], ["C0"])
        L1_1 = helper.make_node("Sub", ["B1", "B2"], ["C1"])

        # Layer 2: the node producing the graph output.
        L2_0 = helper.make_node("Mul", ["C0", "C1"], ["D0"])

        g0 = helper.make_graph(
            [L0_0, L0_1, L0_2, L1_0, L1_1, L2_0], "test", [A0, A1], [D0]
        )
        m0 = helper.make_model(g0, producer_name="test")

        tdir = tempfile.mkdtemp()
        # Register cleanup immediately so the temporary directory is removed
        # even if one of the assertions below fails.
        self.addCleanup(shutil.rmtree, tdir, ignore_errors=True)

        p0 = os.path.join(tdir, "original.onnx")
        onnx.save(m0, p0)

        p1 = os.path.join(tdir, "extracted.onnx")
        input_names = ["B0", "B1", "B2"]
        output_names = ["C0", "C1"]
        onnx.utils.extract_model(p0, p1, input_names, output_names)

        m1 = onnx.load(p1)
        self.assertEqual(m1.producer_name, "onnx.utils.extract_model")
        self.assertEqual(m1.ir_version, m0.ir_version)
        self.assertEqual(m1.opset_import, m0.opset_import)
        self.assertEqual(len(m1.graph.node), 2)
        self.assertEqual(len(m1.graph.input), 3)
        self.assertEqual(len(m1.graph.output), 2)
        self.assertEqual(m1.graph.input[0], B0)
        self.assertEqual(m1.graph.input[1], B1)
        self.assertEqual(m1.graph.input[2], B2)
        self.assertEqual(m1.graph.output[0], C0)
        self.assertEqual(m1.graph.output[1], C1)


if __name__ == "__main__":
    unittest.main()