# modele-test / handler.py
# Aktraiser's picture
# Update handler.py
# 37b22c3 verified
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
import torch
def load_model(model_id):
    """Load a 4-bit quantized causal language model and its tokenizer.

    Args:
        model_id: Model identifier or local path, passed unchanged to
            ``from_pretrained`` for both the tokenizer and the model.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Function-scope import: keeps this block self-contained without
    # touching the file's top-level import line.
    from transformers import BitsAndBytesConfig

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        # Passing ``load_in_4bit=True`` directly to ``from_pretrained`` is
        # deprecated; the same request is expressed through an explicit
        # quantization config.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
    return model, tokenizer
class EndpointHandler:
    """Callable text-generation handler built around a causal language model.

    The model and tokenizer are loaded once at construction time through
    ``load_model`` and wrapped in a ``TextGenerationPipeline``. Each call
    runs one generation request through that pipeline.
    """

    def __init__(self, path=""):
        """Load the model located at *path* and build the pipeline.

        Args:
            path: Model identifier or directory forwarded to ``load_model``.
        """
        self.model, self.tokenizer = load_model(path)
        self.pipeline = TextGenerationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data):
        """Generate text for one request.

        Args:
            data: Either a dict holding the prompt under ``"inputs"`` and
                optional generation overrides under ``"parameters"``, or
                the prompt itself.

        Returns:
            A list of ``{"generated_text": ...}`` dicts. When the pipeline
            returns one list per prompt (batched input), the result keeps
            that nesting. On any failure, ``[{"error": message}]``.
        """
        # Extract the input text and any per-request overrides.
        if isinstance(data, dict):
            text = data.get("inputs", "")
            # ``"parameters": null`` is treated like an absent key.
            overrides = data.get("parameters") or {}
        else:
            text = data
            overrides = {}

        # Fall back to the EOS token when the tokenizer defines no pad
        # token, so ``None`` is never forwarded as the padding id.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        # Default generation parameters.
        generation_kwargs = {
            "max_new_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.95,
            "repetition_penalty": 1.15,
            "do_sample": True,
            "pad_token_id": pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        try:
            # Request parameters take precedence over the defaults. Done
            # inside the ``try`` so a malformed ``parameters`` value yields
            # an error payload instead of an uncaught exception.
            generation_kwargs.update(overrides)
            outputs = self.pipeline(text, **generation_kwargs)
            return self._format_outputs(outputs)
        except Exception as e:
            # Top-level request boundary: report the failure to the caller.
            return [{"error": str(e)}]

    @staticmethod
    def _format_outputs(outputs):
        """Reduce pipeline output to ``{"generated_text": ...}`` entries."""
        if not isinstance(outputs, list):
            return [{"generated_text": outputs["generated_text"]}]

        formatted = []
        for item in outputs:
            if isinstance(item, list):
                # One inner list of generations per prompt in a batch.
                formatted.append(
                    [{"generated_text": seq["generated_text"]} for seq in item]
                )
            else:
                formatted.append({"generated_text": item["generated_text"]})
        return formatted