File size: 1,266 Bytes
2cf78fb
aad22ec
2cf78fb
aad22ec
2cf78fb
aad22ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5e4e8f
aad22ec
a5e4e8f
aad22ec
 
 
 
 
 
 
 
a5e4e8f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import os
import torch.nn as nn
from pipeline import ONNXBaseConfig, ONNXBaseModel

import json

# Directory that receives config.json (and, further down, model.onnx).
local_model_path = './custom_model'

# id2label maps class index -> label name; label2id must be its exact
# inverse (label name -> class index). The previous literal had int keys
# and swapped names, which contradicted id2label.
config = ONNXBaseConfig(model_path='model.onnx',
                        id2label={0: 'label_0', 1: 'label_1'},
                        label2id={'label_0': 0, 'label_1': 1})
# Point the model at the same directory the config is saved to
# (previously misspelled as './custom_mode').
# NOTE(review): the model is constructed before model.onnx is exported at the
# bottom of this script; confirm ONNXBaseModel does not need the file to exist yet.
model = ONNXBaseModel(config, base_path=local_model_path)
config.save_pretrained(local_model_path)

# Make sure config.json has a model_type entry (presumably needed so loaders
# can map this config back to the ONNX base model — confirm with the pipeline).
config_path = os.path.join(local_model_path, 'config.json')
with open(config_path, 'r', encoding='utf-8') as f:
    config_data = json.load(f)
config_data['model_type'] = 'onnx-base'
# Drop the transformers_version entry; pop() tolerates the key being absent
# instead of raising KeyError like `del` did.
config_data.pop('transformers_version', None)
with open(config_path, 'w', encoding='utf-8') as f:
    json.dump(config_data, f, indent=2)

# save onnx
# Example input of shape (1, 1, 3, 3) used to trace the model during export.
# The export call marks only dim 0 (batch) as dynamic, so the other dims of
# the declared ONNX input come from this tensor.
dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)
# Reuse the directory the config was saved to rather than repeating the
# literal; the file name matches the config's model_path ('model.onnx').
onnx_file_path = os.path.join(local_model_path, 'model.onnx')
class ZeroModel(nn.Module):
    """Placeholder network that returns an all-zero tensor shaped like its input.

    It has no parameters; this script only uses it to produce a trivial ONNX
    graph. nn.Module's own constructor is used unchanged.
    """

    def forward(self, x):
        """Return zeros with the same shape, dtype and device as ``x``."""
        zeros = torch.zeros_like(x)
        return zeros
# Export the placeholder network to ONNX, tracing it with dummy_input.
zero_model = ZeroModel()
torch.onnx.export(
    zero_model,
    dummy_input,
    onnx_file_path,
    input_names=['input'],
    output_names=['output'],
    # Only the leading (batch) axis may vary at inference time.
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)