import gradio as gr
import json
import matplotlib.pyplot as plt
import pandas as pd
import io
import base64
import math
import ast
import logging
from matplotlib.widgets import Cursor
# Set up logging.
# Only this module's logger is raised to DEBUG. Configuring the *root* logger
# at DEBUG would also enable DEBUG output from every third-party library that
# logs through the root (this file imports matplotlib, pandas and gradio), and
# that output drowns out this module's own messages. Records from `logger`
# still propagate to the root handler installed by basicConfig, so the
# module's debug output is printed as before.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def _to_json_compatible(obj):
    """Recursively normalise a parsed Python literal into JSON-shaped data.

    Dict keys are converted with ``str()`` (JSON object keys are always
    strings) and tuples become lists (JSON has no tuple type). Lists are
    rebuilt element-wise; every other value is returned unchanged.
    """
    if isinstance(obj, dict):
        return {str(k): _to_json_compatible(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_json_compatible(item) for item in obj]
    return obj


# Function to safely parse JSON or Python dictionary input
def parse_input(json_input):
    """Parse *json_input* as JSON, falling back to a Python literal.

    JSON is tried first with ``json.loads``. If that fails, the text is parsed
    with ``ast.literal_eval``, which evaluates literals only (never arbitrary
    code). This accepts Python-style input such as ``{'content': [...]}``
    written with single quotes. The literal result is then normalised by
    ``_to_json_compatible`` (string keys, tuples -> lists).

    Args:
        json_input: The raw text to parse.

    Returns:
        The parsed object (typically a dict or list).

    Raises:
        ValueError: If the input is neither valid JSON nor a valid Python
            literal. The underlying parser error is chained as ``__cause__``.
    """
    logger.debug("Attempting to parse input: %s", json_input)
    try:
        data = json.loads(json_input)
    except json.JSONDecodeError as e:
        # Not an error yet: Python-literal input is an expected fallback path.
        logger.debug("JSON parsing failed (%s); trying Python literal", e)
    else:
        logger.debug("Successfully parsed as JSON")
        return data

    try:
        data = ast.literal_eval(json_input)
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as e:
        # literal_eval is documented to raise any of these on malformed input
        # (e.g. TypeError for unhashable dict keys such as "{[1]: 2}").
        logger.error("Python literal parsing failed: %s", str(e))
        raise ValueError(
            f"Malformed input: {str(e)}. Ensure property names are in double quotes (e.g., \"content\") or correct Python dictionary format."
        ) from e
    logger.debug("Successfully parsed as Python literal")
    converted_data = _to_json_compatible(data)
    logger.debug("Converted to JSON-compatible format")
    return converted_data
# Function to ensure a value is a float, converting from string if necessary
def ensure_float(value):
    """Coerce *value* to a ``float``, or return ``None`` if that is not possible.

    Accepted inputs:
      * ``int`` / ``float``: converted with ``float()``.
      * ``str``: parsed with ``float()``. Strings such as ``"nan"`` or
        ``"inf"`` parse to non-finite floats, so callers that need finite
        values must still check the result.

    Returns:
        The float value, or ``None`` for ``None``, booleans, unparsable
        strings, integers too large to represent as a float, and any other
        type.
    """
    if value is None:
        return None
    # bool is a subclass of int; a JSON true/false is not a numeric value here.
    if isinstance(value, bool):
        return None
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            logger.error("Failed to convert string '%s' to float", value)
            return None
    if isinstance(value, (int, float)):
        try:
            return float(value)
        except OverflowError:
            # float(int) raises for integers beyond the float range. The value
            # is not interpolated because formatting a huge int can itself fail.
            logger.error("Integer value is too large to convert to float")
            return None
    return None
# Function to process and visualize log probs with hover and alternatives
def visualize_logprobs(json_input):
try:
# Parse the input (handles both JSON and Python dictionaries)
data = parse_input(json_input)
# Ensure data is a list or dictionary with 'content'
if isinstance(data, dict) and "content" in data:
content = data["content"]
elif isinstance(data, list):
content = data
else:
raise ValueError("Input must be a list or dictionary with 'content' key")
# Extract tokens, log probs, and top alternatives, skipping None or non-finite values
tokens = []
logprobs = []
top_alternatives = [] # List to store top 3 log probs (selected token + 2 alternatives)
for entry in content:
logprob = ensure_float(entry.get("logprob", None))
if logprob is not None and math.isfinite(logprob):
tokens.append(entry["token"])
logprobs.append(logprob)
# Get top_logprobs, default to empty dict if None
top_probs = entry.get("top_logprobs", {})
# Ensure all values in top_logprobs are floats
finite_top_probs = {}
for key, value in top_probs.items():
float_value = ensure_float(value)
if float_value is not None and math.isfinite(float_value):
finite_top_probs[key] = float_value
# Get the top 3 log probs (including the selected token)
all_probs = {entry["token"]: logprob} # Add the selected token's logprob
all_probs.update(finite_top_probs) # Add alternatives
sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
top_3 = sorted_probs[:3] # Top 3 log probs (highest to lowest)
top_alternatives.append(top_3)
else:
logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))
# Create the plot with hover functionality
if logprobs:
fig, ax = plt.subplots(figsize=(10, 5))
scatter = ax.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Selected Token")[0]
ax.set_title("Log Probabilities of Generated Tokens")
ax.set_xlabel("Token Position")
ax.set_ylabel("Log Probability")
ax.grid(True)
ax.set_xticks([]) # Hide X-axis labels by default
# Add hover functionality using Matplotlib's Cursor for tooltips
cursor = Cursor(ax, useblit=True, color='red', linewidth=1)
token_annotations = []
for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
annotation = ax.annotate('', (x, y), xytext=(10, 10), textcoords='offset points', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), visible=False)
token_annotations.append(annotation)
def on_hover(event):
if event.inaxes == ax:
for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
contains, _ = scatter.contains(event)
if contains and abs(event.xdata - x) < 0.5 and abs(event.ydata - y) < 0.5:
token_annotations[i].set_text(tokens[i])
token_annotations[i].set_visible(True)
fig.canvas.draw_idle()
else:
token_annotations[i].set_visible(False)
fig.canvas.draw_idle()
fig.canvas.mpl_connect('motion_notify_event', on_hover)
# Save plot to a bytes buffer
buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
buf.seek(0)
plt.close()
# Convert to base64 for Gradio
img_bytes = buf.getvalue()
img_base64 = base64.b64encode(img_bytes).decode("utf-8")
img_html = f''
else:
img_html = "No finite log probabilities to plot."
# Create DataFrame for the table
table_data = []
for i, entry in enumerate(content):
logprob = ensure_float(entry.get("logprob", None))
if logprob is not None and math.isfinite(logprob) and "top_logprobs" in entry and entry["top_logprobs"] is not None:
token = entry["token"]
top_logprobs = entry["top_logprobs"]
# Ensure all values in top_logprobs are floats
finite_top_logprobs = {}
for key, value in top_logprobs.items():
float_value = ensure_float(value)
if float_value is not None and math.isfinite(float_value):
finite_top_logprobs[key] = float_value
# Extract top 3 alternatives from top_logprobs
top_3 = sorted(finite_top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
row = [token, f"{logprob:.4f}"]
for alt_token, alt_logprob in top_3:
row.append(f"{alt_token}: {alt_logprob:.4f}")
while len(row) < 5:
row.append("")
table_data.append(row)
df = (
pd.DataFrame(
table_data,
columns=[
"Token",
"Log Prob",
"Top 1 Alternative",
"Top 2 Alternative",
"Top 3 Alternative",
],
)
if table_data
else None
)
# Generate colored text
if logprobs:
min_logprob = min(logprobs)
max_logprob = max(logprobs)
if max_logprob == min_logprob:
normalized_probs = [0.5] * len(logprobs)
else:
normalized_probs = [
(lp - min_logprob) / (max_logprob - min_logprob) for lp in logprobs
]
colored_text = ""
for i, (token, norm_prob) in enumerate(zip(tokens, normalized_probs)):
r = int(255 * (1 - norm_prob)) # Red for low confidence
g = int(255 * norm_prob) # Green for high confidence
b = 0
color = f"rgb({r}, {g}, {b})"
colored_text += f'{token}'
if i < len(tokens) - 1:
colored_text += " "
colored_text_html = f"
{colored_text}
" else: colored_text_html = "No finite log probabilities to display." # Create an alternative visualization for top 3 tokens alt_viz_html = "" if logprobs and top_alternatives: alt_viz_html = "