codelion committed on
Commit
a83f370
·
verified ·
1 Parent(s): 527fd08
Files changed (1) hide show
  1. app.py +82 -46
app.py CHANGED
@@ -7,6 +7,7 @@ import base64
7
  import math
8
  import ast
9
  import logging
 
10
 
11
  # Set up logging
12
  logging.basicConfig(level=logging.DEBUG)
@@ -55,7 +56,7 @@ def ensure_float(value):
55
  return float(value)
56
  return None
57
 
58
- # Function to process and visualize log probs
59
  def visualize_logprobs(json_input):
60
  try:
61
  # Parse the input (handles both JSON and Python dictionaries)
@@ -69,30 +70,82 @@ def visualize_logprobs(json_input):
69
  else:
70
  raise ValueError("Input must be a list or dictionary with 'content' key")
71
 
72
- # Extract tokens and log probs, skipping None or non-finite values
73
  tokens = []
74
  logprobs = []
 
75
  for entry in content:
76
  logprob = ensure_float(entry.get("logprob", None))
77
  if logprob is not None and math.isfinite(logprob):
78
  tokens.append(entry["token"])
79
  logprobs.append(logprob)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  else:
81
  logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))
82
 
83
- # Prepare table data, handling None in top_logprobs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  table_data = []
85
- for entry in content:
86
  logprob = ensure_float(entry.get("logprob", None))
87
- # Only include entries with finite logprob and non-None top_logprobs
88
- if (
89
- logprob is not None
90
- and math.isfinite(logprob)
91
- and "top_logprobs" in entry
92
- and entry["top_logprobs"] is not None
93
- ):
94
  token = entry["token"]
95
- logger.debug("Processing token: %s, logprob: %s (type: %s)", token, logprob, type(logprob))
96
  top_logprobs = entry["top_logprobs"]
97
  # Ensure all values in top_logprobs are floats
98
  finite_top_logprobs = {}
@@ -100,44 +153,15 @@ def visualize_logprobs(json_input):
100
  float_value = ensure_float(value)
101
  if float_value is not None and math.isfinite(float_value):
102
  finite_top_logprobs[key] = float_value
103
-
104
  # Extract top 3 alternatives from top_logprobs
105
- top_3 = sorted(
106
- finite_top_logprobs.items(), key=lambda x: x[1], reverse=True
107
- )[:3]
108
  row = [token, f"{logprob:.4f}"]
109
  for alt_token, alt_logprob in top_3:
110
  row.append(f"{alt_token}: {alt_logprob:.4f}")
111
- # Pad with empty strings if fewer than 3 alternatives
112
  while len(row) < 5:
113
  row.append("")
114
  table_data.append(row)
115
 
116
- # Create the plot
117
- if logprobs:
118
- plt.figure(figsize=(10, 5))
119
- plt.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b")
120
- plt.title("Log Probabilities of Generated Tokens")
121
- plt.xlabel("Token Position")
122
- plt.ylabel("Log Probability")
123
- plt.grid(True)
124
- plt.xticks(range(len(logprobs)), tokens, rotation=45, ha="right")
125
- plt.tight_layout()
126
-
127
- # Save plot to a bytes buffer
128
- buf = io.BytesIO()
129
- plt.savefig(buf, format="png", bbox_inches="tight")
130
- buf.seek(0)
131
- plt.close()
132
-
133
- # Convert to base64 for Gradio
134
- img_bytes = buf.getvalue()
135
- img_base64 = base64.b64encode(img_bytes).decode("utf-8")
136
- img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
137
- else:
138
- img_html = "No finite log probabilities to plot."
139
-
140
- # Create DataFrame for the table
141
  df = (
142
  pd.DataFrame(
143
  table_data,
@@ -177,11 +201,22 @@ def visualize_logprobs(json_input):
177
  else:
178
  colored_text_html = "No finite log probabilities to display."
179
 
180
- return img_html, df, colored_text_html
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  except Exception as e:
183
  logger.error("Visualization failed: %s", str(e))
184
- return f"Error: {str(e)}", None, None
185
 
186
  # Gradio interface
187
  with gr.Blocks(title="Log Probability Visualizer") as app:
@@ -196,15 +231,16 @@ with gr.Blocks(title="Log Probability Visualizer") as app:
196
  placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
197
  )
198
 
199
- plot_output = gr.HTML(label="Log Probability Plot")
200
  table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
201
  text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
 
202
 
203
  btn = gr.Button("Visualize")
204
  btn.click(
205
  fn=visualize_logprobs,
206
  inputs=json_input,
207
- outputs=[plot_output, table_output, text_output],
208
  )
209
 
210
  app.launch()
 
7
  import math
8
  import ast
9
  import logging
10
+ from matplotlib.widgets import Cursor
11
 
12
  # Set up logging
13
  logging.basicConfig(level=logging.DEBUG)
 
56
  return float(value)
57
  return None
58
 
59
+ # Function to process and visualize log probs with hover and alternatives
60
  def visualize_logprobs(json_input):
61
  try:
62
  # Parse the input (handles both JSON and Python dictionaries)
 
70
  else:
71
  raise ValueError("Input must be a list or dictionary with 'content' key")
72
 
73
+ # Extract tokens, log probs, and top alternatives, skipping None or non-finite values
74
  tokens = []
75
  logprobs = []
76
+ top_alternatives = [] # Per-token list of the 3 highest log probs among the selected token and its alternatives
77
  for entry in content:
78
  logprob = ensure_float(entry.get("logprob", None))
79
  if logprob is not None and math.isfinite(logprob):
80
  tokens.append(entry["token"])
81
  logprobs.append(logprob)
82
+ # Get top_logprobs, defaulting to an empty dict only when the key is missing (an explicit None is not replaced)
83
+ top_probs = entry.get("top_logprobs", {})
84
+ # Ensure all values in top_logprobs are floats
85
+ finite_top_probs = {}
86
+ for key, value in top_probs.items():
87
+ float_value = ensure_float(value)
88
+ if float_value is not None and math.isfinite(float_value):
89
+ finite_top_probs[key] = float_value
90
+ # Get the top 3 log probs (including the selected token)
91
+ all_probs = {entry["token"]: logprob} # Add the selected token's logprob
92
+ all_probs.update(finite_top_probs) # Add alternatives
93
+ sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
94
+ top_3 = sorted_probs[:3] # Top 3 log probs (highest to lowest)
95
+ top_alternatives.append(top_3)
96
  else:
97
  logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))
