File size: 2,488 Bytes
0d16e19
 
c5a16f8
0d16e19
e29f84e
62a752c
05cc433
25fd307
 
05cc433
 
0d16e19
 
25fd307
 
a87db35
0d16e19
 
05da434
 
c5a16f8
d584a44
d6ed5c7
463dbb5
4cb23e7
5888de7
0d16e19
 
 
05da434
e519cca
5888de7
1ebac2e
5888de7
 
d6ed5c7
e519cca
5888de7
05da434
0d16e19
05da434
0d16e19
 
8708518
d584a44
 
 
 
 
 
8708518
0d16e19
 
d584a44
 
 
c1975d1
d584a44
 
c1975d1
0d16e19
 
 
 
 
a87db35
c1975d1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json


# Hugging Face Hub repository id of the fine-tuned seq2seq model served here.
model_name = "jla25/squareV3"

# Module-level tokenizer/model pair loaded from the same repository (tokenizer
# first, then model). NOTE(review): EndpointHandler.__init__ loads its own
# copies, so these globals only matter if other code imports them — confirm
# before removing, since each load holds a full copy of the weights.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)


class EndpointHandler:
    """Inference handler that asks a seq2seq model to extract JSON from text.

    The pipeline is: validate the request and wrap its text in a fixed prompt
    (``preprocess``), run deterministic beam search (``inference``), then
    decode the first generated sequence and keep the span between the first
    ``{`` and the last ``}`` (``postprocess``).
    """

    def __init__(self, model_dir):
        """Load the tokenizer and model from *model_dir*.

        Args:
            model_dir: Local directory or Hub repository id accepted by
                ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

    def preprocess(self, data):
        """Validate the request payload and tokenize the prompt.

        Args:
            data: Request payload; must be a dict whose ``"inputs"`` value is
                not ``None``. The value is interpolated into the prompt.

        Returns:
            The tokenizer output for the prompt, as PyTorch tensors.

        Raises:
            ValueError: If *data* is not a dict or ``"inputs"`` is missing
                or ``None``.
        """
        if not isinstance(data, dict) or data.get("inputs") is None:
            raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido.")

        # Custom prompt that steers the model toward emitting JSON. The wording
        # (including the missing space after the colon) is kept verbatim;
        # presumably the model was fine-tuned on it — verify before changing.
        input_text = f"Generate a valid JSON capturing data from this text:{data['inputs']}"
        print(f"Prompt generado para el modelo: {input_text}")
        # A single prompt needs no padding. padding="max_length" used to
        # inflate every request to 1024 tokens, multiplying encoder and beam
        # cross-attention work; padded positions are masked via the returned
        # attention mask, so dropping them should not change the output.
        return self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=1024,
        )

    def inference(self, tokens):
        """Run deterministic beam-search generation on tokenized input.

        Args:
            tokens: Mapping produced by ``preprocess``; unpacked into
                ``model.generate``.

        Returns:
            Whatever ``model.generate`` returns (generated token ids).
        """
        generate_kwargs = {
            "max_length": 512,
            "num_beams": 5,
            # Deterministic beam search. temperature/top_k/top_p were dropped:
            # they only apply when do_sample=True, so here they were ignored
            # and only made transformers emit configuration warnings.
            "do_sample": False,
            "early_stopping": True,
            "repetition_penalty": 2.5,
        }
        with torch.no_grad():
            return self.model.generate(**tokens, **generate_kwargs)

    def clean_output(self, output):
        """Return the substring from the first '{' through the last '}'.

        If either brace is missing, or the last '}' appears before the first
        '{', there is no plausible JSON span and *output* is returned
        unchanged (previously the reversed case returned an empty string).
        """
        start_index = output.find("{")
        end_index = output.rfind("}")
        if start_index == -1 or end_index < start_index:
            return output
        return output[start_index:end_index + 1]

    def postprocess(self, outputs):
        """Decode the first generated sequence and extract its JSON span.

        Returns:
            ``{"response": <cleaned text>}``.
        """
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        cleaned_output = self.clean_output(decoded_output)

        # Always print the generated text for debugging.
        print(f"Texto generado por el modelo: {decoded_output}")
        print(f"JSON limpiado: {cleaned_output}")

        return {"response": cleaned_output}

    def __call__(self, data):
        """Run the full preprocess -> inference -> postprocess pipeline."""
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        return self.postprocess(outputs)


# Build a module-level handler from the same model repository.
# NOTE(review): EndpointHandler.__init__ loads the tokenizer and model again,
# so this is an additional load at import time; if the hosting runtime
# constructs EndpointHandler itself, this instance may be unnecessary — confirm.
handler = EndpointHandler(model_dir=model_name)