onnx-base / src /demo.py
m3's picture
feat: add onnx model
a5e4e8f
raw
history blame
3.99 kB
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModel, modeling_utils
from transformers.pipelines import PIPELINE_REGISTRY
from huggingface_hub import hf_hub_download
import onnxruntime as ort
import torch
import os
import torch.nn as nn
# 1. register AutoConfig
class ONNXBaseConfig(PretrainedConfig):
    """Configuration for a model backed by a single ONNX file.

    Only ``model_type`` is declared on the class. Extra keywords passed at
    construction (e.g. ``model_path``, which ``ONNXBaseModel`` reads as
    ``config.model_path``) are stored as attributes by ``PretrainedConfig``.
    """

    model_type = 'onnx-base'


# Let AutoConfig resolve this config's model_type string to this class.
AutoConfig.register(ONNXBaseConfig.model_type, ONNXBaseConfig)
# 2. register AutoModel
class ONNXBaseModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that runs a single ONNX file with onnxruntime.

    The ONNX file is looked up at ``<base_path>/<config.model_path>``. If no
    ``base_path`` is given, or that file does not exist, no inference session
    is created and :meth:`forward` raises a ``RuntimeError`` saying so.
    """

    config_class = ONNXBaseConfig

    def __init__(self, config, base_path=None):
        """Build the wrapper and, when the ONNX file is present, open a session.

        Args:
            config: An ``ONNXBaseConfig`` whose ``model_path`` attribute names
                the ONNX file relative to ``base_path``.
            base_path: Directory containing the ONNX file. When falsy, no
                session is created.
        """
        super().__init__(config)
        # Always define the attribute so forward() can report a missing model
        # clearly instead of failing with an AttributeError.
        self.session = None
        if base_path:
            model_path = os.path.join(base_path, config.model_path)
            if os.path.exists(model_path):
                # NOTE(review): no ``providers`` are passed, so onnxruntime picks
                # its default execution provider(s); this is independent of
                # ``self.device`` below — confirm that is intended.
                self.session = ort.InferenceSession(model_path)

    def forward(self, input=None, **kwargs):
        """Run the ONNX graph on ``input``.

        Args:
            input: Value fed to the graph input named ``'input'``. A
                ``torch.Tensor`` is converted to a NumPy array first, because
                ``InferenceSession.run`` consumes NumPy arrays.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            The list returned by ``InferenceSession.run`` with
            ``output_names=None``, i.e. one entry per graph output.

        Raises:
            RuntimeError: If no ONNX session was created in ``__init__``.
        """
        if self.session is None:
            raise RuntimeError(
                'ONNX session is not initialised: no ONNX file was found for '
                f"model_path={getattr(self.config, 'model_path', None)!r} "
                'when this model was constructed.'
            )
        if isinstance(input, torch.Tensor):
            input = input.detach().cpu().numpy()
        return self.session.run(None, {'input': input})

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the config and ONNX file from a local directory or a Hub repo.

        Args:
            pretrained_model_name_or_path: A local directory, or a Hub repo id
                when no such directory exists.
            *model_args: Accepted for API compatibility; unused.
            **kwargs: Forwarded to ``AutoConfig.from_pretrained``.

        Returns:
            An ``ONNXBaseModel`` built from the resolved config and directory.
        """
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        if os.path.isdir(pretrained_model_name_or_path):
            base_path = pretrained_model_name_or_path
        else:
            # Fetch config.json and the ONNX file into the local Hub cache. The
            # ONNX path is later rebuilt from config.json's directory, which
            # relies on both downloads landing in the same snapshot folder.
            config_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename='config.json')
            base_path = os.path.dirname(config_path)
            hf_hub_download(repo_id=pretrained_model_name_or_path, filename=config.model_path)
        return cls(config, base_path=base_path)

    @property
    def device(self):
        """Return ``cuda`` when torch sees a GPU, otherwise ``cpu``.

        This model holds no torch parameters, so the device is reported from
        CUDA availability rather than inferred from weights. It does not
        reflect where the onnxruntime session actually executes.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return torch.device(device)


AutoModel.register(ONNXBaseConfig, ONNXBaseModel)
# option: save config to path
local_model_path = './custom_model'
config = ONNXBaseConfig(
    model_path='model.onnx',
    id2label={0: 'label_0', 1: 'label_1'},
    # label2id must be the exact inverse of id2label (label name -> id).
    label2id={'label_0': 0, 'label_1': 1},
)
# Point at the same directory the config (and, below, the ONNX file) is saved
# to. On a first run the ONNX file has not been exported yet, so no session is
# opened here.
model = ONNXBaseModel(config, base_path=local_model_path)
config.save_pretrained(local_model_path)

# make sure config.json has model_type, so AutoConfig can map it back to the
# class registered under 'onnx-base', and drop the transformers_version stamp.
import json
config_path = os.path.join(local_model_path, 'config.json')
with open(config_path, 'r', encoding='utf-8') as f:
    config_data = json.load(f)
config_data['model_type'] = 'onnx-base'
# pop() instead of del: do not crash if the key is not present.
config_data.pop('transformers_version', None)
with open(config_path, 'w', encoding='utf-8') as f:
    json.dump(config_data, f, indent=2)
# save onnx
# Example input of shape (1, 1, 3, 3) holding the values 1..9 as float32.
dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)
# Write the ONNX file next to config.json under the name recorded in
# config.model_path, instead of repeating both values as literals.
onnx_file_path = os.path.join(local_model_path, config.model_path)
class ZeroModel(nn.Module):
    """Toy module returning zeros with the same shape and dtype as its input.

    It has no parameters; the demo only uses it to produce a small ONNX file.
    The inherited ``nn.Module.__init__`` is sufficient, so none is defined.
    """

    def forward(self, x):
        return torch.zeros_like(x)
# Export the zero model to ONNX. The graph input is named 'input' to match the
# feed key used by ONNXBaseModel.forward, and dimension 0 of both input and
# output is declared dynamic so any batch size is accepted at inference time.
zero_model = ZeroModel()
torch.onnx.export(
    zero_model,
    dummy_input,
    onnx_file_path,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)
# 3. register Pipeline
from transformers.pipelines import Pipeline


class ONNXBasePipeline(Pipeline):
    """Minimal pipeline that hands raw inputs straight to ``ONNXBaseModel``.

    No real pre- or post-processing happens: the input is passed to the model
    as ``input=...`` and the model's raw output list is returned unchanged.
    """

    def __init__(self, model, **kwargs):
        # ``device`` may be absent from kwargs; .get() avoids a KeyError then.
        self.device_id = kwargs.get('device')
        super().__init__(model=model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        # No call-time options are supported: empty parameter dicts for the
        # preprocess, forward and postprocess stages; any kwargs are ignored.
        return {}, {}, {}

    def preprocess(self, input):
        # Wrap the raw input under the keyword ONNXBaseModel.forward expects.
        return {'input': input}

    def _forward(self, model_input):
        with torch.no_grad():
            outputs = self.model(**model_input)
        return outputs

    def postprocess(self, model_outputs):
        # Return the model outputs as-is.
        return model_outputs


PIPELINE_REGISTRY.register_pipeline(
    task='onnx-base',
    pipeline_class=ONNXBasePipeline,
    pt_model=ONNXBaseModel,
)
# 4. show how to use
from transformers import pipeline

pipe = pipeline(
    task='onnx-base',
    model='m3/onnx-base',
    batch_size=10,
    # Fall back to CPU when CUDA is unavailable so the demo also runs on
    # machines without a GPU.
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)
input_data = dummy_input.numpy()
# The device is configured on the pipeline above; ONNXBasePipeline's
# _sanitize_parameters discards call-time kwargs, so a per-call ``device=``
# would have no effect.
result = pipe(inputs=input_data)
print(result)