from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json
import jsonschema
import re
class EndpointHandler:
    def __init__(self, model_dir):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()
        # JSON validation schema for the model output
        self.json_schema = {
            "type": "object",
            "properties": {
                "values": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "string"},
                            "value": {"type": ["string", "array"]},
                        },
                        "required": ["id", "value"],
                    },
                },
            },
            "required": ["values"],
        }
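        # For illustration only: a hypothetical payload that satisfies the
        # schema above ("id"/"value" pairs are the fields the model is asked
        # to emit; "value" may be a string or an array):
        #
        #   {"values": [{"id": "name", "value": "Ada"},
        #               {"id": "tags", "value": ["a", "b"]}]}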
    def preprocess(self, data):
        if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
            raise ValueError("Input must be a dictionary with an 'inputs' key and a valid value.")
        input_text = f"""
        Please generate a valid JSON document based on the following specifications:
        ... (JSON format specifications omitted for brevity)
        Process the following text: "{data['inputs']}"
        """
        tokens = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=1000,
        )
        return tokens
    def inference(self, tokens):
        # Note: num_beams > 1 combined with do_sample=True triggers
        # beam-sample decoding in transformers.
        generate_kwargs = {
            "max_length": 1000,
            "num_beams": 5,
            "do_sample": True,
            "temperature": 0.3,
            "top_k": 50,
            "top_p": 0.9,
            "repetition_penalty": 2.5,
        }
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs
    def clean_output(self, output):
        # Keep only the first {...} span so any surrounding text the model
        # generated does not break JSON parsing.
        json_match = re.search(r"{.*}", output, re.DOTALL)
        if json_match:
            return json_match.group(0)
        return output
    def validate_json(self, decoded_output):
        cleaned_output = self.clean_output(decoded_output)
        try:
            json_data = json.loads(cleaned_output)
            jsonschema.validate(instance=json_data, schema=self.json_schema)
            return {"is_valid": True, "json_data": json_data}
        except json.JSONDecodeError as e:
            return {"is_valid": False, "error": f"JSON decode error: {str(e)}", "raw_output": cleaned_output}
        except jsonschema.exceptions.ValidationError as e:
            return {"is_valid": False, "error": f"JSON validation error: {str(e)}", "raw_output": cleaned_output}
    def postprocess(self, outputs):
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        validation_result = self.validate_json(decoded_output)
        # Always print the generated output
        print(f"Generated text: {decoded_output}")
        if not validation_result["is_valid"]:
            print(f"Validation error: {validation_result['error']}")
            print(f"Raw output: {validation_result.get('raw_output', 'Not available')}")
            raise ValueError(f"Invalid JSON: {validation_result['error']}")
        return {"response": validation_result["json_data"]}
    def __call__(self, data):
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result
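
# A minimal local smoke-test sketch, assuming a compatible seq2seq checkpoint
# is available at "./model" (the path and the sample input below are
# illustrative assumptions, not part of the deployed handler):
if __name__ == "__main__":
    handler = EndpointHandler("./model")  # hypothetical local model directory
    result = handler({"inputs": "Example text to process"})
    print(result["response"])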