import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, model_dir):
        # Load the tokenizer and seq2seq model from the local model directory.
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

    def preprocess(self, data):
        # Validate the input payload.
        if not data or not isinstance(data, dict) or data.get("inputs") is None:
            raise ValueError("Input must be a dictionary with an 'inputs' key and a valid value")
        # Check the raw input before prepending the prompt; otherwise the
        # prefix makes the combined string non-empty even for blank inputs.
        if not data["inputs"].strip():
            raise ValueError("Input text must not be empty")
        input_text = "Generate a valid JSON capturing data from this text: " + data["inputs"]
        # Tokenize the input.
        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
        return tokens

    def inference(self, tokens):
        # Run generation without tracking gradients.
        with torch.no_grad():
            outputs = self.model.generate(**tokens)
        return outputs

    def postprocess(self, outputs):
        # Decode the first generated sequence into plain text.
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": decoded_output}

    def __call__(self, data):
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        return self.postprocess(outputs)
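

# Minimal local smoke-test sketch (not part of the original handler).
# Hugging Face Inference Endpoints instantiate and call EndpointHandler
# automatically, so this block is only for verifying the handler locally.
# The "./model" path is a hypothetical directory containing a saved
# seq2seq model and tokenizer.
if __name__ == "__main__":
    handler = EndpointHandler("./model")  # hypothetical local model directory
    result = handler({"inputs": "Alice ordered 3 laptops on 2024-05-01."})
    print(result["generated_text"])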