File size: 918 Bytes
aad22ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn
import torch.onnx
class BaseModel(nn.Module):
    """Parameter-free module that maps any tensor to an all-zero tensor.

    The result follows ``torch.zeros_like`` semantics: same shape, dtype and
    device as the input. No submodules or parameters are registered.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return a zero-filled tensor shaped and typed like ``x``."""
        zeros = torch.zeros_like(x)
        return zeros

# Build the model and put it in inference mode before tracing it for export.
model = BaseModel()
model.eval()

# Example input used to trace the graph: shape (1, 1, 3, 3), float32.
dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)

onnx_file_path = "model.onnx"
# Single source of truth for the graph I/O names: they are used both at
# export time and when feeding the ONNX Runtime session below.
input_name = "input"
output_name = "output"

torch.onnx.export(
    model,
    dummy_input,
    onnx_file_path,
    input_names=[input_name],
    output_names=[output_name],
    # Only the leading (batch) dimension is symbolic; the other dimensions
    # are fixed to the sizes of dummy_input.
    dynamic_axes={input_name: {0: "batch_size"}, output_name: {0: "batch_size"}},
)

print(f"Model has been exported to {onnx_file_path}")

import numpy as np
import onnx
import onnxruntime as ort

# Structural validation of the exported ONNX model.
onnx_model = onnx.load(onnx_file_path)
onnx.checker.check_model(onnx_model)

# Pass providers explicitly: since ONNX Runtime 1.9, creating a session
# without `providers` raises on builds that enable more than the CPU
# provider (e.g. onnxruntime-gpu). The CPU provider exists in every build.
ort_session = ort.InferenceSession(onnx_file_path, providers=["CPUExecutionProvider"])
input_data = dummy_input.numpy()
outputs = ort_session.run(None, {input_name: input_data})
print("Model output:", outputs)

# Confirm the exported graph reproduces the PyTorch result on the same input;
# raises AssertionError on mismatch.
with torch.no_grad():
    expected = model(dummy_input).numpy()
np.testing.assert_allclose(outputs[0], expected)