Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -4,91 +4,51 @@ import matplotlib.pyplot as plt
|
|
4 |
import pandas as pd
|
5 |
import io
|
6 |
import base64
|
7 |
-
import ast
|
8 |
import math
|
9 |
|
10 |
-
# Function to safely convert string representations of infinity
|
11 |
-
def parse_infinity(value):
    """Turn a textual infinity marker into the matching float.

    Strings equal (case-insensitively) to "-infinity"/"-inf" become
    ``float('-inf')`` and "infinity"/"inf" become ``float('inf')``.
    Every other value, including non-string values and strings that do
    not match exactly, is returned unchanged.
    """
    # Non-strings can never be an infinity marker; hand them straight back.
    if not isinstance(value, str):
        return value

    lowered = value.lower()
    if lowered in ("-infinity", "-inf"):
        return float("-inf")
    if lowered in ("infinity", "inf"):
        return float("inf")
    return value
|
18 |
-
|
19 |
# Function to process and visualize log probs
|
20 |
def visualize_logprobs(json_input):
|
21 |
try:
|
22 |
-
#
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
import re
|
27 |
-
return re.sub(r'-inf', '"-Infinity"', re.sub(r'inf', '"Infinity"', s))
|
28 |
-
|
29 |
-
data = json.loads(replace_inf(json_input))
|
30 |
-
# Convert string "Infinity" or "-Infinity" back to float if needed
|
31 |
-
if isinstance(data, dict) and 'content' in data:
|
32 |
-
for entry in data['content']:
|
33 |
-
if 'logprob' in entry:
|
34 |
-
entry['logprob'] = parse_infinity(entry['logprob'])
|
35 |
-
if 'top_logprobs' in entry:
|
36 |
-
entry['top_logprobs'] = {k: parse_infinity(v) for k, v in entry['top_logprobs'].items()}
|
37 |
-
elif isinstance(data, list):
|
38 |
-
for entry in data:
|
39 |
-
if 'logprob' in entry:
|
40 |
-
entry['logprob'] = parse_infinity(entry['logprob'])
|
41 |
-
if 'top_logprobs' in entry:
|
42 |
-
entry['top_logprobs'] = {k: parse_infinity(v) for k, v in entry['top_logprobs'].items()}
|
43 |
-
|
44 |
-
except json.JSONDecodeError:
|
45 |
-
# If JSON fails, try to parse as Python literal (e.g., with single quotes)
|
46 |
-
try:
|
47 |
-
data = ast.literal_eval(json_input)
|
48 |
-
# Ensure -inf is handled as float('-inf')
|
49 |
-
if isinstance(data, dict) and 'content' in data:
|
50 |
-
for entry in data['content']:
|
51 |
-
if 'logprob' in entry and isinstance(entry['logprob'], str):
|
52 |
-
entry['logprob'] = parse_infinity(entry['logprob'])
|
53 |
-
if 'top_logprobs' in entry:
|
54 |
-
entry['top_logprobs'] = {k: parse_infinity(v) for k, v in entry['top_logprobs'].items()}
|
55 |
-
elif isinstance(data, list):
|
56 |
-
for entry in data:
|
57 |
-
if 'logprob' in entry and isinstance(entry['logprob'], str):
|
58 |
-
entry['logprob'] = parse_infinity(entry['logprob'])
|
59 |
-
if 'top_logprobs' in entry:
|
60 |
-
entry['top_logprobs'] = {k: parse_infinity(v) for k, v in entry['top_logprobs'].items()}
|
61 |
-
except (SyntaxError, ValueError) as e:
|
62 |
-
raise ValueError(f"Malformed input: {str(e)}")
|
63 |
-
|
64 |
-
# Ensure data is a list or dictionary with 'content'
|
65 |
-
if isinstance(data, dict) and 'content' in data:
|
66 |
-
content = data['content']
|
67 |
elif isinstance(data, list):
|
68 |
content = data
|
69 |
else:
|
70 |
raise ValueError("Input must be a list or dictionary with 'content' key")
|
71 |
-
|
72 |
-
# Extract tokens and log probs, skipping None
|
73 |
tokens = []
|
74 |
logprobs = []
|
75 |
for entry in content:
|
76 |
-
if
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
81 |
table_data = []
|
82 |
for entry in content:
|
83 |
-
|
84 |
-
|
85 |
-
logprob
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
row = [token, f"{logprob:.4f}"]
|
93 |
for alt_token, alt_logprob in top_3:
|
94 |
row.append(f"{alt_token}: {alt_logprob:.4f}")
|
@@ -96,88 +56,98 @@ def visualize_logprobs(json_input):
|
|
96 |
while len(row) < 5:
|
97 |
row.append("")
|
98 |
table_data.append(row)
|
99 |
-
|
100 |
-
# Create the plot
|
101 |
if logprobs:
|
102 |
plt.figure(figsize=(10, 5))
|
103 |
-
plt.plot(range(len(logprobs)), logprobs, marker=
|
104 |
plt.title("Log Probabilities of Generated Tokens")
|
105 |
plt.xlabel("Token Position")
|
106 |
plt.ylabel("Log Probability")
|
107 |
plt.grid(True)
|
108 |
-
plt.xticks(range(len(logprobs)), tokens, rotation=45, ha=
|
109 |
plt.tight_layout()
|
110 |
-
|
111 |
# Save plot to a bytes buffer
|
112 |
buf = io.BytesIO()
|
113 |
-
plt.savefig(buf, format=
|
114 |
buf.seek(0)
|
115 |
plt.close()
|
116 |
-
|
117 |
-
# Convert
|
118 |
img_bytes = buf.getvalue()
|
119 |
-
img_base64 = base64.b64encode(img_bytes).decode(
|
120 |
img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
|
121 |
else:
|
122 |
img_html = "No finite log probabilities to plot."
|
123 |
-
|
124 |
-
# Create
|
125 |
-
df =
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
if logprobs:
|
132 |
-
# Normalize log probs to [0, 1] for color scaling (0 = most uncertain, 1 = most confident)
|
133 |
min_logprob = min(logprobs)
|
134 |
max_logprob = max(logprobs)
|
135 |
if max_logprob == min_logprob:
|
136 |
-
normalized_probs = [0.5] * len(logprobs)
|
137 |
else:
|
138 |
-
normalized_probs = [
|
139 |
-
|
140 |
-
|
|
|
141 |
colored_text = ""
|
142 |
for i, (token, norm_prob) in enumerate(zip(tokens, normalized_probs)):
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
color = f'rgb({r}, {g}, {b})'
|
148 |
colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
|
149 |
if i < len(tokens) - 1:
|
150 |
-
colored_text += " "
|
151 |
-
|
152 |
-
colored_text_html = f'<p>{colored_text}</p>'
|
153 |
else:
|
154 |
colored_text_html = "No finite log probabilities to display."
|
155 |
-
|
156 |
return img_html, df, colored_text_html
|
157 |
-
|
158 |
except Exception as e:
|
159 |
return f"Error: {str(e)}", None, None
|
160 |
|
161 |
# Gradio interface
|
162 |
with gr.Blocks(title="Log Probability Visualizer") as app:
|
163 |
gr.Markdown("# Log Probability Visualizer")
|
164 |
-
gr.Markdown(
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
|
|
|
|
|
|
|
|
170 |
plot_output = gr.HTML(label="Log Probability Plot")
|
171 |
table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
|
172 |
text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
|
173 |
-
|
174 |
-
# Button to trigger visualization
|
175 |
btn = gr.Button("Visualize")
|
176 |
btn.click(
|
177 |
fn=visualize_logprobs,
|
178 |
inputs=json_input,
|
179 |
-
outputs=[plot_output, table_output, text_output]
|
180 |
)
|
181 |
|
182 |
-
# Launch the app
|
183 |
app.launch()
|
|
|
4 |
import pandas as pd
|
5 |
import io
|
6 |
import base64
|
|
|
7 |
import math
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
# Function to process and visualize log probs
|
10 |
def _load_logprob_entries(raw_text):
    """Parse the pasted text and return its list of token entries.

    JSON is tried first. If that fails, the text is parsed as a Python
    literal so that dict/list syntax with single quotes or ``None`` also
    works. Raises ValueError when the text cannot be parsed or does not
    have the expected top-level shape.
    """
    import ast  # local import: the file's import header is left untouched

    try:
        data = json.loads(raw_text)
    except json.JSONDecodeError as json_err:
        try:
            # literal_eval only builds literals; it never executes code.
            data = ast.literal_eval(raw_text)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            raise ValueError(f"Malformed input: {json_err}") from json_err

    if isinstance(data, dict) and "content" in data:
        content = data["content"]
    elif isinstance(data, list):
        content = data
    else:
        raise ValueError("Input must be a list or dictionary with 'content' key")

    if not isinstance(content, list):
        raise ValueError("'content' must be a list of token entries")
    return content


def _is_finite_logprob(value):
    """Return True when value is a real number that is neither inf nor NaN."""
    return isinstance(value, (int, float)) and math.isfinite(value)


def _plot_logprobs_html(tokens, logprobs):
    """Draw the log prob line chart and return it as an inline <img> tag."""
    positions = range(len(logprobs))
    fig = plt.figure(figsize=(10, 5))
    try:
        plt.plot(positions, logprobs, marker="o", linestyle="-", color="b")
        plt.title("Log Probabilities of Generated Tokens")
        plt.xlabel("Token Position")
        plt.ylabel("Log Probability")
        plt.grid(True)
        plt.xticks(positions, tokens, rotation=45, ha="right")
        plt.tight_layout()

        # Save plot to a bytes buffer
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Close the figure even when drawing fails so figures do not pile up.
        plt.close(fig)

    img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'


def _colorize_tokens_html(tokens, logprobs):
    """Return the tokens as HTML spans shaded red (lowest) to green (highest)."""
    import html  # local import: the file's import header is left untouched

    lowest = min(logprobs)
    spread = max(logprobs) - lowest

    spans = []
    for token, logprob in zip(tokens, logprobs):
        # Min-max normalise; when every log prob is equal use the midpoint.
        confidence = 0.5 if spread == 0 else (logprob - lowest) / spread
        red = int(255 * (1 - confidence))  # Red for low confidence
        green = int(255 * confidence)  # Green for high confidence
        color = f"rgb({red}, {green}, 0)"
        # Token text comes from user input and ends up in an HTML output,
        # so it must be escaped before being embedded in markup.
        safe_token = html.escape(str(token))
        spans.append(
            f'<span style="color: {color}; font-weight: bold;">{safe_token}</span>'
        )
    return f"<p>{' '.join(spans)}</p>"


# Function to process and visualize log probs
def visualize_logprobs(json_input):
    """Build a plot, a table and colour-coded text from token log probs.

    Args:
        json_input: Text containing either a list of token entries or a
            dictionary whose "content" key holds that list. Each entry is
            a dict with "token" and "logprob", and optionally
            "top_logprobs" mapping alternative tokens to log probs. JSON
            syntax is tried first, then Python-literal syntax.

    Returns:
        A tuple ``(img_html, df, colored_text_html)``:
        - img_html: an inline PNG ``<img>`` tag, or a plain message when
          no entry has a finite log prob.
        - df: a DataFrame with the token, its log prob and up to three
          highest-ranked alternatives, or None when there are no rows.
        - colored_text_html: the tokens coloured by relative log prob, or
          a plain message when no entry has a finite log prob.
        On any failure the tuple is ``("Error: <message>", None, None)``.
    """
    import html  # local import: the file's import header is left untouched

    try:
        content = _load_logprob_entries(json_input)

        # Keep only well-formed entries with a finite log prob; entries
        # that are not dicts or whose log prob is None/inf/NaN are skipped.
        scored = [
            entry
            for entry in content
            if isinstance(entry, dict) and _is_finite_logprob(entry.get("logprob"))
        ]
        tokens = [entry["token"] for entry in scored]
        logprobs = [entry["logprob"] for entry in scored]

        # Prepare table data; rows need a non-None top_logprobs mapping.
        table_data = []
        for entry in scored:
            top_logprobs = entry.get("top_logprobs")
            if top_logprobs is None:
                continue

            # Rank the alternatives, ignoring ones without a numeric log
            # prob (a None value would otherwise break the sort).
            ranked = sorted(
                (
                    (alt_token, alt_logprob)
                    for alt_token, alt_logprob in top_logprobs.items()
                    if isinstance(alt_logprob, (int, float))
                ),
                key=lambda pair: pair[1],
                reverse=True,
            )[:3]

            row = [entry["token"], f"{entry['logprob']:.4f}"]
            row.extend(f"{alt_token}: {alt_logprob:.4f}" for alt_token, alt_logprob in ranked)
            # Pad so every row has the five columns of the table.
            row.extend([""] * (5 - len(row)))
            table_data.append(row)

        # Create the plot
        if logprobs:
            img_html = _plot_logprobs_html(tokens, logprobs)
        else:
            img_html = "No finite log probabilities to plot."

        # Create DataFrame for the table
        df = (
            pd.DataFrame(
                table_data,
                columns=[
                    "Token",
                    "Log Prob",
                    "Top 1 Alternative",
                    "Top 2 Alternative",
                    "Top 3 Alternative",
                ],
            )
            if table_data
            else None
        )

        # Generate colored text
        if logprobs:
            colored_text_html = _colorize_tokens_html(tokens, logprobs)
        else:
            colored_text_html = "No finite log probabilities to display."

        return img_html, df, colored_text_html

    except Exception as e:
        # The message is displayed in an HTML output, so escape markup
        # characters; quotes are left alone to keep messages readable.
        return f"Error: {html.escape(str(e), quote=False)}", None, None
|
128 |
|
129 |
# Gradio interface
|
130 |
# Build the UI: one text box for the pasted data, three outputs for the
# results, and a button that runs visualize_logprobs on the text box value.
with gr.Blocks(title="Log Probability Visualizer") as app:
    # Page heading and short usage hint.
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown("Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities.")

    # Input area for the pasted log prob data.
    json_input = gr.Textbox(label="JSON Input", lines=10, placeholder="Paste your JSON or Python dict here...")

    # One output per element of the tuple returned by visualize_logprobs.
    plot_output = gr.HTML(label="Log Probability Plot")
    table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
    text_output = gr.HTML(label="Colored Text (Confidence Visualization)")

    # Clicking the button feeds the text box into visualize_logprobs and
    # routes its three return values to the outputs, in this order.
    btn = gr.Button("Visualize")
    btn.click(fn=visualize_logprobs, inputs=json_input, outputs=[plot_output, table_output, text_output])

# Start the app.
app.launch()
|