File size: 767 Bytes
d7a991a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile

import torch.nn as nn

from tools.deployment.pytorch2onnx import _convert_batchnorm, pytorch2onnx


class DummyModel(nn.Module):
    """Tiny 3D network used as the subject of the ONNX export test.

    It is a 1x1x1 ``Conv3d`` that maps 1 input channel to 2 output channels,
    followed by a ``SyncBatchNorm`` over those 2 channels.
    """

    def __init__(self):
        super().__init__()
        out_channels = 2
        self.conv = nn.Conv3d(1, out_channels, kernel_size=1)
        self.bn = nn.SyncBatchNorm(num_features=out_channels)

    def forward(self, x):
        """Run ``x`` through the convolution, then the batch norm."""
        features = self.conv(x)
        return self.bn(features)

    def forward_dummy(self, x):
        """Return the result of :meth:`forward` wrapped in a 1-tuple."""
        output = self.forward(x)
        return (output,)


def test_onnx_exporting():
    """Export a batchnorm-converted ``DummyModel`` to ONNX.

    The model is passed through ``_convert_batchnorm`` before export. The
    test then checks that ``pytorch2onnx`` actually wrote the requested
    output file, not merely that it returned without raising.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        out_file = osp.join(tmpdir, 'tmp.onnx')
        model = DummyModel()
        model = _convert_batchnorm(model)
        # test exporting; the 5-D input shape matches what Conv3d expects
        pytorch2onnx(model, (1, 1, 1, 1, 1), output_file=out_file)
        # check inside the context, before the temp dir is removed
        assert osp.isfile(out_file)