import gradio as gr
import json
import matplotlib.pyplot as plt
import pandas as pd
import io
import base64
import math
import ast
import logging
import numpy as np
import plotly.graph_objects as go
import asyncio
import anyio
# Set up logging: configure the root logger at DEBUG level (per the stdlib
# docs, basicConfig does nothing if the root logger already has handlers) and
# create this module's named logger, used by the helper functions in this file.
# NOTE(review): at DEBUG, parse_input logs the entire raw input on every call,
# which may be very large for long generations.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
def parse_input(json_input):
    """Parse *json_input* as JSON and return the decoded value.

    Only JSON is accepted; there is no Python-literal fallback, so property
    names must be double-quoted.

    Args:
        json_input: JSON text as ``str``, ``bytes`` or ``bytearray``.

    Returns:
        The decoded object (callers in this module expect either a dict with
        a ``"content"`` list or a bare list of entries).

    Raises:
        ValueError: if *json_input* is not text/bytes, or is not valid JSON.
            For malformed JSON the original ``json.JSONDecodeError`` is
            chained as ``__cause__``.
    """
    logger.debug("Attempting to parse input: %s", json_input)
    if not isinstance(json_input, (str, bytes, bytearray)):
        # json.loads would raise TypeError here; report it as the same
        # ValueError that callers already handle for malformed input.
        raise ValueError(
            f"Malformed input: expected JSON text, got {type(json_input).__name__}."
        )
    try:
        data = json.loads(json_input)
    except json.JSONDecodeError as e:
        # Let logging format the preview so str and bytes inputs both work
        # (concatenating bytes with "..." would raise TypeError).
        logger.error(
            "JSON parsing failed: %s (Input: %s%s)",
            e,
            json_input[:100],
            "..." if len(json_input) > 100 else "",
        )
        raise ValueError(
            f"Malformed input: {e}. Ensure property names are in double quotes "
            f"(e.g., \"content\") and the format matches JSON (e.g., {{\"content\": [...]}})."
        ) from e
    logger.debug("Successfully parsed as JSON")
    return data
def ensure_float(value):
    """Coerce *value* to a ``float``, falling back to ``0.0``.

    ``int``/``float`` values are converted directly and strings are parsed
    with ``float()``. ``None``, unparsable strings, and values of any other
    type all become ``0.0`` so the entry can still be plotted. A debug
    message is logged for ``None`` and an error for unparsable strings.
    """
    # The three type checks are mutually exclusive, so their order does not
    # affect the result.
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        try:
            converted = float(value)
        except ValueError:
            logger.error("Failed to convert string '%s' to float", value)
            converted = 0.0  # unparsable text -> neutral default
        return converted
    if value is None:
        logger.debug("Replacing None logprob with 0.0")
    return 0.0
def get_token(entry):
    """Return ``entry["token"]``, or ``"Unknown"`` when that key is absent.

    A warning is logged only when the ``"token"`` key is genuinely missing.
    A real token whose text happens to be ``"Unknown"`` is returned without
    a spurious warning.

    Args:
        entry: dict describing one generated token.

    Returns:
        The stored token value as-is (including ``None`` if stored), or
        ``"Unknown"`` when the key is missing.
    """
    try:
        return entry["token"]
    except KeyError:
        logger.warning("Missing 'token' key for entry: %s, using 'Unknown'", entry)
        return "Unknown"
def create_empty_figure(title):
    """Return a blank Plotly figure titled *title*.

    Axis titles are left empty and the legend is hidden; no traces are added.
    """
    figure = go.Figure()
    figure.update_layout(
        title=title,
        xaxis_title="",
        yaxis_title="",
        showlegend=False,
    )
    return figure
