from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json

model_name = "jla25/squareV3"
# Module-level load; the handler below loads the model again from model_dir.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


class EndpointHandler:
    def __init__(self, model_dir):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

    def preprocess(self, data):
        if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
            raise ValueError("The input must be a dictionary with an 'inputs' key and a valid value.")
        # Custom prompt to guide the model
        input_text = f"Generate a valid JSON capturing data from this text:{data['inputs']}"
        print(f"Prompt generated for the model: {input_text}")
        tokens = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=1024,
        )
        return tokens

    def inference(self, tokens):
        generate_kwargs = {
            "max_length": 512,
            "num_beams": 5,
            "do_sample": False,
            # Note: temperature/top_k/top_p have no effect while do_sample is False (beam search).
            "temperature": 0.3,
            "top_k": 50,
            "top_p": 0.8,
            "early_stopping": True,  # Set explicitly for beam search
            "repetition_penalty": 2.5,
        }
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs

    def clean_output(self, output):
        # Keep only the substring between the first '{' and the last '}'.
        try:
            start_index = output.index("{")
            end_index = output.rindex("}") + 1
            return output[start_index:end_index]
        except ValueError:
            # No braces found; return the raw text unchanged.
            return output

    def postprocess(self, outputs):
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        cleaned_output = self.clean_output(decoded_output)
        # Always print the generated text for debugging
        print(f"Text generated by the model: {decoded_output}")
        print(f"Cleaned JSON: {cleaned_output}")
        return {"response": cleaned_output}

    def __call__(self, data):
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result


# Create a handler instance
handler = EndpointHandler(model_name)
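

# Illustrative local usage: a minimal sketch of calling the handler directly,
# assuming the {"inputs": ...} payload shape expected by preprocess().
# The sample text below is hypothetical and not part of the original handler.
if __name__ == "__main__":
    example_payload = {"inputs": "Order #123 was shipped to Madrid on 2024-01-05."}
    result = handler(example_payload)
    # result is a dict of the form {"response": "<extracted JSON string>"}
    print(result["response"])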