codelion committed on
Commit cbaf223 · verified · 1 Parent(s): ccde0a2

Update app.py

Files changed (1): app.py (+127 −597)
app.py CHANGED
@@ -8,13 +8,9 @@ import math
 import ast
 import logging
 import numpy as np
-from sklearn.cluster import KMeans
-from sklearn.manifold import TSNE
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
 from scipy import stats
-from scipy.stats import entropy
-from scipy.signal import correlate
-import networkx as nx
-from matplotlib.widgets import Cursor

 # Set up logging
 logging.basicConfig(level=logging.DEBUG)
@@ -63,8 +59,8 @@ def ensure_float(value):
         return float(value)
     return None

-# Function to process and visualize log probs with multiple analyses
-def visualize_logprobs(json_input, prob_filter=-1e9):
+# Function to process and visualize log probs with interactive Plotly plots
+def visualize_logprobs(json_input, prob_filter=-1e9, page_size=50, page=0):
     try:
         # Parse the input (handles both JSON and Python dictionaries)
         data = parse_input(json_input)
@@ -81,18 +77,11 @@ def visualize_logprobs(json_input, prob_filter=-1e9):
         tokens = []
         logprobs = []
         top_alternatives = []  # List to store top 3 log probs (selected token + 2 alternatives)
-        token_types = []  # Simplified token type categorization
         for entry in content:
             logprob = ensure_float(entry.get("logprob", None))
             if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter:
                 tokens.append(entry["token"])
                 logprobs.append(logprob)
-                # Categorize token type (simple heuristic)
-                token = entry["token"].lower().strip()
-                if token in ["the", "a", "an"]: token_types.append("article")
-                elif token in ["is", "are", "was", "were"]: token_types.append("verb")
-                elif token in ["top", "so", "need", "figure"]: token_types.append("noun")
-                else: token_types.append("other")
                 # Get top_logprobs, default to empty dict if None
                 top_probs = entry.get("top_logprobs", {})
                 # Ensure all values in top_logprobs are floats
@@ -112,505 +101,76 @@ def visualize_logprobs(json_input, prob_filter=-1e9):

         # Check if there's valid data after filtering
         if not logprobs or not tokens:
-            return ("No finite log probabilities or tokens to visualize after filtering.", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
-
-        # 1. Main Log Probability Plot (with click for tokens)
-        def create_main_plot():
-            fig_main, ax_main = plt.subplots(figsize=(10, 5))
-            if not logprobs or not tokens:
-                raise ValueError("No data for main plot")
-            scatter = ax_main.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Selected Token")[0]
-            ax_main.set_title("Log Probabilities of Generated Tokens")
-            ax_main.set_xlabel("Token Position")
-            ax_main.set_ylabel("Log Probability")
-            ax_main.grid(True)
-            ax_main.set_xticks([])  # Hide X-axis labels by default
-
-            # Add click functionality to show token
-            token_annotations = []
-            for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
-                annotation = ax_main.annotate('', (x, y), xytext=(10, 10), textcoords='offset points', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), visible=False)
-                token_annotations.append(annotation)
-
-            def on_click(event):
-                if event.inaxes == ax_main:
-                    for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
-                        contains, _ = scatter.contains(event)
-                        if contains and abs(event.xdata - x) < 0.5 and abs(event.ydata - y) < 0.5:
-                            token_annotations[i].set_text(tokens[i])
-                            token_annotations[i].set_visible(True)
-                            fig_main.canvas.draw_idle()
-                        else:
-                            token_annotations[i].set_visible(False)
-                            fig_main.canvas.draw_idle()
-
-            fig_main.canvas.mpl_connect('button_press_event', on_click)
-
-            buf_main = io.BytesIO()
-            plt.savefig(buf_main, format="png", bbox_inches="tight", dpi=100)
-            buf_main.seek(0)
-            plt.close(fig_main)
-            return buf_main
-
-        # 2. K-Means Clustering of Log Probabilities
-        def create_cluster_plot():
-            if not logprobs:
-                raise ValueError("No data for clustering plot")
-            kmeans = KMeans(n_clusters=3, random_state=42)
-            cluster_labels = kmeans.fit_predict(np.array(logprobs).reshape(-1, 1))
-            fig_cluster, ax_cluster = plt.subplots(figsize=(10, 5))
-            scatter = ax_cluster.scatter(range(len(logprobs)), logprobs, c=cluster_labels, cmap='viridis')
-            ax_cluster.set_title("K-Means Clustering of Log Probabilities")
-            ax_cluster.set_xlabel("Token Position")
-            ax_cluster.set_ylabel("Log Probability")
-            ax_cluster.grid(True)
-            plt.colorbar(scatter, ax=ax_cluster, label="Cluster")
-            buf_cluster = io.BytesIO()
-            plt.savefig(buf_cluster, format="png", bbox_inches="tight", dpi=100)
-            buf_cluster.seek(0)
-            plt.close(fig_cluster)
-            return buf_cluster
-
-        # 3. Probability Drop Analysis
-        def create_drops_plot():
-            if not logprobs or len(logprobs) < 2:
-                raise ValueError("Insufficient data for probability drops")
-            drops = [logprobs[i+1] - logprobs[i] if i < len(logprobs)-1 else 0 for i in range(len(logprobs))]
-            fig_drops, ax_drops = plt.subplots(figsize=(10, 5))
-            ax_drops.bar(range(len(drops)), drops, color='red', alpha=0.5)
-            ax_drops.set_title("Significant Probability Drops")
-            ax_drops.set_xlabel("Token Position")
-            ax_drops.set_ylabel("Log Probability Drop")
-            ax_drops.grid(True)
-            buf_drops = io.BytesIO()
-            plt.savefig(buf_drops, format="png", bbox_inches="tight", dpi=100)
-            buf_drops.seek(0)
-            plt.close(fig_drops)
-            return buf_drops
-
-        # 4. N-Gram Analysis (Bigrams for simplicity)
-        def create_ngram_plot():
-            if not logprobs or len(logprobs) < 2:
-                raise ValueError("Insufficient data for N-gram analysis")
-            bigrams = [(tokens[i], tokens[i+1]) for i in range(len(tokens)-1)]
-            bigram_probs = [logprobs[i] + logprobs[i+1] for i in range(len(tokens)-1)]
-            fig_ngram, ax_ngram = plt.subplots(figsize=(10, 5))
-            ax_ngram.bar(range(len(bigrams)), bigram_probs, color='green')
-            ax_ngram.set_title("N-Gram (Bigrams) Probability Sum")
-            ax_ngram.set_xlabel("Bigram Position")
-            ax_ngram.set_ylabel("Sum of Log Probabilities")
-            ax_ngram.set_xticks(range(len(bigrams)))
-            ax_ngram.set_xticklabels([f"{b[0]}->{b[1]}" for b in bigrams], rotation=45, ha="right")
-            ax_ngram.grid(True)
-            buf_ngram = io.BytesIO()
-            plt.savefig(buf_ngram, format="png", bbox_inches="tight", dpi=100)
-            buf_ngram.seek(0)
-            plt.close(fig_ngram)
-            return buf_ngram
+            return (gr.update(value="No finite log probabilities or tokens to visualize after filtering."), None, None, None, None, None, 1, 0)
+
+        # Paginate data for large inputs
+        total_pages = max(1, (len(logprobs) + page_size - 1) // page_size)
+        start_idx = page * page_size
+        end_idx = min((page + 1) * page_size, len(logprobs))
+        paginated_tokens = tokens[start_idx:end_idx]
+        paginated_logprobs = logprobs[start_idx:end_idx]
+        paginated_alternatives = top_alternatives[start_idx:end_idx] if top_alternatives else []
+
+        # 1. Main Log Probability Plot (Interactive Plotly)
+        main_fig = go.Figure()
+        main_fig.add_trace(go.Scatter(x=list(range(len(paginated_logprobs))), y=paginated_logprobs, mode='markers+lines', name='Log Prob', marker=dict(color='blue')))
+        main_fig.update_layout(
+            title="Log Probabilities of Generated Tokens",
+            xaxis_title="Token Position",
+            yaxis_title="Log Probability",
+            hovermode="closest",
+            clickmode='event+select'
+        )
+        main_fig.update_traces(
+            customdata=[f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i+start_idx}" for i, (tok, prob) in enumerate(zip(paginated_tokens, paginated_logprobs))],
+            hovertemplate='<b>%{customdata}</b><extra></extra>'
+        )

-        # 5. Markov Chain Modeling (Simple Graph)
-        def create_markov_plot():
-            if not tokens or len(tokens) < 2:
-                raise ValueError("Insufficient data for Markov chain")
-            G = nx.DiGraph()
-            for i in range(len(tokens)-1):
-                G.add_edge(tokens[i], tokens[i+1], weight=logprobs[i+1] - logprobs[i])
-            fig_markov, ax_markov = plt.subplots(figsize=(10, 5))
-            pos = nx.spring_layout(G)
-            nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, edge_color='gray', width=1, ax=ax_markov)
-            ax_markov.set_title("Markov Chain of Token Transitions")
-            buf_markov = io.BytesIO()
-            plt.savefig(buf_markov, format="png", bbox_inches="tight", dpi=100)
-            buf_markov.seek(0)
-            plt.close(fig_markov)
-            return buf_markov
+        # 2. Probability Drop Analysis (Interactive Plotly)
+        if len(paginated_logprobs) < 2:
+            drops = []
+            drops_fig = go.Figure()
+            drops_fig.add_trace(go.Bar(x=[], y=[], name='Drop', marker_color='red'))
+        else:
+            drops = [paginated_logprobs[i+1] - paginated_logprobs[i] for i in range(len(paginated_logprobs)-1)]
+            drops_fig = go.Figure()
+            drops_fig.add_trace(go.Bar(x=list(range(len(drops))), y=drops, name='Drop', marker_color='red'))
+        drops_fig.update_layout(
+            title="Significant Probability Drops",
+            xaxis_title="Token Position",
+            yaxis_title="Log Probability Drop",
+            hovermode="closest",
+            clickmode='event+select'
+        )
+        drops_fig.update_traces(
+            customdata=[f"Drop: {drop:.4f}, From: {paginated_tokens[i]} to {paginated_tokens[i+1]}, Position: {i+start_idx}" for i, drop in enumerate(drops)],
+            hovertemplate='<b>%{customdata}</b><extra></extra>'
+        )

-        # 6. Anomaly Detection (Outlier Detection with Z-Score)
-        def create_anomaly_plot():
-            if not logprobs:
-                raise ValueError("No data for anomaly detection")
-            z_scores = np.abs(stats.zscore(logprobs))
+        # 3. Anomaly Detection (Interactive Plotly)
+        if not paginated_logprobs:
+            outliers = np.array([], dtype=bool)
+            anomaly_fig = go.Figure()
+            anomaly_fig.add_trace(go.Scatter(x=[], y=[], mode='markers+lines', name='Log Prob', marker_color='blue'))
+        else:
+            z_scores = np.abs(stats.zscore(paginated_logprobs))
             outliers = z_scores > 2  # Threshold for outliers
-            fig_anomaly, ax_anomaly = plt.subplots(figsize=(10, 5))
-            ax_anomaly.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b")
-            ax_anomaly.plot(np.where(outliers)[0], [logprobs[i] for i in np.where(outliers)[0]], "ro", label="Outliers")
-            ax_anomaly.set_title("Log Probabilities with Outliers")
-            ax_anomaly.set_xlabel("Token Position")
-            ax_anomaly.set_ylabel("Log Probability")
-            ax_anomaly.grid(True)
-            ax_anomaly.legend()
-            ax_anomaly.set_xticks([])  # Hide X-axis labels
-            buf_anomaly = io.BytesIO()
-            plt.savefig(buf_anomaly, format="png", bbox_inches="tight", dpi=100)
-            buf_anomaly.seek(0)
-            plt.close(fig_anomaly)
-            return buf_anomaly
-
-        # 7. Autocorrelation
-        def create_autocorr_plot():
-            if not logprobs:
-                raise ValueError("No data for autocorrelation")
-            autocorr = correlate(logprobs, logprobs, mode='full')
-            autocorr = autocorr[len(autocorr)//2:] / len(logprobs)  # Normalize
-            fig_autocorr, ax_autocorr = plt.subplots(figsize=(10, 5))
-            ax_autocorr.plot(range(len(autocorr)), autocorr, color='purple')
-            ax_autocorr.set_title("Autocorrelation of Log Probabilities")
-            ax_autocorr.set_xlabel("Lag")
-            ax_autocorr.set_ylabel("Autocorrelation")
-            ax_autocorr.grid(True)
-            buf_autocorr = io.BytesIO()
-            plt.savefig(buf_autocorr, format="png", bbox_inches="tight", dpi=100)
-            buf_autocorr.seek(0)
-            plt.close(fig_autocorr)
-            return buf_autocorr
-
-        # 8. Smoothing (Moving Average)
-        def create_smoothing_plot():
-            if not logprobs:
-                raise ValueError("No data for smoothing")
-            window_size = 3
-            moving_avg = np.convolve(logprobs, np.ones(window_size)/window_size, mode='valid')
-            fig_smoothing, ax_smoothing = plt.subplots(figsize=(10, 5))
-            ax_smoothing.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Original")
-            ax_smoothing.plot(range(window_size-1, len(logprobs)), moving_avg, color="orange", label="Moving Average")
-            ax_smoothing.set_title("Log Probabilities with Moving Average")
-            ax_smoothing.set_xlabel("Token Position")
-            ax_smoothing.set_ylabel("Log Probability")
-            ax_smoothing.grid(True)
-            ax_smoothing.legend()
-            ax_smoothing.set_xticks([])  # Hide X-axis labels
-            buf_smoothing = io.BytesIO()
-            plt.savefig(buf_smoothing, format="png", bbox_inches="tight", dpi=100)
-            buf_smoothing.seek(0)
-            plt.close(fig_smoothing)
-            return buf_smoothing
-
-        # 9. Uncertainty Propagation (Variance of Top Logprobs)
-        def create_uncertainty_plot():
-            if not logprobs or not top_alternatives:
-                raise ValueError("No data for uncertainty propagation")
-            variances = []
-            for probs in top_alternatives:
-                if len(probs) > 1:
-                    values = [p[1] for p in probs]
-                    variances.append(np.var(values))
-                else:
-                    variances.append(0)
-            fig_uncertainty, ax_uncertainty = plt.subplots(figsize=(10, 5))
-            ax_uncertainty.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Log Prob")
-            ax_uncertainty.fill_between(range(len(logprobs)), [lp - v for lp, v in zip(logprobs, variances)],
-                                        [lp + v for lp, v in zip(logprobs, variances)], color='gray', alpha=0.3, label="Uncertainty")
-            ax_uncertainty.set_title("Log Probabilities with Uncertainty Propagation")
-            ax_uncertainty.set_xlabel("Token Position")
-            ax_uncertainty.set_ylabel("Log Probability")
-            ax_uncertainty.grid(True)
-            ax_uncertainty.legend()
-            ax_uncertainty.set_xticks([])  # Hide X-axis labels
-            buf_uncertainty = io.BytesIO()
-            plt.savefig(buf_uncertainty, format="png", bbox_inches="tight", dpi=100)
-            buf_uncertainty.seek(0)
-            plt.close(fig_uncertainty)
-            return buf_uncertainty
-
-        # 10. Correlation Heatmap
-        def create_corr_plot():
-            if not logprobs or len(logprobs) < 2:
-                raise ValueError("Insufficient data for correlation heatmap")
-            corr_matrix = np.corrcoef(logprobs, rowvar=False)
-            fig_corr, ax_corr = plt.subplots(figsize=(10, 5))
-            im = ax_corr.imshow(corr_matrix, cmap='coolwarm', interpolation='nearest')
-            ax_corr.set_title("Correlation of Log Probabilities Across Positions")
-            ax_corr.set_xlabel("Token Position")
-            ax_corr.set_ylabel("Token Position")
-            plt.colorbar(im, ax=ax_corr, label="Correlation")
-            buf_corr = io.BytesIO()
-            plt.savefig(buf_corr, format="png", bbox_inches="tight", dpi=100)
-            buf_corr.seek(0)
-            plt.close(fig_corr)
-            return buf_corr
-
-        # 11. Token Type Correlation
-        def create_type_plot():
-            if not logprobs or not token_types:
-                raise ValueError("No data for token type correlation")
-            type_probs = {t: [] for t in set(token_types)}
-            for t, p in zip(token_types, logprobs):
-                type_probs[t].append(p)
-            fig_type, ax_type = plt.subplots(figsize=(10, 5))
-            for t in type_probs:
-                ax_type.bar(t, np.mean(type_probs[t]), yerr=np.std(type_probs[t]), capsize=5, label=t)
-            ax_type.set_title("Average Log Probability by Token Type")
-            ax_type.set_xlabel("Token Type")
-            ax_type.set_ylabel("Average Log Probability")
-            ax_type.grid(True)
-            ax_type.legend()
-            buf_type = io.BytesIO()
-            plt.savefig(buf_type, format="png", bbox_inches="tight", dpi=100)
-            buf_type.seek(0)
-            plt.close(fig_type)
-            return buf_type
-
-        # 12. Token Embedding Similarity vs. Probability (Simulated)
-        def create_embed_plot():
-            if not logprobs or not tokens:
-                raise ValueError("No data for embedding similarity")
-            simulated_embeddings = np.random.rand(len(tokens), 2)  # 2D embeddings
-            fig_embed, ax_embed = plt.subplots(figsize=(10, 5))
-            ax_embed.scatter(simulated_embeddings[:, 0], simulated_embeddings[:, 1], c=logprobs, cmap='viridis')
-            ax_embed.set_title("Token Embedding Similarity vs. Log Probability")
-            ax_embed.set_xlabel("Embedding Dimension 1")
-            ax_embed.set_ylabel("Embedding Dimension 2")
-            plt.colorbar(ax_embed.collections[0], ax=ax_embed, label="Log Probability")
-            buf_embed = io.BytesIO()
-            plt.savefig(buf_embed, format="png", bbox_inches="tight", dpi=100)
-            buf_embed.seek(0)
-            plt.close(fig_embed)
-            return buf_embed
-
-        # 13. Bayesian Inference (Simplified as Inferred Probabilities)
-        def create_bayesian_plot():
-            if not top_alternatives:
-                raise ValueError("No data for Bayesian inference")
-            entropies = [entropy([p[1] for p in probs], base=2) for probs in top_alternatives if len(probs) > 1]
-            fig_bayesian, ax_bayesian = plt.subplots(figsize=(10, 5))
-            ax_bayesian.bar(range(len(entropies)), entropies, color='orange')
-            ax_bayesian.set_title("Bayesian Inferred Uncertainty (Entropy)")
-            ax_bayesian.set_xlabel("Token Position")
-            ax_bayesian.set_ylabel("Entropy")
-            ax_bayesian.grid(True)
-            buf_bayesian = io.BytesIO()
-            plt.savefig(buf_bayesian, format="png", bbox_inches="tight", dpi=100)
-            buf_bayesian.seek(0)
-            plt.close(fig_bayesian)
-            return buf_bayesian
-
-        # 14. Graph-Based Analysis
-        def create_graph_plot():
-            if not tokens or len(tokens) < 2:
-                raise ValueError("Insufficient data for graph analysis")
-            G = nx.DiGraph()
-            for i in range(len(tokens)-1):
-                G.add_edge(tokens[i], tokens[i+1], weight=logprobs[i+1] - logprobs[i])
-            fig_graph, ax_graph = plt.subplots(figsize=(10, 5))
-            pos = nx.spring_layout(G)
-            nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, edge_color='gray', width=1, ax=ax_graph)
-            ax_graph.set_title("Graph of Token Transitions")
-            buf_graph = io.BytesIO()
-            plt.savefig(buf_graph, format="png", bbox_inches="tight", dpi=100)
-            buf_graph.seek(0)
-            plt.close(fig_graph)
-            return buf_graph
-
-        # 15. Dimensionality Reduction (t-SNE)
-        def create_tsne_plot():
-            if not logprobs or not top_alternatives:
-                raise ValueError("No data for t-SNE")
-            features = np.array([logprobs + [p[1] for p in alts[:2]] for logprobs, alts in zip([logprobs], top_alternatives)])
-            tsne = TSNE(n_components=2, random_state=42)
-            tsne_result = tsne.fit_transform(features.T)
-            fig_tsne, ax_tsne = plt.subplots(figsize=(10, 5))
-            scatter = ax_tsne.scatter(tsne_result[:, 0], tsne_result[:, 1], c=logprobs, cmap='viridis')
-            ax_tsne.set_title("t-SNE of Log Probabilities and Top Alternatives")
-            ax_tsne.set_xlabel("t-SNE Dimension 1")
-            ax_tsne.set_ylabel("t-SNE Dimension 2")
-            plt.colorbar(scatter, ax=ax_tsne, label="Log Probability")
-            buf_tsne = io.BytesIO()
-            plt.savefig(buf_tsne, format="png", bbox_inches="tight", dpi=100)
-            buf_tsne.seek(0)
-            plt.close(fig_tsne)
-            return buf_tsne
-
-        # 16. Interactive Heatmap
-        def create_heatmap_plot():
-            if not logprobs:
-                raise ValueError("No data for heatmap")
-            fig_heatmap, ax_heatmap = plt.subplots(figsize=(10, 5))
-            im = ax_heatmap.imshow([logprobs], cmap='viridis', aspect='auto')
-            ax_heatmap.set_title("Interactive Heatmap of Log Probabilities")
-            ax_heatmap.set_xlabel("Token Position")
-            ax_heatmap.set_ylabel("Probability Level")
-            plt.colorbar(im, ax=ax_heatmap, label="Log Probability")
-            buf_heatmap = io.BytesIO()
-            plt.savefig(buf_heatmap, format="png", bbox_inches="tight", dpi=100)
-            buf_heatmap.seek(0)
-            plt.close(fig_heatmap)
-            return buf_heatmap
-
-        # 17. Probability Distribution Plots (Box Plots for Top Logprobs)
-        def create_dist_plot():
-            if not logprobs or not top_alternatives:
-                raise ValueError("No data for probability distribution")
-            all_top_probs = [p[1] for alts in top_alternatives for p in alts]
-            fig_dist, ax_dist = plt.subplots(figsize=(10, 5))
-            ax_dist.boxplot([logprobs] + [p[1] for alts in top_alternatives for p in alts[:2]], labels=["Selected"] + ["Alt1", "Alt2"])
-            ax_dist.set_title("Probability Distribution of Top Tokens")
-            ax_dist.set_xlabel("Token Type")
-            ax_dist.set_ylabel("Log Probability")
-            ax_dist.grid(True)
-            buf_dist = io.BytesIO()
-            plt.savefig(buf_dist, format="png", bbox_inches="tight", dpi=100)
-            buf_dist.seek(0)
-            plt.close(fig_dist)
-            return buf_dist
-
-        # Create all plots safely
-        img_main_html = "Placeholder for Log Probability Plot"
-        img_cluster_html = "Placeholder for K-Means Clustering"
-        img_drops_html = "Placeholder for Probability Drops"
-        img_ngram_html = "Placeholder for N-Gram Analysis"
-        img_markov_html = "Placeholder for Markov Chain"
-        img_anomaly_html = "Placeholder for Anomaly Detection"
-        img_autocorr_html = "Placeholder for Autocorrelation"
-        img_smoothing_html = "Placeholder for Smoothing (Moving Average)"
-        img_uncertainty_html = "Placeholder for Uncertainty Propagation"
-        img_corr_html = "Placeholder for Correlation Heatmap"
-        img_type_html = "Placeholder for Token Type Correlation"
-        img_embed_html = "Placeholder for Embedding Similarity vs. Probability"
-        img_bayesian_html = "Placeholder for Bayesian Inference (Entropy)"
-        img_graph_html = "Placeholder for Graph of Token Transitions"
-        img_tsne_html = "Placeholder for t-SNE of Log Probabilities"
-        img_heatmap_html = "Placeholder for Interactive Heatmap"
-        img_dist_html = "Placeholder for Probability Distribution"
-
-        try:
-            buf_main = create_main_plot()
-            img_main_bytes = buf_main.getvalue()
-            img_main_base64 = base64.b64encode(img_main_bytes).decode("utf-8")
-            img_main_html = f'<img src="data:image/png;base64,{img_main_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create main plot: %s", str(e))
-
-        try:
-            buf_cluster = create_cluster_plot()
-            img_cluster_bytes = buf_cluster.getvalue()
-            img_cluster_base64 = base64.b64encode(img_cluster_bytes).decode("utf-8")
-            img_cluster_html = f'<img src="data:image/png;base64,{img_cluster_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create cluster plot: %s", str(e))
-
-        try:
-            buf_drops = create_drops_plot()
-            img_drops_bytes = buf_drops.getvalue()
-            img_drops_base64 = base64.b64encode(img_drops_bytes).decode("utf-8")
-            img_drops_html = f'<img src="data:image/png;base64,{img_drops_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create drops plot: %s", str(e))
-
-        try:
-            buf_ngram = create_ngram_plot()
-            img_ngram_bytes = buf_ngram.getvalue()
-            img_ngram_base64 = base64.b64encode(img_ngram_bytes).decode("utf-8")
-            img_ngram_html = f'<img src="data:image/png;base64,{img_ngram_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create ngram plot: %s", str(e))
-
-        try:
-            buf_markov = create_markov_plot()
-            img_markov_bytes = buf_markov.getvalue()
-            img_markov_base64 = base64.b64encode(img_markov_bytes).decode("utf-8")
-            img_markov_html = f'<img src="data:image/png;base64,{img_markov_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create markov plot: %s", str(e))
-
-        try:
-            buf_anomaly = create_anomaly_plot()
-            img_anomaly_bytes = buf_anomaly.getvalue()
-            img_anomaly_base64 = base64.b64encode(img_anomaly_bytes).decode("utf-8")
-            img_anomaly_html = f'<img src="data:image/png;base64,{img_anomaly_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create anomaly plot: %s", str(e))
-
-        try:
-            buf_autocorr = create_autocorr_plot()
-            img_autocorr_bytes = buf_autocorr.getvalue()
-            img_autocorr_base64 = base64.b64encode(img_autocorr_bytes).decode("utf-8")
-            img_autocorr_html = f'<img src="data:image/png;base64,{img_autocorr_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create autocorr plot: %s", str(e))
-
-        try:
-            buf_smoothing = create_smoothing_plot()
-            img_smoothing_bytes = buf_smoothing.getvalue()
-            img_smoothing_base64 = base64.b64encode(img_smoothing_bytes).decode("utf-8")
-            img_smoothing_html = f'<img src="data:image/png;base64,{img_smoothing_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create smoothing plot: %s", str(e))
-
-        try:
-            buf_uncertainty = create_uncertainty_plot()
-            img_uncertainty_bytes = buf_uncertainty.getvalue()
-            img_uncertainty_base64 = base64.b64encode(img_uncertainty_bytes).decode("utf-8")
-            img_uncertainty_html = f'<img src="data:image/png;base64,{img_uncertainty_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create uncertainty plot: %s", str(e))
-
-        try:
-            buf_corr = create_corr_plot()
-            img_corr_bytes = buf_corr.getvalue()
-            img_corr_base64 = base64.b64encode(img_corr_bytes).decode("utf-8")
-            img_corr_html = f'<img src="data:image/png;base64,{img_corr_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create correlation plot: %s", str(e))
-
-        try:
-            buf_type = create_type_plot()
-            img_type_bytes = buf_type.getvalue()
-            img_type_base64 = base64.b64encode(img_type_bytes).decode("utf-8")
-            img_type_html = f'<img src="data:image/png;base64,{img_type_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create type plot: %s", str(e))
-
-        try:
-            buf_embed = create_embed_plot()
-            img_embed_bytes = buf_embed.getvalue()
-            img_embed_base64 = base64.b64encode(img_embed_bytes).decode("utf-8")
-            img_embed_html = f'<img src="data:image/png;base64,{img_embed_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create embed plot: %s", str(e))
-
-        try:
-            buf_bayesian = create_bayesian_plot()
-            img_bayesian_bytes = buf_bayesian.getvalue()
-            img_bayesian_base64 = base64.b64encode(img_bayesian_bytes).decode("utf-8")
-            img_bayesian_html = f'<img src="data:image/png;base64,{img_bayesian_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create bayesian plot: %s", str(e))
-
-        try:
-            buf_graph = create_graph_plot()
-            img_graph_bytes = buf_graph.getvalue()
-            img_graph_base64 = base64.b64encode(img_graph_bytes).decode("utf-8")
-            img_graph_html = f'<img src="data:image/png;base64,{img_graph_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create graph plot: %s", str(e))
-
-        try:
-            buf_tsne = create_tsne_plot()
-            img_tsne_bytes = buf_tsne.getvalue()
-            img_tsne_base64 = base64.b64encode(img_tsne_bytes).decode("utf-8")
-            img_tsne_html = f'<img src="data:image/png;base64,{img_tsne_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create tsne plot: %s", str(e))
-
-        try:
-            buf_heatmap = create_heatmap_plot()
-            img_heatmap_bytes = buf_heatmap.getvalue()
-            img_heatmap_base64 = base64.b64encode(img_heatmap_bytes).decode("utf-8")
-            img_heatmap_html = f'<img src="data:image/png;base64,{img_heatmap_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create heatmap plot: %s", str(e))
-
-        try:
-            buf_dist = create_dist_plot()
-            img_dist_bytes = buf_dist.getvalue()
-            img_dist_base64 = base64.b64encode(img_dist_bytes).decode("utf-8")
-            img_dist_html = f'<img src="data:image/png;base64,{img_dist_base64}" style="max-width: 100%; height: auto;">'
-        except Exception as e:
-            logger.error("Failed to create distribution plot: %s", str(e))
+            anomaly_fig = go.Figure()
+            anomaly_fig.add_trace(go.Scatter(x=list(range(len(paginated_logprobs))), y=paginated_logprobs, mode='markers+lines', name='Log Prob', marker_color='blue'))
+            anomaly_fig.add_trace(go.Scatter(x=np.where(outliers)[0], y=[paginated_logprobs[i] for i in np.where(outliers)[0]], mode='markers', name='Outliers', marker_color='red'))
+        anomaly_fig.update_layout(
+            title="Log Probabilities with Outliers",
+            xaxis_title="Token Position",
+            yaxis_title="Log Probability",
+            hovermode="closest",
+            clickmode='event+select'
+        )
+        anomaly_fig.update_traces(
+            customdata=[f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i+start_idx}, Outlier: {out}" for i, (tok, prob, out) in enumerate(zip(paginated_tokens, paginated_logprobs, outliers))],
+            hovertemplate='<b>%{customdata}</b><extra></extra>'
+        )

-        # Create DataFrame for the table
+        # Create DataFrame for the table (paginated)
         table_data = []
-        for i, entry in enumerate(content):
+        for i, entry in enumerate(content[start_idx:end_idx]):
             logprob = ensure_float(entry.get("logprob", None))
             if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter and "top_logprobs" in entry and entry["top_logprobs"] is not None:
                 token = entry["token"]
@@ -645,75 +205,52 @@ def visualize_logprobs(json_input, prob_filter=-1e9):
             else None
         )

-        # Generate colored text
-        if logprobs:
-            min_logprob = min(logprobs)
-            max_logprob = max(logprobs)
+        # Generate colored text (paginated)
+        if paginated_logprobs:
+            min_logprob = min(paginated_logprobs)
+            max_logprob = max(paginated_logprobs)
             if max_logprob == min_logprob:
-                normalized_probs = [0.5] * len(logprobs)
+                normalized_probs = [0.5] * len(paginated_logprobs)
             else:
                 normalized_probs = [
-                    (lp - min_logprob) / (max_logprob - min_logprob) for lp in logprobs
+                    (lp - min_logprob) / (max_logprob - min_logprob) for lp in paginated_logprobs
                 ]

             colored_text = ""
-            for i, (token, norm_prob) in enumerate(zip(tokens, normalized_probs)):
+            for i, (token, norm_prob) in enumerate(zip(paginated_tokens, normalized_probs)):
                 r = int(255 * (1 - norm_prob))  # Red for low confidence
                 g = int(255 * norm_prob)  # Green for high confidence
                 b = 0
                 color = f"rgb({r}, {g}, {b})"
                 colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
-                if i < len(tokens) - 1:
+                if i < len(paginated_tokens) - 1:
                     colored_text += " "
             colored_text_html = f"<p>{colored_text}</p>"
         else:
             colored_text_html = "No finite log probabilities to display."

-        # Top 3 Token Log Probabilities
+        # Top 3 Token Log Probabilities (paginated)
         alt_viz_html = ""
-        if logprobs and top_alternatives:
-            alt_viz_html = "<h3>Top 3 Token Log Probabilities</h3><ul>"
-            for i, (token, probs) in enumerate(zip(tokens, top_alternatives)):
-                alt_viz_html += f"<li>Position {i} (Token: {token}):<br>"
+        if paginated_logprobs and paginated_alternatives:
+            alt_viz_html = "<h3>Top 3 Token Log Probabilities (Paginated)</h3><ul>"
+            for i, (token, probs) in enumerate(zip(paginated_tokens, paginated_alternatives)):
+                alt_viz_html += f"<li>Position {i+start_idx} (Token: {token}):<br>"
                 for tok, prob in probs:
                     alt_viz_html += f"{tok}: {prob:.4f}<br>"
                 alt_viz_html += "</li>"
             alt_viz_html += "</ul>"

-        # Convert buffers to HTML for Gradio
-        def buffer_to_html(buf):
-            if isinstance(buf, str):  # If it's an error message
-                return buf
-            img_bytes = buf.getvalue()
-            img_base64 = base64.b64encode(img_bytes).decode("utf-8")
-            return f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
-
-        return (
-            buffer_to_html(img_main_html), df, colored_text_html, alt_viz_html,
-            buffer_to_html(img_cluster_html), buffer_to_html(img_drops_html), buffer_to_html(img_ngram_html),
-            buffer_to_html(img_markov_html), buffer_to_html(img_anomaly_html), buffer_to_html(img_autocorr_html),
-            buffer_to_html(img_smoothing_html), buffer_to_html(img_uncertainty_html), buffer_to_html(img_corr_html),
-            buffer_to_html(img_type_html), buffer_to_html(img_embed_html), buffer_to_html(img_bayesian_html),
-            buffer_to_html(img_graph_html), buffer_to_html(img_tsne_html), buffer_to_html(img_heatmap_html),
-            buffer_to_html(img_dist_html)
-        )
+        return (main_fig, df, colored_text_html, alt_viz_html, drops_fig, anomaly_fig, total_pages, page)

     except Exception as e:
         logger.error("Visualization failed: %s", str(e))
-        return (
-            f"Error: {str(e)}", None, None, None, "Placeholder for K-Means Clustering", "Placeholder for Probability Drops",
-            "Placeholder for N-Gram Analysis", "Placeholder for Markov Chain", "Placeholder for Anomaly Detection",
-            "Placeholder for Autocorrelation", "Placeholder for Smoothing (Moving Average)", "Placeholder for Uncertainty Propagation",
-            "Placeholder for Correlation Heatmap", "Placeholder for Token Type Correlation", "Placeholder for Embedding Similarity vs. Probability",
-            "Placeholder for Bayesian Inference (Entropy)", "Placeholder for Graph of Token Transitions", "Placeholder for t-SNE of Log Probabilities",
-            "Placeholder for Interactive Heatmap", "Placeholder for Probability Distribution"
-        )
+        return (gr.update(value=f"Error: {str(e)}"), None, "No finite log probabilities to display.", None, gr.update(value="No data for probability drops."), gr.update(value="No data for anomalies."), 1, 0)

-# Gradio interface with improved layout and placeholders
+# Gradio interface with interactive layout and pagination
 with gr.Blocks(title="Log Probability Visualizer") as app:
     gr.Markdown("# Log Probability Visualizer")
     gr.Markdown(
-        "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities. Use the filter to focus on specific log probability ranges."
+        "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities. Use the filter and pagination to navigate large inputs."
     )

     with gr.Row():
@@ -725,61 +262,54 @@ with gr.Blocks(title="Log Probability Visualizer") as app:
725
  )
726
  with gr.Column(scale=1):
727
  prob_filter = gr.Slider(minimum=-1e9, maximum=0, value=-1e9, label="Log Probability Filter (≥)")
 
 
728
 
729
- with gr.Tabs():
730
- with gr.Tab("Core Visualizations"):
731
- with gr.Row():
732
- plot_output = gr.HTML(label="Log Probability Plot (Click for Tokens)", value="Placeholder for Log Probability Plot")
733
- table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives", value=None)
734
- with gr.Row():
735
- text_output = gr.HTML(label="Colored Text (Confidence Visualization)", value="Placeholder for Colored Text (Confidence Visualization)")
736
- alt_viz_output = gr.HTML(label="Top 3 Token Log Probabilities", value="Placeholder for Top 3 Token Log Probabilities")
737
-
738
- with gr.Tab("Clustering & Patterns"):
739
- with gr.Row():
740
- cluster_output = gr.HTML(label="K-Means Clustering", value="Placeholder for K-Means Clustering")
741
- drops_output = gr.HTML(label="Probability Drops", value="Placeholder for Probability Drops")
742
- with gr.Row():
743
- ngram_output = gr.HTML(label="N-Gram Analysis", value="Placeholder for N-Gram Analysis")
744
- markov_output = gr.HTML(label="Markov Chain", value="Placeholder for Markov Chain")
745
-
746
- with gr.Tab("Time Series & Anomalies"):
747
- with gr.Row():
748
- anomaly_output = gr.HTML(label="Anomaly Detection", value="Placeholder for Anomaly Detection")
749
- autocorr_output = gr.HTML(label="Autocorrelation", value="Placeholder for Autocorrelation")
750
- with gr.Row():
751
- smoothing_output = gr.HTML(label="Smoothing (Moving Average)", value="Placeholder for Smoothing (Moving Average)")
752
- uncertainty_output = gr.HTML(label="Uncertainty Propagation", value="Placeholder for Uncertainty Propagation")
753
-
754
- with gr.Tab("Correlation & Types"):
755
- with gr.Row():
756
- corr_output = gr.HTML(label="Correlation Heatmap", value="Placeholder for Correlation Heatmap")
757
- type_output = gr.HTML(label="Token Type Correlation", value="Placeholder for Token Type Correlation")
758
 
759
- with gr.Tab("Advanced Analyses"):
760
- with gr.Row():
761
- embed_output = gr.HTML(label="Embedding Similarity vs. Probability", value="Placeholder for Embedding Similarity vs. Probability")
762
- bayesian_output = gr.HTML(label="Bayesian Inference (Entropy)", value="Placeholder for Bayesian Inference (Entropy)")
763
- with gr.Row():
764
- graph_output = gr.HTML(label="Graph of Token Transitions", value="Placeholder for Graph of Token Transitions")
765
- tsne_output = gr.HTML(label="t-SNE of Log Probabilities", value="Placeholder for t-SNE of Log Probabilities")
766
 
767
- with gr.Tab("Enhanced Visualizations"):
768
- with gr.Row():
769
- heatmap_output = gr.HTML(label="Interactive Heatmap", value="Placeholder for Interactive Heatmap")
770
- dist_output = gr.HTML(label="Probability Distribution", value="Placeholder for Probability Distribution")
771
 
772
  btn = gr.Button("Visualize")
773
  btn.click(
774
  fn=visualize_logprobs,
775
- inputs=[json_input, prob_filter],
776
- outputs=[
777
- plot_output, table_output, text_output, alt_viz_output,
778
- cluster_output, drops_output, ngram_output, markov_output,
779
- anomaly_output, autocorr_output, smoothing_output, uncertainty_output,
780
- corr_output, type_output, embed_output, bayesian_output,
781
- graph_output, tsne_output, heatmap_output, dist_output
782
- ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
783
  )
784
 
785
  app.launch()
 
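
Notes on the new logic, with standalone sketches (illustrative only; none of the following code is part of the commit). First, the pagination arithmetic in visualize_logprobs is plain ceiling division; the paginate helper below is a hypothetical extraction of those four lines:

def paginate(items, page_size=50, page=0):
    # Same arithmetic as the commit: ceiling division for the page count
    total_pages = max(1, (len(items) + page_size - 1) // page_size)
    start = page * page_size
    end = min((page + 1) * page_size, len(items))
    return items[start:end], total_pages

# 120 items at 50 per page -> 3 pages; the last page holds the remaining 20
chunk, total = paginate(list(range(120)), page_size=50, page=2)
assert total == 3 and len(chunk) == 20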
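
The anomaly view flags tokens whose log probability sits more than two standard deviations from the page mean. The same rule in isolation (sample values invented; on pages of five or fewer tokens a z-score above 2 is mathematically impossible, which is one reason the empty and tiny cases need guards):

import numpy as np
from scipy import stats

logprobs = np.array([-0.5] * 9 + [-6.0])  # hypothetical page of log probs
z_scores = np.abs(stats.zscore(logprobs))
outliers = z_scores > 2  # same threshold as the commit
print(np.where(outliers)[0])  # [9] -- only the improbable token is flagged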
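
All three figures attach a per-point string via Plotly's customdata and surface it with hovertemplate; this is the standard Plotly pattern, sketched minimally (tokens and values invented):

import plotly.graph_objects as go

tokens = ["The", "cat", "sat"]  # hypothetical tokens
logprobs = [-0.1, -1.2, -0.4]   # hypothetical log probs

fig = go.Figure(go.Scatter(
    x=list(range(len(logprobs))), y=logprobs, mode="markers+lines",
    customdata=[f"Token: {t}, Log Prob: {p:.4f}" for t, p in zip(tokens, logprobs)],
    hovertemplate="<b>%{customdata}</b><extra></extra>",  # <extra></extra> hides the trace-name box
))
fig.show()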
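
Finally, the colored-text view min-max normalizes each page's log probabilities and interpolates from red (low confidence) to green (high), falling back to 0.5 when the page is constant. The mapping on its own (sample values invented; the confidence_color helper is hypothetical):

def confidence_color(lp, lo, hi):
    # Min-max normalize, then interpolate red (low) -> green (high), as in the commit
    norm = 0.5 if hi == lo else (lp - lo) / (hi - lo)
    return f"rgb({int(255 * (1 - norm))}, {int(255 * norm)}, 0)"

logprobs = [-0.1, -2.3, -0.7]  # hypothetical
lo, hi = min(logprobs), max(logprobs)
print([confidence_color(lp, lo, hi) for lp in logprobs])
# ['rgb(0, 255, 0)', 'rgb(255, 0, 0)', 'rgb(69, 185, 0)']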