Spaces:
Running
Running
import gradio as gr | |
import json | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import io | |
import base64 | |
import ast | |
import math | |
# Log-prob parsing, tabulation and rendering used by the Gradio callback.
def _parse_entries(raw_text):
    """Parse the pasted text and return the sequence of per-token entries.

    The text is tried as JSON first and, if that fails, as a Python literal
    (which allows single-quoted strings and ``None``).  The parsed value must
    be either a list of entries or a dict holding that list under 'content'.

    Raises:
        ValueError: if the input is empty or does not have the expected shape.
        SyntaxError / ValueError: propagated from ``ast.literal_eval`` when the
            text is neither valid JSON nor a valid Python literal.
    """
    if raw_text is None or not str(raw_text).strip():
        raise ValueError("Input is empty")

    try:
        data = json.loads(raw_text)
    except json.JSONDecodeError:
        # literal_eval only evaluates literals, so no user code is executed.
        data = ast.literal_eval(raw_text)

    if isinstance(data, dict) and 'content' in data:
        content = data['content']
    elif isinstance(data, list):
        content = data
    else:
        raise ValueError("Input must be a list or dictionary with 'content' key")

    if not isinstance(content, (list, tuple)):
        raise ValueError("'content' must be a list of token entries")
    return content


def _is_finite_number(value):
    """Return True when *value* is an int/float that is not NaN or infinite."""
    return isinstance(value, (int, float)) and math.isfinite(value)


def _top_alternatives(top_logprobs, limit=3):
    """Return up to *limit* ``(token, logprob)`` pairs, most probable first.

    Accepts either a ``{token: logprob}`` mapping or a list of
    ``{"token": ..., "logprob": ...}`` dicts.  Alternatives whose log prob is
    missing, non-numeric or non-finite are dropped.
    """
    if not top_logprobs:
        return []

    if isinstance(top_logprobs, dict):
        candidates = top_logprobs.items()
    else:
        candidates = (
            (alt.get('token'), alt.get('logprob'))
            for alt in top_logprobs
            if isinstance(alt, dict)
        )

    finite = [(tok, lp) for tok, lp in candidates if _is_finite_number(lp)]
    # Stable sort: alternatives with equal log probs keep their input order.
    finite.sort(key=lambda pair: pair[1], reverse=True)
    return finite[:limit]


def _collect_rows(entries):
    """Walk *entries* once and return ``(tokens, logprobs, table_rows)``.

    Entries whose 'logprob' is absent, ``None`` or non-finite are skipped, so
    the three returned lists always describe the same tokens in the same order.
    Each table row is ``[token, logprob, alt1, alt2, alt3]`` with the numeric
    parts formatted to four decimals and missing alternatives left as "".

    Raises:
        ValueError: if an entry is not a dict, or has a usable log prob but no
            'token' key.
        TypeError: if a 'logprob' value is not a real number (from
            ``math.isfinite``).
    """
    tokens, logprobs, table_rows = [], [], []
    for entry in entries:
        if not isinstance(entry, dict):
            raise ValueError("Each entry must be a dictionary with 'token' and 'logprob' keys")

        logprob = entry.get('logprob')
        if logprob is None or not math.isfinite(logprob):
            continue
        if 'token' not in entry:
            raise ValueError("Entry with a log prob is missing its 'token' key")

        token = entry['token']
        tokens.append(token)
        logprobs.append(logprob)

        row = [token, f"{logprob:.4f}"]
        row.extend(
            f"{alt_token}: {alt_logprob:.4f}"
            for alt_token, alt_logprob in _top_alternatives(entry.get('top_logprobs'))
        )
        # Pad so every row has the five columns the table expects.
        row.extend([""] * (5 - len(row)))
        table_rows.append(row)
    return tokens, logprobs, table_rows


