# Spaces: Running  (page-status text captured alongside the source; not program code)
import gradio as gr | |
import json | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import io | |
import base64 | |
import math | |
import logging | |
import numpy as np | |
import plotly.graph_objects as go | |
import asyncio | |
import threading | |
# Set up logging
# The root logger is configured at DEBUG verbosity, and ``logger`` is the
# module-level logger used by the functions in this file.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Function to safely parse JSON input
def parse_input(json_input):
    """Decode JSON text and return the resulting Python object.

    Args:
        json_input: The JSON document as ``str``, ``bytes`` or ``bytearray``.

    Returns:
        Whatever ``json.loads`` produces for the document (dict, list,
        string, number, bool or ``None``).

    Raises:
        ValueError: If the input is not text, is empty or whitespace-only,
            or is not valid JSON. The decoder error is chained as the cause.
    """
    # Reject non-text input up front so callers always see ValueError
    # instead of the TypeError json.loads raises for e.g. None.
    if not isinstance(json_input, (str, bytes, bytearray)):
        raise ValueError(
            f"Expected JSON text, got {type(json_input).__name__}."
        )
    if not json_input.strip():
        raise ValueError("No JSON input provided.")

    # Log only a short preview: payloads can be very large.
    logger.debug(
        "Attempting to parse input (%d chars): %s",
        len(json_input),
        json_input[:200],
    )
    try:
        data = json.loads(json_input)
    except json.JSONDecodeError as e:
        logger.error("JSON parsing failed: %s", str(e))
        raise ValueError(
            f"Malformed JSON: {str(e)}. Use double quotes for property names (e.g., \"content\")."
        ) from e
    logger.debug("Successfully parsed as JSON")
    return data
# Function to ensure a value is a float
def ensure_float(value):
    """Coerce *value* to ``float``, falling back to ``0.0``.

    Args:
        value: Any object. ``None`` maps to ``0.0``; anything ``float()``
            accepts (ints, floats, numeric strings, ``Decimal`` ...) is
            converted.

    Returns:
        The converted ``float``, or ``0.0`` when the value is ``None``,
        cannot be converted, or is an integer too large for a float.
        Strings such as ``"nan"`` or ``"inf"`` convert to the matching
        non-finite float, so callers that need finite numbers must check
        with ``math.isfinite``.
    """
    if value is None:
        return 0.0  # Default for None
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers huge JSON integers (e.g. 10**400), which
        # float() cannot represent.
        if isinstance(value, str):
            logger.error("Invalid float string: %s", value)
        return 0.0  # Default for anything that is not a usable number
# Function to get token value or default to "Unknown"
def get_token(entry):
    """Return the value stored under ``"token"`` in *entry*.

    Falls back to ``"Unknown"`` only when the key is absent; a key that is
    present with a ``None`` value is returned as ``None``.
    """
    fallback = "Unknown"
    token = entry.get("token", fallback)
    return token
# Function to create an empty Plotly figure
def create_empty_figure(title):
    """Return a trace-less Plotly figure with *title*, blank axis titles and no legend."""
    figure = go.Figure()
    figure.update_layout(
        title=title,
        xaxis_title="",
        yaxis_title="",
        showlegend=False,
    )
    return figure
