Spaces:
Running
Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
|
|
5 |
import io
|
6 |
import base64
|
7 |
import ast
|
|
|
8 |
|
9 |
# Function to process and visualize log probs
|
10 |
def visualize_logprobs(json_input):
|
@@ -16,19 +17,34 @@ def visualize_logprobs(json_input):
|
|
16 |
# If JSON fails, try to parse as Python literal (e.g., with single quotes)
|
17 |
data = ast.literal_eval(json_input)
|
18 |
|
19 |
-
#
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
# Prepare data for the table
|
24 |
table_data = []
|
25 |
-
for entry in
|
26 |
-
if entry['logprob'] is not None:
|
27 |
token = entry['token']
|
28 |
logprob = entry['logprob']
|
29 |
-
top_logprobs = entry['top_logprobs']
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
32 |
row = [token, f"{logprob:.4f}"]
|
33 |
for alt_token, alt_logprob in top_3:
|
34 |
row.append(f"{alt_token}: {alt_logprob:.4f}")
|
@@ -37,56 +53,61 @@ def visualize_logprobs(json_input):
|
|
37 |
row.append("")
|
38 |
table_data.append(row)
|
39 |
|
40 |
-
# Create the plot
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
60 |
|
61 |
# Create a DataFrame for the table
|
62 |
df = pd.DataFrame(
|
63 |
table_data,
|
64 |
columns=["Token", "Log Prob", "Top 1 Alternative", "Top 2 Alternative", "Top 3 Alternative"]
|
65 |
-
)
|
66 |
|
67 |
# Generate colored text based on log probabilities
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
else:
|
74 |
-
|
75 |
-
|
76 |
-
# Create HTML for colored text
|
77 |
-
colored_text = ""
|
78 |
-
for i, (token, norm_prob) in enumerate(zip(tokens, normalized_probs)):
|
79 |
-
# Map normalized probability to RGB color (green for high confidence, red for low)
|
80 |
-
# Use a simple linear interpolation: green (0, 255, 0) to red (255, 0, 0)
|
81 |
-
r = int(255 * (1 - norm_prob)) # Red increases as uncertainty increases
|
82 |
-
g = int(255 * norm_prob) # Green decreases as uncertainty increases
|
83 |
-
b = 0 # Blue stays 0 for simplicity
|
84 |
-
color = f'rgb({r}, {g}, {b})'
|
85 |
-
colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
|
86 |
-
if i < len(tokens) - 1:
|
87 |
-
colored_text += " " # Add space between tokens
|
88 |
-
|
89 |
-
colored_text_html = f'<p>{colored_text}</p>'
|
90 |
|
91 |
return img_html, df, colored_text_html
|
92 |
|
|
|
5 |
import io
|
6 |
import base64
|
7 |
import ast
|
8 |
+
import math
|
9 |
|
10 |
# Function to process and visualize log probs
|
11 |
def visualize_logprobs(json_input):
|
|
|
17 |
# If JSON fails, try to parse as Python literal (e.g., with single quotes)
|
18 |
data = ast.literal_eval(json_input)
|
19 |
|
20 |
+
# Ensure data is a list or dictionary with 'content'
|
21 |
+
if isinstance(data, dict) and 'content' in data:
|
22 |
+
content = data['content']
|
23 |
+
elif isinstance(data, list):
|
24 |
+
content = data
|
25 |
+
else:
|
26 |
+
raise ValueError("Input must be a list or dictionary with 'content' key")
|
27 |
+
|
28 |
+
# Extract tokens and log probs, skipping None values and handling non-finite values
|
29 |
+
tokens = []
|
30 |
+
logprobs = []
|
31 |
+
for entry in content:
|
32 |
+
if entry['logprob'] is not None and math.isfinite(entry['logprob']):
|
33 |
+
tokens.append(entry['token'])
|
34 |
+
logprobs.append(entry['logprob'])
|
35 |
|
36 |
# Prepare data for the table
|
37 |
table_data = []
|
38 |
+
for entry in content:
|
39 |
+
if entry['logprob'] is not None and math.isfinite(entry['logprob']):
|
40 |
token = entry['token']
|
41 |
logprob = entry['logprob']
|
42 |
+
top_logprobs = entry['top_logprobs'] or {}
|
43 |
+
|
44 |
+
# Filter out non-finite (e.g., -inf, inf, nan) log probs from top_logprobs
|
45 |
+
finite_top_logprobs = {k: v for k, v in top_logprobs.items() if math.isfinite(v)}
|
46 |
+
# Extract top 3 finite alternatives, sorted by log prob (most probable first)
|
47 |
+
top_3 = sorted(finite_top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
|
48 |
row = [token, f"{logprob:.4f}"]
|
49 |
for alt_token, alt_logprob in top_3:
|
50 |
row.append(f"{alt_token}: {alt_logprob:.4f}")
|
|
|
53 |
row.append("")
|
54 |
table_data.append(row)
|
55 |
|
56 |
+
# Create the plot (only for finite log probs)
|
57 |
+
if logprobs:
|
58 |
+
plt.figure(figsize=(10, 5))
|
59 |
+
plt.plot(range(len(logprobs)), logprobs, marker='o', linestyle='-', color='b')
|
60 |
+
plt.title("Log Probabilities of Generated Tokens")
|
61 |
+
plt.xlabel("Token Position")
|
62 |
+
plt.ylabel("Log Probability")
|
63 |
+
plt.grid(True)
|
64 |
+
plt.xticks(range(len(logprobs)), tokens, rotation=45, ha='right')
|
65 |
+
plt.tight_layout()
|
66 |
+
|
67 |
+
# Save plot to a bytes buffer
|
68 |
+
buf = io.BytesIO()
|
69 |
+
plt.savefig(buf, format='png', bbox_inches='tight')
|
70 |
+
buf.seek(0)
|
71 |
+
plt.close()
|
72 |
+
|
73 |
+
# Convert buffer to base64 for Gradio
|
74 |
+
img_bytes = buf.getvalue()
|
75 |
+
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
|
76 |
+
img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
|
77 |
+
else:
|
78 |
+
img_html = "No finite log probabilities to plot."
|
79 |
|
80 |
# Create a DataFrame for the table
|
81 |
df = pd.DataFrame(
|
82 |
table_data,
|
83 |
columns=["Token", "Log Prob", "Top 1 Alternative", "Top 2 Alternative", "Top 3 Alternative"]
|
84 |
+
) if table_data else None
|
85 |
|
86 |
# Generate colored text based on log probabilities
|
87 |
+
if logprobs:
|
88 |
+
# Normalize log probs to [0, 1] for color scaling (0 = most uncertain, 1 = most confident)
|
89 |
+
min_logprob = min(logprobs)
|
90 |
+
max_logprob = max(logprobs)
|
91 |
+
if max_logprob == min_logprob:
|
92 |
+
normalized_probs = [0.5] * len(logprobs) # Avoid division by zero
|
93 |
+
else:
|
94 |
+
normalized_probs = [(lp - min_logprob) / (max_logprob - min_logprob) for lp in logprobs]
|
95 |
+
|
96 |
+
# Create HTML for colored text
|
97 |
+
colored_text = ""
|
98 |
+
for i, (token, norm_prob) in enumerate(zip(tokens, normalized_probs)):
|
99 |
+
# Map normalized probability to RGB color (green for high confidence, red for low)
|
100 |
+
r = int(255 * (1 - norm_prob)) # Red increases as uncertainty increases
|
101 |
+
g = int(255 * norm_prob) # Green decreases as uncertainty increases
|
102 |
+
b = 0 # Blue stays 0 for simplicity
|
103 |
+
color = f'rgb({r}, {g}, {b})'
|
104 |
+
colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
|
105 |
+
if i < len(tokens) - 1:
|
106 |
+
colored_text += " " # Add space between tokens
|
107 |
+
|
108 |
+
colored_text_html = f'<p>{colored_text}</p>'
|
109 |
else:
|
110 |
+
colored_text_html = "No finite log probabilities to display."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
return img_html, df, colored_text_html
|
113 |
|