import ast
import base64
import html
import io
import json
import math

import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import pandas as pd
# Function to process and visualize log probs
_TABLE_COLUMNS = ["Token", "Log Prob", "Top 1 Alternative", "Top 2 Alternative", "Top 3 Alternative"]
_NUM_ALTERNATIVES = 3


def _is_finite(value):
    """Return True if value is not None and is neither NaN nor +/-inf."""
    return value is not None and math.isfinite(value)


def _parse_logprob_input(json_input):
    """Parse the textbox contents and return the list of per-token entries.

    Tries JSON first, then falls back to a Python literal (single quotes,
    None/True/False). Accepts either a bare list of entries or a dict with a
    'content' key; a null 'content' is treated as an empty list.

    Raises:
        ValueError: if the parsed value is neither of those shapes.
        SyntaxError / ValueError: if the text is neither JSON nor a literal.
    """
    try:
        data = json.loads(json_input)
    except json.JSONDecodeError:
        # NOTE(review): literal_eval is safe from code execution, but very large
        # or deeply nested input can still exhaust the stack/memory.
        data = ast.literal_eval(json_input)
    if isinstance(data, dict) and 'content' in data:
        return data['content'] or []
    if isinstance(data, list):
        return data
    raise ValueError("Input must be a list or dictionary with 'content' key")


def _top_alternatives(top_logprobs, limit=_NUM_ALTERNATIVES):
    """Return up to `limit` (token, logprob) pairs with finite logprobs, most probable first.

    `top_logprobs` may be None/empty, a {token: logprob} mapping, or a list of
    {"token": ..., "logprob": ...} dicts (the OpenAI chat-completions shape).
    """
    if not top_logprobs:
        return []
    if isinstance(top_logprobs, dict):
        pairs = top_logprobs.items()
    else:
        pairs = ((alt['token'], alt['logprob']) for alt in top_logprobs)
    finite_pairs = [(tok, lp) for tok, lp in pairs if _is_finite(lp)]
    return sorted(finite_pairs, key=lambda pair: pair[1], reverse=True)[:limit]


def _collect_token_rows(content):
    """Return (tokens, logprobs, table_rows) for entries whose logprob is finite.

    Entries with a None or non-finite logprob are skipped entirely, so the
    three lists stay aligned by position.
    """
    tokens, logprobs, rows = [], [], []
    for entry in content:
        logprob = entry['logprob']
        if not _is_finite(logprob):
            continue
        token = entry['token']
        tokens.append(token)
        logprobs.append(logprob)
        row = [token, f"{logprob:.4f}"]
        row.extend(f"{alt}: {alt_lp:.4f}" for alt, alt_lp in _top_alternatives(entry.get('top_logprobs')))
        # Pad so every row has one cell per column even with fewer alternatives.
        row.extend([""] * (len(_TABLE_COLUMNS) - len(row)))
        rows.append(row)
    return tokens, logprobs, rows


def _render_plot_html(tokens, logprobs):
    """Plot logprob vs. token position and return it as an inline base64 PNG <img> tag."""
    # Use a standalone Figure rather than pyplot: Gradio runs handlers on worker
    # threads, and pyplot's global figure state is not thread-safe. The figure is
    # freed by garbage collection, so nothing leaks if an exception escapes.
    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()
    positions = range(len(logprobs))
    ax.plot(positions, logprobs, marker='o', linestyle='-', color='b')
    ax.set_title("Log Probabilities of Generated Tokens")
    ax.set_xlabel("Token Position")
    ax.set_ylabel("Log Probability")
    ax.grid(True)
    ax.set_xticks(positions)
    # Escape '$' so tokens are drawn literally instead of being parsed as mathtext.
    labels = [str(tok).replace('$', r'\$') for tok in tokens]
    ax.set_xticklabels(labels, rotation=45, ha='right')
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    img_base64 = base64.b64encode(buf.getvalue()).decode('ascii')
    return f'<img src="data:image/png;base64,{img_base64}" alt="Log Probability Plot">'


def _render_colored_text_html(tokens, logprobs):
    """Return HTML with each token colored red (lowest logprob) to green (highest).

    Log probs are min-max normalized to [0, 1]; when all are equal every token
    gets the midpoint 0.5. Tokens are HTML-escaped before insertion.
    """
    lowest, highest = min(logprobs), max(logprobs)
    spread = highest - lowest
    pieces = []
    for token, logprob in zip(tokens, logprobs):
        norm = 0.5 if spread == 0 else (logprob - lowest) / spread
        color = f'rgb({int(255 * (1 - norm))}, {int(255 * norm)}, 0)'
        pieces.append(f'<span style="color: {color}">{html.escape(str(token))}</span>')
    return f'<div style="line-height: 1.8;">{" ".join(pieces)}</div>'


def visualize_logprobs(json_input):
    """Parse log-prob data and build the three Gradio outputs.

    Args:
        json_input: Text holding either a list of token entries or a dict with
            a 'content' list, as JSON or as a Python literal. Each entry is read
            via entry['token'], entry['logprob'] and, optionally,
            entry['top_logprobs'] (a {token: logprob} mapping or a list of
            {"token": ..., "logprob": ...} dicts).

    Returns:
        (plot_html, table, colored_text_html):
        - an <img> tag with the PNG plot, or a notice string;
        - a pandas DataFrame of tokens, their logprobs and top-3 alternatives,
          or None when no entry has a finite logprob;
        - HTML with tokens colored by relative confidence, or a notice string.
        On any failure returns ("Error: <escaped message>", None, None)
        instead of raising.
    """
    try:
        content = _parse_logprob_input(json_input)
        tokens, logprobs, table_data = _collect_token_rows(content)
        if logprobs:
            img_html = _render_plot_html(tokens, logprobs)
            colored_text_html = _render_colored_text_html(tokens, logprobs)
        else:
            img_html = "No finite log probabilities to plot."
            colored_text_html = "No finite log probabilities to display."
        df = pd.DataFrame(table_data, columns=_TABLE_COLUMNS) if table_data else None
        return img_html, df, colored_text_html
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the
        # handler. Escape it because it is rendered in an HTML component and may
        # echo user input.
        return f"Error: {html.escape(str(e))}", None, None


# Gradio interface
with gr.Blocks(title="Log Probability Visualizer") as app:
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown("Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities.")
    # Input
    json_input = gr.Textbox(label="JSON Input", lines=10, placeholder="Paste your JSON or Python dict here...")
    # Outputs
    plot_output = gr.HTML(label="Log Probability Plot")
    table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
    text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
    # Button to trigger visualization
    btn = gr.Button("Visualize")
    btn.click(
        fn=visualize_logprobs,
        inputs=json_input,
        outputs=[plot_output, table_output, text_output],
    )

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    app.launch()