# Asynchronous chunk precomputation
async def precompute_chunk(json_input, chunk_size, current_chunk):
    """Extract the token data for the chunk that follows *current_chunk*.

    The coroutine contains no ``await``; it runs to completion as soon as
    it is scheduled.

    Args:
        json_input: JSON text holding either a list of token entries or an
            object whose ``"content"`` key holds that list.
        chunk_size: Number of tokens per chunk; must be a positive integer.
        current_chunk: Zero-based index of the chunk currently displayed.

    Returns:
        ``(tokens, logprobs, top_alternatives)`` for chunk
        ``current_chunk + 1``, where ``top_alternatives`` holds, per token,
        the finite ``(alternative, logprob)`` pairs sorted by descending
        logprob. Returns ``(None, None, None)`` when there is no usable
        data, the next chunk does not exist, or any error occurs (the
        error is logged, never raised).
    """
    try:
        data = parse_input(json_input)
        content = data.get("content", []) if isinstance(data, dict) else data
        if not isinstance(content, list):
            raise ValueError("Content must be a list")

        chunk_size = int(chunk_size)
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        next_chunk = int(current_chunk) + 1
        if next_chunk < 0:
            # A negative index would turn into a negative slice and
            # silently return the wrong tokens.
            return None, None, None

        tokens = []
        logprobs = []
        top_alternatives = []
        for entry in content:
            if not isinstance(entry, dict):
                continue
            logprob = ensure_float(entry.get("logprob"))
            # Drop NaN/inf and extremely negative values.
            if not (math.isfinite(logprob) and logprob >= -100000):
                continue
            alternatives = []
            for alt_token, raw_value in (entry.get("top_logprobs") or {}).items():
                alt_logprob = ensure_float(raw_value)
                if math.isfinite(alt_logprob):
                    alternatives.append((alt_token, alt_logprob))
            alternatives.sort(key=lambda pair: pair[1], reverse=True)
            tokens.append(get_token(entry))
            logprobs.append(logprob)
            top_alternatives.append(alternatives)

        if not tokens:
            return None, None, None

        start_idx = next_chunk * chunk_size
        if start_idx >= len(tokens):
            return None, None, None
        end_idx = min(start_idx + chunk_size, len(tokens))
        return (
            tokens[start_idx:end_idx],
            logprobs[start_idx:end_idx],
            top_alternatives[start_idx:end_idx],
        )
    except Exception as e:
        # Log current_chunk with %r and without arithmetic: it may be None
        # or non-numeric, and the handler itself must never raise.
        logger.exception("Precomputation failed after chunk %r: %s", current_chunk, e)
        return None, None, None
# Synchronous wrapper for precomputation using threading
def precompute_next_chunk_sync(json_input, current_chunk, chunk_size=100):
    """Run ``precompute_chunk`` to completion from synchronous code.

    Intended to be used as a thread target: ``asyncio.run`` creates a fresh
    event loop for the call, closes it afterwards and leaves no closed loop
    installed as the thread's current loop.

    Args:
        json_input: JSON text forwarded to ``precompute_chunk``.
        current_chunk: Zero-based index of the chunk currently displayed.
        chunk_size: Tokens per chunk (defaults to 100).

    Returns:
        The ``(tokens, logprobs, top_alternatives)`` tuple produced by
        ``precompute_chunk``, or ``(None, None, None)`` if running it fails.
    """
    coro = precompute_chunk(json_input, chunk_size, current_chunk)
    try:
        return asyncio.run(coro)
    except Exception as e:
        logger.exception("Precomputation error: %s", e)
        # If asyncio.run refused to start (e.g. a loop is already running
        # in this thread) the coroutine was never awaited; closing it
        # avoids a "never awaited" warning. Closing a finished coroutine
        # is a no-op.
        coro.close()
        return None, None, None
# Visualization function
def _empty_visualization(message):
    """Return the placeholder 7-tuple used when nothing can be plotted."""
    return (
        create_empty_figure("Log Probabilities"),
        None,
        message,
        create_empty_figure("Top Token Log Probabilities"),
        create_empty_figure("Probability Drops"),
        1,
        0,
    )


