|
import torch |
|
import os |
|
import torch.nn as nn |
|
from pipeline import ONNXBaseConfig, ONNXBaseModel |
|
|
|
# Directory that receives the saved config (config.json) and the ONNX file.
local_model_path = './custom_model'

# label2id must be the inverse of id2label (label name -> integer id).
# The previous literal was keyed by id and had the labels swapped, so
# the two mappings disagreed.
config = ONNXBaseConfig(
    model_path='model.onnx',
    id2label={0: 'label_0', 1: 'label_1'},
    label2id={'label_0': 0, 'label_1': 1},
)

# Use the same directory the config and ONNX file are written to; the
# previous literal './custom_mode' was a typo (missing trailing "l").
# NOTE(review): model.onnx is only exported later in this script; if
# ONNXBaseModel loads the ONNX file eagerly in its constructor, this line
# will fail on a fresh run — confirm against pipeline.ONNXBaseModel.
model = ONNXBaseModel(config, base_path=local_model_path)

# Persist the config to local_model_path (read back and patched below).
config.save_pretrained(local_model_path)
|
|
|
import json

# Rewrite the saved config.json so it declares the custom 'onnx-base' model
# type and no longer carries a transformers version stamp.
config_path = os.path.join(local_model_path, 'config.json')

with open(config_path, 'r', encoding='utf-8') as f:
    config_data = json.load(f)

config_data['model_type'] = 'onnx-base'
# pop() with a default tolerates a config that was saved without the key;
# a bare `del` would raise KeyError in that case.
config_data.pop('transformers_version', None)

with open(config_path, 'w', encoding='utf-8') as f:
    json.dump(config_data, f, indent=2)
|
|
|
|
|
# Example input used to trace the model during export: a float32 tensor of
# shape (1, 1, 3, 3) as given by the nested literal.
dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)

# Derive the export path from local_model_path rather than re-hardcoding the
# directory, so the ONNX file always lands next to config.json. The file
# name matches the model_path passed to ONNXBaseConfig ('model.onnx').
onnx_file_path = os.path.join(local_model_path, 'model.onnx')
|
class ZeroModel(nn.Module):
    """Stub network whose output is all zeros, whatever the input values.

    It has no parameters. ``forward`` returns ``torch.zeros_like(x)``, so the
    output matches the input's shape and dtype. It serves as a minimal model
    for exercising the ONNX export below.
    """

    def __init__(self):
        # No layers to register; only the nn.Module base needs initializing.
        super().__init__()

    def forward(self, x):
        """Return a zero tensor shaped like *x*."""
        zeros = torch.zeros_like(x)
        return zeros
|
# Trace the stub model with the example input and write it to ONNX. Dimension
# 0 of both the graph input and output is exported as a named dynamic axis
# ('batch_size') rather than being fixed to the traced size.
zero_model = ZeroModel()

torch.onnx.export(
    zero_model,
    dummy_input,
    onnx_file_path,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)
|
|
|
|