import torch
import torch.nn as nn
import torch.onnx
# Define a simple model
class SimpleModel(nn.Module):
    """Single fully connected layer applied to the last input dimension.

    Args:
        in_features: Size of the last dimension of the input tensor.
            Defaults to 10.
        out_features: Size of the last dimension of the output tensor.
            Defaults to 1.
    """

    def __init__(self, in_features: int = 10, out_features: int = 1) -> None:
        super().__init__()
        # The defaults reproduce the original fixed 10 -> 1 layer, so
        # existing ``SimpleModel()`` call sites keep working unchanged.
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output of the linear layer ``self.fc`` applied to ``x``."""
        return self.fc(x)
# Instantiate and export the model
model = SimpleModel()
# Put the model in inference mode before exporting, as the PyTorch ONNX
# export documentation recommends. This is a no-op for a lone Linear layer
# but keeps the export correct if dropout/batch-norm layers are added later.
model.eval()
# Example input handed to the exporter: a batch of 1 with the 10 features
# that SimpleModel's default linear layer expects.
dummy_input = torch.randn(1, 10)
# Relative path: the file is written one directory above the current
# working directory.
onnx_path = "../model.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    input_names=['input'],
    output_names=['output'],
)