from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json


model_name = "jla25/squareV3"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


class EndpointHandler:
    def __init__(self, model_dir):
        # The model is loaded by name from the Hub; model_dir is kept for
        # compatibility with the Inference Endpoints handler signature.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.model.eval()

    def preprocess(self, data):
        if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
            raise ValueError("The input must be a dictionary with an 'inputs' key and a valid value.")

        # Use the raw input text as the prompt for the model
        input_text = data["inputs"]
        # Print the prompt for debugging
        print(f"Prompt sent to the model: {input_text}")
        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
        return tokens

    def inference(self, tokens):
        generate_kwargs = {
            "max_length": 1024,
            "num_beams": 5,
            "do_sample": False,
            # Note: with do_sample=False (pure beam search), the sampling
            # parameters below are ignored by generate()
            "temperature": 0.3,
            "top_k": 50,
            "top_p": 0.7,
            "repetition_penalty": 2.5,
        }
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs

    def clean_output(self, output):
        # Extract the JSON object embedded in the generated text
        try:
            start_index = output.index("{")
            end_index = output.rindex("}") + 1
            return output[start_index:end_index]
        except ValueError:
            # No JSON object found in the text; return it unchanged
            return output

    def validate_json(self, json_text):
        # Validate the generated JSON; it is expected to contain a "values" list,
        # e.g. {"values": [...]}
        try:
            json_data = json.loads(json_text)
            if "values" in json_data and isinstance(json_data["values"], list):
                return {"is_valid": True, "json_data": json_data}
            else:
                return {"is_valid": False, "error": "The JSON does not match the expected format."}
        except json.JSONDecodeError as e:
            return {"is_valid": False, "error": f"Error decoding JSON: {str(e)}"}

    def postprocess(self, outputs):
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        cleaned_output = self.clean_output(decoded_output)

        # Always print the generated text for debugging
        print(f"Generated text: {decoded_output}")
        print(f"Cleaned JSON: {cleaned_output}")

        # Validate the generated JSON
        validation_result = self.validate_json(cleaned_output)

        if not validation_result["is_valid"]:
            print(f"Validation error: {validation_result['error']}")
            raise ValueError(f"Invalid JSON: {validation_result['error']}")

        return {"response": validation_result["json_data"]}

    def __call__(self, data):
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result
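

# --- Illustrative local usage (sketch) ---
# A minimal example of how the handler could be exercised outside the
# Inference Endpoints runtime. The request payload below is a hypothetical
# input; the actual prompt format expected by jla25/squareV3 depends on how
# the model was trained.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir=".")
    example_request = {"inputs": "Example text to convert into the expected JSON format"}
    try:
        result = handler(example_request)
        print(result["response"])
    except ValueError as err:
        # Raised when the generated output does not contain valid JSON
        print(f"Generation did not produce valid JSON: {err}")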