from typing import Dict, List, Any | |
import torch | |
from transformers import pipeline | |
class EndpointHandler:
    """Hugging Face Inference Endpoints handler serving Mixtral-8x7B-Instruct.

    Loads a 4-bit-quantized text-generation pipeline once at startup and
    forwards each request's inputs/parameters to it.
    """

    def __init__(self, path=""):
        """Build the text-generation pipeline.

        Args:
            path: Model directory provided by the Endpoints runtime.
                NOTE(review): currently ignored — the model id is hard-coded
                below; confirm this is intentional.
        """
        self.pipeline = pipeline(
            task="text-generation",
            model="mistralai/Mixtral-8x7B-Instruct-v0.1",
            device_map='auto',
            model_kwargs={
                # 4-bit quantization; requires bitsandbytes to be installed.
                "load_in_4bit": True
            },
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run generation for one request.

        Args:
            data: Request payload with keys:
                inputs (str): the prompt text.
                parameters (dict, optional): generate kwargs, see
                    https://huggingface.co/docs/transformers/generation_strategies#customize-text-generation

        Returns:
            The pipeline output: a list (or list of lists) of dicts, each
            containing either ``generated_text`` (when ``return_text=True``)
            or ``generated_token_ids`` (when ``return_tensors=True``) —
            never both.
        """
        inputs = data.pop("inputs", "")
        # Any falsy "parameters" value (missing, None, empty) becomes {},
        # matching the previous default-then-replace behavior.
        params = data.pop("parameters", None) or {}
        return self.pipeline(inputs, **params)