m3 commited on
Commit
5d4550d
·
1 Parent(s): 2cf78fb

chore: add architectures for config

Browse files
Files changed (3) hide show
  1. config.json +11 -2
  2. src/demo.py +58 -0
  3. src/init_model.py +1 -1
config.json CHANGED
@@ -1,4 +1,13 @@
1
  {
2
  "model_type": "onnx-base",
3
- "model_name_or_path": "model.onnx"
4
- }
 
 
 
 
 
 
 
 
 
 
1
  {
2
  "model_type": "onnx-base",
3
+ "model_path": "model.onnx",
4
+ "architectures": ["ONNXBaseModel"],
5
+ "id2label": {
6
+ "0": "LABEL_0",
7
+ "1": "LABEL_1"
8
+ },
9
+ "label2id": {
10
+ "LABEL_0": 0,
11
+ "LABEL_1": 1
12
+ }
13
+ }
src/demo.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1. 首先,你需要定义一个 ONNX 模型配置类,并注册它
2
+ from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, AutoModel
3
+ from transformers.pipelines import PIPELINE_REGISTRY
4
+
5
+ class ONNXBaseConfig(PretrainedConfig):
6
+ model_type = "onnx-base"
7
+
8
+ # 注册配置类
9
+ AutoConfig.register("onnx-base", ONNXBaseConfig)
10
+
11
+ # 注册模型类
12
+ class ONNXBaseModel(AutoModel):
13
+ config_class = ONNXBaseConfig
14
+
15
+ class ONNXBaseModel(PreTrainedModel):
16
+ config_class = ONNXBaseConfig
17
+
18
+ def __init__(self, config):
19
+ super().__init__(config)
20
+
21
+ def forward(self, *args, **kwargs):
22
+ return self.dummy_param
23
+
24
+
25
+ AutoModel.register(ONNXBaseConfig, ONNXBaseModel)
26
+
27
+ from transformers.pipelines import Pipeline
28
+
29
+ class ONNXBasePipeline(Pipeline):
30
+ def __init__(self, model, **kwargs):
31
+ super().__init__(model=model, **kwargs)
32
+
33
+ def _sanitize_parameters(self, **kwargs):
34
+ return {}, {}, {}
35
+
36
+ def preprocess(self, inputs):
37
+ return inputs
38
+
39
+ def _forward(self, model_inputs):
40
+ return self.model(**model_inputs)
41
+
42
+ def postprocess(self, model_outputs):
43
+ return model_outputs
44
+
45
+ PIPELINE_REGISTRY.register_pipeline(
46
+ task="onnx-base",
47
+ pipeline_class=ONNXBasePipeline
48
+ )
49
+
50
+
51
+ from transformers import pipeline
52
+
53
+ # 使用自定义的 pipeline 任务
54
+ onnx_pipeline = pipeline(task="onnx-base", model="m3/onnx-base")
55
+
56
+ # 使用 pipeline
57
+ result = onnx_pipeline("Your input data here")
58
+ print(result)
src/init_model.py CHANGED
@@ -6,7 +6,7 @@ import torch.onnx
6
  class SimpleModel(nn.Module):
7
  def __init__(self):
8
  super(SimpleModel, self).__init__()
9
- self.fc = nn.Linear(10, 1)
10
 
11
  def forward(self, x):
12
  return self.fc(x)
 
6
  class SimpleModel(nn.Module):
7
  def __init__(self):
8
  super(SimpleModel, self).__init__()
9
+ self.fc = nn.Linear(1, 1)
10
 
11
  def forward(self, x):
12
  return self.fc(x)