from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

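# Context note (an assumption from the class shape, not stated in the source):
# this matches the Hugging Face Inference Endpoints custom-handler contract,
# where the serving toolkit instantiates EndpointHandler with the model
# directory and invokes __call__ with the parsed request payload.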
class EndpointHandler:
    def __init__(self, model_dir):
        # Load the tokenizer and seq2seq model shipped in the model directory.
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        # Inference only: switch off dropout and other train-mode behavior.
        self.model.eval()

    def preprocess(self, data):
        if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
            raise ValueError("Input must be a dictionary with an 'inputs' key and a valid value.")

        # Prepend the instruction prompt expected by the model.
        input_text = "Generate a valid JSON capturing data from this text: " + data["inputs"]
        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
        # `input_ids` is a tensor; testing its truth value directly raises a
        # RuntimeError for multi-element tensors, so check its size instead.
        if "input_ids" not in tokens or tokens["input_ids"].numel() == 0:
            raise ValueError("Failed to tokenize the input text. Check the input.")
        return tokens

    def inference(self, tokens):
        # Generation settings: beam search combined with sampling (beam-sample)
        # at low temperature, plus a strong repetition penalty.
        generate_kwargs = {
            "max_length": 1000,
            "num_beams": 5,
            "do_sample": True,
            "temperature": 0.3,
            "top_k": 50,
            "top_p": 0.9,
            "repetition_penalty": 2.5,
        }
        # No gradients are needed at inference time.
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        if outputs is None or len(outputs) == 0:
            raise ValueError("The model did not generate any output.")
        return outputs

    def postprocess(self, outputs):
        # Decode the first (highest-scoring) generated sequence.
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": decoded_output}

    def __call__(self, data):
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result

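
# A minimal local smoke test, assuming a seq2seq checkpoint exists in ./model
# (the path and example text are hypothetical; in production the endpoint
# builds the {"inputs": "..."} payload from the HTTP request).
if __name__ == "__main__":
    handler = EndpointHandler("./model")
    payload = {"inputs": "John Smith, 34, lives in Madrid and works as a nurse."}
    print(handler(payload)["generated_text"])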