# NOTE(review): extraction artifact removed here — the original listing carried a
# file-size line, a git-blame commit gutter, and a line-number gutter that are
# not part of the Python source.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json
import jsonschema
class EndpointHandler:
    def __init__(self, model_dir):
        """Load the seq2seq model and tokenizer, and build the output schema.

        Args:
            model_dir: path or hub id passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

        # Validation schema for the JSON the model must produce: a top-level
        # object with a required "values" array whose entries are {id, value}
        # objects, where "value" is either a string or a nested array.
        entry_schema = {
            "type": "object",
            "properties": {
                "id": {"type": "string"},
                "value": {"type": ["string", "array"]},
            },
            "required": ["id", "value"],
        }
        self.json_schema = {
            "type": "object",
            "properties": {
                "values": {
                    "type": "array",
                    "items": entry_schema,
                },
            },
            "required": ["values"],
        }
def preprocess(self, data):
# Validar la entrada
if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido.")
# Construir el prompt con el formato especificado
input_text = f"""
Por favor, genera un JSON válido basado en las siguientes especificaciones:
Formato esperado:
{{
"values": [
{{
"id": "firstName",
"value": "STRING"
}},
{{
"id": "lastName",
"value": "STRING"
}},
{{
"id": "jobTitle",
"value": "STRING"
}},
{{
"id": "adress",
"value": [
{{
"id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
"value": "STRING"
}}
]
}},
{{
"id": "email",
"value": [
{{
"id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
"value": "STRING"
}}
]
}},
{{
"id": "phone",
"value": [
{{
"id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
"value": "STRING (ONLY NUMBERS)"
}}
]
}},
{{
"id": "notes",
"value": "STRING"
}},
{{
"id": "roleFunction",
"value": "[BUYER-SELLER-SUPPLIER-PARTNER-COLLABORATOR-PROVIDER-CUSTOMER]"
}}
]
}}
Solo incluye los campos detectados en el texto de entrada.
Procesa el siguiente texto: "{data['inputs']}"
"""
# Tokenizar el texto de entrada
tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1000)
return tokens
def inference(self, tokens):
# Parámetros de generación
generate_kwargs = {
"max_length": 1000,
"num_beams": 5,
"do_sample": True,
"temperature": 0.3,
"top_k": 50,
"top_p": 0.9,
"repetition_penalty": 2.5
}
# Generar salida con el modelo
with torch.no_grad():
outputs = self.model.generate(**tokens, **generate_kwargs)
return outputs
def validate_json(self, decoded_output):
# Validar el JSON generado con el esquema
try:
json_data = json.loads(decoded_output)
jsonschema.validate(instance=json_data, schema=self.json_schema)
return {"is_valid": True, "json_data": json_data}
except json.JSONDecodeError as e:
return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}"}
except jsonschema.exceptions.ValidationError as e:
return {"is_valid": False, "error": f"Error validando JSON: {str(e)}"}
def postprocess(self, outputs):
# Decodificar la salida generada
decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# Validar el JSON generado
validation_result = self.validate_json(decoded_output)
if not validation_result["is_valid"]:
raise ValueError(f"JSON inválido: {validation_result['error']}")
return {"response": validation_result["json_data"]}
def __call__(self, data):
tokens = self.preprocess(data)
outputs = self.inference(tokens)
result = self.postprocess(outputs)
return result