Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
2.28 kB
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import os
import shutil
import tempfile
import unittest
import onnx
from onnx import TensorProto, helper
class TestUtilityFunctions(unittest.TestCase):
    """Tests for helpers exposed by ``onnx.utils``."""

    def test_extract_model(self) -> None:
        """Extract the middle layer of a three-layer graph and verify it.

        The source graph built below is::

            A0, A1 --Add--> B0    A0, A1 --Sub--> B1    A0, A1 --Mul--> B2
            B0, B1 --Add--> C0    B1, B2 --Sub--> C1
            C0, C1 --Mul--> D0

        Extracting with inputs ``B0, B1, B2`` and outputs ``C0, C1`` should
        keep exactly the two middle-layer nodes, expose the requested
        inputs/outputs in the requested order, and carry over the IR version
        and opset imports of the original model.
        """

        def create_tensor(name: str) -> onnx.ValueInfoProto:
            # Every tensor in this graph shares the same dtype and shape.
            return helper.make_tensor_value_info(name, TensorProto.FLOAT, [1, 2])

        A0 = create_tensor("A0")
        A1 = create_tensor("A1")
        B0 = create_tensor("B0")
        B1 = create_tensor("B1")
        B2 = create_tensor("B2")
        C0 = create_tensor("C0")
        C1 = create_tensor("C1")
        D0 = create_tensor("D0")

        L0_0 = helper.make_node("Add", ["A0", "A1"], ["B0"])
        L0_1 = helper.make_node("Sub", ["A0", "A1"], ["B1"])
        L0_2 = helper.make_node("Mul", ["A0", "A1"], ["B2"])
        L1_0 = helper.make_node("Add", ["B0", "B1"], ["C0"])
        L1_1 = helper.make_node("Sub", ["B1", "B2"], ["C1"])
        L2_0 = helper.make_node("Mul", ["C0", "C1"], ["D0"])

        g0 = helper.make_graph(
            [L0_0, L0_1, L0_2, L1_0, L1_1, L2_0], "test", [A0, A1], [D0]
        )
        m0 = helper.make_model(g0, producer_name="test")

        tdir = tempfile.mkdtemp()
        # Register cleanup immediately so the directory is removed even when
        # an assertion below fails; a trailing rmtree would be skipped then.
        self.addCleanup(shutil.rmtree, tdir, ignore_errors=True)

        p0 = os.path.join(tdir, "original.onnx")
        onnx.save(m0, p0)
        p1 = os.path.join(tdir, "extracted.onnx")

        input_names = ["B0", "B1", "B2"]
        output_names = ["C0", "C1"]
        onnx.utils.extract_model(p0, p1, input_names, output_names)
        m1 = onnx.load(p1)

        # Model-level metadata.
        self.assertEqual(m1.producer_name, "onnx.utils.extract_model")
        self.assertEqual(m1.ir_version, m0.ir_version)
        self.assertEqual(m1.opset_import, m0.opset_import)

        # Only the middle-layer nodes survive, with their wiring intact.
        # Order-insensitive so the test does not depend on the extractor's
        # internal node ordering.
        self.assertEqual(len(m1.graph.node), 2)
        self.assertCountEqual(
            [(n.op_type, list(n.input), list(n.output)) for n in m1.graph.node],
            [("Add", ["B0", "B1"], ["C0"]), ("Sub", ["B1", "B2"], ["C1"])],
        )

        # Graph inputs/outputs match the requested names, in order.
        self.assertEqual(len(m1.graph.input), 3)
        self.assertEqual(len(m1.graph.output), 2)
        self.assertEqual(m1.graph.input[0], B0)
        self.assertEqual(m1.graph.input[1], B1)
        self.assertEqual(m1.graph.input[2], B2)
        self.assertEqual(m1.graph.output[0], C0)
        self.assertEqual(m1.graph.output[1], C1)
if __name__ == "__main__":
    # Let the module be executed directly (``python <file>``) in addition to
    # being collected by an external test runner.
    unittest.main()