import json
import logging
import textwrap

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


logger = logging.getLogger(__name__)


class EndpointHandler:
    """Turn free text about a contact into a structured JSON document.

    A request flows through three stages: ``preprocess`` wraps the user text
    in an instruction prompt and tokenizes it, ``inference`` runs
    ``generate`` on the seq2seq model, and ``postprocess`` decodes the first
    generated sequence, extracts the JSON object from it and validates it.
    """

    # Upper bound (in tokens) applied when tokenizing the prompt.
    _MAX_INPUT_TOKENS = 1000

    # Keyword arguments forwarded verbatim to ``model.generate``.
    # NOTE(review): repetition_penalty=2.5 discourages repeated tokens, while
    # the requested JSON repeats keys such as "id" and "value" many times --
    # confirm this value is intentional.
    _GENERATE_KWARGS = {
        "max_length": 1000,
        "num_beams": 5,
        "do_sample": True,
        "temperature": 0.3,
        "top_k": 50,
        "top_p": 0.9,
        "repetition_penalty": 2.5,
    }

    # Instruction prompt; ``{inputs}`` is replaced with the user text and the
    # doubled braces render as literal braces after ``str.format``.
    # The text is runtime data sent to the model and is kept in Spanish.
    # NOTE(review): the "adress" id is part of the requested output schema;
    # confirm with downstream consumers before correcting the spelling.
    # NOTE(review): the e-mail in the worked example reads "[email protected]",
    # which looks like a redaction placeholder -- confirm the intended value.
    # NOTE(review): the lines are kept flush-left; confirm whether the prompt
    # is meant to carry JSON indentation.
    _PROMPT_TEMPLATE = textwrap.dedent(
        """
        Por favor, genera un JSON válido basado en las siguientes especificaciones:

        Formato esperado:
        {{
        "values": [
        {{
        "id": "firstName",
        "value": "STRING"
        }},
        {{
        "id": "lastName",
        "value": "STRING"
        }},
        {{
        "id": "jobTitle",
        "value": "STRING"
        }},
        {{
        "id": "adress",
        "value": [
        {{
        "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
        "value": "STRING"
        }}
        ]
        }},
        {{
        "id": "email",
        "value": [
        {{
        "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
        "value": "STRING"
        }}
        ]
        }},
        {{
        "id": "phone",
        "value": [
        {{
        "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
        "value": "STRING (ONLY NUMBERS)"
        }}
        ]
        }},
        {{
        "id": "notes",
        "value": "STRING"
        }},
        {{
        "id": "roleFunction",
        "value": "[BUYER-SELLER-SUPPLIER-PARTNER-COLLABORATOR-PROVIDER-CUSTOMER]"
        }}
        ]
        }}
        Ejemplo de salida:
        Para el texto de entrada: "Hablé con Ana López, CEO de Innovatech. Su número es 654 321 987 y su correo es [email protected]."
        La salida sería:
        {{
        "values": [
        {{"id": "firstName", "value": "Ana"}},
        {{"id": "lastName", "value": "López"}},
        {{"id": "jobTitle", "value": "CEO"}},
        {{"id": "phone", "value": [{{"id": "MOBILE", "value": "654321987"}}]}},
        {{"id": "email", "value": [{{"id": "WORK", "value": "[email protected]"}}]}}
        ]
        }}
        Solo incluye los campos detectados en el texto de entrada.
        Procesa el siguiente texto: "{inputs}"
        """
    )

    def __init__(self, model_dir):
        """Load the tokenizer and seq2seq model stored under *model_dir*.

        The model is switched to evaluation mode right after loading.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

    def preprocess(self, data: dict):
        """Build the instruction prompt for ``data["inputs"]`` and tokenize it.

        Args:
            data: Request payload; must be a dict whose ``"inputs"`` entry is
                not ``None``.

        Returns:
            Whatever the tokenizer returns for the prompt (requested as
            PyTorch tensors, truncated to ``_MAX_INPUT_TOKENS``).

        Raises:
            ValueError: If *data* is not a dict, lacks ``"inputs"``, or
                ``"inputs"`` is ``None``.
        """
        if not isinstance(data, dict) or data.get("inputs") is None:
            raise ValueError(
                "La entrada debe ser un diccionario con la clave 'inputs' y un valor válido."
            )

        input_text = self._PROMPT_TEMPLATE.format(inputs=data["inputs"])
        # The prompt embeds user-supplied contact data, so it is only emitted
        # when debug logging is enabled instead of being printed per request.
        logger.debug("Prompt generado para el modelo: %s", input_text)

        # A single sequence has nothing to be aligned with, so it is not
        # padded out to the maximum length; only the upper bound is enforced.
        # NOTE(review): the user text sits at the very end of the prompt; if
        # the tokenizer truncates on the right (the usual default), an
        # over-long prompt loses the user text first -- confirm the budget.
        return self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=self._MAX_INPUT_TOKENS,
        )

    def inference(self, tokens):
        """Run ``model.generate`` on *tokens* with gradients disabled.

        Args:
            tokens: Mapping produced by ``preprocess``; it is unpacked into
                keyword arguments of ``generate``.

        Returns:
            The value returned by ``model.generate``.
        """
        with torch.no_grad():
            return self.model.generate(**tokens, **self._GENERATE_KWARGS)

    def clean_output(self, output: str) -> str:
        """Return the span from the first ``{`` to the last ``}`` of *output*.

        When no such span exists -- a brace is missing, or the last ``}``
        comes before the first ``{`` -- *output* is returned unchanged.
        """
        start_index = output.find("{")
        end_index = output.rfind("}")
        # ``end_index`` is -1 when "}" is absent, which is also < start_index.
        if start_index == -1 or end_index < start_index:
            return output
        return output[start_index:end_index + 1]

    def validate_json(self, json_text: str) -> dict:
        """Check that *json_text* is a JSON object holding a ``"values"`` list.

        Returns:
            ``{"is_valid": True, "json_data": <parsed object>}`` on success,
            otherwise ``{"is_valid": False, "error": <message>}``. Malformed
            input never raises; it is reported through the result dict.
        """
        try:
            json_data = json.loads(json_text)
        except (json.JSONDecodeError, TypeError) as e:
            # TypeError covers input that is not str/bytes (e.g. None).
            return {"is_valid": False, "error": f"Error decodificando JSON: {e}"}

        # Valid JSON can also be a list, string, number or null; only an
        # object can carry the "values" key, so check the type before lookup.
        if isinstance(json_data, dict) and isinstance(json_data.get("values"), list):
            return {"is_valid": True, "json_data": json_data}
        return {"is_valid": False, "error": "El JSON no contiene el formato esperado."}

    def postprocess(self, outputs) -> dict:
        """Decode the first generated sequence and return its validated JSON.

        Args:
            outputs: Generated sequences; only ``outputs[0]`` is decoded.

        Returns:
            ``{"response": <parsed JSON object>}``.

        Raises:
            ValueError: If the decoded text does not contain a JSON object
                with a ``"values"`` list.
        """
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        cleaned_output = self.clean_output(decoded_output)
        logger.debug("Texto generado: %s", decoded_output)
        logger.debug("JSON limpiado: %s", cleaned_output)

        validation_result = self.validate_json(cleaned_output)
        if not validation_result["is_valid"]:
            logger.warning("Error en la validación: %s", validation_result["error"])
            raise ValueError(f"JSON inválido: {validation_result['error']}")

        return {"response": validation_result["json_data"]}

    def __call__(self, data: dict) -> dict:
        """Handle one request: preprocess, generate, then postprocess."""
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        return self.postprocess(outputs)