codelion committed (verified)
Commit 0e1182d · Parent(s): a83f370

Update app.py

Files changed (1): app.py (+403, −60)
app.py CHANGED
@@ -7,6 +7,14 @@ import base64
 import math
 import ast
 import logging
+ import numpy as np
+ from sklearn.cluster import KMeans
+ from sklearn.decomposition import PCA
+ from sklearn.manifold import TSNE
+ from scipy import stats
+ from scipy.stats import entropy
+ from scipy.signal import correlate
+ import networkx as nx
 from matplotlib.widgets import Cursor

 # Set up logging
@@ -56,8 +64,8 @@ def ensure_float(value):
         return float(value)
     return None

- # Function to process and visualize log probs with hover and alternatives
- def visualize_logprobs(json_input):
+ # Function to process and visualize log probs with multiple analyses
+ def visualize_logprobs(json_input, prob_filter=-float('inf')):
     try:
         # Parse the input (handles both JSON and Python dictionaries)
         data = parse_input(json_input)
@@ -74,11 +82,18 @@ def visualize_logprobs(json_input):
         tokens = []
         logprobs = []
         top_alternatives = []  # List to store top 3 log probs (selected token + 2 alternatives)
+         token_types = []  # Simplified token type categorization
         for entry in content:
             logprob = ensure_float(entry.get("logprob", None))
-             if logprob is not None and math.isfinite(logprob):
+             if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter:
                 tokens.append(entry["token"])
                 logprobs.append(logprob)
+                 # Categorize token type (simple heuristic)
+                 token = entry["token"].lower().strip()
+                 if token in ["the", "a", "an"]: token_types.append("article")
+                 elif token in ["is", "are", "was", "were"]: token_types.append("verb")
+                 elif token in ["top", "so", "need", "figure"]: token_types.append("noun")
+                 else: token_types.append("other")
                 # Get top_logprobs, default to empty dict if None
                 top_probs = entry.get("top_logprobs", {})
                 # Ensure all values in top_logprobs are floats
@@ -96,55 +111,340 @@ def visualize_logprobs(json_input):
             else:
                 logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))

-         # Create the plot with hover functionality
-         if logprobs:
-             fig, ax = plt.subplots(figsize=(10, 5))
-             scatter = ax.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Selected Token")[0]
-             ax.set_title("Log Probabilities of Generated Tokens")
-             ax.set_xlabel("Token Position")
-             ax.set_ylabel("Log Probability")
-             ax.grid(True)
-             ax.set_xticks([])  # Hide X-axis labels by default
-
-             # Add hover functionality using Matplotlib's Cursor for tooltips
-             cursor = Cursor(ax, useblit=True, color='red', linewidth=1)
-             token_annotations = []
-             for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
-                 annotation = ax.annotate('', (x, y), xytext=(10, 10), textcoords='offset points', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), visible=False)
-                 token_annotations.append(annotation)
-
-             def on_hover(event):
-                 if event.inaxes == ax:
-                     for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
-                         contains, _ = scatter.contains(event)
-                         if contains and abs(event.xdata - x) < 0.5 and abs(event.ydata - y) < 0.5:
-                             token_annotations[i].set_text(tokens[i])
-                             token_annotations[i].set_visible(True)
-                             fig.canvas.draw_idle()
-                         else:
-                             token_annotations[i].set_visible(False)
-                             fig.canvas.draw_idle()
-
-             fig.canvas.mpl_connect('motion_notify_event', on_hover)
-
-             # Save plot to a bytes buffer
-             buf = io.BytesIO()
-             plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
-             buf.seek(0)
-             plt.close()
-
-             # Convert to base64 for Gradio
-             img_bytes = buf.getvalue()
-             img_base64 = base64.b64encode(img_bytes).decode("utf-8")
-             img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
-         else:
-             img_html = "No finite log probabilities to plot."
+         # If no valid data after filtering, return an error message plus a None for each remaining output
+         if not logprobs:
+             return ("No finite log probabilities to visualize after filtering.",) + (None,) * 19
+
+         # 1. Main Log Probability Plot (with click for tokens)
+         fig_main, ax_main = plt.subplots(figsize=(10, 5))
+         scatter = ax_main.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Selected Token")[0]
+         ax_main.set_title("Log Probabilities of Generated Tokens")
+         ax_main.set_xlabel("Token Position")
+         ax_main.set_ylabel("Log Probability")
+         ax_main.grid(True)
+         ax_main.set_xticks([])  # Hide X-axis labels by default
+
+         # Add click functionality to show token
+         token_annotations = []
+         for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
+             annotation = ax_main.annotate('', (x, y), xytext=(10, 10), textcoords='offset points', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), visible=False)
+             token_annotations.append(annotation)
+
+         def on_click(event):
+             if event.inaxes == ax_main:
+                 for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
+                     contains, _ = scatter.contains(event)
+                     if contains and abs(event.xdata - x) < 0.5 and abs(event.ydata - y) < 0.5:
+                         token_annotations[i].set_text(tokens[i])
+                         token_annotations[i].set_visible(True)
+                         fig_main.canvas.draw_idle()
+                     else:
+                         token_annotations[i].set_visible(False)
+                         fig_main.canvas.draw_idle()
+
+         fig_main.canvas.mpl_connect('button_press_event', on_click)
+
+         # Save main plot
+         buf_main = io.BytesIO()
+         plt.savefig(buf_main, format="png", bbox_inches="tight", dpi=100)
+         buf_main.seek(0)
+         plt.close(fig_main)
+         img_main_bytes = buf_main.getvalue()
+         img_main_base64 = base64.b64encode(img_main_bytes).decode("utf-8")
+         img_main_html = f'<img src="data:image/png;base64,{img_main_base64}" style="max-width: 100%; height: auto;">'
+
+         # 2. K-Means Clustering of Log Probabilities
+         kmeans = KMeans(n_clusters=3, random_state=42)
+         cluster_labels = kmeans.fit_predict(np.array(logprobs).reshape(-1, 1))
+         fig_cluster, ax_cluster = plt.subplots(figsize=(10, 5))
+         scatter = ax_cluster.scatter(range(len(logprobs)), logprobs, c=cluster_labels, cmap='viridis')
+         ax_cluster.set_title("K-Means Clustering of Log Probabilities")
+         ax_cluster.set_xlabel("Token Position")
+         ax_cluster.set_ylabel("Log Probability")
+         ax_cluster.grid(True)
+         plt.colorbar(scatter, ax=ax_cluster, label="Cluster")
+         buf_cluster = io.BytesIO()
+         plt.savefig(buf_cluster, format="png", bbox_inches="tight", dpi=100)
+         buf_cluster.seek(0)
+         plt.close(fig_cluster)
+         img_cluster_bytes = buf_cluster.getvalue()
+         img_cluster_base64 = base64.b64encode(img_cluster_bytes).decode("utf-8")
+         img_cluster_html = f'<img src="data:image/png;base64,{img_cluster_base64}" style="max-width: 100%; height: auto;">'
+
+         # 3. Probability Drop Analysis
+         drops = [logprobs[i+1] - logprobs[i] if i < len(logprobs)-1 else 0 for i in range(len(logprobs))]
+         fig_drops, ax_drops = plt.subplots(figsize=(10, 5))
+         ax_drops.bar(range(len(drops)), drops, color='red', alpha=0.5)
+         ax_drops.set_title("Significant Probability Drops")
+         ax_drops.set_xlabel("Token Position")
+         ax_drops.set_ylabel("Log Probability Drop")
+         ax_drops.grid(True)
+         buf_drops = io.BytesIO()
+         plt.savefig(buf_drops, format="png", bbox_inches="tight", dpi=100)
+         buf_drops.seek(0)
+         plt.close(fig_drops)
+         img_drops_bytes = buf_drops.getvalue()
+         img_drops_base64 = base64.b64encode(img_drops_bytes).decode("utf-8")
+         img_drops_html = f'<img src="data:image/png;base64,{img_drops_base64}" style="max-width: 100%; height: auto;">'
+
+         # 4. N-Gram Analysis (Bigrams for simplicity)
+         bigrams = [(tokens[i], tokens[i+1]) for i in range(len(tokens)-1)]
+         bigram_probs = [logprobs[i] + logprobs[i+1] for i in range(len(tokens)-1)]
+         fig_ngram, ax_ngram = plt.subplots(figsize=(10, 5))
+         ax_ngram.bar(range(len(bigrams)), bigram_probs, color='green')
+         ax_ngram.set_title("N-Gram (Bigrams) Probability Sum")
+         ax_ngram.set_xlabel("Bigram Position")
+         ax_ngram.set_ylabel("Sum of Log Probabilities")
+         ax_ngram.set_xticks(range(len(bigrams)))
+         ax_ngram.set_xticklabels([f"{b[0]}->{b[1]}" for b in bigrams], rotation=45, ha="right")
+         ax_ngram.grid(True)
+         buf_ngram = io.BytesIO()
+         plt.savefig(buf_ngram, format="png", bbox_inches="tight", dpi=100)
+         buf_ngram.seek(0)
+         plt.close(fig_ngram)
+         img_ngram_bytes = buf_ngram.getvalue()
+         img_ngram_base64 = base64.b64encode(img_ngram_bytes).decode("utf-8")
+         img_ngram_html = f'<img src="data:image/png;base64,{img_ngram_base64}" style="max-width: 100%; height: auto;">'
+
+         # 5. Markov Chain Modeling (Simple Graph)
+         G = nx.DiGraph()
+         for i in range(len(tokens)-1):
+             G.add_edge(tokens[i], tokens[i+1], weight=logprobs[i+1] - logprobs[i])
+         fig_markov, ax_markov = plt.subplots(figsize=(10, 5))
+         pos = nx.spring_layout(G)
+         nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, edge_color='gray', width=1, ax=ax_markov)
+         ax_markov.set_title("Markov Chain of Token Transitions")
+         buf_markov = io.BytesIO()
+         plt.savefig(buf_markov, format="png", bbox_inches="tight", dpi=100)
+         buf_markov.seek(0)
+         plt.close(fig_markov)
+         img_markov_bytes = buf_markov.getvalue()
+         img_markov_base64 = base64.b64encode(img_markov_bytes).decode("utf-8")
+         img_markov_html = f'<img src="data:image/png;base64,{img_markov_base64}" style="max-width: 100%; height: auto;">'
+
+         # 6. Anomaly Detection (Outlier Detection with Z-Score)
+         z_scores = np.abs(stats.zscore(logprobs))
+         outliers = z_scores > 2  # Threshold for outliers
+         fig_anomaly, ax_anomaly = plt.subplots(figsize=(10, 5))
+         ax_anomaly.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b")
+         ax_anomaly.plot(np.where(outliers)[0], [logprobs[i] for i in np.where(outliers)[0]], "ro", label="Outliers")
+         ax_anomaly.set_title("Log Probabilities with Outliers")
+         ax_anomaly.set_xlabel("Token Position")
+         ax_anomaly.set_ylabel("Log Probability")
+         ax_anomaly.grid(True)
+         ax_anomaly.legend()
+         ax_anomaly.set_xticks([])  # Hide X-axis labels
+         buf_anomaly = io.BytesIO()
+         plt.savefig(buf_anomaly, format="png", bbox_inches="tight", dpi=100)
+         buf_anomaly.seek(0)
+         plt.close(fig_anomaly)
+         img_anomaly_bytes = buf_anomaly.getvalue()
+         img_anomaly_base64 = base64.b64encode(img_anomaly_bytes).decode("utf-8")
+         img_anomaly_html = f'<img src="data:image/png;base64,{img_anomaly_base64}" style="max-width: 100%; height: auto;">'
+
+         # 7. Autocorrelation
+         autocorr = correlate(logprobs, logprobs, mode='full')
+         autocorr = autocorr[len(autocorr)//2:] / len(logprobs)  # Normalize
+         fig_autocorr, ax_autocorr = plt.subplots(figsize=(10, 5))
+         ax_autocorr.plot(range(len(autocorr)), autocorr, color='purple')
+         ax_autocorr.set_title("Autocorrelation of Log Probabilities")
+         ax_autocorr.set_xlabel("Lag")
+         ax_autocorr.set_ylabel("Autocorrelation")
+         ax_autocorr.grid(True)
+         buf_autocorr = io.BytesIO()
+         plt.savefig(buf_autocorr, format="png", bbox_inches="tight", dpi=100)
+         buf_autocorr.seek(0)
+         plt.close(fig_autocorr)
+         img_autocorr_bytes = buf_autocorr.getvalue()
+         img_autocorr_base64 = base64.b64encode(img_autocorr_bytes).decode("utf-8")
+         img_autocorr_html = f'<img src="data:image/png;base64,{img_autocorr_base64}" style="max-width: 100%; height: auto;">'
+
+         # 8. Smoothing (Moving Average)
+         window_size = 3
+         moving_avg = np.convolve(logprobs, np.ones(window_size)/window_size, mode='valid')
+         fig_smoothing, ax_smoothing = plt.subplots(figsize=(10, 5))
+         ax_smoothing.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Original")
+         ax_smoothing.plot(range(window_size-1, len(logprobs)), moving_avg, color="orange", label="Moving Average")
+         ax_smoothing.set_title("Log Probabilities with Moving Average")
+         ax_smoothing.set_xlabel("Token Position")
+         ax_smoothing.set_ylabel("Log Probability")
+         ax_smoothing.grid(True)
+         ax_smoothing.legend()
+         ax_smoothing.set_xticks([])  # Hide X-axis labels
+         buf_smoothing = io.BytesIO()
+         plt.savefig(buf_smoothing, format="png", bbox_inches="tight", dpi=100)
+         buf_smoothing.seek(0)
+         plt.close(fig_smoothing)
+         img_smoothing_bytes = buf_smoothing.getvalue()
+         img_smoothing_base64 = base64.b64encode(img_smoothing_bytes).decode("utf-8")
+         img_smoothing_html = f'<img src="data:image/png;base64,{img_smoothing_base64}" style="max-width: 100%; height: auto;">'
+
+         # 9. Uncertainty Propagation (Variance of Top Logprobs)
+         variances = []
+         for probs in top_alternatives:
+             if len(probs) > 1:
+                 values = [p[1] for p in probs]
+                 variances.append(np.var(values))
+             else:
+                 variances.append(0)
+         fig_uncertainty, ax_uncertainty = plt.subplots(figsize=(10, 5))
+         ax_uncertainty.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Log Prob")
+         ax_uncertainty.fill_between(range(len(logprobs)), [lp - v for lp, v in zip(logprobs, variances)],
+                                     [lp + v for lp, v in zip(logprobs, variances)], color='gray', alpha=0.3, label="Uncertainty")
+         ax_uncertainty.set_title("Log Probabilities with Uncertainty Propagation")
+         ax_uncertainty.set_xlabel("Token Position")
+         ax_uncertainty.set_ylabel("Log Probability")
+         ax_uncertainty.grid(True)
+         ax_uncertainty.legend()
+         ax_uncertainty.set_xticks([])  # Hide X-axis labels
+         buf_uncertainty = io.BytesIO()
+         plt.savefig(buf_uncertainty, format="png", bbox_inches="tight", dpi=100)
+         buf_uncertainty.seek(0)
+         plt.close(fig_uncertainty)
+         img_uncertainty_bytes = buf_uncertainty.getvalue()
+         img_uncertainty_base64 = base64.b64encode(img_uncertainty_bytes).decode("utf-8")
+         img_uncertainty_html = f'<img src="data:image/png;base64,{img_uncertainty_base64}" style="max-width: 100%; height: auto;">'
+
+         # 10. Correlation Heatmap
+         # np.corrcoef needs a 2-D stack, so correlate short sliding windows of the sequence
+         window = min(5, len(logprobs))
+         segments = np.array([logprobs[i:i + window] for i in range(len(logprobs) - window + 1)])
+         corr_matrix = np.atleast_2d(np.corrcoef(segments))
+         fig_corr, ax_corr = plt.subplots(figsize=(10, 5))
+         im = ax_corr.imshow(corr_matrix, cmap='coolwarm', interpolation='nearest')
+         ax_corr.set_title("Correlation of Log Probabilities Across Positions")
+         ax_corr.set_xlabel("Token Position")
+         ax_corr.set_ylabel("Token Position")
+         plt.colorbar(im, ax=ax_corr, label="Correlation")
+         buf_corr = io.BytesIO()
+         plt.savefig(buf_corr, format="png", bbox_inches="tight", dpi=100)
+         buf_corr.seek(0)
+         plt.close(fig_corr)
+         img_corr_bytes = buf_corr.getvalue()
+         img_corr_base64 = base64.b64encode(img_corr_bytes).decode("utf-8")
+         img_corr_html = f'<img src="data:image/png;base64,{img_corr_base64}" style="max-width: 100%; height: auto;">'
+
+         # 11. Token Type Correlation
+         type_probs = {t: [] for t in set(token_types)}
+         for t, p in zip(token_types, logprobs):
+             type_probs[t].append(p)
+         fig_type, ax_type = plt.subplots(figsize=(10, 5))
+         for t in type_probs:
+             ax_type.bar(t, np.mean(type_probs[t]), yerr=np.std(type_probs[t]), capsize=5, label=t)
+         ax_type.set_title("Average Log Probability by Token Type")
+         ax_type.set_xlabel("Token Type")
+         ax_type.set_ylabel("Average Log Probability")
+         ax_type.grid(True)
+         ax_type.legend()
+         buf_type = io.BytesIO()
+         plt.savefig(buf_type, format="png", bbox_inches="tight", dpi=100)
+         buf_type.seek(0)
+         plt.close(fig_type)
+         img_type_bytes = buf_type.getvalue()
+         img_type_base64 = base64.b64encode(img_type_bytes).decode("utf-8")
+         img_type_html = f'<img src="data:image/png;base64,{img_type_base64}" style="max-width: 100%; height: auto;">'
+
+         # 12. Token Embedding Similarity vs. Probability (Simulated)
+         # Simulate embedding distances (e.g., cosine similarity) as random values for demonstration
+         simulated_embeddings = np.random.rand(len(tokens), 2)  # 2D embeddings
+         fig_embed, ax_embed = plt.subplots(figsize=(10, 5))
+         ax_embed.scatter(simulated_embeddings[:, 0], simulated_embeddings[:, 1], c=logprobs, cmap='viridis')
+         ax_embed.set_title("Token Embedding Similarity vs. Log Probability")
+         ax_embed.set_xlabel("Embedding Dimension 1")
+         ax_embed.set_ylabel("Embedding Dimension 2")
+         plt.colorbar(ax_embed.collections[0], ax=ax_embed, label="Log Probability")
+         buf_embed = io.BytesIO()
+         plt.savefig(buf_embed, format="png", bbox_inches="tight", dpi=100)
+         buf_embed.seek(0)
+         plt.close(fig_embed)
+         img_embed_bytes = buf_embed.getvalue()
+         img_embed_base64 = base64.b64encode(img_embed_bytes).decode("utf-8")
+         img_embed_html = f'<img src="data:image/png;base64,{img_embed_base64}" style="max-width: 100%; height: auto;">'
+
+         # 13. Bayesian Inference (Simplified as Inferred Probabilities)
+         # Simulate inferred probabilities based on top_logprobs entropy
+         # (log probs are exponentiated so entropy() sees a proper probability vector)
+         entropies = [entropy(np.exp([p[1] for p in probs]), base=2) for probs in top_alternatives if len(probs) > 1]
+         fig_bayesian, ax_bayesian = plt.subplots(figsize=(10, 5))
+         ax_bayesian.bar(range(len(entropies)), entropies, color='orange')
+         ax_bayesian.set_title("Bayesian Inferred Uncertainty (Entropy)")
+         ax_bayesian.set_xlabel("Token Position")
+         ax_bayesian.set_ylabel("Entropy")
+         ax_bayesian.grid(True)
+         buf_bayesian = io.BytesIO()
+         plt.savefig(buf_bayesian, format="png", bbox_inches="tight", dpi=100)
+         buf_bayesian.seek(0)
+         plt.close(fig_bayesian)
+         img_bayesian_bytes = buf_bayesian.getvalue()
+         img_bayesian_base64 = base64.b64encode(img_bayesian_bytes).decode("utf-8")
+         img_bayesian_html = f'<img src="data:image/png;base64,{img_bayesian_base64}" style="max-width: 100%; height: auto;">'
+
+         # 14. Graph-Based Analysis
+         G = nx.DiGraph()
+         for i in range(len(tokens)-1):
+             G.add_edge(tokens[i], tokens[i+1], weight=logprobs[i+1] - logprobs[i])
+         fig_graph, ax_graph = plt.subplots(figsize=(10, 5))
+         pos = nx.spring_layout(G)
+         nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, edge_color='gray', width=1, ax=ax_graph)
+         ax_graph.set_title("Graph of Token Transitions")
+         buf_graph = io.BytesIO()
+         plt.savefig(buf_graph, format="png", bbox_inches="tight", dpi=100)
+         buf_graph.seek(0)
+         plt.close(fig_graph)
+         img_graph_bytes = buf_graph.getvalue()
+         img_graph_base64 = base64.b64encode(img_graph_bytes).decode("utf-8")
+         img_graph_html = f'<img src="data:image/png;base64,{img_graph_base64}" style="max-width: 100%; height: auto;">'
+
+         # 15. Dimensionality Reduction (t-SNE)
+         # One feature vector per token: selected log prob plus up to two alternatives (padded if missing)
+         feats = []
+         for lp, alts in zip(logprobs, top_alternatives):
+             alt_vals = [p[1] for p in alts[:2]]
+             feats.append([lp] + alt_vals + [lp] * (2 - len(alt_vals)))
+         features = np.array(feats)
+         tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, max(1, len(features) - 1)))
+         tsne_result = tsne.fit_transform(features)
+         fig_tsne, ax_tsne = plt.subplots(figsize=(10, 5))
+         scatter = ax_tsne.scatter(tsne_result[:, 0], tsne_result[:, 1], c=logprobs, cmap='viridis')
+         ax_tsne.set_title("t-SNE of Log Probabilities and Top Alternatives")
+         ax_tsne.set_xlabel("t-SNE Dimension 1")
+         ax_tsne.set_ylabel("t-SNE Dimension 2")
+         plt.colorbar(scatter, ax=ax_tsne, label="Log Probability")
+         buf_tsne = io.BytesIO()
+         plt.savefig(buf_tsne, format="png", bbox_inches="tight", dpi=100)
+         buf_tsne.seek(0)
+         plt.close(fig_tsne)
+         img_tsne_bytes = buf_tsne.getvalue()
+         img_tsne_base64 = base64.b64encode(img_tsne_bytes).decode("utf-8")
+         img_tsne_html = f'<img src="data:image/png;base64,{img_tsne_base64}" style="max-width: 100%; height: auto;">'
+
+         # 16. Interactive Heatmap
+         fig_heatmap, ax_heatmap = plt.subplots(figsize=(10, 5))
+         im = ax_heatmap.imshow([logprobs], cmap='viridis', aspect='auto')
+         ax_heatmap.set_title("Interactive Heatmap of Log Probabilities")
+         ax_heatmap.set_xlabel("Token Position")
+         ax_heatmap.set_ylabel("Probability Level")
+         plt.colorbar(im, ax=ax_heatmap, label="Log Probability")
+         buf_heatmap = io.BytesIO()
+         plt.savefig(buf_heatmap, format="png", bbox_inches="tight", dpi=100)
+         buf_heatmap.seek(0)
+         plt.close(fig_heatmap)
+         img_heatmap_bytes = buf_heatmap.getvalue()
+         img_heatmap_base64 = base64.b64encode(img_heatmap_bytes).decode("utf-8")
+         img_heatmap_html = f'<img src="data:image/png;base64,{img_heatmap_base64}" style="max-width: 100%; height: auto;">'
+
+         # 17. Probability Distribution Plots (Box Plots for Top Logprobs)
+         alt1_probs = [alts[1][1] for alts in top_alternatives if len(alts) > 1]
+         alt2_probs = [alts[2][1] for alts in top_alternatives if len(alts) > 2]
+         fig_dist, ax_dist = plt.subplots(figsize=(10, 5))
+         ax_dist.boxplot([logprobs, alt1_probs, alt2_probs], labels=["Selected", "Alt1", "Alt2"])
+         ax_dist.set_title("Probability Distribution of Top Tokens")
+         ax_dist.set_xlabel("Token Type")
+         ax_dist.set_ylabel("Log Probability")
+         ax_dist.grid(True)
+         buf_dist = io.BytesIO()
+         plt.savefig(buf_dist, format="png", bbox_inches="tight", dpi=100)
+         buf_dist.seek(0)
+         plt.close(fig_dist)
+         img_dist_bytes = buf_dist.getvalue()
+         img_dist_base64 = base64.b64encode(img_dist_bytes).decode("utf-8")
+         img_dist_html = f'<img src="data:image/png;base64,{img_dist_base64}" style="max-width: 100%; height: auto;">'

         # Create DataFrame for the table
         table_data = []
         for i, entry in enumerate(content):
             logprob = ensure_float(entry.get("logprob", None))
-             if logprob is not None and math.isfinite(logprob) and "top_logprobs" in entry and entry["top_logprobs"] is not None:
+             if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter and "top_logprobs" in entry and entry["top_logprobs"] is not None:
                 token = entry["token"]
                 top_logprobs = entry["top_logprobs"]
                 # Ensure all values in top_logprobs are floats
@@ -201,7 +501,7 @@ def visualize_logprobs(json_input):
         else:
             colored_text_html = "No finite log probabilities to display."

-         # Create an alternative visualization for top 3 tokens
+         # Top 3 Token Log Probabilities
         alt_viz_html = ""
         if logprobs and top_alternatives:
             alt_viz_html = "<h3>Top 3 Token Log Probabilities</h3><ul>"
@@ -212,26 +512,63 @@ def visualize_logprobs(json_input):
             alt_viz_html += "</li>"
         alt_viz_html += "</ul>"

-         return img_html, df, colored_text_html, alt_viz_html
+         return (img_main_html, df, colored_text_html, alt_viz_html, img_cluster_html, img_drops_html,
+                 img_ngram_html, img_markov_html, img_anomaly_html, img_autocorr_html, img_smoothing_html,
+                 img_uncertainty_html, img_corr_html, img_type_html, img_embed_html, img_bayesian_html,
+                 img_graph_html, img_tsne_html, img_heatmap_html, img_dist_html)

     except Exception as e:
         logger.error("Visualization failed: %s", str(e))
-         return f"Error: {str(e)}", None, None, None
+         return (f"Error: {str(e)}", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)

- # Gradio interface
+ # Gradio interface with dynamic filtering
 with gr.Blocks(title="Log Probability Visualizer") as app:
     gr.Markdown("# Log Probability Visualizer")
     gr.Markdown(
-         "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities. Ensure property names are in double quotes (e.g., \"content\") for JSON, or use correct Python dictionary format."
+         "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities. Use the filter to focus on specific log probability ranges."
     )

-     json_input = gr.Textbox(
-         label="JSON Input",
-         lines=10,
-         placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
-     )
+     with gr.Row():
+         json_input = gr.Textbox(
+             label="JSON Input",
+             lines=10,
+             placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
+         )
+         # Finite floor stands in for -inf, which a Slider cannot represent
+         prob_filter = gr.Slider(minimum=-100.0, maximum=0.0, value=-100.0, label="Log Probability Filter (≥)")
+
+     with gr.Row():
+         plot_output = gr.HTML(label="Log Probability Plot (Click for Tokens)")
+         cluster_output = gr.HTML(label="K-Means Clustering")
+         drops_output = gr.HTML(label="Probability Drops")
+
+     with gr.Row():
+         ngram_output = gr.HTML(label="N-Gram Analysis")
+         markov_output = gr.HTML(label="Markov Chain")
+
+     with gr.Row():
+         anomaly_output = gr.HTML(label="Anomaly Detection")
+         autocorr_output = gr.HTML(label="Autocorrelation")
+
+     with gr.Row():
+         smoothing_output = gr.HTML(label="Smoothing (Moving Average)")
+         uncertainty_output = gr.HTML(label="Uncertainty Propagation")
+
+     with gr.Row():
+         corr_output = gr.HTML(label="Correlation Heatmap")
+         type_output = gr.HTML(label="Token Type Correlation")
+
+     with gr.Row():
+         embed_output = gr.HTML(label="Embedding Similarity vs. Probability")
+         bayesian_output = gr.HTML(label="Bayesian Inference (Entropy)")
+
+     with gr.Row():
+         graph_output = gr.HTML(label="Graph of Token Transitions")
+         tsne_output = gr.HTML(label="t-SNE of Log Probabilities")
+
+     with gr.Row():
+         heatmap_output = gr.HTML(label="Interactive Heatmap")
+         dist_output = gr.HTML(label="Probability Distribution")

-     plot_output = gr.HTML(label="Log Probability Plot (Hover for Tokens)")
     table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
     text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
     alt_viz_output = gr.HTML(label="Top 3 Token Log Probabilities")
@@ -239,8 +576,14 @@ with gr.Blocks(title="Log Probability Visualizer") as app:
     btn = gr.Button("Visualize")
     btn.click(
         fn=visualize_logprobs,
-         inputs=json_input,
-         outputs=[plot_output, table_output, text_output, alt_viz_output],
+         inputs=[json_input, prob_filter],
+         outputs=[
+             plot_output, table_output, text_output, alt_viz_output,
+             cluster_output, drops_output, ngram_output, markov_output,
+             anomaly_output, autocorr_output, smoothing_output, uncertainty_output,
+             corr_output, type_output, embed_output, bayesian_output,
+             graph_output, tsne_output, heatmap_output, dist_output
+         ],
     )

 app.launch()
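For reference, a minimal payload that exercises the new analyses might look like the sketch below. The field names (content, token, logprob, top_logprobs) mirror what visualize_logprobs reads from each entry; the tokens and numbers themselves are invented.

import json

# Hypothetical sample input: each logprob is a (negative) log probability and
# top_logprobs maps candidate tokens to their log probabilities.
sample = {
    "content": [
        {"token": "The", "logprob": -0.12, "top_logprobs": {"The": -0.12, "A": -2.5, "It": -3.1}},
        {"token": "cat", "logprob": -1.80, "top_logprobs": {"cat": -1.80, "dog": -2.0, "car": -4.2}},
        {"token": "sat", "logprob": -0.45, "top_logprobs": {"sat": -0.45, "ran": -2.9, "is": -3.3}},
        {"token": ".", "logprob": -0.05, "top_logprobs": {".": -0.05, "!": -3.5, ",": -4.0}},
    ]
}

# Paste json.dumps(sample) into the "JSON Input" box, or call the function directly:
# visualize_logprobs(json.dumps(sample), prob_filter=-5.0)
print(json.dumps(sample, indent=2))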
 
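As a quick illustration of the z-score rule behind analysis 6 (positions where the absolute z-score exceeds 2 are flagged), here is a toy run on invented values:

import numpy as np
from scipy import stats

logprobs = [-0.3] * 9 + [-8.0]           # nine confident tokens, one sharp dip
z_scores = np.abs(stats.zscore(logprobs))
outliers = z_scores > 2                  # same threshold as the committed code
print(np.where(outliers)[0])             # [9]: only the dip is flagged (z is about 3.0)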
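Likewise for the smoothing in analysis 8: np.convolve in 'valid' mode returns window_size - 1 fewer points than the input, which is why the plot offsets its x range by window_size - 1. A toy check, again with invented values:

import numpy as np

window_size = 3
vals = [-1.0, -2.0, -3.0, -4.0]
moving_avg = np.convolve(vals, np.ones(window_size) / window_size, mode="valid")
print(moving_avg)  # [-2. -3.]: each output averages one 3-token window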