File size: 5,007 Bytes
0d16e19
 
c5a16f8
 
0d16e19
 
 
 
 
a87db35
0d16e19
c5a16f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d16e19
333ea63
05da434
 
c5a16f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333ea63
4722f73
0d16e19
 
 
333ea63
05da434
 
 
 
 
 
 
 
 
333ea63
0d16e19
05da434
0d16e19
 
c5a16f8
 
 
 
 
 
 
 
 
 
 
0d16e19
333ea63
0d16e19
c5a16f8
 
 
 
 
 
 
0d16e19
 
 
 
 
a87db35
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json
import jsonschema

class EndpointHandler:
    def __init__(self, model_dir):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

        # Esquema de validación del JSON
        self.json_schema = {
            "type": "object",
            "properties": {
                "values": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "string"},
                            "value": {"type": ["string", "array"]}
                        },
                        "required": ["id", "value"]
                    },
                },
            },
            "required": ["values"],
        }

    def preprocess(self, data):
        # Validar la entrada
        if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
            raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido.")

        # Construir el prompt con el formato especificado
        input_text = f"""
        Por favor, genera un JSON válido basado en las siguientes especificaciones:

        Formato esperado:
        {{
            "values": [
                {{
                    "id": "firstName",
                    "value": "STRING"
                }},
                {{
                    "id": "lastName",
                    "value": "STRING"
                }},
                {{
                    "id": "jobTitle",
                    "value": "STRING"
                }},
                {{
                    "id": "adress",
                    "value": [
                        {{
                            "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
                            "value": "STRING"
                        }}
                    ]
                }},
                {{
                    "id": "email",
                    "value": [
                        {{
                            "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
                            "value": "STRING"
                        }}
                    ]
                }},
                {{
                    "id": "phone",
                    "value": [
                        {{
                            "id": "[MOBILE-WORK-PERSONAL-MAIN-OTHER]",
                            "value": "STRING (ONLY NUMBERS)"
                        }}
                    ]
                }},
                {{
                    "id": "notes",
                    "value": "STRING"
                }},
                {{
                    "id": "roleFunction",
                    "value": "[BUYER-SELLER-SUPPLIER-PARTNER-COLLABORATOR-PROVIDER-CUSTOMER]"
                }}
            ]
        }}

        Solo incluye los campos detectados en el texto de entrada.
        Procesa el siguiente texto: "{data['inputs']}"
        """
        # Tokenizar el texto de entrada
        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1000)
        return tokens

    def inference(self, tokens):
        # Parámetros de generación
        generate_kwargs = {
            "max_length": 1000,
            "num_beams": 5,
            "do_sample": True,
            "temperature": 0.3,
            "top_k": 50,
            "top_p": 0.9,
            "repetition_penalty": 2.5
        }
        # Generar salida con el modelo
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs

    def validate_json(self, decoded_output):
        # Validar el JSON generado con el esquema
        try:
            json_data = json.loads(decoded_output)
            jsonschema.validate(instance=json_data, schema=self.json_schema)
            return {"is_valid": True, "json_data": json_data}
        except json.JSONDecodeError as e:
            return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}"}
        except jsonschema.exceptions.ValidationError as e:
            return {"is_valid": False, "error": f"Error validando JSON: {str(e)}"}

    def postprocess(self, outputs):
        # Decodificar la salida generada
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Validar el JSON generado
        validation_result = self.validate_json(decoded_output)
        if not validation_result["is_valid"]:
            raise ValueError(f"JSON inválido: {validation_result['error']}")
        
        return {"response": validation_result["json_data"]}

    def __call__(self, data):
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result