from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json
class EndpointHandler:
    """Hugging Face Inference Endpoints handler that extracts contact data
    (name, job title, phone, email, ...) from free-form Spanish text and
    returns it as a validated JSON structure.

    Pipeline: preprocess (build few-shot prompt, tokenize) -> inference
    (seq2seq generate) -> postprocess (decode, extract JSON substring,
    validate against the expected schema).
    """

    def __init__(self, model_dir):
        """Load tokenizer and model from *model_dir* and set eval mode.

        eval() disables dropout/batch-norm updates for deterministic-ish
        inference behavior.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

    def preprocess(self, data):
        """Validate the request payload and return the tokenized prompt.

        Args:
            data: request body; must be a dict with a non-None "inputs" key
                holding the raw text to process.

        Returns:
            Tokenized prompt (BatchEncoding of PyTorch tensors) placed on the
            same device as the model.

        Raises:
            ValueError: if *data* is not a dict with a valid "inputs" value.
        """
        if not isinstance(data, dict) or data.get("inputs") is None:
            raise ValueError(
                "La entrada debe ser un diccionario con la clave 'inputs' y un valor válido."
            )
        # Few-shot prompt guiding the model toward the expected JSON schema.
        # Literal braces are doubled because this is an f-string.
        # NOTE(review): the field id "adress" (sic) is kept verbatim — the
        # downstream consumer of this JSON may expect that exact spelling;
        # confirm before correcting it.
        input_text = f"""
Por favor, genera un JSON válido basado en las siguientes especificaciones:
Formato esperado:
{{
  "values": [
    {{
      "id": "firstName",
      "value": "STRING"
    }},
    {{
      "id": "lastName",
      "value": "STRING"
    }},
    {{
      "id": "jobTitle",
      "value": "STRING"
    }},
    {{
      "id": "adress",
      "value": [
        {{
          "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
          "value": "STRING"
        }}
      ]
    }},
    {{
      "id": "email",
      "value": [
        {{
          "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
          "value": "STRING"
        }}
      ]
    }},
    {{
      "id": "phone",
      "value": [
        {{
          "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
          "value": "STRING (ONLY NUMBERS)"
        }}
      ]
    }},
    {{
      "id": "notes",
      "value": "STRING"
    }},
    {{
      "id": "roleFunction",
      "value": "[BUYER-SELLER-SUPPLIER-PARTNER-COLLABORATOR-PROVIDER-CUSTOMER]"
    }}
  ]
}}
Ejemplo de salida:
Para el texto de entrada: "Hablé con Ana López, CEO de Innovatech. Su número es 654 321 987 y su correo es [email protected]."
La salida sería:
{{
  "values": [
    {{"id": "firstName", "value": "Ana"}},
    {{"id": "lastName", "value": "López"}},
    {{"id": "jobTitle", "value": "CEO"}},
    {{"id": "phone", "value": [{{"id": "MOBILE", "value": "654321987"}}]}},
    {{"id": "email", "value": [{{"id": "WORK", "value": "[email protected]"}}]}}
  ]
}}
Solo incluye los campos detectados en el texto de entrada.
Procesa el siguiente texto: "{data['inputs']}"
"""
        # Endpoint-side debug logging of the exact prompt sent to the model.
        print(f"Prompt generado para el modelo: {input_text}")
        # No fixed-length padding: padding a single sequence to max_length
        # only adds pad tokens the model must mask away (wasted compute).
        tokens = self.tokenizer(
            input_text, return_tensors="pt", truncation=True, max_length=1000
        )
        # Keep inputs on the model's device (no-op on CPU, required on GPU).
        return tokens.to(self.model.device)

    def inference(self, tokens):
        """Run text generation and return the raw output token ids."""
        generate_kwargs = {
            "max_length": 1000,
            # NOTE(review): beam search combined with sampling makes the
            # output non-deterministic; confirm this is intentional.
            "num_beams": 5,
            "do_sample": True,
            "temperature": 0.3,
            "top_k": 50,
            "top_p": 0.9,
            "repetition_penalty": 2.5,
        }
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs

    def clean_output(self, output):
        """Extract the JSON-looking substring from the generated text.

        Returns the slice from the first '{' to the last '}' (inclusive).
        If either brace is missing the raw text is returned unchanged, so
        that validate_json reports the decoding error instead of this method
        raising.
        """
        try:
            start_index = output.index("{")
            end_index = output.rindex("}") + 1
            return output[start_index:end_index]
        except ValueError:
            # No complete brace pair present in the generated text.
            return output

    def validate_json(self, json_text):
        """Parse *json_text* and check it matches the expected schema.

        Returns:
            {"is_valid": True, "json_data": <parsed dict>} on success, or
            {"is_valid": False, "error": <message>} when the text is not
            JSON or lacks a top-level "values" list.
        """
        try:
            json_data = json.loads(json_text)
        except json.JSONDecodeError as e:
            return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}"}
        if "values" in json_data and isinstance(json_data["values"], list):
            return {"is_valid": True, "json_data": json_data}
        return {"is_valid": False, "error": "El JSON no contiene el formato esperado."}

    def postprocess(self, outputs):
        """Decode, clean, and validate the model output.

        Raises:
            ValueError: if the generated text does not contain valid JSON
                with the expected "values" list.
        """
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        cleaned_output = self.clean_output(decoded_output)
        # Always log the raw and cleaned text for debugging.
        print(f"Texto generado: {decoded_output}")
        print(f"JSON limpiado: {cleaned_output}")
        validation_result = self.validate_json(cleaned_output)
        if not validation_result["is_valid"]:
            print(f"Error en la validación: {validation_result['error']}")
            raise ValueError(f"JSON inválido: {validation_result['error']}")
        return {"response": validation_result["json_data"]}

    def __call__(self, data):
        """Full request entry point: preprocess -> inference -> postprocess."""
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        return self.postprocess(outputs)