import gradio as gr
import json
import matplotlib.pyplot as plt
import pandas as pd
import io
import base64
import math
import logging
import numpy as np
import plotly.graph_objects as go
import asyncio
import threading
# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Function to safely parse JSON input
def parse_input(json_input):
logger.debug("Attempting to parse input: %s", json_input)
try:
data = json.loads(json_input)
logger.debug("Successfully parsed as JSON")
return data
except json.JSONDecodeError as e:
logger.error("JSON parsing failed: %s", str(e))
raise ValueError(f"Malformed JSON: {str(e)}. Use double quotes for property names (e.g., \"content\").")
# Function to ensure a value is a float
def ensure_float(value):
    if value is None:
        return 0.0  # Default for None
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            logger.error("Invalid float string: %s", value)
            return 0.0
    return 0.0  # Default for other types
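
# Examples: ensure_float(-1.5) -> -1.5, ensure_float("-1.5") -> -1.5,
# ensure_float(None) -> 0.0, ensure_float("n/a") -> 0.0 (and logs an error),
# ensure_float([1.5]) -> 0.0.
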
# Function to get token value or default to "Unknown"
def get_token(entry):
return entry.get("token", "Unknown")
# Function to create an empty Plotly figure
def create_empty_figure(title):
    return go.Figure().update_layout(title=title, xaxis_title="", yaxis_title="", showlegend=False)

# Asynchronous chunk precomputation
async def precompute_chunk(json_input, chunk_size, current_chunk):
    try:
        data = parse_input(json_input)
        content = data.get("content", []) if isinstance(data, dict) else data
        if not isinstance(content, list):
            raise ValueError("Content must be a list")
        tokens = []
        logprobs = []
        top_alternatives = []
        for entry in content:
            if not isinstance(entry, dict):
                continue
            logprob = ensure_float(entry.get("logprob", None))
            if logprob >= -100000:
                tokens.append(get_token(entry))
                logprobs.append(logprob)
                top_probs = entry.get("top_logprobs", {}) or {}
                finite_top_probs = [
                    (key, ensure_float(value))
                    for key, value in top_probs.items()
                    if math.isfinite(ensure_float(value))
                ]
                top_alternatives.append(sorted(finite_top_probs, key=lambda x: x[1], reverse=True))
        if not tokens or not logprobs:
            return None, None, None
        next_chunk = current_chunk + 1
        start_idx = next_chunk * chunk_size
        end_idx = min((next_chunk + 1) * chunk_size, len(tokens))
        if start_idx >= len(tokens):
            return None, None, None
        return (tokens[start_idx:end_idx], logprobs[start_idx:end_idx], top_alternatives[start_idx:end_idx])
    except Exception as e:
        logger.error("Precomputation failed for chunk %d: %s", current_chunk + 1, str(e))
        return None, None, None
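
# For example, with chunk_size=100 and current_chunk=0 this returns the
# slices for token positions 100-199 (i.e., the data backing chunk 2), or
# (None, None, None) when the content holds 100 tokens or fewer.
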
# Synchronous wrapper for precomputation; runs the coroutine on a fresh event
# loop so it can be called from a background thread (worker threads have no
# running event loop of their own)
def precompute_next_chunk_sync(json_input, current_chunk):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        result = loop.run_until_complete(precompute_chunk(json_input, 100, current_chunk))
    except Exception as e:
        logger.error("Precomputation error: %s", str(e))
        result = None, None, None
    finally:
        loop.close()
    return result
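
# Hypothetical usage from a background thread (the reason a fresh event loop
# is created above); collecting the result would need a queue or shared state:
#   threading.Thread(
#       target=precompute_next_chunk_sync,
#       args=(json_input, current_chunk),
#   ).start()
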
# Visualization function
def visualize_logprobs(json_input, chunk=0, chunk_size=100):
    try:
        data = parse_input(json_input)
        content = data.get("content", []) if isinstance(data, dict) else data
        if not isinstance(content, list):
            raise ValueError("Content must be a list")
        tokens = []
        logprobs = []
        top_alternatives = []
        for entry in content:
            if not isinstance(entry, dict):
                continue
            logprob = ensure_float(entry.get("logprob", None))
            if logprob >= -100000:
                tokens.append(get_token(entry))
                logprobs.append(logprob)
                top_probs = entry.get("top_logprobs", {}) or {}
                finite_top_probs = [
                    (key, ensure_float(value))
                    for key, value in top_probs.items()
                    if math.isfinite(ensure_float(value))
                ]
                top_alternatives.append(sorted(finite_top_probs, key=lambda x: x[1], reverse=True))
        if not logprobs or not tokens:
            return (
                create_empty_figure("Log Probabilities"),
                None,
                "No tokens to display.",
                create_empty_figure("Top Token Log Probabilities"),
                create_empty_figure("Probability Drops"),
                1,
                0,
            )
        total_chunks = max(1, (len(logprobs) + chunk_size - 1) // chunk_size)
        start_idx = chunk * chunk_size
        end_idx = min((chunk + 1) * chunk_size, len(logprobs))
        paginated_tokens = tokens[start_idx:end_idx]
        paginated_logprobs = logprobs[start_idx:end_idx]
        paginated_alternatives = top_alternatives[start_idx:end_idx]

        # Main Log Probability Plot
        main_fig = go.Figure()
        main_fig.add_trace(go.Scatter(
            x=list(range(len(paginated_logprobs))),
            y=paginated_logprobs,
            mode='markers+lines',
            name='Log Prob',
            marker=dict(color='blue'),
        ))
        main_fig.update_layout(
            title=f"Log Probabilities of Generated Tokens (Chunk {chunk + 1})",
            xaxis_title="Token Position",
            yaxis_title="Log Probability",
            hovermode="closest",
            clickmode='event+select',
        )
        main_fig.update_traces(
            customdata=[
                f"Token: {tok}, Log Prob: {prob:.4f}, Pos: {i + start_idx}"
                for i, (tok, prob) in enumerate(zip(paginated_tokens, paginated_logprobs))
            ],
            hovertemplate='%{customdata}<extra></extra>',
        )

        # Color-coded text: tint each token from red (low log prob) to green (high)
        min_lp, max_lp = min(paginated_logprobs), max(paginated_logprobs)
        lp_range = max(max_lp - min_lp, 1e-6)  # Guard against division by zero
        colored_text = ""
        for tok, lp in zip(paginated_tokens, paginated_logprobs):
            norm = (lp - min_lp) / lp_range
            red, green = int(255 * (1 - norm)), int(255 * norm)
            colored_text += (
                f'<span style="color: rgb({red}, {green}, 0);" '
                f'title="Log Prob: {lp:.4f}">{tok}</span> '
            )
        colored_text_html = f"<p>{colored_text.rstrip()}</p>"

        # Top Token Log Probabilities Plot
        alt_fig = go.Figure() if paginated_alternatives else create_empty_figure(f"Top Token Log Probabilities (Chunk {chunk + 1})")
        if paginated_alternatives:
            for i, (tok, alts) in enumerate(zip(paginated_tokens, paginated_alternatives)):
                for alt_tok, prob in alts:
                    # Hover text is attached per bar: each bar is its own trace,
                    # so a single update_traces(customdata=...) call would give
                    # every bar the same hover string
                    alt_fig.add_trace(go.Bar(
                        x=[f"{tok} (Pos {i + start_idx})"],
                        y=[prob],
                        name=f"{alt_tok}",
                        marker_color='blue',
                        customdata=[f"Token: {tok}, Alt: {alt_tok}, Log Prob: {prob:.4f}"],
                        hovertemplate='%{customdata}<extra></extra>',
                    ))
            alt_fig.update_layout(
                title=f"Top Token Log Probabilities (Chunk {chunk + 1})",
                xaxis_title="Token (Position)",
                yaxis_title="Log Probability",
                barmode='stack',
                hovermode="closest",
                clickmode='event+select',
            )