Update app.py
app.py
CHANGED
@@ -8,13 +8,9 @@ import math
 8 | import ast
 9 | import logging
10 | import numpy as np
11 | -
12 | - from
13 | from scipy import stats
14 | - from scipy.stats import entropy
15 | - from scipy.signal import correlate
16 | - import networkx as nx
17 | - from matplotlib.widgets import Cursor
18 |
19 | # Set up logging
20 | logging.basicConfig(level=logging.DEBUG)
@@ -63,8 +59,8 @@ def ensure_float(value):
63 |         return float(value)
64 |     return None
65 |
66 | - # Function to process and visualize log probs with
67 | - def visualize_logprobs(json_input, prob_filter=-1e9):
68 |     try:
69 |         # Parse the input (handles both JSON and Python dictionaries)
70 |         data = parse_input(json_input)
@@ -81,18 +77,11 @@ def visualize_logprobs(json_input, prob_filter=-1e9):
81 |         tokens = []
82 |         logprobs = []
83 |         top_alternatives = []  # List to store top 3 log probs (selected token + 2 alternatives)
84 | -       token_types = []  # Simplified token type categorization
85 |         for entry in content:
86 |             logprob = ensure_float(entry.get("logprob", None))
87 |             if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter:
88 |                 tokens.append(entry["token"])
89 |                 logprobs.append(logprob)
90 | -               # Categorize token type (simple heuristic)
91 | -               token = entry["token"].lower().strip()
92 | -               if token in ["the", "a", "an"]: token_types.append("article")
93 | -               elif token in ["is", "are", "was", "were"]: token_types.append("verb")
94 | -               elif token in ["top", "so", "need", "figure"]: token_types.append("noun")
95 | -               else: token_types.append("other")
96 |                 # Get top_logprobs, default to empty dict if None
97 |                 top_probs = entry.get("top_logprobs", {})
98 |                 # Ensure all values in top_logprobs are floats
@@ -112,505 +101,76 @@ def visualize_logprobs(json_input, prob_filter=-1e9):
112 |
113 |         # Check if there's valid data after filtering
114 |         if not logprobs or not tokens:
115 | -           return ("No finite log probabilities or tokens to visualize after filtering
116 | -
117 | -       #
118-138 | -
139 | -               if contains and abs(event.xdata - x) < 0.5 and abs(event.ydata - y) < 0.5:
140 | -                   token_annotations[i].set_text(tokens[i])
141 | -                   token_annotations[i].set_visible(True)
142 | -                   fig_main.canvas.draw_idle()
143 | -               else:
144 | -                   token_annotations[i].set_visible(False)
145 | -                   fig_main.canvas.draw_idle()
146 | -
147 | -           fig_main.canvas.mpl_connect('button_press_event', on_click)
148 | -
149 | -           buf_main = io.BytesIO()
150 | -           plt.savefig(buf_main, format="png", bbox_inches="tight", dpi=100)
151 | -           buf_main.seek(0)
152 | -           plt.close(fig_main)
153 | -           return buf_main
154 | -
155 | -       # 2. K-Means Clustering of Log Probabilities
156 | -       def create_cluster_plot():
157 | -           if not logprobs:
158 | -               raise ValueError("No data for clustering plot")
159 | -           kmeans = KMeans(n_clusters=3, random_state=42)
160 | -           cluster_labels = kmeans.fit_predict(np.array(logprobs).reshape(-1, 1))
161 | -           fig_cluster, ax_cluster = plt.subplots(figsize=(10, 5))
162 | -           scatter = ax_cluster.scatter(range(len(logprobs)), logprobs, c=cluster_labels, cmap='viridis')
163 | -           ax_cluster.set_title("K-Means Clustering of Log Probabilities")
164 | -           ax_cluster.set_xlabel("Token Position")
165 | -           ax_cluster.set_ylabel("Log Probability")
166 | -           ax_cluster.grid(True)
167 | -           plt.colorbar(scatter, ax=ax_cluster, label="Cluster")
168 | -           buf_cluster = io.BytesIO()
169 | -           plt.savefig(buf_cluster, format="png", bbox_inches="tight", dpi=100)
170 | -           buf_cluster.seek(0)
171 | -           plt.close(fig_cluster)
172 | -           return buf_cluster
173 | -
174 | -       # 3. Probability Drop Analysis
175 | -       def create_drops_plot():
176 | -           if not logprobs or len(logprobs) < 2:
177 | -               raise ValueError("Insufficient data for probability drops")
178 | -           drops = [logprobs[i+1] - logprobs[i] if i < len(logprobs)-1 else 0 for i in range(len(logprobs))]
179 | -           fig_drops, ax_drops = plt.subplots(figsize=(10, 5))
180 | -           ax_drops.bar(range(len(drops)), drops, color='red', alpha=0.5)
181 | -           ax_drops.set_title("Significant Probability Drops")
182 | -           ax_drops.set_xlabel("Token Position")
183 | -           ax_drops.set_ylabel("Log Probability Drop")
184 | -           ax_drops.grid(True)
185 | -           buf_drops = io.BytesIO()
186 | -           plt.savefig(buf_drops, format="png", bbox_inches="tight", dpi=100)
187 | -           buf_drops.seek(0)
188 | -           plt.close(fig_drops)
189 | -           return buf_drops
190 | -
191 | -       # 4. N-Gram Analysis (Bigrams for simplicity)
192 | -       def create_ngram_plot():
193 | -           if not logprobs or len(logprobs) < 2:
194 | -               raise ValueError("Insufficient data for N-gram analysis")
195 | -           bigrams = [(tokens[i], tokens[i+1]) for i in range(len(tokens)-1)]
196 | -           bigram_probs = [logprobs[i] + logprobs[i+1] for i in range(len(tokens)-1)]
197 | -           fig_ngram, ax_ngram = plt.subplots(figsize=(10, 5))
198 | -           ax_ngram.bar(range(len(bigrams)), bigram_probs, color='green')
199 | -           ax_ngram.set_title("N-Gram (Bigrams) Probability Sum")
200 | -           ax_ngram.set_xlabel("Bigram Position")
201 | -           ax_ngram.set_ylabel("Sum of Log Probabilities")
202 | -           ax_ngram.set_xticks(range(len(bigrams)))
203 | -           ax_ngram.set_xticklabels([f"{b[0]}->{b[1]}" for b in bigrams], rotation=45, ha="right")
204 | -           ax_ngram.grid(True)
205 | -           buf_ngram = io.BytesIO()
206 | -           plt.savefig(buf_ngram, format="png", bbox_inches="tight", dpi=100)
207 | -           buf_ngram.seek(0)
208 | -           plt.close(fig_ngram)
209 | -           return buf_ngram
210 |
211 | -       #
212-215 | -
216 | -           for i in range(len(
217-226 | -
227 |
228 | -       #
229-232 | -
233 |             outliers = z_scores > 2  # Threshold for outliers
234-248 | -
249 | -       # 7. Autocorrelation
250 | -       def create_autocorr_plot():
251 | -           if not logprobs:
252 | -               raise ValueError("No data for autocorrelation")
253 | -           autocorr = correlate(logprobs, logprobs, mode='full')
254 | -           autocorr = autocorr[len(autocorr)//2:] / len(logprobs)  # Normalize
255 | -           fig_autocorr, ax_autocorr = plt.subplots(figsize=(10, 5))
256 | -           ax_autocorr.plot(range(len(autocorr)), autocorr, color='purple')
257 | -           ax_autocorr.set_title("Autocorrelation of Log Probabilities")
258 | -           ax_autocorr.set_xlabel("Lag")
259 | -           ax_autocorr.set_ylabel("Autocorrelation")
260 | -           ax_autocorr.grid(True)
261 | -           buf_autocorr = io.BytesIO()
262 | -           plt.savefig(buf_autocorr, format="png", bbox_inches="tight", dpi=100)
263 | -           buf_autocorr.seek(0)
264 | -           plt.close(fig_autocorr)
265 | -           return buf_autocorr
266 | -
267 | -       # 8. Smoothing (Moving Average)
268 | -       def create_smoothing_plot():
269 | -           if not logprobs:
270 | -               raise ValueError("No data for smoothing")
271 | -           window_size = 3
272 | -           moving_avg = np.convolve(logprobs, np.ones(window_size)/window_size, mode='valid')
273 | -           fig_smoothing, ax_smoothing = plt.subplots(figsize=(10, 5))
274 | -           ax_smoothing.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Original")
275 | -           ax_smoothing.plot(range(window_size-1, len(logprobs)), moving_avg, color="orange", label="Moving Average")
276 | -           ax_smoothing.set_title("Log Probabilities with Moving Average")
277 | -           ax_smoothing.set_xlabel("Token Position")
278 | -           ax_smoothing.set_ylabel("Log Probability")
279 | -           ax_smoothing.grid(True)
280 | -           ax_smoothing.legend()
281 | -           ax_smoothing.set_xticks([])  # Hide X-axis labels
282 | -           buf_smoothing = io.BytesIO()
283 | -           plt.savefig(buf_smoothing, format="png", bbox_inches="tight", dpi=100)
284 | -           buf_smoothing.seek(0)
285 | -           plt.close(fig_smoothing)
286 | -           return buf_smoothing
287 | -
288 | -       # 9. Uncertainty Propagation (Variance of Top Logprobs)
289 | -       def create_uncertainty_plot():
290 | -           if not logprobs or not top_alternatives:
291 | -               raise ValueError("No data for uncertainty propagation")
292 | -           variances = []
293 | -           for probs in top_alternatives:
294 | -               if len(probs) > 1:
295 | -                   values = [p[1] for p in probs]
296 | -                   variances.append(np.var(values))
297 | -               else:
298 | -                   variances.append(0)
299 | -           fig_uncertainty, ax_uncertainty = plt.subplots(figsize=(10, 5))
300 | -           ax_uncertainty.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Log Prob")
301 | -           ax_uncertainty.fill_between(range(len(logprobs)), [lp - v for lp, v in zip(logprobs, variances)],
302 | -                                       [lp + v for lp, v in zip(logprobs, variances)], color='gray', alpha=0.3, label="Uncertainty")
303 | -           ax_uncertainty.set_title("Log Probabilities with Uncertainty Propagation")
304 | -           ax_uncertainty.set_xlabel("Token Position")
305 | -           ax_uncertainty.set_ylabel("Log Probability")
306 | -           ax_uncertainty.grid(True)
307 | -           ax_uncertainty.legend()
308 | -           ax_uncertainty.set_xticks([])  # Hide X-axis labels
309 | -           buf_uncertainty = io.BytesIO()
310 | -           plt.savefig(buf_uncertainty, format="png", bbox_inches="tight", dpi=100)
311 | -           buf_uncertainty.seek(0)
312 | -           plt.close(fig_uncertainty)
313 | -           return buf_uncertainty
314 | -
315 | -       # 10. Correlation Heatmap
316 | -       def create_corr_plot():
317 | -           if not logprobs or len(logprobs) < 2:
318 | -               raise ValueError("Insufficient data for correlation heatmap")
319 | -           corr_matrix = np.corrcoef(logprobs, rowvar=False)
320 | -           fig_corr, ax_corr = plt.subplots(figsize=(10, 5))
321 | -           im = ax_corr.imshow(corr_matrix, cmap='coolwarm', interpolation='nearest')
322 | -           ax_corr.set_title("Correlation of Log Probabilities Across Positions")
323 | -           ax_corr.set_xlabel("Token Position")
324 | -           ax_corr.set_ylabel("Token Position")
325 | -           plt.colorbar(im, ax=ax_corr, label="Correlation")
326 | -           buf_corr = io.BytesIO()
327 | -           plt.savefig(buf_corr, format="png", bbox_inches="tight", dpi=100)
328 | -           buf_corr.seek(0)
329 | -           plt.close(fig_corr)
330 | -           return buf_corr
331 | -
332 | -       # 11. Token Type Correlation
333 | -       def create_type_plot():
334 | -           if not logprobs or not token_types:
335 | -               raise ValueError("No data for token type correlation")
336 | -           type_probs = {t: [] for t in set(token_types)}
337 | -           for t, p in zip(token_types, logprobs):
338 | -               type_probs[t].append(p)
339 | -           fig_type, ax_type = plt.subplots(figsize=(10, 5))
340 | -           for t in type_probs:
341 | -               ax_type.bar(t, np.mean(type_probs[t]), yerr=np.std(type_probs[t]), capsize=5, label=t)
342 | -           ax_type.set_title("Average Log Probability by Token Type")
343 | -           ax_type.set_xlabel("Token Type")
344 | -           ax_type.set_ylabel("Average Log Probability")
345 | -           ax_type.grid(True)
346 | -           ax_type.legend()
347 | -           buf_type = io.BytesIO()
348 | -           plt.savefig(buf_type, format="png", bbox_inches="tight", dpi=100)
349 | -           buf_type.seek(0)
350 | -           plt.close(fig_type)
351 | -           return buf_type
352 | -
353 | -       # 12. Token Embedding Similarity vs. Probability (Simulated)
354 | -       def create_embed_plot():
355 | -           if not logprobs or not tokens:
356 | -               raise ValueError("No data for embedding similarity")
357 | -           simulated_embeddings = np.random.rand(len(tokens), 2)  # 2D embeddings
358 | -           fig_embed, ax_embed = plt.subplots(figsize=(10, 5))
359 | -           ax_embed.scatter(simulated_embeddings[:, 0], simulated_embeddings[:, 1], c=logprobs, cmap='viridis')
360 | -           ax_embed.set_title("Token Embedding Similarity vs. Log Probability")
361 | -           ax_embed.set_xlabel("Embedding Dimension 1")
362 | -           ax_embed.set_ylabel("Embedding Dimension 2")
363 | -           plt.colorbar(ax_embed.collections[0], ax=ax_embed, label="Log Probability")
364 | -           buf_embed = io.BytesIO()
365 | -           plt.savefig(buf_embed, format="png", bbox_inches="tight", dpi=100)
366 | -           buf_embed.seek(0)
367 | -           plt.close(fig_embed)
368 | -           return buf_embed
369 | -
370 | -       # 13. Bayesian Inference (Simplified as Inferred Probabilities)
371 | -       def create_bayesian_plot():
372 | -           if not top_alternatives:
373 | -               raise ValueError("No data for Bayesian inference")
374 | -           entropies = [entropy([p[1] for p in probs], base=2) for probs in top_alternatives if len(probs) > 1]
375 | -           fig_bayesian, ax_bayesian = plt.subplots(figsize=(10, 5))
376 | -           ax_bayesian.bar(range(len(entropies)), entropies, color='orange')
377 | -           ax_bayesian.set_title("Bayesian Inferred Uncertainty (Entropy)")
378 | -           ax_bayesian.set_xlabel("Token Position")
379 | -           ax_bayesian.set_ylabel("Entropy")
380 | -           ax_bayesian.grid(True)
381 | -           buf_bayesian = io.BytesIO()
382 | -           plt.savefig(buf_bayesian, format="png", bbox_inches="tight", dpi=100)
383 | -           buf_bayesian.seek(0)
384 | -           plt.close(fig_bayesian)
385 | -           return buf_bayesian
386 | -
387 | -       # 14. Graph-Based Analysis
388 | -       def create_graph_plot():
389 | -           if not tokens or len(tokens) < 2:
390 | -               raise ValueError("Insufficient data for graph analysis")
391 | -           G = nx.DiGraph()
392 | -           for i in range(len(tokens)-1):
393 | -               G.add_edge(tokens[i], tokens[i+1], weight=logprobs[i+1] - logprobs[i])
394 | -           fig_graph, ax_graph = plt.subplots(figsize=(10, 5))
395 | -           pos = nx.spring_layout(G)
396 | -           nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, edge_color='gray', width=1, ax=ax_graph)
397 | -           ax_graph.set_title("Graph of Token Transitions")
398 | -           buf_graph = io.BytesIO()
399 | -           plt.savefig(buf_graph, format="png", bbox_inches="tight", dpi=100)
400 | -           buf_graph.seek(0)
401 | -           plt.close(fig_graph)
402 | -           return buf_graph
403 | -
404 | -       # 15. Dimensionality Reduction (t-SNE)
405 | -       def create_tsne_plot():
406 | -           if not logprobs or not top_alternatives:
407 | -               raise ValueError("No data for t-SNE")
408 | -           features = np.array([logprobs + [p[1] for p in alts[:2]] for logprobs, alts in zip([logprobs], top_alternatives)])
409 | -           tsne = TSNE(n_components=2, random_state=42)
410 | -           tsne_result = tsne.fit_transform(features.T)
411 | -           fig_tsne, ax_tsne = plt.subplots(figsize=(10, 5))
412 | -           scatter = ax_tsne.scatter(tsne_result[:, 0], tsne_result[:, 1], c=logprobs, cmap='viridis')
413 | -           ax_tsne.set_title("t-SNE of Log Probabilities and Top Alternatives")
414 | -           ax_tsne.set_xlabel("t-SNE Dimension 1")
415 | -           ax_tsne.set_ylabel("t-SNE Dimension 2")
416 | -           plt.colorbar(scatter, ax=ax_tsne, label="Log Probability")
417 | -           buf_tsne = io.BytesIO()
418 | -           plt.savefig(buf_tsne, format="png", bbox_inches="tight", dpi=100)
419 | -           buf_tsne.seek(0)
420 | -           plt.close(fig_tsne)
421 | -           return buf_tsne
422 | -
423 | -       # 16. Interactive Heatmap
424 | -       def create_heatmap_plot():
425 | -           if not logprobs:
426 | -               raise ValueError("No data for heatmap")
427 | -           fig_heatmap, ax_heatmap = plt.subplots(figsize=(10, 5))
428 | -           im = ax_heatmap.imshow([logprobs], cmap='viridis', aspect='auto')
429 | -           ax_heatmap.set_title("Interactive Heatmap of Log Probabilities")
430 | -           ax_heatmap.set_xlabel("Token Position")
431 | -           ax_heatmap.set_ylabel("Probability Level")
432 | -           plt.colorbar(im, ax=ax_heatmap, label="Log Probability")
433 | -           buf_heatmap = io.BytesIO()
434 | -           plt.savefig(buf_heatmap, format="png", bbox_inches="tight", dpi=100)
435 | -           buf_heatmap.seek(0)
436 | -           plt.close(fig_heatmap)
437 | -           return buf_heatmap
438 | -
439 | -       # 17. Probability Distribution Plots (Box Plots for Top Logprobs)
440 | -       def create_dist_plot():
441 | -           if not logprobs or not top_alternatives:
442 | -               raise ValueError("No data for probability distribution")
443 | -           all_top_probs = [p[1] for alts in top_alternatives for p in alts]
444 | -           fig_dist, ax_dist = plt.subplots(figsize=(10, 5))
445 | -           ax_dist.boxplot([logprobs] + [p[1] for alts in top_alternatives for p in alts[:2]], labels=["Selected"] + ["Alt1", "Alt2"])
446 | -           ax_dist.set_title("Probability Distribution of Top Tokens")
447 | -           ax_dist.set_xlabel("Token Type")
448 | -           ax_dist.set_ylabel("Log Probability")
449 | -           ax_dist.grid(True)
450 | -           buf_dist = io.BytesIO()
451 | -           plt.savefig(buf_dist, format="png", bbox_inches="tight", dpi=100)
452 | -           buf_dist.seek(0)
453 | -           plt.close(fig_dist)
454 | -           return buf_dist
455 | -
456 | -       # Create all plots safely
457 | -       img_main_html = "Placeholder for Log Probability Plot"
458 | -       img_cluster_html = "Placeholder for K-Means Clustering"
459 | -       img_drops_html = "Placeholder for Probability Drops"
460 | -       img_ngram_html = "Placeholder for N-Gram Analysis"
461 | -       img_markov_html = "Placeholder for Markov Chain"
462 | -       img_anomaly_html = "Placeholder for Anomaly Detection"
463 | -       img_autocorr_html = "Placeholder for Autocorrelation"
464 | -       img_smoothing_html = "Placeholder for Smoothing (Moving Average)"
465 | -       img_uncertainty_html = "Placeholder for Uncertainty Propagation"
466 | -       img_corr_html = "Placeholder for Correlation Heatmap"
467 | -       img_type_html = "Placeholder for Token Type Correlation"
468 | -       img_embed_html = "Placeholder for Embedding Similarity vs. Probability"
469 | -       img_bayesian_html = "Placeholder for Bayesian Inference (Entropy)"
470 | -       img_graph_html = "Placeholder for Graph of Token Transitions"
471 | -       img_tsne_html = "Placeholder for t-SNE of Log Probabilities"
472 | -       img_heatmap_html = "Placeholder for Interactive Heatmap"
473 | -       img_dist_html = "Placeholder for Probability Distribution"
474 | -
475 | -       try:
476 | -           buf_main = create_main_plot()
477 | -           img_main_bytes = buf_main.getvalue()
478 | -           img_main_base64 = base64.b64encode(img_main_bytes).decode("utf-8")
479 | -           img_main_html = f'<img src="data:image/png;base64,{img_main_base64}" style="max-width: 100%; height: auto;">'
480 | -       except Exception as e:
481 | -           logger.error("Failed to create main plot: %s", str(e))
482 | -
483 | -       try:
484 | -           buf_cluster = create_cluster_plot()
485 | -           img_cluster_bytes = buf_cluster.getvalue()
486 | -           img_cluster_base64 = base64.b64encode(img_cluster_bytes).decode("utf-8")
487 | -           img_cluster_html = f'<img src="data:image/png;base64,{img_cluster_base64}" style="max-width: 100%; height: auto;">'
488 | -       except Exception as e:
489 | -           logger.error("Failed to create cluster plot: %s", str(e))
490 | -
491 | -       try:
492 | -           buf_drops = create_drops_plot()
493 | -           img_drops_bytes = buf_drops.getvalue()
494 | -           img_drops_base64 = base64.b64encode(img_drops_bytes).decode("utf-8")
495 | -           img_drops_html = f'<img src="data:image/png;base64,{img_drops_base64}" style="max-width: 100%; height: auto;">'
496 | -       except Exception as e:
497 | -           logger.error("Failed to create drops plot: %s", str(e))
498 | -
499 | -       try:
500 | -           buf_ngram = create_ngram_plot()
501 | -           img_ngram_bytes = buf_ngram.getvalue()
502 | -           img_ngram_base64 = base64.b64encode(img_ngram_bytes).decode("utf-8")
503 | -           img_ngram_html = f'<img src="data:image/png;base64,{img_ngram_base64}" style="max-width: 100%; height: auto;">'
504 | -       except Exception as e:
505 | -           logger.error("Failed to create ngram plot: %s", str(e))
506 | -
507 | -       try:
508 | -           buf_markov = create_markov_plot()
509 | -           img_markov_bytes = buf_markov.getvalue()
510 | -           img_markov_base64 = base64.b64encode(img_markov_bytes).decode("utf-8")
511 | -           img_markov_html = f'<img src="data:image/png;base64,{img_markov_base64}" style="max-width: 100%; height: auto;">'
512 | -       except Exception as e:
513 | -           logger.error("Failed to create markov plot: %s", str(e))
514 | -
515 | -       try:
516 | -           buf_anomaly = create_anomaly_plot()
517 | -           img_anomaly_bytes = buf_anomaly.getvalue()
518 | -           img_anomaly_base64 = base64.b64encode(img_anomaly_bytes).decode("utf-8")
519 | -           img_anomaly_html = f'<img src="data:image/png;base64,{img_anomaly_base64}" style="max-width: 100%; height: auto;">'
520 | -       except Exception as e:
521 | -           logger.error("Failed to create anomaly plot: %s", str(e))
522 | -
523 | -       try:
524 | -           buf_autocorr = create_autocorr_plot()
525 | -           img_autocorr_bytes = buf_autocorr.getvalue()
526 | -           img_autocorr_base64 = base64.b64encode(img_autocorr_bytes).decode("utf-8")
527 | -           img_autocorr_html = f'<img src="data:image/png;base64,{img_autocorr_base64}" style="max-width: 100%; height: auto;">'
528 | -       except Exception as e:
529 | -           logger.error("Failed to create autocorr plot: %s", str(e))
530 | -
531 | -       try:
532 | -           buf_smoothing = create_smoothing_plot()
533 | -           img_smoothing_bytes = buf_smoothing.getvalue()
534 | -           img_smoothing_base64 = base64.b64encode(img_smoothing_bytes).decode("utf-8")
535 | -           img_smoothing_html = f'<img src="data:image/png;base64,{img_smoothing_base64}" style="max-width: 100%; height: auto;">'
536 | -       except Exception as e:
537 | -           logger.error("Failed to create smoothing plot: %s", str(e))
538 | -
539 | -       try:
540 | -           buf_uncertainty = create_uncertainty_plot()
541 | -           img_uncertainty_bytes = buf_uncertainty.getvalue()
542 | -           img_uncertainty_base64 = base64.b64encode(img_uncertainty_bytes).decode("utf-8")
543 | -           img_uncertainty_html = f'<img src="data:image/png;base64,{img_uncertainty_base64}" style="max-width: 100%; height: auto;">'
544 | -       except Exception as e:
545 | -           logger.error("Failed to create uncertainty plot: %s", str(e))
546 | -
547 | -       try:
548 | -           buf_corr = create_corr_plot()
549 | -           img_corr_bytes = buf_corr.getvalue()
550 | -           img_corr_base64 = base64.b64encode(img_corr_bytes).decode("utf-8")
551 | -           img_corr_html = f'<img src="data:image/png;base64,{img_corr_base64}" style="max-width: 100%; height: auto;">'
552 | -       except Exception as e:
553 | -           logger.error("Failed to create correlation plot: %s", str(e))
554 | -
555 | -       try:
556 | -           buf_type = create_type_plot()
557 | -           img_type_bytes = buf_type.getvalue()
558 | -           img_type_base64 = base64.b64encode(img_type_bytes).decode("utf-8")
559 | -           img_type_html = f'<img src="data:image/png;base64,{img_type_base64}" style="max-width: 100%; height: auto;">'
560 | -       except Exception as e:
561 | -           logger.error("Failed to create type plot: %s", str(e))
562 | -
563 | -       try:
564 | -           buf_embed = create_embed_plot()
565 | -           img_embed_bytes = buf_embed.getvalue()
566 | -           img_embed_base64 = base64.b64encode(img_embed_bytes).decode("utf-8")
567 | -           img_embed_html = f'<img src="data:image/png;base64,{img_embed_base64}" style="max-width: 100%; height: auto;">'
568 | -       except Exception as e:
569 | -           logger.error("Failed to create embed plot: %s", str(e))
570 | -
571 | -       try:
572 | -           buf_bayesian = create_bayesian_plot()
573 | -           img_bayesian_bytes = buf_bayesian.getvalue()
574 | -           img_bayesian_base64 = base64.b64encode(img_bayesian_bytes).decode("utf-8")
575 | -           img_bayesian_html = f'<img src="data:image/png;base64,{img_bayesian_base64}" style="max-width: 100%; height: auto;">'
576 | -       except Exception as e:
577 | -           logger.error("Failed to create bayesian plot: %s", str(e))
578 | -
579 | -       try:
580 | -           buf_graph = create_graph_plot()
581 | -           img_graph_bytes = buf_graph.getvalue()
582 | -           img_graph_base64 = base64.b64encode(img_graph_bytes).decode("utf-8")
583 | -           img_graph_html = f'<img src="data:image/png;base64,{img_graph_base64}" style="max-width: 100%; height: auto;">'
584 | -       except Exception as e:
585 | -           logger.error("Failed to create graph plot: %s", str(e))
586 | -
587 | -       try:
588 | -           buf_tsne = create_tsne_plot()
589 | -           img_tsne_bytes = buf_tsne.getvalue()
590 | -           img_tsne_base64 = base64.b64encode(img_tsne_bytes).decode("utf-8")
591 | -           img_tsne_html = f'<img src="data:image/png;base64,{img_tsne_base64}" style="max-width: 100%; height: auto;">'
592 | -       except Exception as e:
593 | -           logger.error("Failed to create tsne plot: %s", str(e))
594 | -
595 | -       try:
596 | -           buf_heatmap = create_heatmap_plot()
597 | -           img_heatmap_bytes = buf_heatmap.getvalue()
598 | -           img_heatmap_base64 = base64.b64encode(img_heatmap_bytes).decode("utf-8")
599 | -           img_heatmap_html = f'<img src="data:image/png;base64,{img_heatmap_base64}" style="max-width: 100%; height: auto;">'
600 | -       except Exception as e:
601 | -           logger.error("Failed to create heatmap plot: %s", str(e))
602 | -
603 | -       try:
604 | -           buf_dist = create_dist_plot()
605 | -           img_dist_bytes = buf_dist.getvalue()
606 | -           img_dist_base64 = base64.b64encode(img_dist_bytes).decode("utf-8")
607 | -           img_dist_html = f'<img src="data:image/png;base64,{img_dist_base64}" style="max-width: 100%; height: auto;">'
608 | -       except Exception as e:
609 | -           logger.error("Failed to create distribution plot: %s", str(e))
610 |
611 | -       # Create DataFrame for the table
612 |         table_data = []
613 | -       for i, entry in enumerate(content):
614 |             logprob = ensure_float(entry.get("logprob", None))
615 |             if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter and "top_logprobs" in entry and entry["top_logprobs"] is not None:
616 |                 token = entry["token"]
@@ -645,75 +205,52 @@ def visualize_logprobs(json_input, prob_filter=-1e9):
645 |                     else None
646 |                 )
647 |
648 | -       # Generate colored text
649 | -       if
650 | -           min_logprob = min(
651 | -           max_logprob = max(
652 |             if max_logprob == min_logprob:
653 | -               normalized_probs = [0.5] * len(
654 |             else:
655 |                 normalized_probs = [
656 | -                   (lp - min_logprob) / (max_logprob - min_logprob) for lp in
657 |                 ]
658 |
659 |             colored_text = ""
660 | -           for i, (token, norm_prob) in enumerate(zip(
661 |                 r = int(255 * (1 - norm_prob))  # Red for low confidence
662 |                 g = int(255 * norm_prob)  # Green for high confidence
663 |                 b = 0
664 |                 color = f"rgb({r}, {g}, {b})"
665 |                 colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
666 | -               if i < len(
667 |                     colored_text += " "
668 |             colored_text_html = f"<p>{colored_text}</p>"
669 |         else:
670 |             colored_text_html = "No finite log probabilities to display."
671 |
672 | -       # Top 3 Token Log Probabilities
673 |         alt_viz_html = ""
674 | -       if
675 | -           alt_viz_html = "<h3>Top 3 Token Log Probabilities</h3><ul>"
676 | -           for i, (token, probs) in enumerate(zip(
677 | -               alt_viz_html += f"<li>Position {i} (Token: {token}):<br>"
678 |                 for tok, prob in probs:
679 |                     alt_viz_html += f"{tok}: {prob:.4f}<br>"
680 |                 alt_viz_html += "</li>"
681 |             alt_viz_html += "</ul>"
682 |
683 | -
684 | -       def buffer_to_html(buf):
685 | -           if isinstance(buf, str):  # If it's an error message
686 | -               return buf
687 | -           img_bytes = buf.getvalue()
688 | -           img_base64 = base64.b64encode(img_bytes).decode("utf-8")
689 | -           return f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
690 | -
691 | -       return (
692 | -           buffer_to_html(img_main_html), df, colored_text_html, alt_viz_html,
693 | -           buffer_to_html(img_cluster_html), buffer_to_html(img_drops_html), buffer_to_html(img_ngram_html),
694 | -           buffer_to_html(img_markov_html), buffer_to_html(img_anomaly_html), buffer_to_html(img_autocorr_html),
695 | -           buffer_to_html(img_smoothing_html), buffer_to_html(img_uncertainty_html), buffer_to_html(img_corr_html),
696 | -           buffer_to_html(img_type_html), buffer_to_html(img_embed_html), buffer_to_html(img_bayesian_html),
697 | -           buffer_to_html(img_graph_html), buffer_to_html(img_tsne_html), buffer_to_html(img_heatmap_html),
698 | -           buffer_to_html(img_dist_html)
699 | -       )
700 |
701 |     except Exception as e:
702 |         logger.error("Visualization failed: %s", str(e))
703 | -       return (
704 | -           f"Error: {str(e)}", None, None, None, "Placeholder for K-Means Clustering", "Placeholder for Probability Drops",
705 | -           "Placeholder for N-Gram Analysis", "Placeholder for Markov Chain", "Placeholder for Anomaly Detection",
706 | -           "Placeholder for Autocorrelation", "Placeholder for Smoothing (Moving Average)", "Placeholder for Uncertainty Propagation",
707 | -           "Placeholder for Correlation Heatmap", "Placeholder for Token Type Correlation", "Placeholder for Embedding Similarity vs. Probability",
708 | -           "Placeholder for Bayesian Inference (Entropy)", "Placeholder for Graph of Token Transitions", "Placeholder for t-SNE of Log Probabilities",
709 | -           "Placeholder for Interactive Heatmap", "Placeholder for Probability Distribution"
710 | -       )
711 |
712 | - # Gradio interface with
713 | with gr.Blocks(title="Log Probability Visualizer") as app:
714 |     gr.Markdown("# Log Probability Visualizer")
715 |     gr.Markdown(
716 | -       "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities. Use the filter
717 |     )
718 |
719 |     with gr.Row():
@@ -725,61 +262,54 @@ with gr.Blocks(title="Log Probability Visualizer") as app:
725 |         )
726 |     with gr.Column(scale=1):
727 |         prob_filter = gr.Slider(minimum=-1e9, maximum=0, value=-1e9, label="Log Probability Filter (≥)")
728 |
729 | -   with gr.
730 | -
731 | -
732 | -       plot_output = gr.HTML(label="Log Probability Plot (Click for Tokens)", value="Placeholder for Log Probability Plot")
733 | -       table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives", value=None)
734 | -       with gr.Row():
735 | -           text_output = gr.HTML(label="Colored Text (Confidence Visualization)", value="Placeholder for Colored Text (Confidence Visualization)")
736 | -           alt_viz_output = gr.HTML(label="Top 3 Token Log Probabilities", value="Placeholder for Top 3 Token Log Probabilities")
737 | -
738 | -   with gr.Tab("Clustering & Patterns"):
739 | -       with gr.Row():
740 | -           cluster_output = gr.HTML(label="K-Means Clustering", value="Placeholder for K-Means Clustering")
741 | -           drops_output = gr.HTML(label="Probability Drops", value="Placeholder for Probability Drops")
742 | -       with gr.Row():
743 | -           ngram_output = gr.HTML(label="N-Gram Analysis", value="Placeholder for N-Gram Analysis")
744 | -           markov_output = gr.HTML(label="Markov Chain", value="Placeholder for Markov Chain")
745 | -
746 | -   with gr.Tab("Time Series & Anomalies"):
747 | -       with gr.Row():
748 | -           anomaly_output = gr.HTML(label="Anomaly Detection", value="Placeholder for Anomaly Detection")
749 | -           autocorr_output = gr.HTML(label="Autocorrelation", value="Placeholder for Autocorrelation")
750 | -       with gr.Row():
751 | -           smoothing_output = gr.HTML(label="Smoothing (Moving Average)", value="Placeholder for Smoothing (Moving Average)")
752 | -           uncertainty_output = gr.HTML(label="Uncertainty Propagation", value="Placeholder for Uncertainty Propagation")
753 | -
754 | -   with gr.Tab("Correlation & Types"):
755 | -       with gr.Row():
756 | -           corr_output = gr.HTML(label="Correlation Heatmap", value="Placeholder for Correlation Heatmap")
757 | -           type_output = gr.HTML(label="Token Type Correlation", value="Placeholder for Token Type Correlation")
758 |
759 | -
760 | -
761 | -
762 | -           bayesian_output = gr.HTML(label="Bayesian Inference (Entropy)", value="Placeholder for Bayesian Inference (Entropy)")
763 | -       with gr.Row():
764 | -           graph_output = gr.HTML(label="Graph of Token Transitions", value="Placeholder for Graph of Token Transitions")
765 | -           tsne_output = gr.HTML(label="t-SNE of Log Probabilities", value="Placeholder for t-SNE of Log Probabilities")
766 |
767 | -
768 | -
769 | -
770 | -           dist_output = gr.HTML(label="Probability Distribution", value="Placeholder for Probability Distribution")
771 |
772 |     btn = gr.Button("Visualize")
773 |     btn.click(
774 |         fn=visualize_logprobs,
775 | -       inputs=[json_input, prob_filter],
776 | -       outputs=[
777-782 | -
783 |     )
784 |
785 | app.launch()
 8 | import ast
 9 | import logging
10 | import numpy as np
11 | + import plotly.graph_objects as go
12 | + from plotly.subplots import make_subplots
13 | from scipy import stats
14 |
15 | # Set up logging
16 | logging.basicConfig(level=logging.DEBUG)
59 |         return float(value)
60 |     return None
61 |
62 | + # Function to process and visualize log probs with interactive Plotly plots
63 | + def visualize_logprobs(json_input, prob_filter=-1e9, page_size=50, page=0):
64 |     try:
65 |         # Parse the input (handles both JSON and Python dictionaries)
66 |         data = parse_input(json_input)
77 |         tokens = []
78 |         logprobs = []
79 |         top_alternatives = []  # List to store top 3 log probs (selected token + 2 alternatives)
80 |         for entry in content:
81 |             logprob = ensure_float(entry.get("logprob", None))
82 |             if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter:
83 |                 tokens.append(entry["token"])
84 |                 logprobs.append(logprob)
85 |                 # Get top_logprobs, default to empty dict if None
86 |                 top_probs = entry.get("top_logprobs", {})
87 |                 # Ensure all values in top_logprobs are floats
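An aside, not part of the diff: the loop above reads a "token", "logprob" and "top_logprobs" field from each entry. A made-up sample of that shape, for reference only:

sample_content = [
    {"token": "The", "logprob": -0.12, "top_logprobs": {"The": -0.12, "A": -2.5, "This": -3.1}},
    {"token": "cat", "logprob": -1.80, "top_logprobs": {"cat": -1.80, "dog": -2.0, "car": -4.2}},
]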
101 |
102 |         # Check if there's valid data after filtering
103 |         if not logprobs or not tokens:
104 | +           return (gr.update(value="No finite log probabilities or tokens to visualize after filtering"), None, None, None, 1, 0)
105 | +
106 | +       # Paginate data for large inputs
107 | +       total_pages = max(1, (len(logprobs) + page_size - 1) // page_size)
108 | +       start_idx = page * page_size
109 | +       end_idx = min((page + 1) * page_size, len(logprobs))
110 | +       paginated_tokens = tokens[start_idx:end_idx]
111 | +       paginated_logprobs = logprobs[start_idx:end_idx]
112 | +       paginated_alternatives = top_alternatives[start_idx:end_idx] if top_alternatives else []
113 | +
114 | +       # 1. Main Log Probability Plot (Interactive Plotly)
115 | +       main_fig = go.Figure()
116 | +       main_fig.add_trace(go.Scatter(x=list(range(len(paginated_logprobs))), y=paginated_logprobs, mode='markers+lines', name='Log Prob', marker=dict(color='blue')))
117 | +       main_fig.update_layout(
118 | +           title="Log Probabilities of Generated Tokens",
119 | +           xaxis_title="Token Position",
120 | +           yaxis_title="Log Probability",
121 | +           hovermode="closest",
122 | +           clickmode='event+select'
123 | +       )
124 | +       main_fig.update_traces(
125 | +           customdata=[f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i+start_idx}" for i, (tok, prob) in enumerate(zip(paginated_tokens, paginated_logprobs))],
126 | +           hovertemplate='<b>%{customdata}</b><extra></extra>'
127 | +       )
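For reference (not in the diff): the hover text on the main figure comes from pairing customdata with hovertemplate. A minimal standalone sketch of that pattern, assuming plotly is installed and using invented data:

import plotly.graph_objects as go

tokens = ["The", "cat", "sat"]
logprobs = [-0.1, -1.8, -0.7]
fig = go.Figure(go.Scatter(x=list(range(len(logprobs))), y=logprobs, mode="markers+lines"))
fig.update_traces(
    customdata=[f"Token: {t}, Log Prob: {p:.4f}" for t, p in zip(tokens, logprobs)],
    hovertemplate="<b>%{customdata}</b><extra></extra>",  # <extra></extra> hides the trace-name box
)
fig.show()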
128 |
129 | +       # 2. Probability Drop Analysis (Interactive Plotly)
130 | +       if len(paginated_logprobs) < 2:
131 | +           drops_fig = go.Figure()
132 | +           drops_fig.add_trace(go.Bar(x=list(range(len(paginated_logprobs)-1)), y=[0], name='Drop', marker_color='red'))
133 | +       else:
134 | +           drops = [paginated_logprobs[i+1] - paginated_logprobs[i] for i in range(len(paginated_logprobs)-1)]
135 | +           drops_fig = go.Figure()
136 | +           drops_fig.add_trace(go.Bar(x=list(range(len(drops))), y=drops, name='Drop', marker_color='red'))
137 | +           drops_fig.update_layout(
138 | +               title="Significant Probability Drops",
139 | +               xaxis_title="Token Position",
140 | +               yaxis_title="Log Probability Drop",
141 | +               hovermode="closest",
142 | +               clickmode='event+select'
143 | +           )
144 | +           drops_fig.update_traces(
145 | +               customdata=[f"Drop: {drop:.4f}, From: {paginated_tokens[i]} to {paginated_tokens[i+1]}, Position: {i+start_idx}" for i, drop in enumerate(drops)],
146 | +               hovertemplate='<b>%{customdata}</b><extra></extra>'
147 | +           )
|
149 |
+
# 3. Anomaly Detection (Interactive Plotly)
|
150 |
+
if not paginated_logprobs:
|
151 |
+
anomaly_fig = go.Figure()
|
152 |
+
anomaly_fig.add_trace(go.Scatter(x=[], y=[], mode='markers+lines', name='Log Prob', marker_color='blue'))
|
153 |
+
else:
|
154 |
+
z_scores = np.abs(stats.zscore(paginated_logprobs))
|
155 |
outliers = z_scores > 2 # Threshold for outliers
|
156 |
+
anomaly_fig = go.Figure()
|
157 |
+
anomaly_fig.add_trace(go.Scatter(x=list(range(len(paginated_logprobs))), y=paginated_logprobs, mode='markers+lines', name='Log Prob', marker_color='blue'))
|
158 |
+
anomaly_fig.add_trace(go.Scatter(x=np.where(outliers)[0], y=[paginated_logprobs[i] for i in np.where(outliers)[0]], mode='markers', name='Outliers', marker_color='red'))
|
159 |
+
anomaly_fig.update_layout(
|
160 |
+
title="Log Probabilities with Outliers",
|
161 |
+
xaxis_title="Token Position",
|
162 |
+
yaxis_title="Log Probability",
|
163 |
+
hovermode="closest",
|
164 |
+
clickmode='event+select'
|
165 |
+
)
|
166 |
+
anomaly_fig.update_traces(
|
167 |
+
customdata=[f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i+start_idx}, Outlier: {out}" for i, (tok, prob, out) in enumerate(zip(paginated_tokens, paginated_logprobs, outliers))],
|
168 |
+
hovertemplate='<b>%{customdata}</b><extra></extra>'
|
169 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
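For reference (not in the diff): the anomaly rule above flags points whose absolute z-score exceeds 2. A small self-contained sketch with invented values:

import numpy as np
from scipy import stats

logprobs = np.array([-0.3] * 9 + [-6.0])
z_scores = np.abs(stats.zscore(logprobs))
outliers = z_scores > 2           # same threshold as the diff
print(np.where(outliers)[0])      # [9] -- only the -6.0 entry is flagged

With very short inputs this threshold can never trigger, since the largest possible z-score in a sample of n points is (n - 1) / sqrt(n).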
170 |
171 | +       # Create DataFrame for the table (paginated)
172 |         table_data = []
173 | +       for i, entry in enumerate(content[start_idx:end_idx]):
174 |             logprob = ensure_float(entry.get("logprob", None))
175 |             if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter and "top_logprobs" in entry and entry["top_logprobs"] is not None:
176 |                 token = entry["token"]
205 |                     else None
206 |                 )
207 |
208 | +       # Generate colored text (paginated)
209 | +       if paginated_logprobs:
210 | +           min_logprob = min(paginated_logprobs)
211 | +           max_logprob = max(paginated_logprobs)
212 |             if max_logprob == min_logprob:
213 | +               normalized_probs = [0.5] * len(paginated_logprobs)
214 |             else:
215 |                 normalized_probs = [
216 | +                   (lp - min_logprob) / (max_logprob - min_logprob) for lp in paginated_logprobs
217 |                 ]
218 |
219 |             colored_text = ""
220 | +           for i, (token, norm_prob) in enumerate(zip(paginated_tokens, normalized_probs)):
221 |                 r = int(255 * (1 - norm_prob))  # Red for low confidence
222 |                 g = int(255 * norm_prob)  # Green for high confidence
223 |                 b = 0
224 |                 color = f"rgb({r}, {g}, {b})"
225 |                 colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
226 | +               if i < len(paginated_tokens) - 1:
227 |                     colored_text += " "
228 |             colored_text_html = f"<p>{colored_text}</p>"
229 |         else:
230 |             colored_text_html = "No finite log probabilities to display."
231 |
232 | +       # Top 3 Token Log Probabilities (paginated)
233 |         alt_viz_html = ""
234 | +       if paginated_logprobs and paginated_alternatives:
235 | +           alt_viz_html = "<h3>Top 3 Token Log Probabilities (Paginated)</h3><ul>"
236 | +           for i, (token, probs) in enumerate(zip(paginated_tokens, paginated_alternatives)):
237 | +               alt_viz_html += f"<li>Position {i+start_idx} (Token: {token}):<br>"
238 |                 for tok, prob in probs:
239 |                     alt_viz_html += f"{tok}: {prob:.4f}<br>"
240 |                 alt_viz_html += "</li>"
241 |             alt_viz_html += "</ul>"
242 |
243 | +       return (main_fig, df, colored_text_html, alt_viz_html, drops_fig, anomaly_fig, total_pages, page)
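Side note (not in the diff): the colored-text block maps each log prob to a red-green ramp via min-max normalization. The mapping in isolation, on made-up numbers:

logprobs = [-0.1, -2.0, -0.5]
lo, hi = min(logprobs), max(logprobs)
for lp in logprobs:
    norm = 0.5 if hi == lo else (lp - lo) / (hi - lo)  # 0 = least confident, 1 = most confident
    r, g = int(255 * (1 - norm)), int(255 * norm)
    print(f"rgb({r}, {g}, 0)")  # green for high-confidence tokens, red for low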
244 |
245 |     except Exception as e:
246 |         logger.error("Visualization failed: %s", str(e))
247 | +       return (gr.update(value=f"Error: {str(e)}"), None, "No finite log probabilities to display.", None, gr.update(value="No data for probability drops."), gr.update(value="No data for anomalies."), 1, 0)
248 |
249 | + # Gradio interface with interactive layout and pagination
250 | with gr.Blocks(title="Log Probability Visualizer") as app:
251 |     gr.Markdown("# Log Probability Visualizer")
252 |     gr.Markdown(
253 | +       "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities. Use the filter and pagination to navigate large inputs."
254 |     )
255 |
256 |     with gr.Row():
262 |         )
263 |     with gr.Column(scale=1):
264 |         prob_filter = gr.Slider(minimum=-1e9, maximum=0, value=-1e9, label="Log Probability Filter (≥)")
265 | +       page_size = gr.Number(value=50, label="Page Size", precision=0, minimum=10, maximum=1000)
266 | +       page = gr.Number(value=0, label="Page Number", precision=0, minimum=0)
267 |
268 | +   with gr.Row():
269 | +       plot_output = gr.Plot(label="Log Probability Plot (Click for Tokens)")
270 | +       drops_output = gr.Plot(label="Probability Drops (Click for Details)")
271 |
272 | +   with gr.Row():
273 | +       anomaly_output = gr.Plot(label="Anomaly Detection (Click for Details)")
274 | +       table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
275 |
276 | +   with gr.Row():
277 | +       text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
278 | +       alt_viz_output = gr.HTML(label="Top 3 Token Log Probabilities")
279 |
280 |     btn = gr.Button("Visualize")
281 |     btn.click(
282 |         fn=visualize_logprobs,
283 | +       inputs=[json_input, prob_filter, page_size, page],
284 | +       outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, anomaly_output, gr.State(visible=False), gr.State(visible=False)],
285 | +   )
286 | +
287 | +   # Pagination controls
288 | +   with gr.Row():
289 | +       prev_btn = gr.Button("Previous Page")
290 | +       next_btn = gr.Button("Next Page")
291 | +       total_pages_output = gr.Number(label="Total Pages", interactive=False, visible=False)
292 | +       current_page_output = gr.Number(label="Current Page", interactive=False, visible=False)
293 | +
294 | +   def update_page(json_input, prob_filter, page_size, current_page, action):
295 | +       if action == "prev" and current_page > 0:
296 | +           current_page -= 1
297 | +       elif action == "next":
298 | +           total_pages = visualize_logprobs(json_input, prob_filter, page_size, 0)[6]  # Get total pages
299 | +           if current_page < total_pages - 1:
300 | +               current_page += 1
301 | +       return gr.update(value=current_page), gr.update(value=total_pages)
302 | +
303 | +   prev_btn.click(
304 | +       fn=lambda *args: update_page(*args, "prev"),
305 | +       inputs=[json_input, prob_filter, page_size, page, gr.State()],
306 | +       outputs=[page, total_pages_output]
307 | +   )
308 | +
309 | +   next_btn.click(
310 | +       fn=lambda *args: update_page(*args, "next"),
311 | +       inputs=[json_input, prob_filter, page_size, page, gr.State()],
312 | +       outputs=[page, total_pages_output]
313 |     )
314 |
315 | app.launch()
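For reference (not in the diff): the paging that visualize_logprobs applies is plain ceiling division plus slicing; on made-up sizes it behaves like this:

n_tokens, page_size = 123, 50
total_pages = max(1, (n_tokens + page_size - 1) // page_size)  # 3
for page in range(total_pages):
    start = page * page_size
    end = min((page + 1) * page_size, n_tokens)
    print(page, start, end)  # 0 0 50 / 1 50 100 / 2 100 123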