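# XVERSE-13B-256K / ms_wrapper.py
# NOTE: the following header lines appear to be residue from a model-hub file
# page (commit author "pom", commit message "update files", commit c760666,
# "raw" / "history" / "blame" links, file size 3.62 kB).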
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from modelscope.models.base import TorchModel
from modelscope.preprocessors.base import Preprocessor
from modelscope.pipelines.base import Model, Pipeline
from modelscope.utils.config import Config
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.models.builder import MODELS
from transformers import AutoTokenizer, AutoModelForCausalLM
@MODELS.register_module('text-generation', module_name='XVERSE-13B')
class XVERSE13BTextGeneration(TorchModel):
    """ModelScope ``TorchModel`` wrapping the XVERSE causal LM for text generation.

    The tokenizer and the model are loaded from ``model_dir`` with Hugging Face
    ``transformers``. The model is loaded in float16 with ``device_map="auto"``
    and switched to eval mode.
    """

    def __init__(self, model_dir, *args, **kwargs):
        """Load the tokenizer and the float16 causal LM from ``model_dir``.

        Args:
            model_dir: directory containing the tokenizer and model files.
            *args, **kwargs: forwarded unchanged to ``TorchModel.__init__``.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # device_map="auto" lets transformers/accelerate decide where the
        # weights live (possibly split across several devices).
        # trust_remote_code permits modeling code shipped with the checkpoint.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        self.model = self.model.eval()

    def forward(self, inputs, **forward_params):
        """Generate text continuing ``inputs``.

        Args:
            inputs: prompt text handed straight to the tokenizer. It is
                tokenized without padding, so a single string is the safe
                input.
            **forward_params: extra keyword arguments for ``model.generate``
                (e.g. ``max_new_tokens``). A caller-supplied
                ``attention_mask`` takes precedence over the tokenizer's.

        Returns:
            list[str]: ``generate``'s output ids decoded with special tokens
            removed. For a decoder-only model these ids normally start with
            the prompt, so the prompt is usually part of the text.
        """
        encoded = self.tokenizer(inputs, return_tensors='pt')
        # Use the device the model's parameters report instead of a
        # hard-coded ``.cuda()``. This works on CPU-only hosts and with
        # whatever device device_map="auto" chose (typically the one holding
        # the input embeddings when the model is sharded).
        device = self.model.device
        input_ids = encoded.input_ids.to(device)
        # Forward the tokenizer's attention mask so generate() does not have
        # to infer it. Not every tokenizer emits one, hence the .get().
        attention_mask = encoded.get('attention_mask')
        if attention_mask is not None:
            forward_params.setdefault('attention_mask', attention_mask.to(device))
        generated_ids = self.model.generate(
            input_ids,
            eos_token_id=self.tokenizer.eos_token_id,
            **forward_params,
        )
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
@PIPELINES.register_module('text-generation', module_name='XVERSE-13B-pipeline')
class XVERSE13BTextGenerationPipeline(Pipeline):
    """Text-generation pipeline around the XVERSE ModelScope model.

    The input is passed to the model untouched. The model's output is
    returned as ``{'text': <model output>}``.

    Examples:
        >>> from modelscope.pipelines import pipeline
        >>> prompt = "Hello, ModelScope!"
        >>> my_pipeline = pipeline('text-generation', 'xverse/XVERSE-13B')
        >>> result = my_pipeline(prompt)
    """

    def __init__(self, model, **kwargs):
        """Build the pipeline around ``model``.

        Args:
            model: either a ModelScope model id / path (``str``), which is
                loaded with ``Model.from_pretrained``, or an already
                constructed ``Model`` instance.
            **kwargs: forwarded unchanged to ``Pipeline.__init__``.

        Raises:
            TypeError: if ``model`` is neither a ``str`` nor a ``Model``.
        """
        # An explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, and the old assert also made the fallback branch
        # unreachable.
        if isinstance(model, str):
            pipe_model = Model.from_pretrained(model)
        elif isinstance(model, Model):
            pipe_model = model
        else:
            raise TypeError('model must be a single str or Model')
        super().__init__(model=pipe_model, **kwargs)

    def _sanitize_parameters(self, **pipeline_parameters):
        """Split call-time keyword arguments between the pipeline stages.

        Every keyword given when calling the pipeline is routed to the
        forward stage, which ends up in the model's ``generate`` call.
        Nothing is routed to preprocess or postprocess.

        Returns:
            tuple: ``(preprocess_params, forward_params, postprocess_params)``
            equal to ``({}, pipeline_parameters, {})``.
        """
        return {}, pipeline_parameters, {}

    def preprocess(self, inputs, **preprocess_params):
        """Return ``inputs`` unchanged; tokenization happens inside the model."""
        return inputs

    def forward(self, inputs, **forward_params):
        """Run the model via ``Pipeline.forward`` and wrap its output.

        Returns:
            dict: ``{'text': output}``, where ``output`` is whatever the
            model returned for ``inputs``.
        """
        output = super().forward(inputs, **forward_params)
        return {'text': output}

    def postprocess(self, inputs):
        """Return the forward-stage result unchanged.

        Args:
            inputs: the dict produced by ``forward``.

        Returns:
            dict: the same dict, whose ``'text'`` key holds the model output.
        """
        return inputs