import ast
import base64
import html
import io
import json
import math

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd

_TABLE_COLUMNS = [
    "Token",
    "Log Prob",
    "Top 1 Alternative",
    "Top 2 Alternative",
    "Top 3 Alternative",
]


def _parse_input(text):
    """Parse *text* as JSON, falling back to a Python literal.

    The UI advertises "JSON or Python dict" input, but ``json.loads`` rejects
    Python reprs (single quotes, ``None``, ``True``). Those are retried with
    ``ast.literal_eval``, which only evaluates literals and never executes code.

    Raises:
        ValueError: if *text* is neither valid JSON nor a Python literal.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError as json_err:
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            raise ValueError(
                f"Input is neither valid JSON nor a Python literal: {json_err}"
            ) from json_err


def _finite_float(value):
    """Return *value* as a float if it is a finite real number, else None.

    ``bool`` is rejected explicitly because it is a subclass of ``int``.
    """
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    try:
        number = float(value)
    except OverflowError:  # int too large to represent as a float
        return None
    return number if math.isfinite(number) else None


def _top_alternatives(top_logprobs, k=3):
    """Return up to *k* ``(token, logprob)`` pairs, highest logprob first.

    Accepts either a mapping ``{token: logprob}`` or a list whose items are
    ``{"token": ..., "logprob": ...}`` objects (the shape used by e.g. OpenAI
    chat-completion logprobs) or ``{token: logprob}`` mappings. Alternatives
    with a missing or non-finite logprob are skipped, so sorting never
    compares None/NaN.
    """
    if isinstance(top_logprobs, dict):
        candidates = list(top_logprobs.items())
    elif isinstance(top_logprobs, list):
        candidates = []
        for item in top_logprobs:
            if not isinstance(item, dict):
                continue
            if "token" in item and "logprob" in item:
                candidates.append((item["token"], item["logprob"]))
            else:
                candidates.extend(item.items())
    else:
        candidates = []

    pairs = []
    for alt_token, alt_logprob in candidates:
        value = _finite_float(alt_logprob)
        if value is not None:
            pairs.append((str(alt_token), value))
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return pairs[:k]


def _finite_points(content):
    """Return parallel ``(tokens, logprobs)`` lists for entries with a finite logprob."""
    tokens, logprobs = [], []
    for entry in content:
        if not isinstance(entry, dict):
            continue
        logprob = _finite_float(entry.get("logprob"))
        if logprob is None:
            continue
        tokens.append(str(entry.get("token", "")))
        logprobs.append(logprob)
    return tokens, logprobs


def _table_rows(content):
    """Build 5-column table rows for entries with a finite logprob and top_logprobs."""
    rows = []
    for entry in content:
        if not isinstance(entry, dict):
            continue
        logprob = _finite_float(entry.get("logprob"))
        top_logprobs = entry.get("top_logprobs")
        if logprob is None or top_logprobs is None:
            continue
        row = [str(entry.get("token", "")), f"{logprob:.4f}"]
        row.extend(
            f"{alt_token}: {alt_logprob:.4f}"
            for alt_token, alt_logprob in _top_alternatives(top_logprobs)
        )
        # Pad with empty strings if fewer than 3 alternatives.
        row.extend([""] * (len(_TABLE_COLUMNS) - len(row)))
        rows.append(row)
    return rows


def _plot_html(tokens, logprobs):
    """Render the logprob curve as an inline base64 PNG ``<img>`` tag."""
    if not logprobs:
        return "No finite log probabilities to plot."

    # Matplotlib parses text between paired "$" as mathtext; escape them so
    # tokens such as "$x$" are drawn literally (and cannot fail to parse).
    labels = [token.replace("$", r"\$") for token in tokens]
    positions = list(range(len(logprobs)))

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(positions, logprobs, marker="o", linestyle="-", color="b")
        ax.set_title("Log Probabilities of Generated Tokens")
        ax.set_xlabel("Token Position")
        ax.set_ylabel("Log Probability")
        ax.grid(True)
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # pyplot keeps every open figure alive; always close it so a failure
        # mid-render does not leak a figure per request.
        plt.close(fig)

    img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f'<img src="data:image/png;base64,{img_base64}" alt="Log Probability Plot">'


def _colored_text_html(tokens, logprobs):
    """Render tokens as HTML spans colored red (low) to green (high) logprob."""
    if not logprobs:
        return "No finite log probabilities to display."

    min_logprob = min(logprobs)
    max_logprob = max(logprobs)
    value_range = max_logprob - min_logprob

    spans = []
    for token, logprob in zip(tokens, logprobs):
        # Min-max normalise; a constant sequence gets a neutral midpoint.
        norm_prob = 0.5 if value_range == 0 else (logprob - min_logprob) / value_range
        r = int(255 * (1 - norm_prob))  # Red for low confidence
        g = int(255 * norm_prob)  # Green for high confidence
        b = 0
        color = f"rgb({r}, {g}, {b})"
        # Tokens come from user input: escape them before embedding in HTML.
        spans.append(f'<span style="color: {color}">{html.escape(token)}</span>')
    return f"<p>{' '.join(spans)}</p>"


def visualize_logprobs(json_input):
    """Parse token log-prob data and build the three Gradio outputs.

    Args:
        json_input: JSON text or a Python literal. It must be either a list of
            token entries or a dict with a ``"content"`` list. Each entry is a
            dict with ``"token"`` and ``"logprob"`` and an optional
            ``"top_logprobs"``. Entries whose logprob is missing, None, or
            non-finite are skipped.

    Returns:
        ``(img_html, df, colored_text_html)``: an HTML ``<img>`` of the plot,
        a DataFrame of tokens with their top-3 alternatives (or None if no
        entry has ``top_logprobs``), and HTML of the confidence-colored text.
        On any error, returns ``("Error: <message>", None, None)``.
    """
    try:
        data = _parse_input(json_input)
        if isinstance(data, dict) and "content" in data:
            content = data["content"]
        elif isinstance(data, list):
            content = data
        else:
            raise ValueError("Input must be a list or dictionary with 'content' key")
        if not isinstance(content, list):
            raise ValueError("'content' must be a list of token entries")

        tokens, logprobs = _finite_points(content)
        table_data = _table_rows(content)

        img_html = _plot_html(tokens, logprobs)
        df = pd.DataFrame(table_data, columns=_TABLE_COLUMNS) if table_data else None
        colored_text_html = _colored_text_html(tokens, logprobs)
        return img_html, df, colored_text_html
    except Exception as e:
        # UI boundary: show the failure in the HTML pane instead of crashing
        # the handler; escape it because messages may echo user input.
        return f"Error: {html.escape(str(e))}", None, None


# Gradio interface
with gr.Blocks(title="Log Probability Visualizer") as app:
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown(
        "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities."
    )

    json_input = gr.Textbox(
        label="JSON Input",
        lines=10,
        placeholder="Paste your JSON or Python dict here...",
    )

    plot_output = gr.HTML(label="Log Probability Plot")
    table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
    text_output = gr.HTML(label="Colored Text (Confidence Visualization)")

    btn = gr.Button("Visualize")
    btn.click(
        fn=visualize_logprobs,
        inputs=json_input,
        outputs=[plot_output, table_output, text_output],
    )

if __name__ == "__main__":
    app.launch()