File size: 5,514 Bytes
0d16e19
 
c5a16f8
0d16e19
 
 
 
 
a87db35
0d16e19
 
05da434
 
c5a16f8
d584a44
50adeb4
463dbb5
63701ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463dbb5
 
4722f73
0d16e19
 
 
05da434
 
 
 
 
 
 
 
 
0d16e19
05da434
0d16e19
 
8708518
d584a44
 
 
 
 
 
 
 
8708518
d584a44
 
c5a16f8
d584a44
 
 
 
 
c5a16f8
d584a44
c5a16f8
0d16e19
 
d584a44
 
 
8708518
d584a44
 
 
 
8708518
c5a16f8
8708518
c5a16f8
d584a44
c5a16f8
0d16e19
 
 
 
 
a87db35
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json

class EndpointHandler:
    def __init__(self, model_dir):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

    def preprocess(self, data):
        if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
            raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor v谩lido.")

        # Prompt personalizado para guiar al modelo
        input_text = (
            f"""
            Por favor, genera un JSON v谩lido basado en las siguientes especificaciones:

        Formato esperado:
        {{
            "values": [
                {{
                    "id": "firstName",
                    "value": "STRING"
                }},
                {{
                    "id": "lastName",
                    "value": "STRING"
                }},
                {{
                    "id": "jobTitle",
                    "value": "STRING"
                }},
                {{
                    "id": "adress",
                    "value": [
                        {{
                            "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
                            "value": "STRING"
                        }}
                    ]
                }},
                {{
                    "id": "email",
                    "value": [
                        {{
                            "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
                            "value": "STRING"
                        }}
                    ]
                }},
                {{
                    "id": "phone",
                    "value": [
                        {{
                            "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
                            "value": "STRING (ONLY NUMBERS)"
                        }}
                    ]
                }},
                {{
                    "id": "notes",
                    "value": "STRING"
                }},
                {{
                    "id": "roleFunction",
                    "value": "[BUYER-SELLER-SUPPLIER-PARTNER-COLLABORATOR-PROVIDER-CUSTOMER]"
                }}
            ]
        }}
        Ejemplo de salida:
        Para el texto de entrada: "Habl茅 con Ana L贸pez, CEO de Innovatech. Su n煤mero es 654 321 987 y su correo es [email protected]."
        La salida ser铆a:
        {{
            "values": [
                {{"id": "firstName", "value": "Ana"}},
                {{"id": "lastName", "value": "L贸pez"}},
                {{"id": "jobTitle", "value": "CEO"}},
                {{"id": "phone", "value": [{{"id": "MOBILE", "value": "654321987"}}]}},
                {{"id": "email", "value": [{{"id": "WORK", "value": "[email protected]"}}]}}
            ]
        }}
        Solo incluye los campos detectados en el texto de entrada.
        Procesa el siguiente texto: "{data['inputs']}"
            """)
        # Imprimir el texto generado para el prompt
        print(f"Prompt generado para el modelo: {input_text}")
        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1000)
        return tokens

    def inference(self, tokens):
        generate_kwargs = {
            "max_length": 1000,
            "num_beams": 5,
            "do_sample": True,
            "temperature": 0.3,
            "top_k": 50,
            "top_p": 0.9,
            "repetition_penalty": 2.5
        }
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs

    def clean_output(self, output):
        # Extraer el JSON dentro del texto generado
        try:
            start_index = output.index("{")
            end_index = output.rindex("}") + 1
            return output[start_index:end_index]
        except ValueError:
            # Si no hay un JSON v谩lido en el texto
            return output

    def validate_json(self, json_text):
        # Validar el JSON generado
        try:
            json_data = json.loads(json_text)
            if "values" in json_data and isinstance(json_data["values"], list):
                return {"is_valid": True, "json_data": json_data}
            else:
                return {"is_valid": False, "error": "El JSON no contiene el formato esperado."}
        except json.JSONDecodeError as e:
            return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}"}

    def postprocess(self, outputs):
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        cleaned_output = self.clean_output(decoded_output)

        # Imprimir siempre el texto generado para depuraci贸n
        print(f"Texto generado: {decoded_output}")
        print(f"JSON limpiado: {cleaned_output}")

        # Validar el JSON generado
        validation_result = self.validate_json(cleaned_output)

        if not validation_result["is_valid"]:
            print(f"Error en la validaci贸n: {validation_result['error']}")
            raise ValueError(f"JSON inv谩lido: {validation_result['error']}")

        return {"response": validation_result["json_data"]}

    def __call__(self, data):
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result