yabramuvdi committed
Commit bb934b2 · verified · 1 Parent(s): 373e477

Update app.py

Files changed (1)
  1. app.py +77 -88
app.py CHANGED
@@ -26,142 +26,131 @@ from huggingface_hub import login
 HF_TOKEN = os.getenv('HF_TOKEN')
 login(token=HF_TOKEN)

-# Available models
 AVAILABLE_MODELS = {
-    "bloomz-560m": "bigscience/bloomz-560m",
-    "bloomz-7B1": "bigscience/bloomz-7b1",
-    "GPT-J-6B": "EleutherAI/gpt-j-6b",
-    "mT5-XL": "google/mt5-xl",
 }

-# Initialize model and tokenizer
 current_model = None
 current_tokenizer = None
 current_model_name = None
 device = "cuda" if torch.cuda.is_available() else "cpu"

-def load_model(model_name):
-    """Load the selected model and tokenizer."""
     global current_model, current_tokenizer, current_model_name
-    if current_model_name != model_name:
-        current_model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name]).to(device)
-        current_tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
-        current_model_name = model_name

-# Load the default model at startup
-load_model("bloomz-560m")

 @spaces.GPU()
-def get_next_token_predictions(text, model_name, top_k=10):
-    """Generate the next token predictions with their probabilities."""
     global current_model, current_tokenizer

-    # Load the model if it has changed
-    if current_model_name != model_name:
-        load_model(model_name)

-    inputs = current_tokenizer(text, return_tensors="pt").to(device)

     with torch.no_grad():
-        outputs = current_model(**inputs)
-        logits = outputs.logits[0, -1, :]
-        probs = torch.nn.functional.softmax(logits, dim=-1)

-    top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
     top_k_tokens = [current_tokenizer.decode([idx.item()]) for idx in top_k_indices]

-    return top_k_tokens, top_k_probs.cpu().tolist()
-
-def plot_probabilities(tokens, probs):
-    """Generate a horizontal bar chart for token probabilities."""
-    fig, ax = plt.subplots(figsize=(8, 5))
-    ax.barh(tokens[::-1], probs[::-1], color="skyblue")
-    ax.set_xlabel("Probability")
-    ax.set_title("Next Token Predictions")
-    plt.tight_layout()
-
-    # Save plot as an image and return the file path
-    plot_path = "token_probabilities.png"
-    plt.savefig(plot_path)
-    plt.close(fig)
-
-    return plot_path

-def predict_next_token(model_name, text, top_k, custom_token=""):
-    """Get predictions and update the UI with text and a chart."""
-    if custom_token:
-        text += custom_token

-    tokens, probs = get_next_token_predictions(text, model_name, top_k)

-    # Generate bar chart
-    plot_path = plot_probabilities(tokens, probs)

-    return gr.update(choices=[f"'{t}'" for t in tokens]), plot_path

-def append_selected_token(text, selected_token):
-    """Append selected token from dropdown to the text input."""
-    if selected_token:
-        clean_token = selected_token.strip("'")
-        text += f" {clean_token}"
-    return text

-# Create the UI
 with gr.Blocks() as demo:
-    gr.Markdown("# 🔥 Interactive Text Prediction with Transformers")
     gr.Markdown(
-        "This application lets you interactively generate text using multiple transformer models. "
-        "Choose a model, type your text, and explore token predictions."
     )

     with gr.Row():
-        model_dropdown = gr.Dropdown(
             choices=list(AVAILABLE_MODELS.keys()),
-            value="distilgpt2",
-            label="Select Model"
         )

     with gr.Row():
-        text_input = gr.Textbox(
             lines=5,
-            label="Input Text",
-            placeholder="Type your text here...",
-            value="The quick brown fox"
         )

     with gr.Row():
-        top_k_slider = gr.Slider(
-            minimum=1,
-            maximum=20,
-            value=10,
-            step=1,
-            label="Top-k Predictions"
-        )

     with gr.Row():
-        predict_button = gr.Button("Predict")
-
-    with gr.Row():
-        token_dropdown = gr.Dropdown(
-            label="Predicted Tokens",
             choices=[]
         )
-        append_button = gr.Button("Append Token")

     with gr.Row():
-        chart_output = gr.Image(label="Token Probability Chart")

-    # Button click events
-    predict_button.click(
-        predict_next_token,
-        inputs=[model_dropdown, text_input, top_k_slider],
-        outputs=[token_dropdown, chart_output]
     )

-    append_button.click(
-        append_selected_token,
-        inputs=[text_input, token_dropdown],
-        outputs=text_input
     )

     demo.queue().launch()
 
 HF_TOKEN = os.getenv('HF_TOKEN')
 login(token=HF_TOKEN)

+# Available models
 AVAILABLE_MODELS = {
+    "BLOOMZ-560M": "bigscience/bloomz-560m"
 }

+# Initialize model and tokenizer
 current_model = None
 current_tokenizer = None
 current_model_name = None
 device = "cuda" if torch.cuda.is_available() else "cpu"

+def cargar_modelo(nombre_modelo):
+    """Load the selected model and tokenizer."""
     global current_model, current_tokenizer, current_model_name
+    if current_model_name != nombre_modelo:
+        current_model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[nombre_modelo]).to(device)
+        current_tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[nombre_modelo])
+        current_model_name = nombre_modelo

+# Load the default model
+cargar_modelo("BLOOMZ-560M")

 @spaces.GPU()
+def obtener_predicciones(texto, nombre_modelo, top_k=10):
+    """Generate next-word predictions with their probabilities."""
     global current_model, current_tokenizer

+    # Reload the model if it has changed
+    if current_model_name != nombre_modelo:
+        cargar_modelo(nombre_modelo)

+    entradas = current_tokenizer(texto, return_tensors="pt").to(device)

     with torch.no_grad():
+        salidas = current_model(**entradas)
+        logits = salidas.logits[0, -1, :]
+        probabilidades = torch.nn.functional.softmax(logits, dim=-1)

+    top_k_prob, top_k_indices = torch.topk(probabilidades, k=top_k)
     top_k_tokens = [current_tokenizer.decode([idx.item()]) for idx in top_k_indices]

+    return top_k_tokens, top_k_prob.cpu().tolist()
+
+def generar_barplot(tokens, probabilidades):
+    """Build the data for the Gradio bar chart of the most likely words."""
+    # gr.BarPlot expects a pandas DataFrame rather than a plain dict,
+    # so this assumes `import pandas as pd` at the top of the file.
+    datos = pd.DataFrame({"Palabra": tokens, "Probabilidad": probabilidades})
+    return datos

+def predecir_siguiente_palabra(nombre_modelo, texto, top_k, token_custom=""):
+    """Get predictions and update the UI."""
+    if token_custom:
+        texto += token_custom

+    tokens, probabilidades = obtener_predicciones(texto, nombre_modelo, int(top_k))

+    # Build the chart data for the Gradio BarPlot
+    barplot_data = generar_barplot(tokens, probabilidades)

+    return gr.update(choices=[f"'{t}'" for t in tokens]), barplot_data

+def agregar_token_seleccionado(texto, token_seleccionado):
+    """Append the selected token to the input text."""
+    if token_seleccionado:
+        token_limpio = token_seleccionado.strip("'")
+        texto += f" {token_limpio}"
+    return texto

+# Build the Spanish-language interface
 with gr.Blocks() as demo:
+    gr.Markdown("# 🔥 Predicción de Texto con Modelos Transformadores")
     gr.Markdown(
+        "Esta aplicación permite generar palabras utilizando un modelo de lenguaje. "
+        "Selecciona un modelo, introduce un texto y explora las palabras más probables a continuación."
     )

     with gr.Row():
+        dropdown_modelo = gr.Dropdown(
             choices=list(AVAILABLE_MODELS.keys()),
+            value="BLOOMZ-560M",
+            label="📌 Modelo de lenguaje"
+        )
+
+        dropdown_top_k = gr.Dropdown(
+            choices=["5", "10", "15", "20"],
+            value="10",
+            label="🔢 Número de palabras a mostrar"
         )

     with gr.Row():
+        texto_entrada = gr.Textbox(
             lines=5,
+            label="📝 Texto de entrada",
+            placeholder="Escribe aquí...",
+            value="Mi abuela me dejó una gran"
         )

     with gr.Row():
+        boton_predecir = gr.Button("🔮 Predecir")

     with gr.Row():
+        dropdown_tokens = gr.Dropdown(
+            label="🔠 Palabras predichas",
             choices=[]
         )
+        boton_agregar = gr.Button("Agregar palabra")

     with gr.Row():
+        barplot_resultados = gr.BarPlot(
+            label="📊 Palabras más probables",
+            x="Palabra",
+            y="Probabilidad",
+            title="Predicciones del modelo"
+        )

+    # Button actions
+    boton_predecir.click(
+        predecir_siguiente_palabra,
+        inputs=[dropdown_modelo, texto_entrada, dropdown_top_k],
+        outputs=[dropdown_tokens, barplot_resultados]
     )

+    boton_agregar.click(
+        agregar_token_seleccionado,
+        inputs=[texto_entrada, dropdown_tokens],
+        outputs=texto_entrada
     )

     demo.queue().launch()
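
For reference, the next-token logic shared by both versions of the file can be exercised outside Gradio. The following is a minimal standalone sketch, not part of the commit: it assumes torch, transformers, and pandas are installed, uses the app's default checkpoint on CPU, and builds the same Palabra/Probabilidad table the new gr.BarPlot consumes.

# Minimal standalone sketch of the app's core: score the next token with a
# causal LM and keep the top-k candidates (assumes torch, transformers, pandas).
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigscience/bloomz-560m"  # the app's default checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

texto = "Mi abuela me dejó una gran"  # the app's default input text
entradas = tokenizer(texto, return_tensors="pt")

with torch.no_grad():
    logits = model(**entradas).logits[0, -1, :]  # logits for the next position
probabilidades = torch.nn.functional.softmax(logits, dim=-1)
top_probs, top_ids = torch.topk(probabilidades, k=10)

tokens = [tokenizer.decode([i.item()]) for i in top_ids]
# Same table the app feeds to gr.BarPlot (which expects a DataFrame)
print(pd.DataFrame({"Palabra": tokens, "Probabilidad": top_probs.tolist()}))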