File size: 3,641 Bytes
0d16e19
 
c5a16f8
 
8708518
0d16e19
 
 
 
 
a87db35
0d16e19
c5a16f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d16e19
05da434
 
c5a16f8
 
 
8708518
c5a16f8
 
4722f73
0d16e19
 
 
05da434
 
 
 
 
 
 
 
 
0d16e19
05da434
0d16e19
 
8708518
 
 
 
 
 
c5a16f8
8708518
c5a16f8
8708518
c5a16f8
 
 
8708518
c5a16f8
8708518
c5a16f8
0d16e19
 
c5a16f8
 
8708518
 
 
 
c5a16f8
8708518
 
c5a16f8
 
 
0d16e19
 
 
 
 
a87db35
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json
import jsonschema
import re

class EndpointHandler:
    """Callable handler that asks a seq2seq model for JSON and validates it.

    Calling an instance runs ``preprocess`` -> ``inference`` ->
    ``postprocess`` and returns ``{"response": <validated JSON>}``.
    """

    # Greedy "first '{' to last '}'" span. Kept only as a fallback so that the
    # raw text reported on failure is the same as before.
    _JSON_SPAN_RE = re.compile(r"{.*}", re.DOTALL)

    def __init__(self, model_dir):
        """Load the tokenizer and seq2seq model from *model_dir*.

        The model is switched to eval mode and the JSON schema used to
        validate generated output is stored on the instance.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

        # Validation schema for the generated JSON: an object with a required
        # "values" array whose items each carry an "id" and a "value".
        self.json_schema = {
            "type": "object",
            "properties": {
                "values": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "string"},
                            "value": {"type": ["string", "array"]}
                        },
                        "required": ["id", "value"]
                    },
                },
            },
            "required": ["values"],
        }

    def preprocess(self, data):
        """Build the prompt around ``data["inputs"]`` and tokenize it.

        Args:
            data: Request payload; must be a dict with a non-None "inputs".

        Returns:
            The tokenizer output as PyTorch tensors, truncated and padded to
            1000 tokens.

        Raises:
            ValueError: If *data* is not a dict, lacks "inputs", or "inputs"
                is None.
        """
        if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
            raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido.")

        # NOTE: the prompt text (including its indentation) is model input and
        # must not be reformatted.
        input_text = f"""
        Por favor, genera un JSON válido basado en las siguientes especificaciones:
        ... (Especificaciones del formato JSON omitidas por brevedad)
        Procesa el siguiente texto: "{data['inputs']}"
        """
        # NOTE(review): padding a single sequence to max_length looks
        # unnecessary; left as-is to keep model inputs unchanged — confirm.
        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1000)
        return tokens

    def inference(self, tokens):
        """Run ``model.generate`` on *tokens* without tracking gradients.

        Args:
            tokens: Mapping of tokenizer outputs, unpacked into ``generate``.

        Returns:
            Whatever ``self.model.generate`` returns for these inputs.
        """
        # NOTE(review): do_sample=True makes output non-deterministic, and a
        # repetition_penalty of 2.5 penalises tokens JSON has to repeat —
        # confirm these values are intentional.
        generate_kwargs = {
            "max_length": 1000,
            "num_beams": 5,
            "do_sample": True,
            "temperature": 0.3,
            "top_k": 50,
            "top_p": 0.9,
            "repetition_penalty": 2.5
        }
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs

    def clean_output(self, output):
        """Extract the JSON object embedded in *output*.

        Each ``{`` is tried in order and the first span that decodes as a
        JSON object is returned, so surrounding text or stray braces after
        the object no longer break extraction. If no span decodes, the
        greedy first-``{``-to-last-``}`` span is returned, and if there is
        none, *output* itself, which keeps failure reporting unchanged.

        Args:
            output: Decoded model text.

        Returns:
            The extracted JSON text, or a best-effort fallback string.
        """
        decoder = json.JSONDecoder()
        start = output.find("{")
        while start != -1:
            try:
                _, end = decoder.raw_decode(output, start)
            except json.JSONDecodeError:
                # Not a valid object at this brace; try the next candidate.
                start = output.find("{", start + 1)
            else:
                return output[start:end]

        json_match = self._JSON_SPAN_RE.search(output)
        if json_match:
            return json_match.group(0)
        return output

    def validate_json(self, decoded_output):
        """Parse *decoded_output* as JSON and validate it against the schema.

        Args:
            decoded_output: Decoded model text, cleaned with ``clean_output``.

        Returns:
            ``{"is_valid": True, "json_data": ...}`` on success, otherwise
            ``{"is_valid": False, "error": ..., "raw_output": ...}``.
        """
        cleaned_output = self.clean_output(decoded_output)
        try:
            json_data = json.loads(cleaned_output)
            jsonschema.validate(instance=json_data, schema=self.json_schema)
            return {"is_valid": True, "json_data": json_data}
        except json.JSONDecodeError as e:
            return {"is_valid": False, "error": f"Error decodificando JSON: {str(e)}", "raw_output": cleaned_output}
        except jsonschema.exceptions.ValidationError as e:
            return {"is_valid": False, "error": f"Error validando JSON: {str(e)}", "raw_output": cleaned_output}

    def postprocess(self, outputs):
        """Decode the first generated sequence and return its validated JSON.

        Args:
            outputs: Generated sequences; only ``outputs[0]`` is decoded.

        Returns:
            ``{"response": <parsed JSON>}``.

        Raises:
            ValueError: If the generated text is not valid JSON or does not
                match the schema.
        """
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        validation_result = self.validate_json(decoded_output)

        # Always print the generated output, valid or not.
        print(f"Texto generado: {decoded_output}")

        if not validation_result["is_valid"]:
            print(f"Error en la validación: {validation_result['error']}")
            print(f"Salida sin procesar: {validation_result.get('raw_output', 'No disponible')}")
            raise ValueError(f"JSON inválido: {validation_result['error']}")

        return {"response": validation_result["json_data"]}

    def __call__(self, data):
        """Run the full pipeline for one request payload and return the result."""
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result