Spaces:
Sleeping
Sleeping
Update app.py (#1)
Browse files- Update app.py (b7159a07f2b7b4eb2ba22113df0d4c04de90a6bc)
app.py
CHANGED
@@ -7,6 +7,7 @@ import base64
|
|
7 |
import math
|
8 |
import ast
|
9 |
import logging
|
|
|
10 |
|
11 |
# Set up logging
|
12 |
logging.basicConfig(level=logging.DEBUG)
|
@@ -55,7 +56,7 @@ def ensure_float(value):
|
|
55 |
return float(value)
|
56 |
return None
|
57 |
|
58 |
-
# Function to process and visualize log probs
|
59 |
def visualize_logprobs(json_input):
|
60 |
try:
|
61 |
# Parse the input (handles both JSON and Python dictionaries)
|
@@ -69,30 +70,82 @@ def visualize_logprobs(json_input):
|
|
69 |
else:
|
70 |
raise ValueError("Input must be a list or dictionary with 'content' key")
|
71 |
|
72 |
-
# Extract tokens
|
73 |
tokens = []
|
74 |
logprobs = []
|
|
|
75 |
for entry in content:
|
76 |
logprob = ensure_float(entry.get("logprob", None))
|
77 |
if logprob is not None and math.isfinite(logprob):
|
78 |
tokens.append(entry["token"])
|
79 |
logprobs.append(logprob)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
else:
|
81 |
logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))
|
82 |
|
83 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
table_data = []
|
85 |
-
for entry in content:
|
86 |
logprob = ensure_float(entry.get("logprob", None))
|
87 |
-
|
88 |
-
if (
|
89 |
-
logprob is not None
|
90 |
-
and math.isfinite(logprob)
|
91 |
-
and "top_logprobs" in entry
|
92 |
-
and entry["top_logprobs"] is not None
|
93 |
-
):
|
94 |
token = entry["token"]
|
95 |
-
logger.debug("Processing token: %s, logprob: %s (type: %s)", token, logprob, type(logprob))
|
96 |
top_logprobs = entry["top_logprobs"]
|
97 |
# Ensure all values in top_logprobs are floats
|
98 |
finite_top_logprobs = {}
|
@@ -100,44 +153,15 @@ def visualize_logprobs(json_input):
|
|
100 |
float_value = ensure_float(value)
|
101 |
if float_value is not None and math.isfinite(float_value):
|
102 |
finite_top_logprobs[key] = float_value
|
103 |
-
|
104 |
# Extract top 3 alternatives from top_logprobs
|
105 |
-
top_3 = sorted(
|
106 |
-
finite_top_logprobs.items(), key=lambda x: x[1], reverse=True
|
107 |
-
)[:3]
|
108 |
row = [token, f"{logprob:.4f}"]
|
109 |
for alt_token, alt_logprob in top_3:
|
110 |
row.append(f"{alt_token}: {alt_logprob:.4f}")
|
111 |
-
# Pad with empty strings if fewer than 3 alternatives
|
112 |
while len(row) < 5:
|
113 |
row.append("")
|
114 |
table_data.append(row)
|
115 |
|
116 |
-
# Create the plot
|
117 |
-
if logprobs:
|
118 |
-
plt.figure(figsize=(10, 5))
|
119 |
-
plt.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b")
|
120 |
-
plt.title("Log Probabilities of Generated Tokens")
|
121 |
-
plt.xlabel("Token Position")
|
122 |
-
plt.ylabel("Log Probability")
|
123 |
-
plt.grid(True)
|
124 |
-
plt.xticks(range(len(logprobs)), tokens, rotation=45, ha="right")
|
125 |
-
plt.tight_layout()
|
126 |
-
|
127 |
-
# Save plot to a bytes buffer
|
128 |
-
buf = io.BytesIO()
|
129 |
-
plt.savefig(buf, format="png", bbox_inches="tight")
|
130 |
-
buf.seek(0)
|
131 |
-
plt.close()
|
132 |
-
|
133 |
-
# Convert to base64 for Gradio
|
134 |
-
img_bytes = buf.getvalue()
|
135 |
-
img_base64 = base64.b64encode(img_bytes).decode("utf-8")
|
136 |
-
img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
|
137 |
-
else:
|
138 |
-
img_html = "No finite log probabilities to plot."
|
139 |
-
|
140 |
-
# Create DataFrame for the table
|
141 |
df = (
|
142 |
pd.DataFrame(
|
143 |
table_data,
|
@@ -177,11 +201,22 @@ def visualize_logprobs(json_input):
|
|
177 |
else:
|
178 |
colored_text_html = "No finite log probabilities to display."
|
179 |
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
except Exception as e:
|
183 |
logger.error("Visualization failed: %s", str(e))
|
184 |
-
return f"Error: {str(e)}", None, None
|
185 |
|
186 |
# Gradio interface
|
187 |
with gr.Blocks(title="Log Probability Visualizer") as app:
|
@@ -196,15 +231,16 @@ with gr.Blocks(title="Log Probability Visualizer") as app:
|
|
196 |
placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
|
197 |
)
|
198 |
|
199 |
-
plot_output = gr.HTML(label="Log Probability Plot")
|
200 |
table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
|
201 |
text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
|
|
|
202 |
|
203 |
btn = gr.Button("Visualize")
|
204 |
btn.click(
|
205 |
fn=visualize_logprobs,
|
206 |
inputs=json_input,
|
207 |
-
outputs=[plot_output, table_output, text_output],
|
208 |
)
|
209 |
|
210 |
app.launch()
|
|
|
7 |
import math
|
8 |
import ast
|
9 |
import logging
|
10 |
+
from matplotlib.widgets import Cursor
|
11 |
|
12 |
# Set up logging
|
13 |
logging.basicConfig(level=logging.DEBUG)
|
|
|
56 |
return float(value)
|
57 |
return None
|
58 |
|
59 |
+
# Function to process and visualize log probs with hover and alternatives
|
60 |
def visualize_logprobs(json_input):
|
61 |
try:
|
62 |
# Parse the input (handles both JSON and Python dictionaries)
|
|
|
70 |
else:
|
71 |
raise ValueError("Input must be a list or dictionary with 'content' key")
|
72 |
|
73 |
+
# Extract tokens, log probs, and top alternatives, skipping None or non-finite values
|
74 |
tokens = []
|
75 |
logprobs = []
|
76 |
+
top_alternatives = [] # List to store top 3 log probs (selected token + 2 alternatives)
|
77 |
for entry in content:
|
78 |
logprob = ensure_float(entry.get("logprob", None))
|
79 |
if logprob is not None and math.isfinite(logprob):
|
80 |
tokens.append(entry["token"])
|
81 |
logprobs.append(logprob)
|
82 |
+
# Get top_logprobs, default to empty dict if None
|
83 |
+
top_probs = entry.get("top_logprobs", {})
|
84 |
+
# Ensure all values in top_logprobs are floats
|
85 |
+
finite_top_probs = {}
|
86 |
+
for key, value in top_probs.items():
|
87 |
+
float_value = ensure_float(value)
|
88 |
+
if float_value is not None and math.isfinite(float_value):
|
89 |
+
finite_top_probs[key] = float_value
|
90 |
+
# Get the top 3 log probs (including the selected token)
|
91 |
+
all_probs = {entry["token"]: logprob} # Add the selected token's logprob
|
92 |
+
all_probs.update(finite_top_probs) # Add alternatives
|
93 |
+
sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
|
94 |
+
top_3 = sorted_probs[:3] # Top 3 log probs (highest to lowest)
|
95 |
+
top_alternatives.append(top_3)
|
96 |
else:
|
97 |
logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))
|
98 |
|
99 |
+
# Create the plot with hover functionality
|
100 |
+
if logprobs:
|
101 |
+
fig, ax = plt.subplots(figsize=(10, 5))
|
102 |
+
scatter = ax.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Selected Token")[0]
|
103 |
+
ax.set_title("Log Probabilities of Generated Tokens")
|
104 |
+
ax.set_xlabel("Token Position")
|
105 |
+
ax.set_ylabel("Log Probability")
|
106 |
+
ax.grid(True)
|
107 |
+
ax.set_xticks([]) # Hide X-axis labels by default
|
108 |
+
|
109 |
+
# Add hover functionality using Matplotlib's Cursor for tooltips
|
110 |
+
cursor = Cursor(ax, useblit=True, color='red', linewidth=1)
|
111 |
+
token_annotations = []
|
112 |
+
for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
|
113 |
+
annotation = ax.annotate('', (x, y), xytext=(10, 10), textcoords='offset points', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), visible=False)
|
114 |
+
token_annotations.append(annotation)
|
115 |
+
|
116 |
+
def on_hover(event):
|
117 |
+
if event.inaxes == ax:
|
118 |
+
for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
|
119 |
+
contains, _ = scatter.contains(event)
|
120 |
+
if contains and abs(event.xdata - x) < 0.5 and abs(event.ydata - y) < 0.5:
|
121 |
+
token_annotations[i].set_text(tokens[i])
|
122 |
+
token_annotations[i].set_visible(True)
|
123 |
+
fig.canvas.draw_idle()
|
124 |
+
else:
|
125 |
+
token_annotations[i].set_visible(False)
|
126 |
+
fig.canvas.draw_idle()
|
127 |
+
|
128 |
+
fig.canvas.mpl_connect('motion_notify_event', on_hover)
|
129 |
+
|
130 |
+
# Save plot to a bytes buffer
|
131 |
+
buf = io.BytesIO()
|
132 |
+
plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
|
133 |
+
buf.seek(0)
|
134 |
+
plt.close()
|
135 |
+
|
136 |
+
# Convert to base64 for Gradio
|
137 |
+
img_bytes = buf.getvalue()
|
138 |
+
img_base64 = base64.b64encode(img_bytes).decode("utf-8")
|
139 |
+
img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
|
140 |
+
else:
|
141 |
+
img_html = "No finite log probabilities to plot."
|
142 |
+
|
143 |
+
# Create DataFrame for the table
|
144 |
table_data = []
|
145 |
+
for i, entry in enumerate(content):
|
146 |
logprob = ensure_float(entry.get("logprob", None))
|
147 |
+
if logprob is not None and math.isfinite(logprob) and "top_logprobs" in entry and entry["top_logprobs"] is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
token = entry["token"]
|
|
|
149 |
top_logprobs = entry["top_logprobs"]
|
150 |
# Ensure all values in top_logprobs are floats
|
151 |
finite_top_logprobs = {}
|
|
|
153 |
float_value = ensure_float(value)
|
154 |
if float_value is not None and math.isfinite(float_value):
|
155 |
finite_top_logprobs[key] = float_value
|
|
|
156 |
# Extract top 3 alternatives from top_logprobs
|
157 |
+
top_3 = sorted(finite_top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
|
|
|
|
|
158 |
row = [token, f"{logprob:.4f}"]
|
159 |
for alt_token, alt_logprob in top_3:
|
160 |
row.append(f"{alt_token}: {alt_logprob:.4f}")
|
|
|
161 |
while len(row) < 5:
|
162 |
row.append("")
|
163 |
table_data.append(row)
|
164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
df = (
|
166 |
pd.DataFrame(
|
167 |
table_data,
|
|
|
201 |
else:
|
202 |
colored_text_html = "No finite log probabilities to display."
|
203 |
|
204 |
+
# Create an alternative visualization for top 3 tokens
|
205 |
+
alt_viz_html = ""
|
206 |
+
if logprobs and top_alternatives:
|
207 |
+
alt_viz_html = "<h3>Top 3 Token Log Probabilities</h3><ul>"
|
208 |
+
for i, (token, probs) in enumerate(zip(tokens, top_alternatives)):
|
209 |
+
alt_viz_html += f"<li>Position {i} (Token: {token}):<br>"
|
210 |
+
for tok, prob in probs:
|
211 |
+
alt_viz_html += f"{tok}: {prob:.4f}<br>"
|
212 |
+
alt_viz_html += "</li>"
|
213 |
+
alt_viz_html += "</ul>"
|
214 |
+
|
215 |
+
return img_html, df, colored_text_html, alt_viz_html
|
216 |
|
217 |
except Exception as e:
|
218 |
logger.error("Visualization failed: %s", str(e))
|
219 |
+
return f"Error: {str(e)}", None, None, None
|
220 |
|
221 |
# Gradio interface
|
222 |
with gr.Blocks(title="Log Probability Visualizer") as app:
|
|
|
231 |
placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
|
232 |
)
|
233 |
|
234 |
+
plot_output = gr.HTML(label="Log Probability Plot (Hover for Tokens)")
|
235 |
table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
|
236 |
text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
|
237 |
+
alt_viz_output = gr.HTML(label="Top 3 Token Log Probabilities")
|
238 |
|
239 |
btn = gr.Button("Visualize")
|
240 |
btn.click(
|
241 |
fn=visualize_logprobs,
|
242 |
inputs=json_input,
|
243 |
+
outputs=[plot_output, table_output, text_output, alt_viz_output],
|
244 |
)
|
245 |
|
246 |
app.launch()
|