|
import torch |
|
import torch.nn as nn |
|
import torch.onnx |
|
class BaseModel(nn.Module):
    """Trivial module that ignores input values and emits zeros.

    The output mirrors the input tensor's shape, dtype and device, which makes
    it a minimal stand-in model for exercising the ONNX export round trip.
    """

    def __init__(self):
        # No parameters or submodules; only the nn.Module bookkeeping is needed.
        super().__init__()

    def forward(self, x):
        """Return an all-zero tensor shaped and typed like ``x``."""
        zeros = torch.zeros_like(x)
        return zeros
|
|
|
|
|
# Destination file and a fixed example input used to trace the model:
# a float32 tensor of shape (1, 1, 3, 3) holding the values 1..9.
onnx_file_path = "model.onnx"
model = BaseModel()
dummy_input = torch.arange(1, 10, dtype=torch.float32).reshape(1, 1, 3, 3)

# Axis 0 of both the graph input and output is declared dynamic (named
# 'batch_size'), so the export does not pin it to the traced size of 1.
torch.onnx.export(
    model,
    dummy_input,
    onnx_file_path,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={name: {0: "batch_size"} for name in ("input", "output")},
)
print(f"Model has been exported to {onnx_file_path}")
|
|
|
import onnx |
|
import onnxruntime as ort |
|
# Reload the exported file, validate it with the ONNX checker, and run it
# through ONNX Runtime on the same example input that was used for export.
onnx_model = onnx.load(onnx_file_path)
onnx.checker.check_model(onnx_model)

# Pass `providers` explicitly: since onnxruntime 1.9, builds that expose more
# than one execution provider (e.g. CUDA/TensorRT) raise ValueError when it is
# omitted. The CPU execution provider is always available.
ort_session = ort.InferenceSession(
    onnx_file_path, providers=["CPUExecutionProvider"]
)

input_data = dummy_input.numpy()
# The feed key must match the input_names passed to torch.onnx.export.
outputs = ort_session.run(None, {'input': input_data})

print("Model output:", outputs)

# Check that ONNX Runtime agrees with eager PyTorch on the same input, so a
# broken export fails loudly instead of only being printed.
with torch.no_grad():
    expected = model(dummy_input)
torch.testing.assert_close(torch.from_numpy(outputs[0]), expected)
|
|
|
|