from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json


model_name = "jla25/squareV3"

# Loaded once at module import time; note that EndpointHandler below loads its own copies from model_dir.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=True)


class EndpointHandler:
    def __init__(self, model_dir):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_auth_token=True)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, use_auth_token=True)
        self.model.eval()

    def preprocess(self, data):
        if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
            raise ValueError("The input must be a dictionary with an 'inputs' key and a valid value.")

        # Build the prompt passed to the model (currently the raw input text is used as-is)
        input_text = f"{data['inputs']}"
        print(f"Prompt sent to the model: {input_text}")

        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
        return tokens

    def inference(self, tokens):
        generate_kwargs = {
            "max_length": 1024,
            "num_beams": 5,
            "do_sample": False,
            "temperature": 0.3,
            "top_k": 50,
            "top_p": 0.7,
            "repetition_penalty": 2.5
        }
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs

    def clean_output(self, output):
        # Keep only the substring between the first "{" and the last "}" so the
        # response is a single JSON object; fall back to the raw text otherwise.
        try:
            start_index = output.index("{")
            end_index = output.rindex("}") + 1
            return output[start_index:end_index]
        except ValueError:
            return output

    def postprocess(self, outputs):
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        cleaned_output = self.clean_output(decoded_output)

        # Always print the generated text for debugging
        print(f"Raw text generated by the model: {decoded_output}")
        print(f"Cleaned JSON: {cleaned_output}")

        return {"response": cleaned_output}

    def __call__(self, data):
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result


# Create an instance of the handler
handler = EndpointHandler(model_name)
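

# Minimal usage sketch (assumption: run locally as a quick smoke test; a Hugging Face
# Inference Endpoint normally instantiates and calls the handler itself). The payload
# shape {"inputs": "..."} matches what preprocess() above expects; the example text is
# purely hypothetical.
if __name__ == "__main__":
    example_payload = {"inputs": "Example request text to convert into JSON"}
    result = handler(example_payload)
    print(result["response"])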