"""Gradio app that visualizes per-token log probabilities as a plot, a table, and colored text."""
import ast
import base64
import html
import io
import json
import math

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure

# Column headers of the output table; the three "Top N" columns hold the
# most probable finite alternatives for each token.
_TABLE_COLUMNS = ["Token", "Log Prob", "Top 1 Alternative", "Top 2 Alternative", "Top 3 Alternative"]
_TOP_K = 3


def _is_finite_number(value):
    """Return True if *value* is a real (non-bool) number that is finite."""
    return isinstance(value, (int, float)) and not isinstance(value, bool) and math.isfinite(value)


def _parse_input(text):
    """Parse *text* as JSON, falling back to a Python literal (e.g. single-quoted dicts).

    Raises:
        ValueError: if *text* is neither valid JSON nor a valid Python literal.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # NOTE(security): ast.literal_eval never executes code, but the Python docs
    # warn that deeply nested input can crash the interpreter via stack depth;
    # consider capping the input size if this app is publicly exposed.
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError) as err:
        raise ValueError(f"Input is neither valid JSON nor a Python literal: {err}") from err


def _extract_content(data):
    """Return the list of token entries from a bare list or a ``{'content': [...]}`` dict."""
    if isinstance(data, dict) and 'content' in data:
        content = data['content']
    elif isinstance(data, list):
        content = data
    else:
        raise ValueError("Input must be a list or dictionary with 'content' key")
    if not isinstance(content, list):
        raise ValueError("'content' must be a list of token entries")
    return content


def _top_alternatives(top_logprobs, k=_TOP_K):
    """Return up to *k* ``(token, logprob)`` pairs with finite logprobs, most probable first.

    Accepts either a ``{token: logprob}`` mapping or a list of
    ``{"token": ..., "logprob": ...}`` dicts; any other value (including None)
    yields no alternatives. Non-finite or non-numeric logprobs are dropped.
    """
    if isinstance(top_logprobs, dict):
        pairs = list(top_logprobs.items())
    elif isinstance(top_logprobs, list):
        pairs = [
            (alt['token'], alt.get('logprob'))
            for alt in top_logprobs
            if isinstance(alt, dict) and 'token' in alt
        ]
    else:
        pairs = []
    finite = [(tok, lp) for tok, lp in pairs if _is_finite_number(lp)]
    return sorted(finite, key=lambda pair: pair[1], reverse=True)[:k]


def _extract_rows(content):
    """Return ``(tokens, logprobs, table_rows)`` for every entry with a finite logprob.

    Raises:
        ValueError: if an entry is not a dict with ``'token'`` and ``'logprob'`` keys.
    """
    tokens, logprobs, table_rows = [], [], []
    for index, entry in enumerate(content):
        if not isinstance(entry, dict) or 'token' not in entry or 'logprob' not in entry:
            raise ValueError(f"Entry {index} must be a dict with 'token' and 'logprob' keys")
        logprob = entry['logprob']
        if not _is_finite_number(logprob):
            continue  # None / inf / nan cannot be plotted or normalized
        token = entry['token']
        tokens.append(token)
        logprobs.append(logprob)
        row = [token, f"{logprob:.4f}"]
        row.extend(
            f"{alt_token}: {alt_logprob:.4f}"
            for alt_token, alt_logprob in _top_alternatives(entry.get('top_logprobs'))
        )
        # Pad so every row has one cell per column even with < 3 alternatives.
        row.extend([""] * (len(_TABLE_COLUMNS) - len(row)))
        table_rows.append(row)
    return tokens, logprobs, table_rows


def _render_plot(tokens, logprobs):
    """Plot *logprobs* by token position and return an HTML ``<img>`` with an inline PNG."""
    # Object-oriented Figure instead of pyplot: Gradio calls handlers on worker
    # threads and pyplot's global "current figure" is not thread-safe. A Figure
    # not registered with pyplot also needs no explicit close.
    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()
    positions = range(len(logprobs))
    ax.plot(positions, logprobs, marker='o', linestyle='-', color='b')
    ax.set_title("Log Probabilities of Generated Tokens")
    ax.set_xlabel("Token Position")
    ax.set_ylabel("Log Probability")
    ax.grid(True)
    ax.set_xticks(positions)
    # Escape '$' so tokens are rendered literally instead of being parsed as
    # mathtext (which can raise on arbitrary token text).
    ax.set_xticklabels([str(tok).replace('$', r'\$') for tok in tokens], rotation=45, ha='right')
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f'<img src="data:image/png;base64,{img_base64}" alt="Log probability plot" style="max-width: 100%;">'


def _render_colored_text(tokens, logprobs):
    """Return HTML with each token colored from red (lowest logprob) to green (highest)."""
    low, high = min(logprobs), max(logprobs)
    spans = []
    for token, logprob in zip(tokens, logprobs):
        # Min-max normalize within this input; all-equal logprobs map to 0.5
        # to avoid division by zero.
        norm = 0.5 if high == low else (logprob - low) / (high - low)
        color = f'rgb({int(255 * (1 - norm))}, {int(255 * norm)}, 0)'
        # Tokens are untrusted text: escape before embedding in HTML.
        spans.append(f'<span style="color: {color};">{html.escape(str(token))}</span>')
    body = " ".join(spans)
    return f'<p>{body}</p>'


def visualize_logprobs(json_input):
    """Build the plot, table, and colored-text views for pasted log-prob data.

    Args:
        json_input: JSON or Python-literal text holding either a list of token
            entries or a dict with a ``'content'`` list. Each entry is a dict
            with ``'token'``, ``'logprob'``, and optionally ``'top_logprobs'``
            (a ``{token: logprob}`` dict or a list of ``{"token", "logprob"}``
            dicts). Entries whose logprob is None or non-finite are skipped.

    Returns:
        ``(plot_html, table_df, colored_text_html)``. ``table_df`` is None when
        no entry has a finite logprob. On any error, returns
        ``("Error: <message>", None, None)``.
    """
    try:
        content = _extract_content(_parse_input(json_input))
        tokens, logprobs, table_rows = _extract_rows(content)
        if logprobs:
            img_html = _render_plot(tokens, logprobs)
            colored_text_html = _render_colored_text(tokens, logprobs)
        else:
            img_html = "No finite log probabilities to plot."
            colored_text_html = "No finite log probabilities to display."
        df = pd.DataFrame(table_rows, columns=_TABLE_COLUMNS) if table_rows else None
        return img_html, df, colored_text_html
    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing the
        # handler. The message may echo user input, so escape it for HTML.
        return f"Error: {html.escape(str(e))}", None, None


# Gradio interface
with gr.Blocks(title="Log Probability Visualizer") as app:
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown("Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities.")

    # Input
    json_input = gr.Textbox(label="JSON Input", lines=10, placeholder="Paste your JSON or Python dict here...")

    # Outputs
    plot_output = gr.HTML(label="Log Probability Plot")
    table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
    text_output = gr.HTML(label="Colored Text (Confidence Visualization)")

    # Button to trigger visualization
    btn = gr.Button("Visualize")
    btn.click(
        fn=visualize_logprobs,
        inputs=json_input,
        outputs=[plot_output, table_output, text_output],
    )

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    app.launch()