m3 committed
Commit aad22ec · 1 Parent(s): 05297dc

chore: split the file
Files changed (4)
  1. src/demo.py +2 -106
  2. src/init_model.py +26 -23
  3. src/init_onnx.py +31 -0
  4. src/pipeline.py +76 -0
src/demo.py CHANGED
@@ -1,111 +1,6 @@
-from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModel, modeling_utils
-from transformers.pipelines import PIPELINE_REGISTRY
-from huggingface_hub import hf_hub_download
-
-import onnxruntime as ort
 import torch
-import os
-import torch.nn as nn
-
-# 1. register AutoConfig
-class ONNXBaseConfig(PretrainedConfig):
-    model_type = 'onnx-base'
-
-AutoConfig.register('onnx-base', ONNXBaseConfig)
-
-# 2. register AutoModel
-class ONNXBaseModel(PreTrainedModel):
-    config_class = ONNXBaseConfig
-    def __init__(self, config, base_path=None):
-        super().__init__(config)
-        if base_path:
-            model_path = base_path + '/' + config.model_path
-            if os.path.exists(model_path):
-                self.session = ort.InferenceSession(model_path)
-
-    def forward(self, input=None, **kwargs):
-        outs = self.session.run(None, {'input': input})
-        return outs
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        is_local = os.path.isdir(pretrained_model_name_or_path)
-        if is_local:
-            base_path = pretrained_model_name_or_path
-        else:
-            config_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename='config.json')
-            base_path = os.path.dirname(config_path)
-            hf_hub_download(repo_id=pretrained_model_name_or_path, filename=config.model_path)
-        return cls(config, base_path=base_path)
-
-    @property
-    def device(self):
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        return torch.device(device)
-
-AutoModel.register(ONNXBaseConfig, ONNXBaseModel)
-
-# option: save config to path
-local_model_path = './custom_model'
-config = ONNXBaseConfig(model_path='model.onnx',
-                        id2label={0: 'label_0', 1: 'label_1'},
-                        label2id={0: 'label_1', 1: 'label_0'})
-model = ONNXBaseModel(config, base_path='./custom_mode')
-config.save_pretrained(local_model_path)
-# make sure have model_type
-import json
-config_path = local_model_path + '/config.json'
-with open(config_path, 'r') as f:
-    config_data = json.load(f)
-config_data['model_type'] = 'onnx-base'
-del config_data['transformers_version']
-with open(config_path, 'w') as f:
-    json.dump(config_data, f, indent=2)
-
-# save onnx
-dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)
-onnx_file_path = './custom_model' + '/' + 'model.onnx'
-class ZeroModel(nn.Module):
-    def __init__(self):
-        super(ZeroModel, self).__init__()
-    def forward(self, x):
-        return torch.zeros_like(x)
-zero_model = ZeroModel()
-torch.onnx.export(zero_model, dummy_input, onnx_file_path,
-                  input_names=['input'], output_names=['output'],
-                  dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
+import pipeline
 
-
-# 2. register Pipeline
-from transformers.pipelines import Pipeline
-
-class ONNXBasePipeline(Pipeline):
-    def __init__(self, model, **kwargs):
-        self.device_id = kwargs['device']
-        super().__init__(model=model, **kwargs)
-
-    def _sanitize_parameters(self, **kwargs):
-        return {}, {}, {}
-
-    def preprocess(self, input):
-        return {'input': input}
-
-    def _forward(self, model_input):
-        with torch.no_grad():
-            outputs = self.model(**model_input)
-        return outputs
-
-    def postprocess(self, model_outputs):
-        return model_outputs
-
-PIPELINE_REGISTRY.register_pipeline(
-    task='onnx-base',
-    pipeline_class=ONNXBasePipeline,
-    pt_model=ONNXBaseModel
-)
-
-# 4. show how to use
 from transformers import pipeline
 
 pipe = pipeline(
@@ -120,5 +15,6 @@ input_data = dummy_input.numpy()
 result = pipe(
     inputs=input_data, device='cuda',
 )
+
 print(result)
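After this commit, demo.py only needs to import the local `pipeline` module for its side effects (the config, model, and task registrations) before calling the `transformers` factory. A minimal end-to-end sketch of that flow, assuming `./custom_model` was already produced by `src/init_model.py`; the alias `onnx_base`, the `device='cpu'` choice, and the input array are illustrative, not part of the commit:

```python
# sketch: demo flow after the split (assumes ./custom_model already exists)
import numpy as np
import pipeline as onnx_base  # side effect: registers the 'onnx-base' config, model, and task
from transformers import pipeline

model = onnx_base.ONNXBaseModel.from_pretrained('./custom_model')
pipe = pipeline(task='onnx-base', model=model, device='cpu')

input_data = np.ones((1, 1, 3, 3), dtype=np.float32)
result = pipe(inputs=input_data)
print(result)  # zeros in the shape of the input, from the exported ZeroModel
```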
src/init_model.py CHANGED
@@ -1,31 +1,34 @@
 import torch
+import os
 import torch.nn as nn
-import torch.onnx
-class BaseModel(nn.Module):
-    def __init__(self):
-        super(BaseModel, self).__init__()
+from pipeline import ONNXBaseConfig, ONNXBaseModel
 
-    def forward(self, x):
-        return torch.zeros_like(x)
-
-# create a model
-model = BaseModel()
+local_model_path = './custom_model'
+config = ONNXBaseConfig(model_path='model.onnx',
+                        id2label={0: 'label_0', 1: 'label_1'},
+                        label2id={'label_0': 0, 'label_1': 1})
+model = ONNXBaseModel(config, base_path=local_model_path)
+config.save_pretrained(local_model_path)
+# make sure the saved config carries model_type
+import json
+config_path = local_model_path + '/config.json'
+with open(config_path, 'r') as f:
+    config_data = json.load(f)
+config_data['model_type'] = 'onnx-base'
+del config_data['transformers_version']
+with open(config_path, 'w') as f:
+    json.dump(config_data, f, indent=2)
 
+# save onnx
 dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)
-
-onnx_file_path = "model.onnx"
-torch.onnx.export(model, dummy_input, onnx_file_path,
+onnx_file_path = local_model_path + '/' + 'model.onnx'
+class ZeroModel(nn.Module):
+    def __init__(self):
+        super(ZeroModel, self).__init__()
+    def forward(self, x):
+        return torch.zeros_like(x)
+zero_model = ZeroModel()
+torch.onnx.export(zero_model, dummy_input, onnx_file_path,
                   input_names=['input'], output_names=['output'],
                   dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
 
-print(f"Model has been exported to {onnx_file_path}")
-
-import onnx
-import onnxruntime as ort
-onnx_model = onnx.load(onnx_file_path)
-onnx.checker.check_model(onnx_model)
-ort_session = ort.InferenceSession(onnx_file_path)
-input_data = dummy_input.numpy()
-outputs = ort_session.run(None, {'input': input_data})
-print("Model output:", outputs)
-
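Once this script has run, `./custom_model` holds both `config.json` (patched to carry `model_type`) and `model.onnx`, so the artifacts can be loaded back through the Auto classes. A quick round-trip sketch, assuming the registrations in `src/pipeline.py` are importable and that the Auto classes dispatch cleanly to the custom `from_pretrained`:

```python
# sketch: verify the saved artifacts load back through the Auto classes
import pipeline  # noqa: F401 -- side effect: registers ONNXBaseConfig/ONNXBaseModel
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained('./custom_model')
print(config.model_type, config.model_path)  # onnx-base model.onnx

model = AutoModel.from_pretrained('./custom_model')  # dispatches to ONNXBaseModel.from_pretrained
print(type(model).__name__)                          # ONNXBaseModel
```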
src/init_onnx.py ADDED
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+import torch.onnx
+class BaseModel(nn.Module):
+    def __init__(self):
+        super(BaseModel, self).__init__()
+
+    def forward(self, x):
+        return torch.zeros_like(x)
+
+# create a model
+model = BaseModel()
+
+dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)
+
+onnx_file_path = "model.onnx"
+torch.onnx.export(model, dummy_input, onnx_file_path,
+                  input_names=['input'], output_names=['output'],
+                  dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
+
+print(f"Model has been exported to {onnx_file_path}")
+
+import onnx
+import onnxruntime as ort
+onnx_model = onnx.load(onnx_file_path)
+onnx.checker.check_model(onnx_model)
+ort_session = ort.InferenceSession(onnx_file_path)
+input_data = dummy_input.numpy()
+outputs = ort_session.run(None, {'input': input_data})
+print("Model output:", outputs)
+
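Because the export marks dimension 0 of both input and output as a symbolic `batch_size`, the resulting graph accepts batches of any size, not just the 1×1×3×3 dummy input. A small sketch of that property (the batch of 4 is an arbitrary choice):

```python
# sketch: the dynamic batch axis accepts shapes other than the export-time one
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
batch = np.random.rand(4, 1, 3, 3).astype(np.float32)  # batch of 4 instead of 1
outputs = session.run(None, {'input': batch})
print(outputs[0].shape)  # (4, 1, 3, 3): zeros matching the input shape
```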
src/pipeline.py ADDED
@@ -0,0 +1,76 @@
+from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModel
+from transformers.pipelines import PIPELINE_REGISTRY
+from huggingface_hub import hf_hub_download
+
+import onnxruntime as ort
+import torch
+import os
+
+# 1. register AutoConfig
+class ONNXBaseConfig(PretrainedConfig):
+    model_type = 'onnx-base'
+
+AutoConfig.register('onnx-base', ONNXBaseConfig)
+
+# 2. register AutoModel
+class ONNXBaseModel(PreTrainedModel):
+    config_class = ONNXBaseConfig
+    def __init__(self, config, base_path=None):
+        super().__init__(config)
+        if base_path:
+            model_path = base_path + '/' + config.model_path
+            if os.path.exists(model_path):
+                self.session = ort.InferenceSession(model_path)
+
+    def forward(self, input=None, **kwargs):
+        outs = self.session.run(None, {'input': input})
+        return outs
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        if getattr(config, 'model_path', None) is None:
+            config.model_path = 'model.onnx'
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        if is_local:
+            base_path = pretrained_model_name_or_path
+        else:
+            config_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename='config.json')
+            base_path = os.path.dirname(config_path)
+            hf_hub_download(repo_id=pretrained_model_name_or_path, filename=config.model_path)
+        return cls(config, base_path=base_path)
+
+    @property
+    def device(self):
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        return torch.device(device)
+
+AutoModel.register(ONNXBaseConfig, ONNXBaseModel)
+
+# 3. register Pipeline
+from transformers.pipelines import Pipeline
+
+class ONNXBasePipeline(Pipeline):
+    def __init__(self, model, **kwargs):
+        self.device_id = kwargs['device']
+        super().__init__(model=model, **kwargs)
+
+    def _sanitize_parameters(self, **kwargs):
+        return {}, {}, {}
+
+    def preprocess(self, input):
+        return {'input': input}
+
+    def _forward(self, model_input):
+        with torch.no_grad():
+            outputs = self.model(**model_input)
+        return outputs
+
+    def postprocess(self, model_outputs):
+        return model_outputs
+
+PIPELINE_REGISTRY.register_pipeline(
+    task='onnx-base',
+    pipeline_class=ONNXBasePipeline,
+    pt_model=ONNXBaseModel
+)
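Since `ONNXBasePipeline` is an ordinary `Pipeline` subclass, its three stages can also be exercised directly, which is handy for testing the ONNX path without going through the `pipeline()` factory. A sketch, assuming `./custom_model` was produced by `src/init_model.py` and that the base `Pipeline` constructor accepts a string `device`:

```python
# sketch: driving the pipeline stages directly (assumes ./custom_model exists)
import numpy as np
from pipeline import ONNXBaseModel, ONNXBasePipeline

model = ONNXBaseModel.from_pretrained('./custom_model')
pipe = ONNXBasePipeline(model=model, device='cpu')

x = np.ones((1, 1, 3, 3), dtype=np.float32)
model_input = pipe.preprocess(x)       # {'input': x}
outputs = pipe._forward(model_input)   # runs the ONNX Runtime session via model.forward
print(pipe.postprocess(outputs))       # passthrough of the session outputs
```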