Spaces:
Build error
Build error
# Copyright (c) OpenMMLab. All rights reserved. | |
import os.path as osp | |
import tempfile | |
import torch.nn as nn | |
from tools.deployment.pytorch2onnx import _convert_batchnorm, pytorch2onnx | |
class DummyModel(nn.Module):
    """Tiny Conv3d + SyncBatchNorm network used as an ONNX export fixture.

    The model holds a ``SyncBatchNorm`` layer. The export test in this file
    passes the model through ``_convert_batchnorm`` before calling
    ``pytorch2onnx``.
    """

    def __init__(self):
        super(DummyModel, self).__init__()
        # 1 input channel -> 2 output channels with a 1x1x1 kernel.
        self.conv = nn.Conv3d(in_channels=1, out_channels=2, kernel_size=1)
        # Normalises the 2 channels produced by ``self.conv``.
        self.bn = nn.SyncBatchNorm(num_features=2)

    def forward(self, x):
        """Apply ``self.conv`` and then ``self.bn`` to ``x``.

        Args:
            x: Input tensor. Presumably a 5D (N, 1, D, H, W) batch, matching
                the single-input-channel ``Conv3d`` (TODO confirm).

        Returns:
            The batch-normalised convolution output.
        """
        features = self.conv(x)
        return self.bn(features)

    def forward_dummy(self, x):
        """Run :meth:`forward` and wrap its output in a one-element tuple."""
        output = self.forward(x)
        return (output,)
def test_onnx_exporting():
    """Export ``DummyModel`` to ONNX and check that an output file is written.

    The model's ``SyncBatchNorm`` layer is first passed through
    ``_convert_batchnorm``. The exported model is then written into a
    temporary directory.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        out_file = osp.join(tmpdir, 'tmp.onnx')
        model = DummyModel()
        # DummyModel contains a SyncBatchNorm; convert it before exporting.
        model = _convert_batchnorm(model)
        # test exporting; (1, 1, 1, 1, 1) is presumably the input tensor
        # shape, i.e. the smallest 5D input for the 1-channel Conv3d.
        pytorch2onnx(model, (1, 1, 1, 1, 1), output_file=out_file)
        # Without these checks the test only proves that export did not
        # raise. They must run inside the ``with`` block, before the
        # temporary directory is deleted.
        assert osp.isfile(out_file)
        assert osp.getsize(out_file) > 0