from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json
import jsonschema


class EndpointHandler:
    def __init__(self, model_dir):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

        # Validation schema for the generated JSON
        self.json_schema = {
            "type": "object",
            "properties": {
                "values": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "string"},
                            "value": {"type": ["string", "array"]},
                        },
                        "required": ["id", "value"],
                    },
                },
            },
            "required": ["values"],
        }

    def preprocess(self, data):
        # Validate the input payload
        if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
            raise ValueError("Input must be a dictionary with an 'inputs' key and a valid value.")

        # Build the prompt in the specified format
        input_text = f"""
Please generate valid JSON based on the following specifications:

Expected format:

{{
  "values": [
    {{"id": "firstName", "value": "STRING"}},
    {{"id": "lastName", "value": "STRING"}},
    {{"id": "jobTitle", "value": "STRING"}},
    {{"id": "adress", "value": [{{"id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]", "value": "STRING"}}]}},
    {{"id": "email", "value": [{{"id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]", "value": "STRING"}}]}},
    {{"id": "phone", "value": [{{"id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]", "value": "STRING (ONLY NUMBERS)"}}]}},
    {{"id": "notes", "value": "STRING"}},
    {{"id": "roleFunction", "value": "[BUYER-SELLER-SUPPLIER-PARTNER-COLLABORATOR-PROVIDER-CUSTOMER]"}}
  ]
}}

Only include the fields detected in the input text.

Process the following text: "{data['inputs']}"
"""

        # Tokenize the input text; no padding is needed for a single sequence
        tokens = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=1000,
        )
        return tokens

    def inference(self, tokens):
        # Generation parameters (beam search combined with sampling)
        generate_kwargs = {
            "max_length": 1000,
            "num_beams": 5,
            "do_sample": True,
            "temperature": 0.3,
            "top_k": 50,
            "top_p": 0.9,
            "repetition_penalty": 2.5,
        }

        # Generate the output with the model
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs

    def validate_json(self, decoded_output):
        # Validate the generated JSON against the schema
        try:
            json_data = json.loads(decoded_output)
            jsonschema.validate(instance=json_data, schema=self.json_schema)
            return {"is_valid": True, "json_data": json_data}
        except json.JSONDecodeError as e:
            return {"is_valid": False, "error": f"Error decoding JSON: {str(e)}"}
        except jsonschema.exceptions.ValidationError as e:
            return {"is_valid": False, "error": f"Error validating JSON: {str(e)}"}

    def postprocess(self, outputs):
        # Decode the generated output
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Validate the generated JSON
        validation_result = self.validate_json(decoded_output)
        if not validation_result["is_valid"]:
            raise ValueError(f"Invalid JSON: {validation_result['error']}")

        return {"response": validation_result["json_data"]}

    def __call__(self, data):
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        return self.postprocess(outputs)
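

# A minimal local smoke test, as a sketch: the Inference Endpoint itself
# instantiates EndpointHandler and calls it with {"inputs": ...} payloads,
# so this block only illustrates the call contract. The "./model" checkpoint
# path and the sample input text are assumptions, not part of the handler.
if __name__ == "__main__":
    handler = EndpointHandler("./model")  # hypothetical local checkpoint directory
    payload = {"inputs": "John Smith, Sales Manager, john.smith@example.com, +1 555 0100"}
    try:
        result = handler(payload)
        print(json.dumps(result, indent=2, ensure_ascii=False))
    except ValueError as err:
        # Raised when generation did not produce schema-valid JSON
        print(f"Handler error: {err}")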