camilaaeromoca commited on
Commit
dcf41c9
·
verified ·
1 Parent(s): 71892ba

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ modelo_llm = AutoModelForCausalLM.from_pretrained("modelos/modelo_final")
2
+
3
+ # Definindo uma classe chamada NumberTokenizer, que é usada para tokenizar os números
4
+ class DSATokenizer:
5
+
6
+ # Método construtor da classe, que é executado quando um objeto dessa classe é criado
7
+ def __init__(self, numbers_qty = 10):
8
+
9
+ # Lista de tokens possíveis que o tokenizador pode encontrar
10
+ vocab = ['+', '=', '-1', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
11
+
12
+ # Definindo a quantidade de números que o tokenizador pode lidar
13
+ self.numbers_qty = numbers_qty
14
+
15
+ # Definindo o token de preenchimento (padding)
16
+ self.pad_token = '-1'
17
+
18
+ # Criando um dicionário que mapeia cada token para um índice único
19
+ self.encoder = {str(v):i for i,v in enumerate(vocab)}
20
+
21
+ # Criando um dicionário que mapeia cada índice único de volta ao token correspondente
22
+ self.decoder = {i:str(v) for i,v in enumerate(vocab)}
23
+
24
+ # Obtendo o índice do token de preenchimento no encoder
25
+ self.pad_token_id = self.encoder[self.pad_token]
26
+
27
+ # Método para decodificar uma lista de IDs de token de volta para uma string
28
+ def decode(self, token_ids):
29
+ return ' '.join(self.decoder[t] for t in token_ids)
30
+
31
+ # Método que é chamado quando o objeto da classe é invocado como uma função
32
+ def __call__(self, text):
33
+ # Dividindo o texto em tokens individuais e retornando uma lista dos IDs correspondentes
34
+ return [self.encoder[t] for t in text.split()]
35
+
36
+ # Cria o objeto
37
+ tokenizer = DSATokenizer(13)
38
+
39
+ # Definindo a função gera_solution com três parâmetros: input, solution_length e model
40
+ def faz_previsao(entrada, solution_length = 6, model = modelo_llm):
41
+
42
+ # Colocando o modelo em modo de avaliação.
43
+ model.eval()
44
+
45
+ # Convertendo a entrada (string) em tensor utilizando o tokenizer.
46
+ # O tensor é uma estrutura de dados que o modelo de aprendizado de máquina pode processar.
47
+ entrada = torch.tensor(tokenizer(entrada))
48
+
49
+ # Iniciando uma lista vazia para armazenar a solução
50
+ solution = []
51
+
52
+ # Loop que gera a solução de comprimento solution_length
53
+ for i in range(solution_length):
54
+
55
+ # Alimentando a entrada atual ao modelo e obtendo a saída
56
+ saida = model(entrada)
57
+
58
+ # Pegando o índice do maior valor no último conjunto de logits (log-odds) da saída,
59
+ # que é a previsão do modelo para o próximo token
60
+ predicted = saida.logits[-1].argmax()
61
+
62
+ # Concatenando a previsão atual com a entrada atual.
63
+ # Isso servirá como a nova entrada para a próxima iteração.
64
+ entrada = torch.cat((entrada, predicted.unsqueeze(0)), dim = 0)
65
+
66
+ # Adicionando a previsão atual à lista de soluções e convertendo o tensor em um número Python padrão
67
+ solution.append(predicted.cpu().item())
68
+
69
+ # Decodificando a lista de soluções para obter a string de saída e retornando-a
70
+ return tokenizer.decode(solution)
71
+
72
+ # Testa a função
73
+ faz_previsao('3 + 5 =', solution_length = 2)
74
+
75
+ # Função para retornar a função que faz a previsão
76
+ def funcsolve(entrada):
77
+ return faz_previsao(entrada, solution_length = 2)
78
+
79
+
80
+ # Cria a web app
81
+ webapp = gr.Interface(fn = funcsolve,
82
+ inputs = [gr.Textbox(label = "Dados de Entrada",
83
+ lines = 1,
84
+ info = "Os dados devem estar na forma: '1 + 2 =' com um único espaço entre cada caractere e apenas números de um dígito são permitidos.")],
85
+ outputs = [gr.Textbox(label = "Resultado (Previsão do Modelo)", lines = 1)],
86
+ title = "Deploy de LLM Após o Fine-Tuning",
87
+ description = "Digite os dados de entrada e clique no botão Submit para o modelo fazer a previsão.",
88
+ examples = ["5 + 3 =", "2 + 9 ="])
89
+
90
+
91
+ webapp.launch()