import os

import onnxruntime as ort
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.pipelines import PIPELINE_REGISTRY


# 1. register AutoConfig
class ONNXBaseConfig(PretrainedConfig):
    """Configuration class registered under the ``'onnx-base'`` model type."""

    model_type = 'onnx-base'


AutoConfig.register('onnx-base', ONNXBaseConfig)


# 2. register AutoModel
class ONNXBaseModel(PreTrainedModel):
    """Model wrapper that runs inference through an ONNX Runtime session.

    The session is created from ``<base_path>/<config.model_path>`` when that
    file exists; otherwise ``session`` stays ``None`` and ``forward`` raises.
    """

    config_class = ONNXBaseConfig

    def __init__(self, config, base_path=None):
        """Build the model and, if possible, open the ONNX session.

        Args:
            config: Model configuration; its optional ``model_path`` attribute
                names the ONNX file relative to ``base_path``.
            base_path: Directory that contains the ONNX file. When falsy, no
                session is created.
        """
        super().__init__(config)
        # Always define the attribute so forward() can report a missing
        # session clearly instead of failing on attribute lookup.
        self.session = None
        if base_path:
            relative_path = getattr(config, 'model_path', None) or 'model.onnx'
            model_path = os.path.join(base_path, relative_path)
            if os.path.exists(model_path):
                self.session = ort.InferenceSession(model_path)

    def forward(self, input=None, **kwargs):
        """Run the ONNX session on ``input`` and return all of its outputs.

        Args:
            input: Value fed to the graph input named ``'input'``.
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            The list returned by ``InferenceSession.run``.

        Raises:
            RuntimeError: If no ONNX session was loaded for this model.
        """
        if self.session is None:
            raise RuntimeError(
                'ONNX session is not initialized: the model file was not '
                'found when the model was constructed.'
            )
        return self.session.run(None, {'input': input})

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the config and ONNX file from a local directory or the Hub.

        Args:
            pretrained_model_name_or_path: Local directory, or a Hub repo id
                when the path is not an existing directory.
            *model_args: Accepted for signature compatibility; not used.
            **kwargs: Forwarded to ``AutoConfig.from_pretrained``.

        Returns:
            An instance of ``cls`` bound to the resolved base directory.
        """
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        # The config may not declare model_path at all; fall back to the
        # default file name rather than failing on the attribute access.
        if getattr(config, 'model_path', None) is None:
            config.model_path = 'model.onnx'

        if os.path.isdir(pretrained_model_name_or_path):
            base_path = pretrained_model_name_or_path
        else:
            # The directory holding the downloaded config.json is used as the
            # base path; the ONNX file is fetched from the same repo.
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename='config.json',
            )
            base_path = os.path.dirname(config_path)
            hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename=config.model_path,
            )
        return cls(config, base_path=base_path)

    @property
    def device(self):
        """Return a CUDA device when CUDA is available, otherwise CPU."""
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return torch.device(device)


AutoModel.register(ONNXBaseConfig, ONNXBaseModel)


# 3.
# register Pipeline
from transformers.pipelines import Pipeline


class ONNXBasePipeline(Pipeline):
    """Pass-through pipeline: wraps the input, calls the model, returns its output."""

    def __init__(self, model, **kwargs):
        """Record the requested device and initialize the base pipeline.

        Args:
            model: Model instance the pipeline calls in ``_forward``.
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        # `device` is optional; default to None instead of raising KeyError
        # when the caller does not pass it.
        self.device_id = kwargs.get('device')
        super().__init__(model=model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess, forward and postprocess parameter dicts."""
        return {}, {}, {}

    def preprocess(self, input):
        """Wrap the raw input under the ``'input'`` key expected by the model."""
        return {'input': input}

    def _forward(self, model_input):
        """Call the model on the preprocessed input with gradients disabled."""
        with torch.no_grad():
            outputs = self.model(**model_input)
        return outputs

    def postprocess(self, model_outputs):
        """Return the model outputs unchanged."""
        return model_outputs


PIPELINE_REGISTRY.register_pipeline(
    task='onnx-base',
    pipeline_class=ONNXBasePipeline,
    pt_model=ONNXBaseModel,
)