def _render_plot_html(tokens, logprobs):
    """Plot *logprobs* by position and return the PNG as an ``<img>`` data URI."""
    positions = list(range(len(logprobs)))
    # Escape "$" so tokens such as "$x$" or "$$" are drawn literally instead
    # of being handed to matplotlib's mathtext parser.
    tick_labels = [str(tok).replace("$", r"\$") for tok in tokens]

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(positions, logprobs, marker='o', linestyle='-', color='b')
        ax.set_title("Log Probabilities of Generated Tokens")
        ax.set_xlabel("Token Position")
        ax.set_ylabel("Log Probability")
        ax.grid(True)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

    img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'


def _render_colored_text_html(tokens, logprobs):
    """Return the tokens as HTML spans coloured from red (lowest) to green (highest).

    Log probs are min-max normalised over this input; when all values are
    equal every token gets the midpoint colour.  Token text is HTML-escaped.
    """
    import html

    lowest, highest = min(logprobs), max(logprobs)
    spans = []
    for token, logprob in zip(tokens, logprobs):
        if highest == lowest:
            confidence = 0.5  # Avoid division by zero
        else:
            confidence = (logprob - lowest) / (highest - lowest)
        red = int(255 * (1 - confidence))
        green = int(255 * confidence)
        safe_token = html.escape(str(token), quote=False)
        spans.append(
            f'<span style="color: rgb({red}, {green}, 0); font-weight: bold;">{safe_token}</span>'
        )
    joined = " ".join(spans)
    return f'<p>{joined}</p>'


def visualize_logprobs(json_input):
    """Turn pasted log-prob data into a plot, a table and colour-coded text.

    Args:
        json_input: Text containing JSON or a Python literal: either a list of
            entries or a dict with that list under 'content'.  Each entry is a
            dict with 'token', 'logprob' and optionally 'top_logprobs'.

    Returns:
        A 3-tuple ``(plot_html, table, colored_text_html)``:
        * on success, an ``<img>`` tag (or a plot error message if only the
          plotting step failed), a pandas DataFrame and an HTML paragraph;
        * when no entry has a finite log prob, two explanatory strings with
          ``None`` for the table;
        * on any other failure, ``("Error: <message>", None, None)``.
    """
    import html

    try:
        entries = _parse_entries(json_input)
        tokens, logprobs, table_rows = _collect_rows(entries)

        if not logprobs:
            return (
                "No finite log probabilities to plot.",
                None,
                "No finite log probabilities to display.",
            )

        df = pd.DataFrame(
            table_rows,
            columns=["Token", "Log Prob", "Top 1 Alternative", "Top 2 Alternative", "Top 3 Alternative"],
        )
        colored_text_html = _render_colored_text_html(tokens, logprobs)

        # The table and coloured text do not depend on the plot, so a plotting
        # failure is reported in the plot slot instead of discarding them.
        try:
            img_html = _render_plot_html(tokens, logprobs)
        except Exception as plot_error:
            img_html = f"Error generating plot: {html.escape(str(plot_error), quote=False)}"

        return img_html, df, colored_text_html
    except Exception as e:
        # UI boundary: show the problem to the user rather than crash the app.
        # The message may echo user input, so it is escaped before display.
        return f"Error: {html.escape(str(e), quote=False)}", None, None
# Gradio UI: page header, input box, three result panels and a trigger button.
# The component creation order is kept exactly as in the original layout.
with gr.Blocks(title="Log Probability Visualizer") as app:
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown(
        "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities."
    )

    # Free-form text area for the pasted log-prob payload.
    json_input = gr.Textbox(
        label="JSON Input",
        lines=10,
        placeholder="Paste your JSON or Python dict here...",
    )

    # Result panels; they receive visualize_logprobs' 3-tuple in this order.
    plot_output = gr.HTML(label="Log Probability Plot")
    table_output = gr.Dataframe(
        label="Token Log Probabilities and Top Alternatives"
    )
    text_output = gr.HTML(label="Colored Text (Confidence Visualization)")

    # Clicking the button runs the callback on the textbox contents.
    btn = gr.Button("Visualize")
    btn.click(
        fn=visualize_logprobs,
        inputs=json_input,
        outputs=[
            plot_output,
            table_output,
            text_output,
        ],
    )

# Start serving the interface.
app.launch()