"""Minimal example of registering and running a custom `transformers` pipeline.

The script defines a placeholder config, model and pipeline, registers the
pipeline under the task name ``dummy-task`` and then runs it once on a small
tensor. Nothing is actually inferred: the pipeline only echoes back the paths
and the device it was configured with.
"""
import os
from typing import Any

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY, Pipeline


# 1. create auto config
class DummyConfig(PretrainedConfig):
    """Configuration for `DummyModel`; identified by the model type 'dummy'."""

    model_type = 'dummy'


# 2. create model
class DummyModel(PreTrainedModel):
    """Placeholder model that only records where a model file lives.

    No weights are loaded here; the resolved paths are stored on the config
    (as ``base_path`` and ``model_path``) so that `DummyPipeline` can read
    them back.
    """

    config_class = DummyConfig

    def __init__(self, model: str, model_path: str):
        """Resolve the model file and build the config.

        Args:
            model: Either a local directory or a Hub repository id. It is
                treated as local when `os.path.isdir(model)` is true.
            model_path: File name of the model file, relative to the local
                directory or to the Hub repository root.

        NOTE(review): this signature differs from the usual
        ``__init__(self, config)`` of `PreTrainedModel` subclasses, so
        `from_pretrained` is unlikely to work with this class -- confirm
        before relying on it.
        """
        if os.path.isdir(model):
            base_path = model
            resolved_path = os.path.join(base_path, model_path)
        else:
            # Not a local directory: fetch the single file from the Hub and
            # use the directory it was downloaded into as the base path.
            resolved_path = hf_hub_download(repo_id=model, filename=model_path)
            base_path = os.path.dirname(resolved_path)

        config = DummyConfig(base_path=base_path, model_path=resolved_path)
        super().__init__(config)

    def forward(self, input=None, **kwargs):
        """Return an empty dict; the arguments are ignored."""
        return {}

    @property
    def device(self):
        """Return the CUDA device when CUDA is available, otherwise the CPU.

        This model registers no parameters, which is presumably why the
        device is derived from CUDA availability instead of from weights.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return torch.device(device)


# 3. register Pipeline
class DummyPipeline(Pipeline):
    """Pipeline that echoes its configuration instead of running inference."""

    def __init__(self, model, **kwargs):
        """Initialise the base pipeline and cache values reported by `_forward`.

        Args:
            model: A `DummyModel` whose config carries ``model_path`` and
                ``base_path``.
            **kwargs: Forwarded unchanged to `Pipeline.__init__`. If a
                ``device`` entry is present it is also kept as `device_id`.
        """
        super().__init__(model=model, **kwargs)
        # `device` is optional for callers, so look it up without raising
        # KeyError when the pipeline is built without an explicit device.
        self.device_id = kwargs.get('device')
        self.model_path = self.model.config.model_path
        self.base_path = self.model.config.base_path

    def __call__(
        self,
        inputs: Any,
        **kwargs,
    ):
        """Wrap `inputs` in a dict under the key "inputs" and run the pipeline."""
        inputs = {"inputs": inputs}
        return super().__call__(inputs, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Return three empty parameter dicts; every keyword is discarded."""
        return {}, {}, {}

    def preprocess(self, input):
        """Pass the input through unchanged under the key 'input'."""
        return {'input': input}

    def _forward(self, model_input):
        """Return a fixed payload describing this pipeline; the model is not called."""
        return {
            'data': 'dummy',
            'device_id': self.device_id,
            'base_path': self.base_path,
            'model_path': self.model_path,
        }

    def postprocess(self, model_outputs):
        """Return the model outputs unchanged."""
        return model_outputs


PIPELINE_REGISTRY.register_pipeline(
    task='dummy-task',
    pipeline_class=DummyPipeline,
    pt_model=DummyModel,
)

# 4. show how to use
# Pick the device the same way DummyModel.device does, so the example also
# runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

pipe = pipeline(
    model=DummyModel("m3/onnx-base", 'model.onnx'),
    task='dummy-task',
    batch_size=10,
    device=device,
)

dummy_input = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32)
input_data = dummy_input.numpy()

result = pipe(
    inputs=input_data,
    # Kept from the original call; DummyPipeline._sanitize_parameters
    # currently discards every extra keyword, including this one.
    device=device,
)
print(result)