from typing import Any, Dict, List, Union

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextGenerationPipeline,
)


def load_model(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        # Passing load_in_4bit directly to from_pretrained is deprecated;
        # wrap the quantization settings in a BitsAndBytesConfig instead.
        # The compute dtype replaces the former torch_dtype=torch.float16.
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        ),
    )
    return model, tokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.model, self.tokenizer = load_model(path)
        self.pipeline = TextGenerationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: Union[Dict[str, Any], str]) -> List[Dict[str, Any]]:
        # Extract the input text.
        if isinstance(data, dict):
            text = data.get("inputs", "")
        else:
            text = data

        # Many causal LMs ship without a pad token; fall back to the EOS token
        # so generation does not fail with pad_token_id=None.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        # Default generation parameters.
        generation_kwargs = {
            "max_new_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.95,
            "repetition_penalty": 1.15,
            "do_sample": True,
            "pad_token_id": pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        # Override the defaults with request parameters, if provided.
        if isinstance(data, dict) and "parameters" in data:
            generation_kwargs.update(data["parameters"])

        try:
            # Generate the response.
            outputs = self.pipeline(text, **generation_kwargs)

            # Format the output as a list of dicts, as required by the API.
            if isinstance(outputs, list):
                return [{"generated_text": output["generated_text"]} for output in outputs]
            return [{"generated_text": outputs["generated_text"]}]
        except Exception as e:
            return [{"error": str(e)}]
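

# --- Local smoke test (a minimal sketch, not part of the handler contract) ---
# Hugging Face Inference Endpoints instantiates EndpointHandler with the
# repository path and calls it with a JSON-like dict; the snippet below mimics
# that flow for quick local testing. The model id is a placeholder assumption:
# substitute any causal LM you can load with 4-bit quantization.
if __name__ == "__main__":
    handler = EndpointHandler(path="your-username/your-model")  # placeholder id
    result = handler(
        {
            "inputs": "Once upon a time",
            "parameters": {"max_new_tokens": 32, "temperature": 0.8},
        }
    )
    print(result)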