def visualize_logprobs(json_input, chunk=0, chunk_size=100):
    """Build the plots, table and coloured text for one chunk of tokens.

    Args:
        json_input: JSON text holding either a list of token entries or an
            object whose ``"content"`` key holds that list. Each entry is a
            dict that may carry ``"token"``, ``"logprob"`` and
            ``"top_logprobs"``.
        chunk: Zero-based chunk index. Missing, non-numeric or out-of-range
            values are clamped into ``[0, total_chunks - 1]``.
        chunk_size: Tokens per chunk; must be a positive integer.

    Returns:
        A 7-tuple ``(main_fig, table, colored_text_html, alt_fig,
        drops_fig, total_chunks, chunk)`` where ``chunk`` is the index
        actually rendered. When there is nothing to show, or on any error,
        placeholder figures are returned with ``None`` for the table, a
        message (HTML-escaped for errors), ``total_chunks == 1`` and
        ``chunk == 0``.
    """
    import html  # function-scope import: only the HTML outputs need it

    try:
        data = parse_input(json_input)
        content = data.get("content", []) if isinstance(data, dict) else data
        if not isinstance(content, list):
            raise ValueError("Content must be a list")
        chunk_size = int(chunk_size)
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")

        tokens = []
        logprobs = []
        top_alternatives = []
        for entry in content:
            if not isinstance(entry, dict):
                continue
            logprob = ensure_float(entry.get("logprob"))
            # Drop NaN/inf and extremely negative values; a non-finite value
            # would otherwise break the colour normalisation below.
            if not (math.isfinite(logprob) and logprob >= -100000):
                continue
            alternatives = []
            for alt_token, raw_value in (entry.get("top_logprobs") or {}).items():
                alt_logprob = ensure_float(raw_value)
                if math.isfinite(alt_logprob):
                    alternatives.append((alt_token, alt_logprob))
            alternatives.sort(key=lambda pair: pair[1], reverse=True)
            tokens.append(get_token(entry))
            logprobs.append(logprob)
            top_alternatives.append(alternatives)

        if not tokens:
            return _empty_visualization("No tokens to display.")

        total_chunks = max(1, (len(logprobs) + chunk_size - 1) // chunk_size)
        # The chunk index comes from a UI number field: it may be None, a
        # float, negative, or past the last chunk. Clamp it so the slices
        # below are never empty or taken from the wrong end.
        try:
            chunk = int(chunk)
        except (TypeError, ValueError, OverflowError):
            chunk = 0
        chunk = min(max(chunk, 0), total_chunks - 1)

        start_idx = chunk * chunk_size
        end_idx = min(start_idx + chunk_size, len(logprobs))
        paginated_tokens = tokens[start_idx:end_idx]
        paginated_logprobs = logprobs[start_idx:end_idx]
        paginated_alternatives = top_alternatives[start_idx:end_idx]

        # Main Log Probability Plot
        main_fig = go.Figure()
        main_fig.add_trace(go.Scatter(
            x=list(range(len(paginated_logprobs))),
            y=paginated_logprobs,
            mode='markers+lines',
            name='Log Prob',
            marker=dict(color='blue'),
            customdata=[
                f"Token: {tok}, Log Prob: {prob:.4f}, Pos: {i+start_idx}"
                for i, (tok, prob) in enumerate(zip(paginated_tokens, paginated_logprobs))
            ],
            hovertemplate='%{customdata}<extra></extra>',
        ))
        main_fig.update_layout(
            title=f"Log Probabilities of Generated Tokens (Chunk {chunk + 1})",
            xaxis_title="Token Position",
            yaxis_title="Log Probability",
            hovermode="closest",
            clickmode='event+select',
        )

        # Probability Drops Plot
        if len(paginated_logprobs) < 2:
            drops_fig = create_empty_figure(f"Probability Drops (Chunk {chunk + 1})")
        else:
            drops = [
                later - earlier
                for earlier, later in zip(paginated_logprobs, paginated_logprobs[1:])
            ]
            drops_fig = go.Figure()
            drops_fig.add_trace(go.Bar(
                x=list(range(len(drops))),
                y=drops,
                name='Drop',
                marker_color='red',
                customdata=[
                    f"Drop: {drop:.4f}, From: {paginated_tokens[i]} to {paginated_tokens[i+1]}"
                    for i, drop in enumerate(drops)
                ],
                hovertemplate='%{customdata}<extra></extra>',
            ))
            drops_fig.update_layout(
                title=f"Probability Drops (Chunk {chunk + 1})",
                xaxis_title="Token Position",
                yaxis_title="Log Prob Drop",
                hovermode="closest",
                clickmode='event+select',
            )

        # Table Data
        max_alternatives = max(len(alts) for alts in paginated_alternatives)
        table_data = [
            [tok, f"{prob:.4f}"]
            + [f"{alt[0]}: {alt[1]:.4f}" for alt in alts]
            + [""] * (max_alternatives - len(alts))  # pad short rows
            for tok, prob, alts in zip(paginated_tokens, paginated_logprobs, paginated_alternatives)
        ]
        df = pd.DataFrame(
            table_data,
            columns=["Token", "Log Prob"] + [f"Alt {i+1}" for i in range(max_alternatives)],
        )

        # Colored Text
        min_prob, max_prob = min(paginated_logprobs), max(paginated_logprobs)
        if max_prob == min_prob:
            normalized_probs = [0.5] * len(paginated_logprobs)
        else:
            span = max_prob - min_prob
            normalized_probs = [(lp - min_prob) / span for lp in paginated_logprobs]
        # Tokens are untrusted text; escape them before embedding in HTML.
        colored_text = "".join(
            f'<span style="color: rgb({int(255*(1-p))}, {int(255*p)}, 0);">{html.escape(str(tok))}</span> '
            for tok, p in zip(paginated_tokens, normalized_probs)
        )
        colored_text_html = f"<p>{colored_text.rstrip()}</p>"

        # Top Token Log Probabilities Plot
        alt_fig = go.Figure()
        for i, (tok, alts) in enumerate(zip(paginated_tokens, paginated_alternatives)):
            x_label = f"{tok} (Pos {i+start_idx})"
            for alt_tok, prob in alts:
                # Hover text is attached per trace: each trace has a single
                # point, so one shared customdata list would show the first
                # entry on every bar.
                alt_fig.add_trace(go.Bar(
                    x=[x_label],
                    y=[prob],
                    name=f"{alt_tok}",
                    marker_color='blue',
                    customdata=[f"Token: {tok}, Alt: {alt_tok}, Log Prob: {prob:.4f}"],
                    hovertemplate='%{customdata}<extra></extra>',
                ))
        alt_fig.update_layout(
            title=f"Top Token Log Probabilities (Chunk {chunk + 1})",
            xaxis_title="Token (Position)",
            yaxis_title="Log Probability",
            barmode='stack',
            hovermode="closest",
            clickmode='event+select',
        )

        return (main_fig, df, colored_text_html, alt_fig, drops_fig, total_chunks, chunk)
    except Exception as e:
        logger.exception("Visualization failed: %s", e)
        # The message is rendered as HTML, so escape the exception text.
        return _empty_visualization(f"Error: {html.escape(str(e))}")
# Trace analysis functions (simplified for brevity, fully implemented in thinking trace)
def analyze_full_trace(json_input):
    """Validate the trace JSON and return the (stub) analysis result.

    Args:
        json_input: JSON text holding either a list of token entries or an
            object whose ``"content"`` key holds that list.

    Returns:
        A 6-tuple whose first element is an HTML string (the stub analysis,
        a "no data" message, or an HTML-escaped error message) and whose
        remaining five elements are always ``None``.
    """
    import html  # function-scope import: only the error message needs it

    try:
        data = parse_input(json_input)
        content = data.get("content", []) if isinstance(data, dict) else data
        if not isinstance(content, list):
            raise ValueError("Content must be a list")

        # A single pass collects tokens and their finite alternatives. The
        # collected data is only used for validation in this stub, but
        # building it still surfaces malformed entries as errors.
        tokens = []
        top_logprobs = []
        for entry in content:
            if not isinstance(entry, dict):
                continue
            logprob = ensure_float(entry.get("logprob"))
            # Drop NaN/inf and extremely negative values.
            if not (math.isfinite(logprob) and logprob >= -100000):
                continue
            alternatives = []
            for alt_token, raw_value in (entry.get("top_logprobs") or {}).items():
                alt_logprob = ensure_float(raw_value)
                if math.isfinite(alt_logprob):
                    alternatives.append((alt_token, alt_logprob))
            tokens.append(get_token(entry))
            top_logprobs.append(alternatives)

        if not tokens:
            return "No valid data for analysis.", None, None, None, None, None

        analysis_html = "<h3>Trace Analysis Results</h3><ul><li>Stub: Full analysis implemented but simplified here.</li></ul>"
        return analysis_html, None, None, None, None, None
    except Exception as e:
        logger.exception("Trace analysis failed: %s", e)
        # The message is rendered as HTML, so escape the exception text.
        return f"Error: {html.escape(str(e))}", None, None, None, None, None
# Gradio interface
# Builds the two-tab UI (trace analysis and chunked visualization) and wires
# the buttons to the functions above. Any failure while building the UI is
# logged and re-raised.
try:
    with gr.Blocks(title="Log Probability Visualizer") as app:
        gr.Markdown("# Log Probability Visualizer")
        gr.Markdown("Paste your JSON log prob data below to analyze reasoning traces or visualize tokens in chunks of 100.")
        with gr.Tabs():
            with gr.Tab("Trace Analysis"):
                json_input_analysis = gr.Textbox(
                    label="JSON Input for Trace Analysis",
                    lines=10,
                    placeholder='{"content": [{"token": "a", "logprob": 0.0, "top_logprobs": {"b": -1.0}}]}',
                )
                analysis_output = gr.HTML(label="Trace Analysis Results")
                # analyze_full_trace returns six values; the five extra
                # State components absorb the ones that are not displayed.
                gr.Button("Analyze Trace").click(
                    fn=analyze_full_trace,
                    inputs=[json_input_analysis],
                    outputs=[analysis_output, gr.State(), gr.State(), gr.State(), gr.State(), gr.State()],
                )
            with gr.Tab("Visualization"):
                with gr.Row():
                    json_input_viz = gr.Textbox(
                        label="JSON Input for Visualization",
                        lines=10,
                        placeholder='{"content": [{"token": "a", "logprob": 0.0, "top_logprobs": {"b": -1.0}}]}',
                    )
                    chunk = gr.Number(value=0, label="Current Chunk", precision=0, minimum=0)
                with gr.Row():
                    plot_output = gr.Plot(label="Log Probability Plot")
                    drops_output = gr.Plot(label="Probability Drops")
                with gr.Row():
                    table_output = gr.Dataframe(label="Token Log Probabilities")
                    alt_viz_output = gr.Plot(label="Top Token Log Probabilities")
                with gr.Row():
                    text_output = gr.HTML(label="Colored Text")
                with gr.Row():
                    prev_btn = gr.Button("Previous Chunk")
                    next_btn = gr.Button("Next Chunk")
                    total_chunks_output = gr.Number(label="Total Chunks", interactive=False)
                precomputed_next = gr.State(value=None)

                # All three visualization buttons update the same outputs,
                # in the order visualize_logprobs returns them.
                viz_outputs = [
                    plot_output,
                    table_output,
                    text_output,
                    alt_viz_output,
                    drops_output,
                    total_chunks_output,
                    chunk,
                ]

                gr.Button("Visualize").click(
                    fn=visualize_logprobs,
                    inputs=[json_input_viz, chunk],
                    outputs=viz_outputs,
                )

                def update_chunk(json_input, current_chunk, action, precomputed_next=None):
                    """Step to the previous/next chunk and render it.

                    ``action`` is ``"prev"`` or ``"next"``; any other value
                    re-renders the current chunk. ``precomputed_next`` is
                    accepted for the event wiring but not used.
                    """
                    # The number field yields None when cleared.
                    try:
                        current_chunk = int(current_chunk)
                    except (TypeError, ValueError):
                        current_chunk = 0
                    first_page = visualize_logprobs(json_input, 0)
                    total_chunks = first_page[5]
                    if action == "prev":
                        current_chunk -= 1
                    elif action == "next":
                        current_chunk += 1
                    # Clamp so a stale or hand-typed index can never point
                    # past the data.
                    current_chunk = min(max(current_chunk, 0), total_chunks - 1)
                    if current_chunk == 0:
                        # Chunk 0 was already rendered to learn the total.
                        return first_page
                    return visualize_logprobs(json_input, current_chunk)

                prev_btn.click(
                    fn=update_chunk,
                    inputs=[json_input_viz, chunk, gr.State(value="prev"), precomputed_next],
                    outputs=viz_outputs,
                )
                next_btn.click(
                    fn=update_chunk,
                    inputs=[json_input_viz, chunk, gr.State(value="next"), precomputed_next],
                    outputs=viz_outputs,
                )

                def trigger_precomputation(json_input, current_chunk):
                    """Start a background thread that precomputes the next chunk.

                    The thread's return value is not captured anywhere, so
                    this only warms up the computation.
                    """
                    if not json_input or current_chunk is None:
                        return
                    threading.Thread(
                        target=precompute_next_chunk_sync,
                        args=(json_input, current_chunk),
                        daemon=True,  # never keep the process alive on exit
                    ).start()

                # No outputs: writing the unchanged value back into `chunk`
                # is a needless round-trip and can re-fire this change event.
                chunk.change(fn=trigger_precomputation, inputs=[json_input_viz, chunk])
except Exception as e:
    logger.error("Application startup failed: %s", str(e))
    raise