"""Gradio app that plots per-token log probabilities from a JSON payload.

The expected payload is an object with a ``content`` list whose entries
carry ``token``, ``logprob`` and (optionally) ``top_logprobs``.
"""

import json
import tempfile
from io import StringIO

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd

# Column headers of the result table: token, its log prob, then the alternatives.
TABLE_COLUMNS = [
    "Token",
    "Log Prob",
    "Top 1 Alternative",
    "Top 2 Alternative",
    "Top 3 Alternative",
]

# Number of alternative tokens shown per row (matches the "Top N" columns).
MAX_ALTERNATIVES = 3


def _top_alternatives(top_logprobs, limit=MAX_ALTERNATIVES):
    """Return up to *limit* ``(token, logprob)`` pairs, most probable first.

    Accepts either a ``{token: logprob}`` mapping or a list of
    ``{"token": ..., "logprob": ...}`` objects; the list form is presumably
    what OpenAI-style responses send -- confirm against the real producer.
    A missing/empty value yields an empty list, and alternatives whose
    log prob is ``None`` are skipped because they can be neither sorted
    nor formatted.
    """
    if not top_logprobs:
        return []
    if isinstance(top_logprobs, dict):
        pairs = top_logprobs.items()
    else:
        pairs = ((alt["token"], alt["logprob"]) for alt in top_logprobs)
    scored = [(token, score) for token, score in pairs if score is not None]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:limit]


def _render_plot(tokens, logprobs):
    """Draw the log-prob line chart and return the path of the saved PNG.

    The image is written to a temporary file that is intentionally not
    deleted here: the caller hands the path to ``gr.Image``, which needs
    the file to exist after this function returns.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        positions = list(range(len(logprobs)))
        ax.plot(positions, logprobs, marker="o", linestyle="-", color="b")
        ax.set_title("Log Probabilities of Generated Tokens")
        ax.set_xlabel("Token Position")
        ax.set_ylabel("Log Probability")
        ax.grid(True)
        ax.set_xticks(positions)
        # Escape "$" so a token is never interpreted as matplotlib mathtext.
        ax.set_xticklabels(
            [str(token).replace("$", r"\$") for token in tokens],
            rotation=45,
            ha="right",
        )
        fig.tight_layout()
        # PNG is binary data, so it must go to a binary sink (a text buffer
        # such as StringIO rejects bytes).
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            fig.savefig(tmp, format="png", bbox_inches="tight")
        return tmp.name
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)


def visualize_logprobs(json_input):
    """Parse log-prob JSON and build the plot image and the summary table.

    Args:
        json_input: JSON text of an object with a ``content`` list. Each
            entry must have ``token`` and ``logprob``; entries whose
            ``logprob`` is ``None`` are ignored. ``top_logprobs`` is optional.

    Returns:
        A ``(image_path, dataframe)`` tuple: the path of a PNG chart of the
        log probabilities, and a DataFrame with one row per token holding
        the token, its log prob and up to three top alternatives.

    Raises:
        gr.Error: If the input cannot be parsed or processed, so that the
            Gradio UI shows the message instead of receiving an invalid
            image value.
    """
    try:
        data = json.loads(json_input)

        # Filter once; the plot and the table both use the same entries.
        entries = [
            entry for entry in data["content"] if entry["logprob"] is not None
        ]
        tokens = [entry["token"] for entry in entries]
        logprobs = [entry["logprob"] for entry in entries]

        table_rows = []
        for entry in entries:
            row = [entry["token"], f"{entry['logprob']:.4f}"]
            row.extend(
                f"{alt_token}: {alt_logprob:.4f}"
                for alt_token, alt_logprob in _top_alternatives(
                    entry.get("top_logprobs")
                )
            )
            # Pad with empty strings when fewer alternatives are available.
            row.extend([""] * (len(TABLE_COLUMNS) - len(row)))
            table_rows.append(row)

        df = pd.DataFrame(table_rows, columns=TABLE_COLUMNS)
        image_path = _render_plot(tokens, logprobs)
        return image_path, df
    except Exception as exc:
        # Boundary of the UI callback: report any failure to the user.
        raise gr.Error(f"{type(exc).__name__}: {exc}") from exc


# Gradio interface
with gr.Blocks(title="Log Probability Visualizer") as app:
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown(
        "Paste your JSON log prob data below to visualize the tokens and their probabilities."
    )

    # Input
    json_input = gr.Textbox(
        label="JSON Input",
        lines=10,
        placeholder="Paste your JSON here...",
    )

    # Outputs
    plot_output = gr.Image(label="Log Probability Plot")
    table_output = gr.Dataframe(
        label="Token Log Probabilities and Top Alternatives"
    )

    # Button to trigger visualization
    btn = gr.Button("Visualize")
    btn.click(
        fn=visualize_logprobs,
        inputs=json_input,
        outputs=[plot_output, table_output],
    )

# Launch the app
app.launch()