Update app.py

app.py CHANGED
@@ -7,6 +7,14 @@ import base64
 import math
 import ast
 import logging
+import numpy as np
+from sklearn.cluster import KMeans
+from sklearn.decomposition import PCA
+from sklearn.manifold import TSNE
+from scipy import stats
+from scipy.stats import entropy
+from scipy.signal import correlate
+import networkx as nx
 from matplotlib.widgets import Cursor

 # Set up logging
@@ -56,8 +64,8 @@ def ensure_float(value):
         return float(value)
     return None

-# Function to process and visualize log probs with …
-def visualize_logprobs(json_input):
+# Function to process and visualize log probs with multiple analyses
+def visualize_logprobs(json_input, prob_filter=-float('inf')):
     try:
         # Parse the input (handles both JSON and Python dictionaries)
         data = parse_input(json_input)
@@ -74,11 +82,18 @@ def visualize_logprobs(json_input):
         tokens = []
         logprobs = []
         top_alternatives = []  # List to store top 3 log probs (selected token + 2 alternatives)
+        token_types = []  # Simplified token type categorization
         for entry in content:
             logprob = ensure_float(entry.get("logprob", None))
-            if logprob is not None and math.isfinite(logprob):
+            if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter:
                 tokens.append(entry["token"])
                 logprobs.append(logprob)
+                # Categorize token type (simple heuristic)
+                token = entry["token"].lower().strip()
+                if token in ["the", "a", "an"]: token_types.append("article")
+                elif token in ["is", "are", "was", "were"]: token_types.append("verb")
+                elif token in ["top", "so", "need", "figure"]: token_types.append("noun")
+                else: token_types.append("other")
                 # Get top_logprobs, default to empty dict if None
                 top_probs = entry.get("top_logprobs", {})
                 # Ensure all values in top_logprobs are floats
@@ -96,55 +111,340 @@ def visualize_logprobs(json_input):
             else:
                 logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))

-        # …
-        if logprobs:
-            […the original single-plot visualization code (41 removed lines) is not preserved in this view…]
+        # If no valid data after filtering, return error messages
+        if not logprobs:
+            return "No finite log probabilities to visualize after filtering.", None, None, None, None, None, None, None, None, None, None
+
+        # 1. Main Log Probability Plot (with click for tokens)
+        fig_main, ax_main = plt.subplots(figsize=(10, 5))
+        scatter = ax_main.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Selected Token")[0]
+        ax_main.set_title("Log Probabilities of Generated Tokens")
+        ax_main.set_xlabel("Token Position")
+        ax_main.set_ylabel("Log Probability")
+        ax_main.grid(True)
+        ax_main.set_xticks([])  # Hide X-axis labels by default
+
+        # Add click functionality to show token
+        token_annotations = []
+        for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
+            annotation = ax_main.annotate('', (x, y), xytext=(10, 10), textcoords='offset points', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), visible=False)
+            token_annotations.append(annotation)
+
+        def on_click(event):
+            if event.inaxes == ax_main:
+                for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
+                    contains, _ = scatter.contains(event)
+                    if contains and abs(event.xdata - x) < 0.5 and abs(event.ydata - y) < 0.5:
+                        token_annotations[i].set_text(tokens[i])
+                        token_annotations[i].set_visible(True)
+                        fig_main.canvas.draw_idle()
+                    else:
+                        token_annotations[i].set_visible(False)
+                        fig_main.canvas.draw_idle()
+
+        fig_main.canvas.mpl_connect('button_press_event', on_click)
+
+        # Save main plot
+        buf_main = io.BytesIO()
+        plt.savefig(buf_main, format="png", bbox_inches="tight", dpi=100)
+        buf_main.seek(0)
+        plt.close(fig_main)
+        img_main_bytes = buf_main.getvalue()
+        img_main_base64 = base64.b64encode(img_main_bytes).decode("utf-8")
+        img_main_html = f'<img src="data:image/png;base64,{img_main_base64}" style="max-width: 100%; height: auto;">'
+
+        # 2. K-Means Clustering of Log Probabilities
+        kmeans = KMeans(n_clusters=3, random_state=42)
+        cluster_labels = kmeans.fit_predict(np.array(logprobs).reshape(-1, 1))
+        fig_cluster, ax_cluster = plt.subplots(figsize=(10, 5))
+        scatter = ax_cluster.scatter(range(len(logprobs)), logprobs, c=cluster_labels, cmap='viridis')
+        ax_cluster.set_title("K-Means Clustering of Log Probabilities")
+        ax_cluster.set_xlabel("Token Position")
+        ax_cluster.set_ylabel("Log Probability")
+        ax_cluster.grid(True)
+        plt.colorbar(scatter, ax=ax_cluster, label="Cluster")
+        buf_cluster = io.BytesIO()
+        plt.savefig(buf_cluster, format="png", bbox_inches="tight", dpi=100)
+        buf_cluster.seek(0)
+        plt.close(fig_cluster)
+        img_cluster_bytes = buf_cluster.getvalue()
+        img_cluster_base64 = base64.b64encode(img_cluster_bytes).decode("utf-8")
+        img_cluster_html = f'<img src="data:image/png;base64,{img_cluster_base64}" style="max-width: 100%; height: auto;">'
+
+        # 3. Probability Drop Analysis
+        drops = [logprobs[i+1] - logprobs[i] if i < len(logprobs)-1 else 0 for i in range(len(logprobs))]
+        fig_drops, ax_drops = plt.subplots(figsize=(10, 5))
+        ax_drops.bar(range(len(drops)), drops, color='red', alpha=0.5)
+        ax_drops.set_title("Significant Probability Drops")
+        ax_drops.set_xlabel("Token Position")
+        ax_drops.set_ylabel("Log Probability Drop")
+        ax_drops.grid(True)
+        buf_drops = io.BytesIO()
+        plt.savefig(buf_drops, format="png", bbox_inches="tight", dpi=100)
+        buf_drops.seek(0)
+        plt.close(fig_drops)
+        img_drops_bytes = buf_drops.getvalue()
+        img_drops_base64 = base64.b64encode(img_drops_bytes).decode("utf-8")
+        img_drops_html = f'<img src="data:image/png;base64,{img_drops_base64}" style="max-width: 100%; height: auto;">'
+
+        # 4. N-Gram Analysis (Bigrams for simplicity)
+        bigrams = [(tokens[i], tokens[i+1]) for i in range(len(tokens)-1)]
+        bigram_probs = [logprobs[i] + logprobs[i+1] for i in range(len(tokens)-1)]
+        fig_ngram, ax_ngram = plt.subplots(figsize=(10, 5))
+        ax_ngram.bar(range(len(bigrams)), bigram_probs, color='green')
+        ax_ngram.set_title("N-Gram (Bigrams) Probability Sum")
+        ax_ngram.set_xlabel("Bigram Position")
+        ax_ngram.set_ylabel("Sum of Log Probabilities")
+        ax_ngram.set_xticks(range(len(bigrams)))
+        ax_ngram.set_xticklabels([f"{b[0]}->{b[1]}" for b in bigrams], rotation=45, ha="right")
+        ax_ngram.grid(True)
+        buf_ngram = io.BytesIO()
+        plt.savefig(buf_ngram, format="png", bbox_inches="tight", dpi=100)
+        buf_ngram.seek(0)
+        plt.close(fig_ngram)
+        img_ngram_bytes = buf_ngram.getvalue()
+        img_ngram_base64 = base64.b64encode(img_ngram_bytes).decode("utf-8")
+        img_ngram_html = f'<img src="data:image/png;base64,{img_ngram_base64}" style="max-width: 100%; height: auto;">'
+
+        # 5. Markov Chain Modeling (Simple Graph)
+        G = nx.DiGraph()
+        for i in range(len(tokens)-1):
+            G.add_edge(tokens[i], tokens[i+1], weight=logprobs[i+1] - logprobs[i])
+        fig_markov, ax_markov = plt.subplots(figsize=(10, 5))
+        pos = nx.spring_layout(G)
+        nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, edge_color='gray', width=1, ax=ax_markov)
+        ax_markov.set_title("Markov Chain of Token Transitions")
+        buf_markov = io.BytesIO()
+        plt.savefig(buf_markov, format="png", bbox_inches="tight", dpi=100)
+        buf_markov.seek(0)
+        plt.close(fig_markov)
+        img_markov_bytes = buf_markov.getvalue()
+        img_markov_base64 = base64.b64encode(img_markov_bytes).decode("utf-8")
+        img_markov_html = f'<img src="data:image/png;base64,{img_markov_base64}" style="max-width: 100%; height: auto;">'
+
+        # 6. Anomaly Detection (Outlier Detection with Z-Score)
+        z_scores = np.abs(stats.zscore(logprobs))
+        outliers = z_scores > 2  # Threshold for outliers
+        fig_anomaly, ax_anomaly = plt.subplots(figsize=(10, 5))
+        ax_anomaly.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b")
+        ax_anomaly.plot(np.where(outliers)[0], [logprobs[i] for i in np.where(outliers)[0]], "ro", label="Outliers")
+        ax_anomaly.set_title("Log Probabilities with Outliers")
+        ax_anomaly.set_xlabel("Token Position")
+        ax_anomaly.set_ylabel("Log Probability")
+        ax_anomaly.grid(True)
+        ax_anomaly.legend()
+        ax_anomaly.set_xticks([])  # Hide X-axis labels
+        buf_anomaly = io.BytesIO()
+        plt.savefig(buf_anomaly, format="png", bbox_inches="tight", dpi=100)
+        buf_anomaly.seek(0)
+        plt.close(fig_anomaly)
+        img_anomaly_bytes = buf_anomaly.getvalue()
+        img_anomaly_base64 = base64.b64encode(img_anomaly_bytes).decode("utf-8")
+        img_anomaly_html = f'<img src="data:image/png;base64,{img_anomaly_base64}" style="max-width: 100%; height: auto;">'
+
+        # 7. Autocorrelation
+        autocorr = correlate(logprobs, logprobs, mode='full')
+        autocorr = autocorr[len(autocorr)//2:] / len(logprobs)  # Normalize
+        fig_autocorr, ax_autocorr = plt.subplots(figsize=(10, 5))
+        ax_autocorr.plot(range(len(autocorr)), autocorr, color='purple')
+        ax_autocorr.set_title("Autocorrelation of Log Probabilities")
+        ax_autocorr.set_xlabel("Lag")
+        ax_autocorr.set_ylabel("Autocorrelation")
+        ax_autocorr.grid(True)
+        buf_autocorr = io.BytesIO()
+        plt.savefig(buf_autocorr, format="png", bbox_inches="tight", dpi=100)
+        buf_autocorr.seek(0)
+        plt.close(fig_autocorr)
+        img_autocorr_bytes = buf_autocorr.getvalue()
+        img_autocorr_base64 = base64.b64encode(img_autocorr_bytes).decode("utf-8")
+        img_autocorr_html = f'<img src="data:image/png;base64,{img_autocorr_base64}" style="max-width: 100%; height: auto;">'
+
+        # 8. Smoothing (Moving Average)
+        window_size = 3
+        moving_avg = np.convolve(logprobs, np.ones(window_size)/window_size, mode='valid')
+        fig_smoothing, ax_smoothing = plt.subplots(figsize=(10, 5))
+        ax_smoothing.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Original")
+        ax_smoothing.plot(range(window_size-1, len(logprobs)), moving_avg, color="orange", label="Moving Average")
+        ax_smoothing.set_title("Log Probabilities with Moving Average")
+        ax_smoothing.set_xlabel("Token Position")
+        ax_smoothing.set_ylabel("Log Probability")
+        ax_smoothing.grid(True)
+        ax_smoothing.legend()
+        ax_smoothing.set_xticks([])  # Hide X-axis labels
+        buf_smoothing = io.BytesIO()
+        plt.savefig(buf_smoothing, format="png", bbox_inches="tight", dpi=100)
+        buf_smoothing.seek(0)
+        plt.close(fig_smoothing)
+        img_smoothing_bytes = buf_smoothing.getvalue()
+        img_smoothing_base64 = base64.b64encode(img_smoothing_bytes).decode("utf-8")
+        img_smoothing_html = f'<img src="data:image/png;base64,{img_smoothing_base64}" style="max-width: 100%; height: auto;">'
+
+        # 9. Uncertainty Propagation (Variance of Top Logprobs)
+        variances = []
+        for probs in top_alternatives:
+            if len(probs) > 1:
+                values = [p[1] for p in probs]
+                variances.append(np.var(values))
+            else:
+                variances.append(0)
+        fig_uncertainty, ax_uncertainty = plt.subplots(figsize=(10, 5))
+        ax_uncertainty.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Log Prob")
+        ax_uncertainty.fill_between(range(len(logprobs)), [lp - v for lp, v in zip(logprobs, variances)],
+                                    [lp + v for lp, v in zip(logprobs, variances)], color='gray', alpha=0.3, label="Uncertainty")
+        ax_uncertainty.set_title("Log Probabilities with Uncertainty Propagation")
+        ax_uncertainty.set_xlabel("Token Position")
+        ax_uncertainty.set_ylabel("Log Probability")
+        ax_uncertainty.grid(True)
+        ax_uncertainty.legend()
+        ax_uncertainty.set_xticks([])  # Hide X-axis labels
+        buf_uncertainty = io.BytesIO()
+        plt.savefig(buf_uncertainty, format="png", bbox_inches="tight", dpi=100)
+        buf_uncertainty.seek(0)
+        plt.close(fig_uncertainty)
+        img_uncertainty_bytes = buf_uncertainty.getvalue()
+        img_uncertainty_base64 = base64.b64encode(img_uncertainty_bytes).decode("utf-8")
+        img_uncertainty_html = f'<img src="data:image/png;base64,{img_uncertainty_base64}" style="max-width: 100%; height: auto;">'
+
+        # 10. Correlation Heatmap
+        corr_matrix = np.corrcoef(logprobs, rowvar=False)
+        fig_corr, ax_corr = plt.subplots(figsize=(10, 5))
+        im = ax_corr.imshow(corr_matrix, cmap='coolwarm', interpolation='nearest')
+        ax_corr.set_title("Correlation of Log Probabilities Across Positions")
+        ax_corr.set_xlabel("Token Position")
+        ax_corr.set_ylabel("Token Position")
+        plt.colorbar(im, ax=ax_corr, label="Correlation")
+        buf_corr = io.BytesIO()
+        plt.savefig(buf_corr, format="png", bbox_inches="tight", dpi=100)
+        buf_corr.seek(0)
+        plt.close(fig_corr)
+        img_corr_bytes = buf_corr.getvalue()
+        img_corr_base64 = base64.b64encode(img_corr_bytes).decode("utf-8")
+        img_corr_html = f'<img src="data:image/png;base64,{img_corr_base64}" style="max-width: 100%; height: auto;">'
+
+        # 11. Token Type Correlation
+        type_probs = {t: [] for t in set(token_types)}
+        for t, p in zip(token_types, logprobs):
+            type_probs[t].append(p)
+        fig_type, ax_type = plt.subplots(figsize=(10, 5))
+        for t in type_probs:
+            ax_type.bar(t, np.mean(type_probs[t]), yerr=np.std(type_probs[t]), capsize=5, label=t)
+        ax_type.set_title("Average Log Probability by Token Type")
+        ax_type.set_xlabel("Token Type")
+        ax_type.set_ylabel("Average Log Probability")
+        ax_type.grid(True)
+        ax_type.legend()
+        buf_type = io.BytesIO()
+        plt.savefig(buf_type, format="png", bbox_inches="tight", dpi=100)
+        buf_type.seek(0)
+        plt.close(fig_type)
+        img_type_bytes = buf_type.getvalue()
+        img_type_base64 = base64.b64encode(img_type_bytes).decode("utf-8")
+        img_type_html = f'<img src="data:image/png;base64,{img_type_base64}" style="max-width: 100%; height: auto;">'
+
+        # 12. Token Embedding Similarity vs. Probability (Simulated)
+        # Simulate embedding distances (e.g., cosine similarity) as random values for demonstration
+        simulated_embeddings = np.random.rand(len(tokens), 2)  # 2D embeddings
+        fig_embed, ax_embed = plt.subplots(figsize=(10, 5))
+        ax_embed.scatter(simulated_embeddings[:, 0], simulated_embeddings[:, 1], c=logprobs, cmap='viridis')
+        ax_embed.set_title("Token Embedding Similarity vs. Log Probability")
+        ax_embed.set_xlabel("Embedding Dimension 1")
+        ax_embed.set_ylabel("Embedding Dimension 2")
+        plt.colorbar(ax_embed.collections[0], ax=ax_embed, label="Log Probability")
+        buf_embed = io.BytesIO()
+        plt.savefig(buf_embed, format="png", bbox_inches="tight", dpi=100)
+        buf_embed.seek(0)
+        plt.close(fig_embed)
+        img_embed_bytes = buf_embed.getvalue()
+        img_embed_base64 = base64.b64encode(img_embed_bytes).decode("utf-8")
+        img_embed_html = f'<img src="data:image/png;base64,{img_embed_base64}" style="max-width: 100%; height: auto;">'
+
+        # 13. Bayesian Inference (Simplified as Inferred Probabilities)
+        # Simulate inferred probabilities based on top_logprobs entropy
+        entropies = [entropy([p[1] for p in probs], base=2) for probs in top_alternatives if len(probs) > 1]
+        fig_bayesian, ax_bayesian = plt.subplots(figsize=(10, 5))
+        ax_bayesian.bar(range(len(entropies)), entropies, color='orange')
+        ax_bayesian.set_title("Bayesian Inferred Uncertainty (Entropy)")
+        ax_bayesian.set_xlabel("Token Position")
+        ax_bayesian.set_ylabel("Entropy")
+        ax_bayesian.grid(True)
+        buf_bayesian = io.BytesIO()
+        plt.savefig(buf_bayesian, format="png", bbox_inches="tight", dpi=100)
+        buf_bayesian.seek(0)
+        plt.close(fig_bayesian)
+        img_bayesian_bytes = buf_bayesian.getvalue()
+        img_bayesian_base64 = base64.b64encode(img_bayesian_bytes).decode("utf-8")
+        img_bayesian_html = f'<img src="data:image/png;base64,{img_bayesian_base64}" style="max-width: 100%; height: auto;">'
+
+        # 14. Graph-Based Analysis
+        G = nx.DiGraph()
+        for i in range(len(tokens)-1):
+            G.add_edge(tokens[i], tokens[i+1], weight=logprobs[i+1] - logprobs[i])
+        fig_graph, ax_graph = plt.subplots(figsize=(10, 5))
+        pos = nx.spring_layout(G)
+        nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, edge_color='gray', width=1, ax=ax_graph)
+        ax_graph.set_title("Graph of Token Transitions")
+        buf_graph = io.BytesIO()
+        plt.savefig(buf_graph, format="png", bbox_inches="tight", dpi=100)
+        buf_graph.seek(0)
+        plt.close(fig_graph)
+        img_graph_bytes = buf_graph.getvalue()
+        img_graph_base64 = base64.b64encode(img_graph_bytes).decode("utf-8")
+        img_graph_html = f'<img src="data:image/png;base64,{img_graph_base64}" style="max-width: 100%; height: auto;">'
+
+        # 15. Dimensionality Reduction (t-SNE)
+        features = np.array([logprobs + [p[1] for p in alts[:2]] for logprobs, alts in zip([logprobs], top_alternatives)])
+        tsne = TSNE(n_components=2, random_state=42)
+        tsne_result = tsne.fit_transform(features.T)
+        fig_tsne, ax_tsne = plt.subplots(figsize=(10, 5))
+        scatter = ax_tsne.scatter(tsne_result[:, 0], tsne_result[:, 1], c=logprobs, cmap='viridis')
+        ax_tsne.set_title("t-SNE of Log Probabilities and Top Alternatives")
+        ax_tsne.set_xlabel("t-SNE Dimension 1")
+        ax_tsne.set_ylabel("t-SNE Dimension 2")
+        plt.colorbar(scatter, ax=ax_tsne, label="Log Probability")
+        buf_tsne = io.BytesIO()
+        plt.savefig(buf_tsne, format="png", bbox_inches="tight", dpi=100)
+        buf_tsne.seek(0)
+        plt.close(fig_tsne)
+        img_tsne_bytes = buf_tsne.getvalue()
+        img_tsne_base64 = base64.b64encode(img_tsne_bytes).decode("utf-8")
+        img_tsne_html = f'<img src="data:image/png;base64,{img_tsne_base64}" style="max-width: 100%; height: auto;">'
+
+        # 16. Interactive Heatmap
+        fig_heatmap, ax_heatmap = plt.subplots(figsize=(10, 5))
+        im = ax_heatmap.imshow([logprobs], cmap='viridis', aspect='auto')
+        ax_heatmap.set_title("Interactive Heatmap of Log Probabilities")
+        ax_heatmap.set_xlabel("Token Position")
+        ax_heatmap.set_ylabel("Probability Level")
+        plt.colorbar(im, ax=ax_heatmap, label="Log Probability")
+        buf_heatmap = io.BytesIO()
+        plt.savefig(buf_heatmap, format="png", bbox_inches="tight", dpi=100)
+        buf_heatmap.seek(0)
+        plt.close(fig_heatmap)
+        img_heatmap_bytes = buf_heatmap.getvalue()
+        img_heatmap_base64 = base64.b64encode(img_heatmap_bytes).decode("utf-8")
+        img_heatmap_html = f'<img src="data:image/png;base64,{img_heatmap_base64}" style="max-width: 100%; height: auto;">'
+
+        # 17. Probability Distribution Plots (Box Plots for Top Logprobs)
+        all_top_probs = [p[1] for alts in top_alternatives for p in alts]
+        fig_dist, ax_dist = plt.subplots(figsize=(10, 5))
+        ax_dist.boxplot([logprobs] + [p[1] for alts in top_alternatives for p in alts[:2]], labels=["Selected"] + ["Alt1", "Alt2"])
+        ax_dist.set_title("Probability Distribution of Top Tokens")
+        ax_dist.set_xlabel("Token Type")
+        ax_dist.set_ylabel("Log Probability")
+        ax_dist.grid(True)
+        buf_dist = io.BytesIO()
+        plt.savefig(buf_dist, format="png", bbox_inches="tight", dpi=100)
+        buf_dist.seek(0)
+        plt.close(fig_dist)
+        img_dist_bytes = buf_dist.getvalue()
+        img_dist_base64 = base64.b64encode(img_dist_bytes).decode("utf-8")
+        img_dist_html = f'<img src="data:image/png;base64,{img_dist_base64}" style="max-width: 100%; height: auto;">'

         # Create DataFrame for the table
         table_data = []
         for i, entry in enumerate(content):
             logprob = ensure_float(entry.get("logprob", None))
-            if logprob is not None and math.isfinite(logprob) and "top_logprobs" in entry and entry["top_logprobs"] is not None:
+            if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter and "top_logprobs" in entry and entry["top_logprobs"] is not None:
                 token = entry["token"]
                 top_logprobs = entry["top_logprobs"]
                 # Ensure all values in top_logprobs are floats
@@ -201,7 +501,7 @@ def visualize_logprobs(json_input):
         else:
             colored_text_html = "No finite log probabilities to display."

-        # …
+        # Top 3 Token Log Probabilities
        alt_viz_html = ""
        if logprobs and top_alternatives:
            alt_viz_html = "<h3>Top 3 Token Log Probabilities</h3><ul>"
@@ -212,26 +512,63 @@ def visualize_logprobs(json_input):
             alt_viz_html += "</li>"
         alt_viz_html += "</ul>"

-        return …
+        return (img_main_html, df, colored_text_html, alt_viz_html, img_cluster_html, img_drops_html,
+                img_ngram_html, img_markov_html, img_anomaly_html, img_autocorr_html, img_smoothing_html,
+                img_uncertainty_html, img_corr_html, img_type_html, img_embed_html, img_bayesian_html,
+                img_graph_html, img_tsne_html, img_heatmap_html, img_dist_html)

     except Exception as e:
         logger.error("Visualization failed: %s", str(e))
-        return f"Error: {str(e)}", None, None, None
+        return (f"Error: {str(e)}", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)

-# Gradio interface
+# Gradio interface with dynamic filtering
 with gr.Blocks(title="Log Probability Visualizer") as app:
     gr.Markdown("# Log Probability Visualizer")
     gr.Markdown(
-        "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities."
+        "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities. Use the filter to focus on specific log probability ranges."
     )

-    json_input = gr.Textbox(
-        label="JSON Input",
-        lines=10,
-        placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
-    )
+    with gr.Row():
+        json_input = gr.Textbox(
+            label="JSON Input",
+            lines=10,
+            placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
+        )
+        prob_filter = gr.Slider(minimum=-float('inf'), maximum=0, value=-float('inf'), label="Log Probability Filter (≥)")
+
+    with gr.Row():
+        plot_output = gr.HTML(label="Log Probability Plot (Click for Tokens)")
+        cluster_output = gr.HTML(label="K-Means Clustering")
+        drops_output = gr.HTML(label="Probability Drops")
+
+    with gr.Row():
+        ngram_output = gr.HTML(label="N-Gram Analysis")
+        markov_output = gr.HTML(label="Markov Chain")
+
+    with gr.Row():
+        anomaly_output = gr.HTML(label="Anomaly Detection")
+        autocorr_output = gr.HTML(label="Autocorrelation")
+
+    with gr.Row():
+        smoothing_output = gr.HTML(label="Smoothing (Moving Average)")
+        uncertainty_output = gr.HTML(label="Uncertainty Propagation")
+
+    with gr.Row():
+        corr_output = gr.HTML(label="Correlation Heatmap")
+        type_output = gr.HTML(label="Token Type Correlation")
+
+    with gr.Row():
+        embed_output = gr.HTML(label="Embedding Similarity vs. Probability")
+        bayesian_output = gr.HTML(label="Bayesian Inference (Entropy)")
+
+    with gr.Row():
+        graph_output = gr.HTML(label="Graph of Token Transitions")
+        tsne_output = gr.HTML(label="t-SNE of Log Probabilities")
+
+    with gr.Row():
+        heatmap_output = gr.HTML(label="Interactive Heatmap")
+        dist_output = gr.HTML(label="Probability Distribution")

-    plot_output = gr.HTML(label="Log Probability Plot (Hover for Tokens)")
     table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
     text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
     alt_viz_output = gr.HTML(label="Top 3 Token Log Probabilities")
@@ -239,8 +576,14 @@ with gr.Blocks(title="Log Probability Visualizer") as app:
|
|
239 |
btn = gr.Button("Visualize")
|
240 |
btn.click(
|
241 |
fn=visualize_logprobs,
|
242 |
-
inputs=json_input,
|
243 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
)
|
245 |
|
246 |
app.launch()
|
|
|
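
For reference, a minimal sketch of the input this function expects, inferred from the field accesses in the diff (token, logprob, top_logprobs) and the {"content": [...]} shape hinted at by the textbox placeholder. The tokens and values below are invented for illustration:

    import json

    sample = {
        "content": [
            {"token": "The", "logprob": -0.12,
             "top_logprobs": {"The": -0.12, "A": -2.5, "This": -3.1}},
            {"token": "cat", "logprob": -1.4,
             "top_logprobs": {"cat": -1.4, "dog": -1.6, "car": -4.0}},
        ]
    }

    # With visualize_logprobs in scope (e.g., imported from app.py), a
    # prob_filter of -5.0 keeps both tokens; per the code above, parse_input()
    # accepts both JSON strings and Python dict literals.
    outputs = visualize_logprobs(json.dumps(sample), prob_filter=-5.0)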
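
One detail from the "Bayesian Inference" section worth a sketch: scipy.stats.entropy is defined over non-negative probabilities, while top_alternatives holds log probabilities (negative numbers, as the p[1] comprehensions suggest). A minimal sketch of the conversion, assuming each element of top_alternatives is a list of (token, logprob) pairs; alternatives_entropy is a hypothetical helper, not part of this commit:

    import numpy as np
    from scipy.stats import entropy

    def alternatives_entropy(alts):
        # Exponentiate to map log probabilities back to probability space;
        # entropy() renormalizes its input, so the values need not sum to 1.
        probs = np.exp([lp for _, lp in alts])
        return float(entropy(probs, base=2))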