def _collect_token_data(content):
    """Extract parallel token / logprob / sorted-alternative lists from *content*.

    Non-dict entries are skipped with a warning. Each ``logprob`` is coerced
    with ``ensure_float`` (missing or ``None`` becomes 0.0 and is kept).
    Entries whose value fails the ``>= -100000`` filter (NaN, -inf, or
    anything below the floor) are dropped. ``top_logprobs`` is treated as a
    mapping of alternative token -> logprob: ``None`` means "no alternatives",
    non-finite values are discarded, and the rest are sorted by logprob,
    highest first.

    Returns:
        ``(tokens, logprobs, top_alternatives)``, three lists of equal length.
    """
    tokens = []
    logprobs = []
    top_alternatives = []
    for entry in content:
        if not isinstance(entry, dict):
            logger.warning("Skipping non-dictionary entry: %s", entry)
            continue
        logprob = ensure_float(entry.get("logprob", None))
        # Written as a negated >= so NaN (which compares False) is dropped too.
        if not logprob >= -100000:
            continue
        token = get_token(entry)  # look up once so a missing key warns once
        top_probs = entry.get("top_logprobs", {})
        if top_probs is None:
            logger.debug("top_logprobs is None for token: %s, using empty dict", token)
            top_probs = {}
        alternatives = []
        for alt_token, raw_value in top_probs.items():
            alt_logprob = ensure_float(raw_value)  # never returns None
            if math.isfinite(alt_logprob):
                alternatives.append((alt_token, alt_logprob))
        alternatives.sort(key=lambda pair: pair[1], reverse=True)
        # Append all three together so the lists stay aligned even if an
        # earlier step raised for this entry.
        tokens.append(token)
        logprobs.append(logprob)
        top_alternatives.append(alternatives)
    return tokens, logprobs, top_alternatives


async def precompute_chunk(json_input, chunk_size, current_chunk):
    """Build the token data for the chunk that follows *current_chunk*.

    Declared ``async`` but contains no ``await``. All work runs synchronously
    once the coroutine is awaited.

    Args:
        json_input: JSON text accepted by ``parse_input``. It is either an
            object with a ``"content"`` list or a bare list of entries.
        chunk_size: number of tokens per chunk; must be positive.
        current_chunk: 0-based index of the chunk currently shown. The chunk
            at ``current_chunk + 1`` is computed.

    Returns:
        ``(tokens, logprobs, alternatives)`` slices for the next chunk, or
        ``(None, None, None)`` when there is no next chunk, the arguments are
        invalid, or any error occurs. Errors are logged, never raised.
    """
    try:
        next_chunk = current_chunk + 1
        if chunk_size <= 0 or next_chunk < 0:
            # Negative indices would otherwise slice from the end of the list
            # and return an unrelated window of tokens.
            logger.warning(
                "Invalid pagination request (chunk_size=%r, current_chunk=%r)",
                chunk_size,
                current_chunk,
            )
            return None, None, None

        data = parse_input(json_input)
        content = data.get("content", []) if isinstance(data, dict) else data
        if not isinstance(content, list):
            raise ValueError("Content must be a list of entries")

        tokens, logprobs, top_alternatives = _collect_token_data(content)
        start_idx = next_chunk * chunk_size
        if start_idx >= len(tokens):  # also covers "no tokens at all"
            return None, None, None
        end_idx = min(start_idx + chunk_size, len(tokens))
        return (
            tokens[start_idx:end_idx],
            logprobs[start_idx:end_idx],
            top_alternatives[start_idx:end_idx],
        )
    except Exception as e:
        # Best-effort by design: precomputation must never break the UI.
        # Log the raw current_chunk; doing arithmetic on it here could raise
        # again when it is not a number.
        logger.exception("Precomputation failed for the chunk after chunk %r: %s", current_chunk, e)
        return None, None, None
