from typing import Any, Dict, List, Mapping

import torch  # noqa: F401  (not referenced directly in this module)
from transformers import BitsAndBytesConfig, pipeline

# Model loaded when no explicit ``model_id`` is given to the handler.
DEFAULT_MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"


class EndpointHandler:
    """Inference handler wrapping a 4-bit quantized text-generation pipeline."""

    def __init__(self, path: str = "", *, model_id: str = DEFAULT_MODEL_ID):
        """Build the text-generation pipeline.

        Args:
            path: Accepted for interface compatibility; it is not used. The
                model is loaded from ``model_id`` instead.
            model_id: Hub id (or local path) of the model to load. Defaults to
                ``DEFAULT_MODEL_ID``.
        """
        self.pipeline = pipeline(
            "text-generation",
            model=model_id,
            device_map="auto",
            # Passing ``load_in_4bit`` directly to from_pretrained is deprecated;
            # an equivalent BitsAndBytesConfig is the supported spelling.
            model_kwargs={
                "quantization_config": BitsAndBytesConfig(load_in_4bit=True),
            },
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run text generation on a request payload.

        Args:
            data: Request payload with keys:
                ``inputs`` (str): the prompt; defaults to ``""`` when absent.
                ``parameters`` (mapping, optional): keyword arguments forwarded
                    unchanged to the pipeline call (generation kwargs, see
                    https://huggingface.co/docs/transformers/generation_strategies#customize-text-generation).

        Returns:
            The raw pipeline output, to be serialized and returned to the
            client: a list of dicts (or a list of lists of dicts for batched
            inputs). Each dict holds either ``generated_text`` (str, when
            ``return_text=True``) or ``generated_token_ids`` (when
            ``return_tensors=True``), never both.

        Raises:
            TypeError: if ``parameters`` is present and non-empty but is not a
                mapping.
        """
        # Read without popping so the caller's payload dict is left intact.
        inputs = data.get("inputs", "")
        # Missing, None, or empty parameters all mean "no extra generation kwargs".
        params = data.get("parameters") or {}
        if not isinstance(params, Mapping):
            raise TypeError(
                f"'parameters' must be a mapping of generation kwargs, "
                f"got {type(params).__name__}"
            )

        return self.pipeline(inputs, **params)