from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import torch import json token = os.getenv("HUGGINGFACE_TOKEN") model_name = "jla25/squareV3" tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token) model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=token) class EndpointHandler: def __init__(self, model_dir): self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name) self.model.eval() def preprocess(self, data): if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None: raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido.") # Prompt personalizado para guiar al modelo input_text = ({data['inputs']}) # Imprimir el texto generado para el prompt print(f"Prompt generado para el modelo: {input_text}") tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024) return tokens def inference(self, tokens): generate_kwargs = { "max_length": 1024, "num_beams": 5, "do_sample": False, "temperature": 0.3, "top_k": 50, "top_p": 0.7, "repetition_penalty": 2.5 } with torch.no_grad(): outputs = self.model.generate(**tokens, **generate_kwargs) return outputs def clean_output(self, output): # Extraer el JSON dentro del texto generado try: start_index = output.index("{") end_index = output.rindex("}") + 1 return output[start_index:end_index] except ValueError: # Si no hay un JSON válido en el texto return output def validate_json(self, json_text): # Validar el JSON generado try: json_data = json.loads(json_text) if "values" in json_data and isinstance(json_data["values"], list): return {"is_valid": True, "json_data": json_data} else: return {"is_valid": False, "error": "El JSON no contiene el formato esperado."} except json.JSONDecodeError as e: return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}"} def postprocess(self, outputs): decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True) cleaned_output = self.clean_output(decoded_output) # Imprimir siempre el texto generado para depuración print(f"Texto generado: {decoded_output}") print(f"JSON limpiado: {cleaned_output}") # Validar el JSON generado validation_result = self.validate_json(cleaned_output) if not validation_result["is_valid"]: print(f"Error en la validación: {validation_result['error']}") raise ValueError(f"JSON inválido: {validation_result['error']}") return {"response": validation_result["json_data"]} def __call__(self, data): tokens = self.preprocess(data) outputs = self.inference(tokens) result = self.postprocess(outputs) return result