codelion's picture
Create app.py
0244d3c verified
raw
history blame
2.97 kB
import json
from io import BytesIO, StringIO

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure
# Function to process and visualize log probs


def _top_alternatives(raw_top_logprobs, k):
    """Return the ``k`` highest-logprob ``(token, logprob)`` pairs.

    ``raw_top_logprobs`` may be a mapping of ``token -> logprob`` or a list of
    ``{"token": ..., "logprob": ...}`` objects (the OpenAI chat-completions
    shape). Missing/empty input yields ``[]``; pairs whose logprob is ``None``
    are ignored so the sort never compares ``None`` with a float.
    """
    if not raw_top_logprobs:
        return []
    if isinstance(raw_top_logprobs, dict):
        pairs = raw_top_logprobs.items()
    else:
        pairs = ((alt["token"], alt["logprob"]) for alt in raw_top_logprobs)
    scored = [(tok, lp) for tok, lp in pairs if lp is not None]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


def _extract_rows(json_input, top_k):
    """Parse ``json_input`` into plot series and formatted table rows.

    Returns ``(tokens, logprobs, rows)``. Entries whose ``logprob`` is ``None``
    are skipped. Each row has ``2 + top_k`` cells: token, formatted log prob,
    then ``top_k`` alternative cells, padded with ``""`` when fewer exist.
    """
    data = json.loads(json_input)
    tokens, logprobs, rows = [], [], []
    for entry in data["content"]:
        logprob = entry["logprob"]
        if logprob is None:
            continue
        token = entry["token"]
        alternatives = _top_alternatives(entry.get("top_logprobs"), top_k)
        row = [token, f"{logprob:.4f}"]
        row.extend(f"{alt_token}: {alt_logprob:.4f}" for alt_token, alt_logprob in alternatives)
        # Pad so every row has exactly top_k alternative columns.
        row.extend([""] * (top_k - len(alternatives)))
        tokens.append(token)
        logprobs.append(logprob)
        rows.append(row)
    return tokens, logprobs, rows


def _render_logprob_plot(tokens, logprobs):
    """Render the per-token log-prob line plot as an RGBA ``uint8`` array."""
    # A standalone Figure is never registered with pyplot, so nothing leaks if
    # rendering fails and concurrent requests do not share a "current figure".
    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()
    positions = range(len(logprobs))
    ax.plot(positions, logprobs, marker="o", linestyle="-", color="b")
    ax.set_title("Log Probabilities of Generated Tokens")
    ax.set_xlabel("Token Position")
    ax.set_ylabel("Log Probability")
    ax.grid(True)
    # Escape "$" so tokens like "$x$" are drawn literally, not as mathtext.
    labels = [str(token).replace("$", r"\$") for token in tokens]
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    fig.tight_layout()
    # PNG output is binary, so it must go to BytesIO (StringIO rejects bytes).
    buffer = BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight")
    buffer.seek(0)
    # gr.Image cannot display a raw buffer; decode it into an image array.
    # imread returns PNG pixels as floats in [0, 1].
    image = plt.imread(buffer, format="png")
    return (image * 255).round().astype("uint8")


def visualize_logprobs(json_input, top_k=3):
    """Parse log-prob JSON and return a plot image plus a per-token table.

    The input is expected to be a JSON object whose ``content`` list holds
    entries with ``token``, ``logprob`` and (optionally) ``top_logprobs``.

    Args:
        json_input: Raw JSON text pasted by the user.
        top_k: Number of highest-probability alternatives shown per token.

    Returns:
        ``(image, dataframe)``: an RGBA ``uint8`` array of the line plot,
        suitable for ``gr.Image``, and a DataFrame with one row per token.

    Raises:
        ValueError: If ``top_k`` is negative.
        gr.Error: If the input cannot be parsed or rendered; Gradio shows
            the message in the UI instead of a broken output.
    """
    if top_k < 0:
        raise ValueError("top_k must be non-negative")
    try:
        tokens, logprobs, table_data = _extract_rows(json_input, top_k)
        image = _render_logprob_plot(tokens, logprobs)
    except Exception as e:  # UI boundary: surface any failure to the user.
        raise gr.Error(f"Error: {type(e).__name__}: {e}") from e
    columns = ["Token", "Log Prob"] + [
        f"Top {i} Alternative" for i in range(1, top_k + 1)
    ]
    df = pd.DataFrame(table_data, columns=columns)
    return image, df
# Gradio interface: a JSON textbox, a plot, a table, and a button that wires
# the textbox through visualize_logprobs into the two outputs.
with gr.Blocks(title="Log Probability Visualizer") as app:
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown("Paste your JSON log prob data below to visualize the tokens and their probabilities.")
    # Input
    json_input = gr.Textbox(label="JSON Input", lines=10, placeholder="Paste your JSON here...")
    # Outputs
    plot_output = gr.Image(label="Log Probability Plot")
    table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
    # Button to trigger visualization
    btn = gr.Button("Visualize")
    btn.click(
        fn=visualize_logprobs,
        inputs=json_input,
        outputs=[plot_output, table_output],
    )

# Launch the app only when run as a script, so importing this module (e.g.
# from tests or Gradio's reload mode) does not start a server as a side effect.
if __name__ == "__main__":
    app.launch()