98
 
99
+ # Create the plot with hover functionality
100
+ if logprobs:
101
+ fig, ax = plt.subplots(figsize=(10, 5))
102
+ scatter = ax.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Selected Token")[0]
103
+ ax.set_title("Log Probabilities of Generated Tokens")
104
+ ax.set_xlabel("Token Position")
105
+ ax.set_ylabel("Log Probability")
106
+ ax.grid(True)
107
+ ax.set_xticks([]) # Hide X-axis labels by default
108
+
109
+ # Add a crosshair Cursor plus per-point annotations shown by on_hover; note the figure is saved to a PNG below and embedded as an <img>
110
+ cursor = Cursor(ax, useblit=True, color='red', linewidth=1)
111
+ token_annotations = []
112
+ for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
113
+ annotation = ax.annotate('', (x, y), xytext=(10, 10), textcoords='offset points', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), visible=False)
114
+ token_annotations.append(annotation)
115
+
116
+ def on_hover(event):
117
+ if event.inaxes == ax:
118
+ for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
119
+ contains, _ = scatter.contains(event)
120
+ if contains and abs(event.xdata - x) < 0.5 and abs(event.ydata - y) < 0.5:
121
+ token_annotations[i].set_text(tokens[i])
122
+ token_annotations[i].set_visible(True)
123
+ fig.canvas.draw_idle()
124
+ else:
125
+ token_annotations[i].set_visible(False)
126
+ fig.canvas.draw_idle()
127
+
128
+ fig.canvas.mpl_connect('motion_notify_event', on_hover)
129
+
130
+ # Save plot to a bytes buffer
131
+ buf = io.BytesIO()
132
+ plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
133
+ buf.seek(0)
134
+ plt.close()
135
+
136
+ # Convert to base64 for Gradio
137
+ img_bytes = buf.getvalue()
138
+ img_base64 = base64.b64encode(img_bytes).decode("utf-8")
139
+ img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
140
+ else:
141
+ img_html = "No finite log probabilities to plot."
142
+
143
+ # Create DataFrame for the table
144
  table_data = []
145
+ for i, entry in enumerate(content):
146
  logprob = ensure_float(entry.get("logprob", None))
147
+ if logprob is not None and math.isfinite(logprob) and "top_logprobs" in entry and entry["top_logprobs"] is not None:
 
 
 
 
 
 
148
  token = entry["token"]
 
149
  top_logprobs = entry["top_logprobs"]
150
  # Ensure all values in top_logprobs are floats
151
  finite_top_logprobs = {}
 
153
  float_value = ensure_float(value)
154
  if float_value is not None and math.isfinite(float_value):
155
  finite_top_logprobs[key] = float_value
 
156
  # Extract top 3 alternatives from top_logprobs
157
+ top_3 = sorted(finite_top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
 
 
158
  row = [token, f"{logprob:.4f}"]
159
  for alt_token, alt_logprob in top_3:
160
  row.append(f"{alt_token}: {alt_logprob:.4f}")
 
161
  while len(row) < 5:
162
  row.append("")
163
  table_data.append(row)
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  df = (
166
  pd.DataFrame(
167
  table_data,
 
201
  else:
202
  colored_text_html = "No finite log probabilities to display."
203
 
204
+ # Create an alternative visualization for top 3 tokens
205
+ alt_viz_html = ""
206
+ if logprobs and top_alternatives:
207
+ alt_viz_html = "<h3>Top 3 Token Log Probabilities</h3><ul>"
208
+ for i, (token, probs) in enumerate(zip(tokens, top_alternatives)):
209
+ alt_viz_html += f"<li>Position {i} (Token: {token}):<br>"
210
+ for tok, prob in probs:
211
+ alt_viz_html += f"{tok}: {prob:.4f}<br>"
212
+ alt_viz_html += "</li>"
213
+ alt_viz_html += "</ul>"
214
+
215
+ return img_html, df, colored_text_html, alt_viz_html
216
 
217
  except Exception as e:
218
  logger.error("Visualization failed: %s", str(e))
219
+ return f"Error: {str(e)}", None, None, None
220
 
221
  # Gradio interface
222
  with gr.Blocks(title="Log Probability Visualizer") as app:
 
231
  placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
232
  )
233
 
234
+ plot_output = gr.HTML(label="Log Probability Plot (Hover for Tokens)")
235
  table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
236
  text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
237
+ alt_viz_output = gr.HTML(label="Top 3 Token Log Probabilities")
238
 
239
  btn = gr.Button("Visualize")
240
  btn.click(
241
  fn=visualize_logprobs,
242
  inputs=json_input,
243
+ outputs=[plot_output, table_output, text_output, alt_viz_output],
244
  )
245
 
246
  app.launch()