File size: 2,281 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0

import os
import shutil
import tempfile
import unittest

import onnx
from onnx import TensorProto, helper


class TestUtilityFunctions(unittest.TestCase):
    """Tests for ``onnx.utils.extract_model``."""

    def test_extract_model(self) -> None:
        """Extract the middle layer of a small three-layer graph and check it.

        The source graph built below computes::

            B0 = A0 + A1    B1 = A0 - A1    B2 = A0 * A1
            C0 = B0 + B1    C1 = B1 - B2
            D0 = C0 * C1

        Extracting with inputs ``B0, B1, B2`` and outputs ``C0, C1`` is
        expected to keep exactly two nodes. The extracted model is expected to:

        * expose those tensors as its graph inputs/outputs, in order;
        * carry over the original model's IR version and opset imports;
        * be stamped with the ``onnx.utils.extract_model`` producer name.
        """

        def create_tensor(name: str) -> onnx.ValueInfoProto:
            # Every tensor in this graph is a FLOAT tensor of shape [1, 2].
            return helper.make_tensor_value_info(name, TensorProto.FLOAT, [1, 2])

        A0 = create_tensor("A0")
        A1 = create_tensor("A1")
        B0 = create_tensor("B0")
        B1 = create_tensor("B1")
        B2 = create_tensor("B2")
        C0 = create_tensor("C0")
        C1 = create_tensor("C1")
        D0 = create_tensor("D0")
        L0_0 = helper.make_node("Add", ["A0", "A1"], ["B0"])
        L0_1 = helper.make_node("Sub", ["A0", "A1"], ["B1"])
        L0_2 = helper.make_node("Mul", ["A0", "A1"], ["B2"])
        L1_0 = helper.make_node("Add", ["B0", "B1"], ["C0"])
        L1_1 = helper.make_node("Sub", ["B1", "B2"], ["C1"])
        L2_0 = helper.make_node("Mul", ["C0", "C1"], ["D0"])

        g0 = helper.make_graph(
            [L0_0, L0_1, L0_2, L1_0, L1_1, L2_0], "test", [A0, A1], [D0]
        )
        m0 = helper.make_model(g0, producer_name="test")

        tdir = tempfile.mkdtemp()
        # Register cleanup immediately so the directory is removed even when
        # an assertion below fails; previously a failure leaked it.
        self.addCleanup(shutil.rmtree, tdir, ignore_errors=True)

        p0 = os.path.join(tdir, "original.onnx")
        onnx.save(m0, p0)

        p1 = os.path.join(tdir, "extracted.onnx")
        input_names = ["B0", "B1", "B2"]
        output_names = ["C0", "C1"]
        onnx.utils.extract_model(p0, p1, input_names, output_names)

        m1 = onnx.load(p1)
        self.assertEqual(m1.producer_name, "onnx.utils.extract_model")
        self.assertEqual(m1.ir_version, m0.ir_version)
        self.assertEqual(m1.opset_import, m0.opset_import)
        self.assertEqual(len(m1.graph.node), 2)
        self.assertEqual(len(m1.graph.input), 3)
        self.assertEqual(len(m1.graph.output), 2)
        self.assertEqual(m1.graph.input[0], B0)
        self.assertEqual(m1.graph.input[1], B1)
        self.assertEqual(m1.graph.input[2], B2)
        self.assertEqual(m1.graph.output[0], C0)
        self.assertEqual(m1.graph.output[1], C1)


if "__main__" == __name__:
    # Executed directly as a script: hand control to the unittest CLI runner.
    unittest.main()