|
import json
import logging
import re

import jsonschema
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
class EndpointHandler:
    """Inference handler that turns free text into schema-validated JSON.

    It wraps a seq2seq model: it builds a Spanish instruction prompt around the
    caller's text, generates with the model, extracts the JSON object from the
    generated text and validates it against ``self.json_schema``.
    """

    # Upper bound (in tokens) for the encoded prompt; longer prompts are
    # truncated from the end.
    MAX_INPUT_LENGTH = 1000

    # Default decoding settings. Individual keys can be overridden per request
    # through the optional "parameters" mapping (see __call__ / inference).
    DEFAULT_GENERATE_KWARGS = {
        "max_length": 1000,
        "num_beams": 5,
        "do_sample": True,
        "temperature": 0.3,
        "top_k": 50,
        "top_p": 0.9,
        "repetition_penalty": 2.5,
    }

    # Greedy match from the first "{" to the last "}" so that nested objects
    # are captured whole; DOTALL lets the object span several lines.
    _JSON_OBJECT_RE = re.compile(r"\{.*\}", re.DOTALL)

    _logger = logging.getLogger(__name__)

    def __init__(self, model_dir):
        """Load the tokenizer and model from *model_dir* and prepare them.

        Args:
            model_dir: Location passed to ``from_pretrained`` for both the
                tokenizer and the seq2seq model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        # Use a GPU when the host has one; inference() moves the inputs to
        # whatever device the model ends up on.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

        # Expected shape of the model's answer:
        # {"values": [{"id": <string>, "value": <string | array>}, ...]}
        self.json_schema = {
            "type": "object",
            "properties": {
                "values": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "string"},
                            "value": {"type": ["string", "array"]},
                        },
                        "required": ["id", "value"],
                    },
                },
            },
            "required": ["values"],
        }

    def preprocess(self, data):
        """Validate the request payload and tokenize the instruction prompt.

        Args:
            data: Request payload; must be a dict whose ``"inputs"`` value is
                the text to convert into JSON.

        Returns:
            The tokenizer's encoding of the prompt as PyTorch tensors.

        Raises:
            ValueError: if *data* is not a dict, has no ``"inputs"`` key, or
                its value is ``None`` or a blank string.
        """
        text = data.get("inputs") if isinstance(data, dict) else None
        if text is None or (isinstance(text, str) and not text.strip()):
            raise ValueError(
                "La entrada debe ser un diccionario con la clave 'inputs' y un valor válido."
            )

        # Only the last line is an f-string, so literal braces in the format
        # specification (e.g. JSON examples) need no escaping.
        # NOTE(review): the specification line is a placeholder that is sent to
        # the model verbatim — confirm the real specification is filled in.
        prompt = (
            "Por favor, genera un JSON válido basado en las siguientes especificaciones:\n"
            "... (Especificaciones del formato JSON omitidas por brevedad)\n"
            f'Procesa el siguiente texto: "{text}"\n'
        )

        # A single prompt needs no padding; padding to MAX_INPUT_LENGTH only
        # made the model process up to that many pad tokens on every request.
        return self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_INPUT_LENGTH,
        )

    def inference(self, tokens, generate_kwargs=None):
        """Run ``model.generate`` on the encoded prompt.

        Args:
            tokens: Tokenizer output from :meth:`preprocess`; it must support
                ``.to(device)`` and unpack as keyword arguments.
            generate_kwargs: Optional mapping whose entries override
                ``DEFAULT_GENERATE_KWARGS`` for this call only.

        Returns:
            Whatever ``model.generate`` returns (token ids by default).
        """
        kwargs = {**self.DEFAULT_GENERATE_KWARGS, **(generate_kwargs or {})}
        # Inputs must live on the same device as the model weights.
        tokens = tokens.to(self.model.device)
        with torch.inference_mode():
            return self.model.generate(**tokens, **kwargs)

    def clean_output(self, output):
        """Return the span from the first ``{`` to the last ``}`` in *output*.

        When *output* contains no such span it is returned unchanged, so the
        subsequent ``json.loads`` reports the parse error.
        """
        match = self._JSON_OBJECT_RE.search(output)
        return match.group(0) if match else output

    def validate_json(self, decoded_output):
        """Extract, parse and schema-check the JSON in *decoded_output*.

        Returns:
            ``{"is_valid": True, "json_data": <parsed>}`` on success, otherwise
            ``{"is_valid": False, "error": <message>, "raw_output": <text>}``
            where ``raw_output`` is the extracted candidate text.
        """
        cleaned_output = self.clean_output(decoded_output)
        try:
            json_data = json.loads(cleaned_output)
            jsonschema.validate(instance=json_data, schema=self.json_schema)
            return {"is_valid": True, "json_data": json_data}
        except json.JSONDecodeError as e:
            return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}", "raw_output": cleaned_output}
        except jsonschema.exceptions.ValidationError as e:
            return {"is_valid": False, "error": f"Error validando JSON: {str(e)}", "raw_output": cleaned_output}

    def postprocess(self, outputs):
        """Decode the first generated sequence and return it as validated JSON.

        Returns:
            ``{"response": <parsed JSON>}``.

        Raises:
            ValueError: if the decoded text is not valid JSON or does not match
                ``self.json_schema``.
        """
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        self._logger.info("Texto generado: %s", decoded_output)

        validation_result = self.validate_json(decoded_output)
        if not validation_result["is_valid"]:
            self._logger.error(
                "Error en la validación: %s | Salida sin procesar: %s",
                validation_result["error"],
                validation_result.get("raw_output", "No disponible"),
            )
            raise ValueError(f"JSON inválido: {validation_result['error']}")

        return {"response": validation_result["json_data"]}

    def __call__(self, data):
        """Handle one request of the form ``{"inputs": <text>, "parameters": {...}}``.

        ``"parameters"`` is optional; when present it must be a dict and its
        entries override the default generation settings for this request.

        Raises:
            ValueError: on an invalid payload or when the model's output is not
                valid JSON matching the schema.
        """
        tokens = self.preprocess(data)
        parameters = data.get("parameters")
        if parameters is not None and not isinstance(parameters, dict):
            raise ValueError("'parameters' debe ser un diccionario.")
        outputs = self.inference(tokens, parameters)
        return self.postprocess(outputs)
|
|