File size: 5,358 Bytes
ac72c21
d7f8dad
43130a6
 
 
 
 
 
 
 
 
 
 
 
93a3f9a
d7f8dad
 
7c2299f
d7f8dad
 
93cbacc
d7f8dad
 
e5cdcee
 
 
 
 
43130a6
bb934b2
617bd81
bb934b2
617bd81
 
bb934b2
ac72c21
 
 
310d018
6c99f7c
bb934b2
 
ac72c21
bb934b2
 
 
 
6c99f7c
bb934b2
 
310d018
d7f8dad
bb934b2
 
ac72c21
617bd81
bb934b2
 
 
6c99f7c
22ac25d
310d018
6c99f7c
bb934b2
 
 
310d018
bb934b2
22ac25d
6c99f7c
bb934b2
 
 
fb423d0
69a7fd1
fb423d0
22ac25d
93cbacc
bb934b2
 
 
22ac25d
310d018
bb934b2
310d018
fb423d0
bb934b2
93cbacc
bb934b2
ac72c21
bb934b2
22ac25d
bb934b2
22ac25d
 
bb934b2
310d018
bb934b2
9a72c69
fb423d0
310d018
fb423d0
bb934b2
310d018
9a72c69
 
bb934b2
9a72c69
bb934b2
fb423d0
bb934b2
 
 
 
 
 
9a72c69
93a3f9a
9a72c69
bb934b2
93a3f9a
bb934b2
 
899bbf4
9a72c69
93a3f9a
310d018
bb934b2
310d018
93a3f9a
bb934b2
 
9a72c69
 
bb934b2
310d018
9a72c69
bb934b2
69a7fd1
22ac25d
 
 
 
bb934b2
310d018
bb934b2
 
 
 
 
9a72c69
175fea5
bb934b2
 
 
 
310d018
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
# Handle Spaces GPU
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper

@spaces.GPU
def fake_gpu():
    pass
    
import numpy as np
import pandas as pd
import torch
import gradio as gr
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
from huggingface_hub import login

# Authenticate
HF_TOKEN = os.getenv('HF_TOKEN')
login(token=HF_TOKEN)

# Modelos disponibles
AVAILABLE_MODELS = {
    "BLOOMZ-560M": "bigscience/bloomz-560m"
}

# Inicializar modelo y tokenizer
current_model = None
current_tokenizer = None
current_model_name = None
device = "cuda" if torch.cuda.is_available() else "cpu"

def cargar_modelo(nombre_modelo):
    """Carga el modelo y el tokenizer seleccionado."""
    global current_model, current_tokenizer, current_model_name
    if current_model_name != nombre_modelo:
        current_model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[nombre_modelo]).to(device)
        current_tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[nombre_modelo])
        current_model_name = nombre_modelo

# Cargar el modelo por defecto
cargar_modelo("BLOOMZ-560M")

@spaces.GPU()
def obtener_predicciones(texto, nombre_modelo, top_k=10):
    """Genera las predicciones de las siguientes palabras con sus probabilidades."""
    global current_model, current_tokenizer
    
    # Cargar modelo si ha cambiado
    if current_model_name != nombre_modelo:
        cargar_modelo(nombre_modelo)
    
    entradas = current_tokenizer(texto.strip(), return_tensors="pt").to(device)

    with torch.no_grad():
        salidas = current_model(**entradas)
        logits = salidas.logits[0, -1, :]
        probabilidades = torch.nn.functional.softmax(logits, dim=-1)
    
    top_k_prob, top_k_indices = torch.topk(probabilidades, k=top_k)
    top_k_tokens = [current_tokenizer.decode([idx.item()]).strip() for idx in top_k_indices]  # ✅ Strip spaces
    
    return top_k_tokens, top_k_prob.cpu().tolist()

def generar_barplot(tokens, probabilidades):
    """Convierte los datos en un DataFrame y lo ordena de mayor a menor probabilidad."""
    df = pd.DataFrame({"Palabra": tokens, "Probabilidad": probabilidades})
    df = df.sort_values(by="Probabilidad", ascending=False)  # ✅ Sort by probability (highest first)
    return df

def predecir_siguiente_palabra(nombre_modelo, texto, top_k, token_custom=""):
    """Obtiene predicciones y actualiza la UI."""
    if token_custom:
        texto = texto.rstrip() + " " + token_custom.strip()  # ✅ Prevents extra whitespaces

    tokens, probabilidades = obtener_predicciones(texto, nombre_modelo, int(top_k))

    # Generar gráfico con Gradio BarPlot (ahora ordenado)
    barplot_data = generar_barplot(tokens, probabilidades)

    return gr.update(choices=[f"'{t}'" for t in tokens]), barplot_data

def agregar_token_seleccionado(texto, token_seleccionado):
    """Agrega el token seleccionado al texto de entrada sin espacios extra."""
    if token_seleccionado:
        token_limpio = token_seleccionado.strip("'").strip()  # ✅ Removes unwanted spaces
        texto = texto.rstrip() + " " + token_limpio  # ✅ Ensures no double spaces
    return texto

# Crear la interfaz en español
with gr.Blocks() as demo:
    gr.Markdown("# 🔥 Predicción de Texto con Modelos Transformadores")
    gr.Markdown(
        "Esta aplicación permite generar palabras utilizando un modelo de lenguaje. "
        "Selecciona un modelo, introduce un texto y explora las palabras más probables a continuación."
    )
    
    with gr.Row():
        dropdown_modelo = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="BLOOMZ-560M",
            label="📌 Modelo de lenguaje"
        )

        dropdown_top_k = gr.Dropdown(
            choices=["5", "10", "15", "20"],
            value="10",
            label="🔢 Número de palabras a mostrar"
        )
    
    with gr.Row():
        texto_entrada = gr.Textbox(
            lines=5,
            label="📝 Texto de entrada",
            placeholder="Escribe aquí...",
            value=""
        )
    
    with gr.Row():
        boton_predecir = gr.Button("🔮 Predecir")

    with gr.Row():
        dropdown_tokens = gr.Dropdown(
            label="🔠 Palabras predichas",
            choices=[]
        )
        boton_agregar = gr.Button("➕ Agregar palabra")

    with gr.Row():
        barplot_resultados = gr.BarPlot(
            value=pd.DataFrame(columns=["Palabra", "Probabilidad"]),  # ✅ Empty DataFrame to initialize
            x="Probabilidad",  # ✅ Swap axes to make it horizontal
            y="Palabra",
            title="📊 Predicciones del modelo",
            orientation="h"  # ✅ Makes the barplot horizontal
        )

    # Acciones de botones
    boton_predecir.click(
        predecir_siguiente_palabra,
        inputs=[dropdown_modelo, texto_entrada, dropdown_top_k],
        outputs=[dropdown_tokens, barplot_resultados]
    )

    boton_agregar.click(
        agregar_token_seleccionado,
        inputs=[texto_entrada, dropdown_tokens],
        outputs=texto_entrada
    )

demo.queue().launch()