yabramuvdi commited on
Commit
93cbacc
·
verified ·
1 Parent(s): 5169734

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -9
app.py CHANGED
@@ -17,6 +17,7 @@ def fake_gpu():
17
  import numpy as np
18
  import torch
19
  import gradio as gr
 
20
  from transformers import AutoModelForCausalLM, AutoTokenizer
21
  import spaces
22
 
@@ -67,16 +68,32 @@ def get_next_token_predictions(text, model_name, top_k=10):
67
 
68
  return top_k_tokens, top_k_probs.cpu().tolist()
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def predict_next_token(model_name, text, top_k, custom_token=""):
71
- """Get predictions and update the UI."""
72
  if custom_token:
73
  text += custom_token
74
 
75
  tokens, probs = get_next_token_predictions(text, model_name, top_k)
76
 
77
- predictions = "\n".join([f"'{token}': {prob:.4f}" for token, prob in zip(tokens, probs)])
78
-
79
- return gr.update(choices=[f"'{t}'" for t in tokens]), predictions
 
80
 
81
  def append_selected_token(text, selected_token):
82
  """Append selected token from dropdown to the text input."""
@@ -128,16 +145,13 @@ with gr.Blocks() as demo:
128
  append_button = gr.Button("Append Token")
129
 
130
  with gr.Row():
131
- predictions_output = gr.Textbox(
132
- lines=10,
133
- label="Token Probabilities"
134
- )
135
 
136
  # Button click events
137
  predict_button.click(
138
  predict_next_token,
139
  inputs=[model_dropdown, text_input, top_k_slider],
140
- outputs=[token_dropdown, predictions_output]
141
  )
142
 
143
  append_button.click(
 
17
  import numpy as np
18
  import torch
19
  import gradio as gr
20
+ import matplotlib.pyplot as plt
21
  from transformers import AutoModelForCausalLM, AutoTokenizer
22
  import spaces
23
 
 
68
 
69
  return top_k_tokens, top_k_probs.cpu().tolist()
70
 
71
+ def plot_probabilities(tokens, probs):
72
+ """Generate a horizontal bar chart for token probabilities."""
73
+ fig, ax = plt.subplots(figsize=(8, 5))
74
+ ax.barh(tokens[::-1], probs[::-1], color="skyblue")
75
+ ax.set_xlabel("Probability")
76
+ ax.set_title("Next Token Predictions")
77
+ plt.tight_layout()
78
+
79
+ # Save plot as an image and return the file path
80
+ plot_path = "token_probabilities.png"
81
+ plt.savefig(plot_path)
82
+ plt.close(fig)
83
+
84
+ return plot_path
85
+
86
  def predict_next_token(model_name, text, top_k, custom_token=""):
87
+ """Get predictions and update the UI with text and a chart."""
88
  if custom_token:
89
  text += custom_token
90
 
91
  tokens, probs = get_next_token_predictions(text, model_name, top_k)
92
 
93
+ # Generate bar chart
94
+ plot_path = plot_probabilities(tokens, probs)
95
+
96
+ return gr.update(choices=[f"'{t}'" for t in tokens]), plot_path
97
 
98
  def append_selected_token(text, selected_token):
99
  """Append selected token from dropdown to the text input."""
 
145
  append_button = gr.Button("Append Token")
146
 
147
  with gr.Row():
148
+ chart_output = gr.Image(label="Token Probability Chart")
 
 
 
149
 
150
  # Button click events
151
  predict_button.click(
152
  predict_next_token,
153
  inputs=[model_dropdown, text_input, top_k_slider],
154
+ outputs=[token_dropdown, chart_output]
155
  )
156
 
157
  append_button.click(