m3 committed
Commit a5e4e8f · 1 Parent(s): 8aa52b9

feat: add onnx model

Files changed (3):
  1. .gitignore +1 -0
  2. src/demo.py +93 -27
  3. src/init_model.py +24 -11
.gitignore ADDED
@@ -0,0 +1 @@
+ custom_model/
src/demo.py CHANGED
@@ -1,58 +1,124 @@
- # 1. First, define an ONNX model config class and register it
- from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, AutoModel
  from transformers.pipelines import PIPELINE_REGISTRY

- class ONNXBaseConfig(PretrainedConfig):
-     model_type = "onnx-base"

- # register the config class
- AutoConfig.register("onnx-base", ONNXBaseConfig)

- # register the model class
- class ONNXBaseModel(AutoModel):
-     config_class = ONNXBaseConfig

  class ONNXBaseModel(PreTrainedModel):
      config_class = ONNXBaseConfig
-
-     def __init__(self, config):
          super().__init__(config)
-
-     def forward(self, *args, **kwargs):
-         return self.dummy_param
-

  AutoModel.register(ONNXBaseConfig, ONNXBaseModel)

  from transformers.pipelines import Pipeline

  class ONNXBasePipeline(Pipeline):
      def __init__(self, model, **kwargs):
          super().__init__(model=model, **kwargs)

      def _sanitize_parameters(self, **kwargs):
          return {}, {}, {}

-     def preprocess(self, inputs):
-         return inputs

-     def _forward(self, model_inputs):
-         return self.model(**model_inputs)

      def postprocess(self, model_outputs):
          return model_outputs

  PIPELINE_REGISTRY.register_pipeline(
-     task="onnx-base",
-     pipeline_class=ONNXBasePipeline
  )

-
  from transformers import pipeline

- # use the custom pipeline task
- onnx_pipeline = pipeline(task="onnx-base", model="m3/onnx-base")

- # use the pipeline
- result = onnx_pipeline("Your input data here")
- print(result)
+ from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModel
  from transformers.pipelines import PIPELINE_REGISTRY
+ from huggingface_hub import hf_hub_download

+ import onnxruntime as ort
+ import torch
+ import os
+ import torch.nn as nn

+ # 1. register AutoConfig
+ class ONNXBaseConfig(PretrainedConfig):
+     model_type = 'onnx-base'

+ AutoConfig.register('onnx-base', ONNXBaseConfig)

+ # 2. register AutoModel
  class ONNXBaseModel(PreTrainedModel):
      config_class = ONNXBaseConfig
+     def __init__(self, config, base_path=None):
          super().__init__(config)
+         # open an ONNX Runtime session on the local model file, if present
+         if base_path:
+             model_path = os.path.join(base_path, config.model_path)
+             if os.path.exists(model_path):
+                 self.session = ort.InferenceSession(model_path)
+
+     def forward(self, input=None, **kwargs):
+         outs = self.session.run(None, {'input': input})
+         return outs
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+         is_local = os.path.isdir(pretrained_model_name_or_path)
+         if is_local:
+             base_path = pretrained_model_name_or_path
+         else:
+             # pull config.json and the ONNX file into the local HF cache
+             config_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename='config.json')
+             base_path = os.path.dirname(config_path)
+             hf_hub_download(repo_id=pretrained_model_name_or_path, filename=config.model_path)
+         return cls(config, base_path=base_path)
+
+     @property
+     def device(self):
+         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         return torch.device(device)

  AutoModel.register(ONNXBaseConfig, ONNXBaseModel)

+ # option: save config to path
+ local_model_path = './custom_model'
+ config = ONNXBaseConfig(model_path='model.onnx',
+                         id2label={0: 'label_0', 1: 'label_1'},
+                         label2id={'label_0': 0, 'label_1': 1})
+ model = ONNXBaseModel(config, base_path=local_model_path)
+ config.save_pretrained(local_model_path)
+ # make sure config.json carries model_type
+ import json
+ config_path = local_model_path + '/config.json'
+ with open(config_path, 'r') as f:
+     config_data = json.load(f)
+ config_data['model_type'] = 'onnx-base'
+ del config_data['transformers_version']
+ with open(config_path, 'w') as f:
+     json.dump(config_data, f, indent=2)
+
+ # save onnx
+ dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)
+ onnx_file_path = local_model_path + '/model.onnx'
+ class ZeroModel(nn.Module):
+     def __init__(self):
+         super().__init__()
+     def forward(self, x):
+         return torch.zeros_like(x)
+ zero_model = ZeroModel()
+ torch.onnx.export(zero_model, dummy_input, onnx_file_path,
+                   input_names=['input'], output_names=['output'],
+                   dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
+
+ # 3. register Pipeline
  from transformers.pipelines import Pipeline

  class ONNXBasePipeline(Pipeline):
      def __init__(self, model, **kwargs):
+         self.device_id = kwargs.get('device')
          super().__init__(model=model, **kwargs)

      def _sanitize_parameters(self, **kwargs):
          return {}, {}, {}

+     def preprocess(self, input):
+         return {'input': input}

+     def _forward(self, model_input):
+         with torch.no_grad():
+             outputs = self.model(**model_input)
+         return outputs

      def postprocess(self, model_outputs):
          return model_outputs

  PIPELINE_REGISTRY.register_pipeline(
+     task='onnx-base',
+     pipeline_class=ONNXBasePipeline,
+     pt_model=ONNXBaseModel
  )

+ # 4. show how to use
  from transformers import pipeline

+ pipe = pipeline(
+     task='onnx-base',
+     model='m3/onnx-base',
+     batch_size=10,
+     device='cuda',
+ )
+
+ dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)
+ input_data = dummy_input.numpy()
+ result = pipe(input_data)
+ print(result)
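
Note: once the script above has written ./custom_model, the registered classes can also be exercised from that local directory instead of the m3/onnx-base Hub repo. A minimal sketch, assuming ./custom_model exists and the classes are already registered in the current process:

# hypothetical local round trip, not part of this commit
import numpy as np
from transformers import AutoModel

local_model = AutoModel.from_pretrained('./custom_model')  # dispatches to ONNXBaseModel via model_type
batch = np.ones((2, 1, 3, 3), dtype=np.float32)            # batch of 2 exercises the dynamic batch axis
print(local_model(input=batch))                            # ZeroModel weights, so the output is all zeros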
src/init_model.py CHANGED
@@ -1,18 +1,31 @@
  import torch
  import torch.nn as nn
  import torch.onnx
-
- # Define a simple model
- class SimpleModel(nn.Module):
      def __init__(self):
-         super(SimpleModel, self).__init__()
-         self.fc = nn.Linear(1, 1)

      def forward(self, x):
-         return self.fc(x)

- # Instantiate and export the model
- model = SimpleModel()
- dummy_input = torch.randn(1, 10)
- onnx_path = "../model.onnx"
- torch.onnx.export(model, dummy_input, onnx_path, input_names=['input'], output_names=['output'])
  import torch
  import torch.nn as nn
  import torch.onnx
+ class BaseModel(nn.Module):
      def __init__(self):
+         super(BaseModel, self).__init__()

      def forward(self, x):
+         return torch.zeros_like(x)
+
+ # create a model
+ model = BaseModel()
+
+ dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)
+
+ onnx_file_path = "model.onnx"
+ torch.onnx.export(model, dummy_input, onnx_file_path,
+                   input_names=['input'], output_names=['output'],
+                   dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
+
+ print(f"Model has been exported to {onnx_file_path}")
+
+ # load the exported model back and verify it with onnxruntime
+ import onnx
+ import onnxruntime as ort
+ onnx_model = onnx.load(onnx_file_path)
+ onnx.checker.check_model(onnx_model)
+ ort_session = ort.InferenceSession(onnx_file_path)
+ input_data = dummy_input.numpy()
+ outputs = ort_session.run(None, {'input': input_data})
+ print("Model output:", outputs)
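
Note: because the export declares batch_size as a dynamic axis, the saved model accepts batches of any size. A minimal follow-up check, assuming model.onnx was just written by the script above:

# hypothetical sanity check of the dynamic batch axis, not part of this commit
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
batch = np.random.rand(4, 1, 3, 3).astype(np.float32)  # batch of 4 instead of the batch of 1 used at export
out = session.run(None, {'input': batch})[0]
assert out.shape == (4, 1, 3, 3) and not out.any()     # zeros of the input shape, per BaseModel.forward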