# Function to process and visualize a chunk of log probs with dynamic top_logprobs
def visualize_logprobs(json_input, chunk=0, chunk_size=1000):
try:
# Parse the input (handles JSON only)
data = parse_input(json_input)
# Ensure data is a dictionary with 'content' key containing a list
if isinstance(data, dict) and "content" in data:
content = data["content"]
if not isinstance(content, list):
raise ValueError("Content must be a list of entries")
elif isinstance(data, list):
content = data # Handle direct list input (though only JSON is expected)
else:
raise ValueError("Input must be a dictionary with 'content' key or a list of entries")
# Extract tokens, log probs, and top alternatives, skipping non-finite values with fixed filter of -100000
tokens = []
logprobs = []
top_alternatives = [] # List to store all top_logprobs (dynamic length)
for entry in content:
if not isinstance(entry, dict):
logger.warning("Skipping non-dictionary entry: %s", entry)
continue
logprob = ensure_float(entry.get("logprob", None))
if logprob >= -100000: # Include all entries with default 0.0
tokens.append(get_token(entry))
logprobs.append(logprob)
# Get top_logprobs, default to empty dict if None
top_probs = entry.get("top_logprobs", {})
if top_probs is None:
logger.debug("top_logprobs is None for token: %s, using empty dict", get_token(entry))
top_probs = {} # Default to empty dict for None
# Ensure all values in top_logprobs are floats and create a list of tuples
finite_top_probs = []
for key, value in top_probs.items():
float_value = ensure_float(value)
if float_value is not None and math.isfinite(float_value):
finite_top_probs.append((key, float_value))
# Sort by log probability (descending) to get all alternatives
sorted_probs = sorted(finite_top_probs, key=lambda x: x[1], reverse=True)
top_alternatives.append(sorted_probs) # Store all alternatives, dynamic length
else:
logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))
# Check if there's valid data after filtering
if not logprobs or not tokens:
return (create_empty_figure("Log Probabilities of Generated Tokens"), None, "No tokens to display.", create_empty_figure("Top Token Log Probabilities"), create_empty_figure("Significant Probability Drops"), 1, 0)
# Paginate data for chunks of 1,000 tokens
total_chunks = max(1, (len(logprobs) + chunk_size - 1) // chunk_size)
start_idx = chunk * chunk_size
end_idx = min((chunk + 1) * chunk_size, len(logprobs))
paginated_tokens = tokens[start_idx:end_idx]
paginated_logprobs = logprobs[start_idx:end_idx]
paginated_alternatives = top_alternatives[start_idx:end_idx] if top_alternatives else []
# 1. Main Log Probability Plot (Interactive Plotly)
main_fig = go.Figure()
main_fig.add_trace(go.Scatter(x=list(range(len(paginated_logprobs))), y=paginated_logprobs, mode='markers+lines', name='Log Prob', marker=dict(color='blue')))
main_fig.update_layout(
title="Log Probabilities of Generated Tokens (Chunk %d)" % (chunk + 1),
xaxis_title="Token Position (within chunk)",
yaxis_title="Log Probability",
hovermode="closest",
clickmode='event+select'
)
main_fig.update_traces(
customdata=[f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i+start_idx}" for i, (tok, prob) in enumerate(zip(paginated_tokens, paginated_logprobs))],
hovertemplate='%{customdata}
{colored_text}
" else: colored_text_html = "No tokens to display in this chunk." # Top Token Log Probabilities (Interactive Plotly, dynamic length, for the current chunk) alt_viz_fig = create_empty_figure("Top Token Log Probabilities (Chunk %d)" % (chunk + 1)) if not paginated_logprobs or not paginated_alternatives else go.Figure() if paginated_logprobs and paginated_alternatives: for i, (token, probs) in enumerate(zip(paginated_tokens, paginated_alternatives)): for j, (alt_tok, prob) in enumerate(probs): alt_viz_fig.add_trace(go.Bar(x=[f"{token} (Pos {i+start_idx})"], y=[prob], name=f"{alt_tok}", marker_color=['blue', 'green', 'red', 'purple', 'orange'][:len(probs)])) alt_viz_fig.update_layout( title="Top Token Log Probabilities (Chunk %d)" % (chunk + 1), xaxis_title="Token (Position)", yaxis_title="Log Probability", barmode='stack', hovermode="closest", clickmode='event+select' ) alt_viz_fig.update_traces( customdata=[f"Token: {tok}, Alt: {alt}, Log Prob: {prob:.4f}, Position: {i+start_idx}" for i, (tok, alts) in enumerate(zip(paginated_tokens, paginated_alternatives)) for alt, prob in alts], hovertemplate='%{customdata}