# app.py, author: codelion; commit 181b7be "Update app.py" (verified); 5.47 kB.
import ast
import base64
import html
import io
import json
import math

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
# Function to process and visualize log probs
def visualize_logprobs(json_input):
    """Parse per-token log-prob data and build the three UI outputs.

    Args:
        json_input: Text holding either JSON or a Python literal. It must
            decode to a list of token entries, or to a dict whose "content"
            key holds that list. Each entry is expected to be a dict with
            "token" and "logprob" keys and an optional "top_logprobs" value,
            either a ``{token: logprob}`` dict or a list of dicts with
            "token" and "logprob" keys.

    Returns:
        A ``(plot_html, table, colored_text_html)`` tuple:

        - plot_html: an ``<img>`` tag embedding a PNG line plot of the finite
          log probabilities, or a message when there are none.
        - table: a DataFrame with each token, its log prob and up to three
          top alternatives, or None when no kept entry has ``top_logprobs``.
        - colored_text_html: the tokens as HTML spans colored red (lowest
          log prob) to green (highest), or a message when there are none.

        On any failure it returns ``("Error: <message>", None, None)``, so the
        UI shows the problem instead of raising.
    """
    try:
        content = _extract_content(_parse_input(json_input))

        # Single pass: keep only entries with a finite logprob (None/NaN/inf
        # cannot be plotted or normalized). Entries that also carry
        # top_logprobs get a table row.
        tokens = []
        logprobs = []
        table_data = []
        for entry in content:
            logprob = _finite_logprob(entry)
            if logprob is None:
                continue
            token = str(entry["token"])
            tokens.append(token)
            logprobs.append(logprob)

            top_logprobs = entry.get("top_logprobs")
            if top_logprobs is not None:
                row = [token, f"{logprob:.4f}"]
                row.extend(
                    f"{alt_token}: {alt_logprob:.4f}"
                    for alt_token, alt_logprob in _top_alternatives(top_logprobs)
                )
                # Pad with empty cells when fewer than 3 alternatives exist.
                row.extend([""] * (5 - len(row)))
                table_data.append(row)

        if logprobs:
            img_html = _plot_html(tokens, logprobs)
        else:
            img_html = "No finite log probabilities to plot."

        df = (
            pd.DataFrame(
                table_data,
                columns=[
                    "Token",
                    "Log Prob",
                    "Top 1 Alternative",
                    "Top 2 Alternative",
                    "Top 3 Alternative",
                ],
            )
            if table_data
            else None
        )

        if logprobs:
            colored_text_html = _colored_text_html(tokens, logprobs)
        else:
            colored_text_html = "No finite log probabilities to display."

        return img_html, df, colored_text_html
    except Exception as e:
        # UI boundary: report any failure in the HTML pane instead of raising.
        # Escape it because the message can echo user-supplied data.
        return f"Error: {html.escape(str(e), quote=False)}", None, None


def _parse_input(text):
    """Decode ``text`` as JSON, falling back to a Python literal.

    ``ast.literal_eval`` evaluates only literals (no names or calls), so
    Python-dict pastes with single quotes, ``None`` or ``True`` are accepted
    without executing arbitrary code.

    Raises:
        ValueError: if ``text`` is neither valid JSON nor a Python literal.
        TypeError: if ``text`` is not a str/bytes (propagated from json).
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError as json_err:
        try:
            return ast.literal_eval(text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            raise ValueError(
                f"Input is not valid JSON or a Python literal ({json_err})"
            ) from json_err


def _extract_content(data):
    """Return the list of token entries from decoded input, or raise ValueError."""
    if isinstance(data, dict) and "content" in data:
        return data["content"]
    if isinstance(data, list):
        return data
    raise ValueError("Input must be a list or dictionary with 'content' key")


def _finite_logprob(entry):
    """Return ``entry["logprob"]`` when entry is a dict with a finite logprob, else None."""
    if not isinstance(entry, dict):
        return None
    logprob = entry.get("logprob")
    if logprob is None or not math.isfinite(logprob):
        return None
    return logprob


def _top_alternatives(top_logprobs, k=3):
    """Return up to ``k`` ``(token, logprob)`` pairs, highest logprob first.

    Accepts a ``{token: logprob}`` mapping or a list of dicts with "token"
    and "logprob" keys. Alternatives whose logprob is None are dropped so
    they cannot break sorting or number formatting.
    """
    if isinstance(top_logprobs, dict):
        pairs = top_logprobs.items()
    else:
        pairs = (
            (alt.get("token"), alt.get("logprob"))
            for alt in top_logprobs
            if isinstance(alt, dict)
        )
    candidates = [(tok, lp) for tok, lp in pairs if lp is not None]
    return sorted(candidates, key=lambda pair: pair[1], reverse=True)[:k]


def _plot_html(tokens, logprobs):
    """Render ``logprobs`` as a line plot and return it as an inline-PNG ``<img>`` tag."""
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        positions = list(range(len(logprobs)))
        ax.plot(positions, logprobs, marker="o", linestyle="-", color="b")
        ax.set_title("Log Probabilities of Generated Tokens")
        ax.set_xlabel("Token Position")
        ax.set_ylabel("Log Probability")
        ax.grid(True)
        ax.set_xticks(positions)
        # Escape "$" so tokens containing dollar signs are drawn literally
        # instead of being interpreted (and possibly failing) as mathtext.
        ax.set_xticklabels(
            [token.replace("$", r"\$") for token in tokens],
            rotation=45,
            ha="right",
        )
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even when rendering raises.
        plt.close(fig)
    img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'


def _colored_text_html(tokens, logprobs):
    """Return ``tokens`` as HTML spans colored red (lowest) to green (highest logprob)."""
    low, high = min(logprobs), max(logprobs)
    spread = high - low
    spans = []
    for token, logprob in zip(tokens, logprobs):
        # With no spread every token gets the midpoint color.
        norm = 0.5 if high == low else (logprob - low) / spread
        red = int(255 * (1 - norm))  # red for low confidence
        green = int(255 * norm)  # green for high confidence
        spans.append(
            f'<span style="color: rgb({red}, {green}, 0); font-weight: bold;">'
            f"{html.escape(token, quote=False)}</span>"
        )
    return f"<p>{' '.join(spans)}</p>"
# Gradio interface
# Builds the UI: a text box for the log-prob data, a "Visualize" button, and
# three output panes wired, in order, to the three values returned by
# visualize_logprobs (plot HTML, alternatives table, colored-text HTML).
with gr.Blocks(title="Log Probability Visualizer") as app:
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown(
        "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities."
    )
    # Input: raw text that visualize_logprobs decodes.
    json_input = gr.Textbox(
        label="JSON Input",
        lines=10,
        placeholder="Paste your JSON or Python dict here...",
    )
    # Output components; their order must match visualize_logprobs' return tuple.
    plot_output = gr.HTML(label="Log Probability Plot")
    table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
    text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
    btn = gr.Button("Visualize")
    # Clicking the button runs visualize_logprobs on the text box contents.
    btn.click(
        fn=visualize_logprobs,
        inputs=json_input,
        outputs=[plot_output, table_output, text_output],
    )
# NOTE(review): launch() runs at module level, so importing this module also
# starts the server; consider an `if __name__ == "__main__":` guard if the
# file is ever imported rather than run as a script.
app.launch()