"""Gradio app that plots token log probabilities and tabulates top alternatives.

Paste a log-prob payload (JSON, or a Python dict literal such as one using
single quotes) containing a ``content`` list of per-token entries. The app
renders a line plot of each token's log probability and a table with the
three most probable alternative tokens at each position.
"""

import ast  # For safely evaluating Python literals (never use eval() here)
import base64
import html
import io
import json

import gradio as gr
import pandas as pd
from matplotlib.figure import Figure

TABLE_COLUMNS = [
    "Token",
    "Log Prob",
    "Top 1 Alternative",
    "Top 2 Alternative",
    "Top 3 Alternative",
]
NUM_ALTERNATIVES = 3


def _parse_input(text):
    """Parse *text* as JSON, falling back to a Python literal.

    The fallback lets users paste Python reprs (single-quoted strings, None).

    Raises:
        ValueError: if *text* is empty or is neither valid JSON nor a
            valid Python literal.
    """
    if text is None or not str(text).strip():
        raise ValueError("Input is empty.")
    try:
        return json.loads(text)
    except json.JSONDecodeError as json_err:
        try:
            # literal_eval only accepts literals, so it cannot execute code.
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as lit_err:
            raise ValueError(
                f"Input is neither valid JSON ({json_err}) nor a Python literal."
            ) from lit_err


def _extract_entries(data):
    """Return the per-token entries whose ``logprob`` is not None.

    *data* may be a mapping with a ``content`` list, or that list itself.

    Raises:
        ValueError: if no ``content`` list is found or an entry is not a mapping.
    """
    content = data.get("content") if isinstance(data, dict) else data
    if not isinstance(content, list):
        raise ValueError("Expected an object with a 'content' list of token entries.")
    entries = []
    for position, entry in enumerate(content):
        if not isinstance(entry, dict):
            raise ValueError(f"Entry {position} in 'content' is not an object.")
        if entry.get("logprob") is not None:
            entries.append(entry)
    return entries


def _top_alternatives(top_logprobs, limit=NUM_ALTERNATIVES):
    """Return up to *limit* ``(token, logprob)`` pairs, most probable first.

    *top_logprobs* may be either a ``{token: logprob}`` mapping or a list of
    ``{"token": ..., "logprob": ...}`` objects (the shape used by OpenAI's
    chat-completions logprobs). Missing/empty input yields ``[]`` and
    alternatives with a None log prob are skipped.
    """
    if not top_logprobs:
        return []
    if isinstance(top_logprobs, dict):
        pairs = top_logprobs.items()
    else:
        pairs = ((alt.get("token", ""), alt.get("logprob")) for alt in top_logprobs)
    valid = [(token, logprob) for token, logprob in pairs if logprob is not None]
    return sorted(valid, key=lambda pair: pair[1], reverse=True)[:limit]


def _build_table(entries):
    """Build the token / log-prob / top-alternatives table for *entries*."""
    rows = []
    for entry in entries:
        row = [entry.get("token", ""), f"{entry['logprob']:.4f}"]
        row.extend(
            f"{alt_token}: {alt_logprob:.4f}"
            for alt_token, alt_logprob in _top_alternatives(entry.get("top_logprobs"))
        )
        # Pad with empty strings if fewer than NUM_ALTERNATIVES alternatives exist.
        row.extend([""] * (len(TABLE_COLUMNS) - len(row)))
        rows.append(row)
    return pd.DataFrame(rows, columns=TABLE_COLUMNS)


def _render_plot_html(tokens, logprobs):
    """Plot *logprobs* against token position; return an inline base64 <img> tag."""
    # Use the Figure API instead of pyplot: pyplot keeps global figure state,
    # is not thread-safe, and Gradio may run callbacks concurrently. A bare
    # Figure needs no explicit close, so nothing leaks if savefig raises.
    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()
    positions = list(range(len(logprobs)))
    ax.plot(positions, logprobs, marker="o", linestyle="-", color="b")
    ax.set_title("Log Probabilities of Generated Tokens")
    ax.set_xlabel("Token Position")
    ax.set_ylabel("Log Probability")
    ax.grid(True)
    ax.set_xticks(positions)
    # Escape '$' so tokens are not interpreted as matplotlib mathtext.
    ax.set_xticklabels(
        [token.replace("$", r"\$") for token in tokens], rotation=45, ha="right"
    )
    fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    img_base64 = base64.b64encode(buf.getvalue()).decode("ascii")
    return f'<img src="data:image/png;base64,{img_base64}" alt="Log probability plot">'


def visualize_logprobs(json_input):
    """Parse a log-prob payload and return ``(plot_html, table_dataframe)``.

    Args:
        json_input: JSON text, or a Python dict/list literal, with a
            ``content`` list of entries carrying ``token``, ``logprob``, and
            optionally ``top_logprobs``.

    Returns:
        On success, an HTML ``<img>`` tag with the plot and a DataFrame of
        tokens and their top alternatives. On failure, an HTML-escaped
        ``"Error: ..."`` message and ``None``.
    """
    try:
        data = _parse_input(json_input)
        entries = _extract_entries(data)
        if not entries:
            return (
                "No tokens with log probabilities found.",
                pd.DataFrame(columns=TABLE_COLUMNS),
            )
        tokens = [str(entry.get("token", "")) for entry in entries]
        logprobs = [float(entry["logprob"]) for entry in entries]
        return _render_plot_html(tokens, logprobs), _build_table(entries)
    except Exception as e:  # UI boundary: show any failure instead of crashing the app
        # Escape the message: it is rendered inside a gr.HTML component and may
        # echo fragments of user input.
        return f"Error: {html.escape(str(e))}", None


# Gradio interface
with gr.Blocks(title="Log Probability Visualizer") as app:
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown("Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities.")

    # Input
    json_input = gr.Textbox(label="JSON Input", lines=10, placeholder="Paste your JSON or Python dict here...")

    # Outputs
    plot_output = gr.HTML(label="Log Probability Plot")
    table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")

    # Button to trigger visualization
    btn = gr.Button("Visualize")
    btn.click(
        fn=visualize_logprobs,
        inputs=json_input,
        outputs=[plot_output, table_output],
    )

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    app.launch()