import gradio as gr
import json
import matplotlib.pyplot as plt
import pandas as pd
import io
import base64
import math
import ast
import logging
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy import stats
from scipy.stats import entropy
from scipy.signal import correlate
import networkx as nx
from matplotlib.widgets import Cursor

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Function to safely parse JSON or Python dictionary input
def parse_input(json_input):
    logger.debug("Attempting to parse input: %s", json_input)
    try:
        # Try to parse as JSON first
        data = json.loads(json_input)
        logger.debug("Successfully parsed as JSON")
        return data
    except json.JSONDecodeError as e:
        logger.error("JSON parsing failed: %s", str(e))
        try:
            # If JSON fails, try to parse as a Python literal (e.g., a dict written with single quotes)
            data = ast.literal_eval(json_input)
            logger.debug("Successfully parsed as Python literal")

            # Normalize the parsed literal into a JSON-compatible structure
            # (stringify dictionary keys, recurse into nested dicts and lists)
            def dict_to_json(obj):
                if isinstance(obj, dict):
                    return {str(k): dict_to_json(v) for k, v in obj.items()}
                elif isinstance(obj, list):
                    return [dict_to_json(item) for item in obj]
                else:
                    return obj

            converted_data = dict_to_json(data)
            logger.debug("Converted to JSON-compatible format")
            return converted_data
        except (SyntaxError, ValueError) as e:
            logger.error("Python literal parsing failed: %s", str(e))
            raise ValueError(
                f"Malformed input: {str(e)}. Ensure property names are in double quotes "
                '(e.g., "content") or use a valid Python dictionary literal.'
            )
# Function to ensure a value is a float, converting from string if necessary
def ensure_float(value):
    if value is None:
        return None
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            logger.error("Failed to convert string '%s' to float", value)
            return None
    if isinstance(value, (int, float)):
        return float(value)
    return None
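
# For example (illustrative): ensure_float("-0.25") -> -0.25, ensure_float(3) -> 3.0,
# and ensure_float(None) or ensure_float("n/a") -> None.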
# Function to process and visualize log probs with multiple analyses
def visualize_logprobs(json_input, prob_filter=-float('inf')):
    try:
        # Parse the input (handles both JSON and Python dictionaries)
        data = parse_input(json_input)

        # Ensure data is a list or a dictionary with a 'content' key
        if isinstance(data, dict) and "content" in data:
            content = data["content"]
        elif isinstance(data, list):
            content = data
        else:
            raise ValueError("Input must be a list or dictionary with 'content' key")

        # Extract tokens, log probs, and top alternatives, skipping None or non-finite values
        tokens = []
        logprobs = []
        top_alternatives = []  # Top 3 log probs per position (selected token + 2 alternatives)
        token_types = []  # Simplified token type categorization
        for entry in content:
            logprob = ensure_float(entry.get("logprob", None))
            if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter:
                tokens.append(entry["token"])
                logprobs.append(logprob)
                # Categorize token type (simple heuristic)
                token = entry["token"].lower().strip()
                if token in ["the", "a", "an"]:
                    token_types.append("article")
                elif token in ["is", "are", "was", "were"]:
                    token_types.append("verb")
                elif token in ["top", "so", "need", "figure"]:
                    token_types.append("noun")
                else:
                    token_types.append("other")
                # Get top_logprobs, treating a missing key or an explicit None as an empty dict
                top_probs = entry.get("top_logprobs") or {}
                # Keep only finite float values from top_logprobs
                finite_top_probs = {}
                for key, value in top_probs.items():
                    float_value = ensure_float(value)
                    if float_value is not None and math.isfinite(float_value):
                        finite_top_probs[key] = float_value
                # Get the top 3 log probs (including the selected token)
                all_probs = {entry["token"]: logprob}  # Add the selected token's logprob
                all_probs.update(finite_top_probs)  # Add alternatives
                sorted_probs = sorted(all_probs.items(), key=lambda x: x[1], reverse=True)
                top_3 = sorted_probs[:3]  # Top 3 log probs (highest to lowest)
                top_alternatives.append(top_3)
            else:
                logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))

        # If no valid data remains after filtering, return an error message plus a None for every other output
        if not logprobs:
            return ("No finite log probabilities to visualize after filtering.",) + (None,) * 19
        # 1. Main Log Probability Plot (with click for tokens)
        fig_main, ax_main = plt.subplots(figsize=(10, 5))
        scatter = ax_main.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Selected Token")[0]
        ax_main.set_title("Log Probabilities of Generated Tokens")
        ax_main.set_xlabel("Token Position")
        ax_main.set_ylabel("Log Probability")
        ax_main.grid(True)
        ax_main.set_xticks([])  # Hide X-axis labels by default

        # Add click functionality to reveal the token at a point.
        # Note: this only works in an interactive Matplotlib backend; the figure is
        # exported below as a static PNG, so the Gradio output itself is not clickable.
        token_annotations = []
        for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
            annotation = ax_main.annotate('', (x, y), xytext=(10, 10), textcoords='offset points', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8), visible=False)
            token_annotations.append(annotation)

        def on_click(event):
            if event.inaxes == ax_main:
                for i, (x, y) in enumerate(zip(range(len(logprobs)), logprobs)):
                    contains, _ = scatter.contains(event)
                    if contains and abs(event.xdata - x) < 0.5 and abs(event.ydata - y) < 0.5:
                        token_annotations[i].set_text(tokens[i])
                        token_annotations[i].set_visible(True)
                    else:
                        token_annotations[i].set_visible(False)
                fig_main.canvas.draw_idle()

        fig_main.canvas.mpl_connect('button_press_event', on_click)

        # Save main plot
        buf_main = io.BytesIO()
        plt.savefig(buf_main, format="png", bbox_inches="tight", dpi=100)
        buf_main.seek(0)
        plt.close(fig_main)
        img_main_bytes = buf_main.getvalue()
        img_main_base64 = base64.b64encode(img_main_bytes).decode("utf-8")
        img_main_html = f'<img src="data:image/png;base64,{img_main_base64}" style="max-width: 100%; height: auto;">'
        # 2. K-Means Clustering of Log Probabilities
        # (cap n_clusters at the number of points so very short inputs don't raise)
        kmeans = KMeans(n_clusters=min(3, len(logprobs)), random_state=42)
        cluster_labels = kmeans.fit_predict(np.array(logprobs).reshape(-1, 1))
        fig_cluster, ax_cluster = plt.subplots(figsize=(10, 5))
        scatter = ax_cluster.scatter(range(len(logprobs)), logprobs, c=cluster_labels, cmap='viridis')
        ax_cluster.set_title("K-Means Clustering of Log Probabilities")
        ax_cluster.set_xlabel("Token Position")
        ax_cluster.set_ylabel("Log Probability")
        ax_cluster.grid(True)
        plt.colorbar(scatter, ax=ax_cluster, label="Cluster")
        buf_cluster = io.BytesIO()
        plt.savefig(buf_cluster, format="png", bbox_inches="tight", dpi=100)
        buf_cluster.seek(0)
        plt.close(fig_cluster)
        img_cluster_bytes = buf_cluster.getvalue()
        img_cluster_base64 = base64.b64encode(img_cluster_bytes).decode("utf-8")
        img_cluster_html = f'<img src="data:image/png;base64,{img_cluster_base64}" style="max-width: 100%; height: auto;">'
        # 3. Probability Drop Analysis
        drops = [logprobs[i + 1] - logprobs[i] if i < len(logprobs) - 1 else 0 for i in range(len(logprobs))]
        fig_drops, ax_drops = plt.subplots(figsize=(10, 5))
        ax_drops.bar(range(len(drops)), drops, color='red', alpha=0.5)
        ax_drops.set_title("Significant Probability Drops")
        ax_drops.set_xlabel("Token Position")
        ax_drops.set_ylabel("Log Probability Drop")
        ax_drops.grid(True)
        buf_drops = io.BytesIO()
        plt.savefig(buf_drops, format="png", bbox_inches="tight", dpi=100)
        buf_drops.seek(0)
        plt.close(fig_drops)
        img_drops_bytes = buf_drops.getvalue()
        img_drops_base64 = base64.b64encode(img_drops_bytes).decode("utf-8")
        img_drops_html = f'<img src="data:image/png;base64,{img_drops_base64}" style="max-width: 100%; height: auto;">'
        # 4. N-Gram Analysis (bigrams for simplicity)
        bigrams = [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
        bigram_probs = [logprobs[i] + logprobs[i + 1] for i in range(len(tokens) - 1)]
        fig_ngram, ax_ngram = plt.subplots(figsize=(10, 5))
        ax_ngram.bar(range(len(bigrams)), bigram_probs, color='green')
        ax_ngram.set_title("N-Gram (Bigrams) Probability Sum")
        ax_ngram.set_xlabel("Bigram Position")
        ax_ngram.set_ylabel("Sum of Log Probabilities")
        ax_ngram.set_xticks(range(len(bigrams)))
        ax_ngram.set_xticklabels([f"{b[0]}->{b[1]}" for b in bigrams], rotation=45, ha="right")
        ax_ngram.grid(True)
        buf_ngram = io.BytesIO()
        plt.savefig(buf_ngram, format="png", bbox_inches="tight", dpi=100)
        buf_ngram.seek(0)
        plt.close(fig_ngram)
        img_ngram_bytes = buf_ngram.getvalue()
        img_ngram_base64 = base64.b64encode(img_ngram_bytes).decode("utf-8")
        img_ngram_html = f'<img src="data:image/png;base64,{img_ngram_base64}" style="max-width: 100%; height: auto;">'
        # 5. Markov Chain Modeling (simple transition graph)
        G = nx.DiGraph()
        for i in range(len(tokens) - 1):
            G.add_edge(tokens[i], tokens[i + 1], weight=logprobs[i + 1] - logprobs[i])
        fig_markov, ax_markov = plt.subplots(figsize=(10, 5))
        pos = nx.spring_layout(G)
        nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, edge_color='gray', width=1, ax=ax_markov)
        ax_markov.set_title("Markov Chain of Token Transitions")
        buf_markov = io.BytesIO()
        plt.savefig(buf_markov, format="png", bbox_inches="tight", dpi=100)
        buf_markov.seek(0)
        plt.close(fig_markov)
        img_markov_bytes = buf_markov.getvalue()
        img_markov_base64 = base64.b64encode(img_markov_bytes).decode("utf-8")
        img_markov_html = f'<img src="data:image/png;base64,{img_markov_base64}" style="max-width: 100%; height: auto;">'
        # 6. Anomaly Detection (outliers via z-score)
        z_scores = np.abs(stats.zscore(logprobs))
        outliers = z_scores > 2  # Threshold for outliers
        fig_anomaly, ax_anomaly = plt.subplots(figsize=(10, 5))
        ax_anomaly.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b")
        ax_anomaly.plot(np.where(outliers)[0], [logprobs[i] for i in np.where(outliers)[0]], "ro", label="Outliers")
        ax_anomaly.set_title("Log Probabilities with Outliers")
        ax_anomaly.set_xlabel("Token Position")
        ax_anomaly.set_ylabel("Log Probability")
        ax_anomaly.grid(True)
        ax_anomaly.legend()
        ax_anomaly.set_xticks([])  # Hide X-axis labels
        buf_anomaly = io.BytesIO()
        plt.savefig(buf_anomaly, format="png", bbox_inches="tight", dpi=100)
        buf_anomaly.seek(0)
        plt.close(fig_anomaly)
        img_anomaly_bytes = buf_anomaly.getvalue()
        img_anomaly_base64 = base64.b64encode(img_anomaly_bytes).decode("utf-8")
        img_anomaly_html = f'<img src="data:image/png;base64,{img_anomaly_base64}" style="max-width: 100%; height: auto;">'
        # 7. Autocorrelation
        autocorr = correlate(logprobs, logprobs, mode='full')
        autocorr = autocorr[len(autocorr) // 2:] / len(logprobs)  # Keep non-negative lags and normalize
        fig_autocorr, ax_autocorr = plt.subplots(figsize=(10, 5))
        ax_autocorr.plot(range(len(autocorr)), autocorr, color='purple')
        ax_autocorr.set_title("Autocorrelation of Log Probabilities")
        ax_autocorr.set_xlabel("Lag")
        ax_autocorr.set_ylabel("Autocorrelation")
        ax_autocorr.grid(True)
        buf_autocorr = io.BytesIO()
        plt.savefig(buf_autocorr, format="png", bbox_inches="tight", dpi=100)
        buf_autocorr.seek(0)
        plt.close(fig_autocorr)
        img_autocorr_bytes = buf_autocorr.getvalue()
        img_autocorr_base64 = base64.b64encode(img_autocorr_bytes).decode("utf-8")
        img_autocorr_html = f'<img src="data:image/png;base64,{img_autocorr_base64}" style="max-width: 100%; height: auto;">'
        # 8. Smoothing (moving average, window capped at the sequence length)
        window_size = min(3, len(logprobs))
        moving_avg = np.convolve(logprobs, np.ones(window_size) / window_size, mode='valid')
        fig_smoothing, ax_smoothing = plt.subplots(figsize=(10, 5))
        ax_smoothing.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Original")
        ax_smoothing.plot(range(window_size - 1, len(logprobs)), moving_avg, color="orange", label="Moving Average")
        ax_smoothing.set_title("Log Probabilities with Moving Average")
        ax_smoothing.set_xlabel("Token Position")
        ax_smoothing.set_ylabel("Log Probability")
        ax_smoothing.grid(True)
        ax_smoothing.legend()
        ax_smoothing.set_xticks([])  # Hide X-axis labels
        buf_smoothing = io.BytesIO()
        plt.savefig(buf_smoothing, format="png", bbox_inches="tight", dpi=100)
        buf_smoothing.seek(0)
        plt.close(fig_smoothing)
        img_smoothing_bytes = buf_smoothing.getvalue()
        img_smoothing_base64 = base64.b64encode(img_smoothing_bytes).decode("utf-8")
        img_smoothing_html = f'<img src="data:image/png;base64,{img_smoothing_base64}" style="max-width: 100%; height: auto;">'
        # 9. Uncertainty Propagation (variance of the top log probs at each position)
        variances = []
        for probs in top_alternatives:
            if len(probs) > 1:
                values = [p[1] for p in probs]
                variances.append(np.var(values))
            else:
                variances.append(0)
        fig_uncertainty, ax_uncertainty = plt.subplots(figsize=(10, 5))
        ax_uncertainty.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b", label="Log Prob")
        ax_uncertainty.fill_between(range(len(logprobs)),
                                    [lp - v for lp, v in zip(logprobs, variances)],
                                    [lp + v for lp, v in zip(logprobs, variances)],
                                    color='gray', alpha=0.3, label="Uncertainty")
        ax_uncertainty.set_title("Log Probabilities with Uncertainty Propagation")
        ax_uncertainty.set_xlabel("Token Position")
        ax_uncertainty.set_ylabel("Log Probability")
        ax_uncertainty.grid(True)
        ax_uncertainty.legend()
        ax_uncertainty.set_xticks([])  # Hide X-axis labels
        buf_uncertainty = io.BytesIO()
        plt.savefig(buf_uncertainty, format="png", bbox_inches="tight", dpi=100)
        buf_uncertainty.seek(0)
        plt.close(fig_uncertainty)
        img_uncertainty_bytes = buf_uncertainty.getvalue()
        img_uncertainty_base64 = base64.b64encode(img_uncertainty_bytes).decode("utf-8")
        img_uncertainty_html = f'<img src="data:image/png;base64,{img_uncertainty_base64}" style="max-width: 100%; height: auto;">'
        # 10. Correlation Heatmap
        # np.corrcoef on a single 1-D series collapses to a scalar, which cannot be drawn
        # as a heatmap, so correlate short sliding windows of the sequence instead.
        window = min(5, len(logprobs))
        windows = np.array([logprobs[i:i + window] for i in range(len(logprobs) - window + 1)])
        corr_matrix = np.corrcoef(windows) if len(windows) > 1 else np.array([[1.0]])
        fig_corr, ax_corr = plt.subplots(figsize=(10, 5))
        im = ax_corr.imshow(corr_matrix, cmap='coolwarm', interpolation='nearest')
        ax_corr.set_title("Correlation of Log Probabilities Across Positions")
        ax_corr.set_xlabel("Token Position (window start)")
        ax_corr.set_ylabel("Token Position (window start)")
        plt.colorbar(im, ax=ax_corr, label="Correlation")
        buf_corr = io.BytesIO()
        plt.savefig(buf_corr, format="png", bbox_inches="tight", dpi=100)
        buf_corr.seek(0)
        plt.close(fig_corr)
        img_corr_bytes = buf_corr.getvalue()
        img_corr_base64 = base64.b64encode(img_corr_bytes).decode("utf-8")
        img_corr_html = f'<img src="data:image/png;base64,{img_corr_base64}" style="max-width: 100%; height: auto;">'
        # 11. Token Type Correlation
        type_probs = {t: [] for t in set(token_types)}
        for t, p in zip(token_types, logprobs):
            type_probs[t].append(p)
        fig_type, ax_type = plt.subplots(figsize=(10, 5))
        for t in type_probs:
            ax_type.bar(t, np.mean(type_probs[t]), yerr=np.std(type_probs[t]), capsize=5, label=t)
        ax_type.set_title("Average Log Probability by Token Type")
        ax_type.set_xlabel("Token Type")
        ax_type.set_ylabel("Average Log Probability")
        ax_type.grid(True)
        ax_type.legend()
        buf_type = io.BytesIO()
        plt.savefig(buf_type, format="png", bbox_inches="tight", dpi=100)
        buf_type.seek(0)
        plt.close(fig_type)
        img_type_bytes = buf_type.getvalue()
        img_type_base64 = base64.b64encode(img_type_bytes).decode("utf-8")
        img_type_html = f'<img src="data:image/png;base64,{img_type_base64}" style="max-width: 100%; height: auto;">'
        # 12. Token Embedding Similarity vs. Probability (simulated)
        # Simulate 2D embedding coordinates with random values for demonstration
        simulated_embeddings = np.random.rand(len(tokens), 2)
        fig_embed, ax_embed = plt.subplots(figsize=(10, 5))
        ax_embed.scatter(simulated_embeddings[:, 0], simulated_embeddings[:, 1], c=logprobs, cmap='viridis')
        ax_embed.set_title("Token Embedding Similarity vs. Log Probability")
        ax_embed.set_xlabel("Embedding Dimension 1")
        ax_embed.set_ylabel("Embedding Dimension 2")
        plt.colorbar(ax_embed.collections[0], ax=ax_embed, label="Log Probability")
        buf_embed = io.BytesIO()
        plt.savefig(buf_embed, format="png", bbox_inches="tight", dpi=100)
        buf_embed.seek(0)
        plt.close(fig_embed)
        img_embed_bytes = buf_embed.getvalue()
        img_embed_base64 = base64.b64encode(img_embed_bytes).decode("utf-8")
        img_embed_html = f'<img src="data:image/png;base64,{img_embed_base64}" style="max-width: 100%; height: auto;">'
        # 13. Bayesian Inference (simplified as inferred uncertainty)
        # Entropy of the top-token distribution at each position; log probs are
        # exponentiated back to probabilities before computing entropy.
        entropies = [entropy(np.exp([p[1] for p in probs]), base=2) for probs in top_alternatives if len(probs) > 1]
        fig_bayesian, ax_bayesian = plt.subplots(figsize=(10, 5))
        ax_bayesian.bar(range(len(entropies)), entropies, color='orange')
        ax_bayesian.set_title("Bayesian Inferred Uncertainty (Entropy)")
        ax_bayesian.set_xlabel("Token Position")
        ax_bayesian.set_ylabel("Entropy")
        ax_bayesian.grid(True)
        buf_bayesian = io.BytesIO()
        plt.savefig(buf_bayesian, format="png", bbox_inches="tight", dpi=100)
        buf_bayesian.seek(0)
        plt.close(fig_bayesian)
        img_bayesian_bytes = buf_bayesian.getvalue()
        img_bayesian_base64 = base64.b64encode(img_bayesian_bytes).decode("utf-8")
        img_bayesian_html = f'<img src="data:image/png;base64,{img_bayesian_base64}" style="max-width: 100%; height: auto;">'
        # 14. Graph-Based Analysis
        G = nx.DiGraph()
        for i in range(len(tokens) - 1):
            G.add_edge(tokens[i], tokens[i + 1], weight=logprobs[i + 1] - logprobs[i])
        fig_graph, ax_graph = plt.subplots(figsize=(10, 5))
        pos = nx.spring_layout(G)
        nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500, edge_color='gray', width=1, ax=ax_graph)
        ax_graph.set_title("Graph of Token Transitions")
        buf_graph = io.BytesIO()
        plt.savefig(buf_graph, format="png", bbox_inches="tight", dpi=100)
        buf_graph.seek(0)
        plt.close(fig_graph)
        img_graph_bytes = buf_graph.getvalue()
        img_graph_base64 = base64.b64encode(img_graph_bytes).decode("utf-8")
        img_graph_html = f'<img src="data:image/png;base64,{img_graph_base64}" style="max-width: 100%; height: auto;">'
        # 15. Dimensionality Reduction (t-SNE)
        # One feature vector per token: its top-3 log probs (padded if fewer than 3 are available).
        features = np.array([
            [p[1] for p in alts] + [alts[-1][1]] * (3 - len(alts))
            for alts in top_alternatives
        ])
        # t-SNE requires perplexity < n_samples, so cap it for short sequences
        perplexity = max(1, min(30, len(logprobs) - 1))
        tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        tsne_result = tsne.fit_transform(features)
        fig_tsne, ax_tsne = plt.subplots(figsize=(10, 5))
        scatter = ax_tsne.scatter(tsne_result[:, 0], tsne_result[:, 1], c=logprobs, cmap='viridis')
        ax_tsne.set_title("t-SNE of Log Probabilities and Top Alternatives")
        ax_tsne.set_xlabel("t-SNE Dimension 1")
        ax_tsne.set_ylabel("t-SNE Dimension 2")
        plt.colorbar(scatter, ax=ax_tsne, label="Log Probability")
        buf_tsne = io.BytesIO()
        plt.savefig(buf_tsne, format="png", bbox_inches="tight", dpi=100)
        buf_tsne.seek(0)
        plt.close(fig_tsne)
        img_tsne_bytes = buf_tsne.getvalue()
        img_tsne_base64 = base64.b64encode(img_tsne_bytes).decode("utf-8")
        img_tsne_html = f'<img src="data:image/png;base64,{img_tsne_base64}" style="max-width: 100%; height: auto;">'
        # 16. Heatmap of Log Probabilities (rendered as a static image)
        fig_heatmap, ax_heatmap = plt.subplots(figsize=(10, 5))
        im = ax_heatmap.imshow([logprobs], cmap='viridis', aspect='auto')
        ax_heatmap.set_title("Heatmap of Log Probabilities")
        ax_heatmap.set_xlabel("Token Position")
        ax_heatmap.set_ylabel("Probability Level")
        plt.colorbar(im, ax=ax_heatmap, label="Log Probability")
        buf_heatmap = io.BytesIO()
        plt.savefig(buf_heatmap, format="png", bbox_inches="tight", dpi=100)
        buf_heatmap.seek(0)
        plt.close(fig_heatmap)
        img_heatmap_bytes = buf_heatmap.getvalue()
        img_heatmap_base64 = base64.b64encode(img_heatmap_bytes).decode("utf-8")
        img_heatmap_html = f'<img src="data:image/png;base64,{img_heatmap_base64}" style="max-width: 100%; height: auto;">'
        # 17. Probability Distribution Plots (box plots for the selected token and top alternatives)
        alt1 = [alts[1][1] for alts in top_alternatives if len(alts) > 1]
        alt2 = [alts[2][1] for alts in top_alternatives if len(alts) > 2]
        box_data = [d for d in (logprobs, alt1, alt2) if d]
        box_labels = [lbl for lbl, d in zip(("Selected", "Alt 1", "Alt 2"), (logprobs, alt1, alt2)) if d]
        fig_dist, ax_dist = plt.subplots(figsize=(10, 5))
        ax_dist.boxplot(box_data, labels=box_labels)
        ax_dist.set_title("Probability Distribution of Top Tokens")
        ax_dist.set_xlabel("Token Type")
        ax_dist.set_ylabel("Log Probability")
        ax_dist.grid(True)
        buf_dist = io.BytesIO()
        plt.savefig(buf_dist, format="png", bbox_inches="tight", dpi=100)
        buf_dist.seek(0)
        plt.close(fig_dist)
        img_dist_bytes = buf_dist.getvalue()
        img_dist_base64 = base64.b64encode(img_dist_bytes).decode("utf-8")
        img_dist_html = f'<img src="data:image/png;base64,{img_dist_base64}" style="max-width: 100%; height: auto;">'
        # Create DataFrame for the table
        table_data = []
        for i, entry in enumerate(content):
            logprob = ensure_float(entry.get("logprob", None))
            if logprob is not None and math.isfinite(logprob) and logprob >= prob_filter and "top_logprobs" in entry and entry["top_logprobs"] is not None:
                token = entry["token"]
                top_logprobs = entry["top_logprobs"]
                # Keep only finite float values from top_logprobs
                finite_top_logprobs = {}
                for key, value in top_logprobs.items():
                    float_value = ensure_float(value)
                    if float_value is not None and math.isfinite(float_value):
                        finite_top_logprobs[key] = float_value
                # Extract the top 3 alternatives from top_logprobs
                top_3 = sorted(finite_top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
                row = [token, f"{logprob:.4f}"]
                for alt_token, alt_logprob in top_3:
                    row.append(f"{alt_token}: {alt_logprob:.4f}")
                while len(row) < 5:
                    row.append("")
                table_data.append(row)
        df = (
            pd.DataFrame(
                table_data,
                columns=[
                    "Token",
                    "Log Prob",
                    "Top 1 Alternative",
                    "Top 2 Alternative",
                    "Top 3 Alternative",
                ],
            )
            if table_data
            else None
        )
        # Generate colored text (red = low confidence, green = high confidence)
        if logprobs:
            min_logprob = min(logprobs)
            max_logprob = max(logprobs)
            if max_logprob == min_logprob:
                normalized_probs = [0.5] * len(logprobs)
            else:
                normalized_probs = [
                    (lp - min_logprob) / (max_logprob - min_logprob) for lp in logprobs
                ]
            colored_text = ""
            for i, (token, norm_prob) in enumerate(zip(tokens, normalized_probs)):
                r = int(255 * (1 - norm_prob))  # Red for low confidence
                g = int(255 * norm_prob)  # Green for high confidence
                b = 0
                color = f"rgb({r}, {g}, {b})"
                colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
                if i < len(tokens) - 1:
                    colored_text += " "
            colored_text_html = f"<p>{colored_text}</p>"
        else:
            colored_text_html = "No finite log probabilities to display."
        # Top 3 Token Log Probabilities
        alt_viz_html = ""
        if logprobs and top_alternatives:
            alt_viz_html = "<h3>Top 3 Token Log Probabilities</h3><ul>"
            for i, (token, probs) in enumerate(zip(tokens, top_alternatives)):
                alt_viz_html += f"<li>Position {i} (Token: {token}):<br>"
                for tok, prob in probs:
                    alt_viz_html += f"{tok}: {prob:.4f}<br>"
                alt_viz_html += "</li>"
            alt_viz_html += "</ul>"
        return (img_main_html, df, colored_text_html, alt_viz_html, img_cluster_html, img_drops_html,
                img_ngram_html, img_markov_html, img_anomaly_html, img_autocorr_html, img_smoothing_html,
                img_uncertainty_html, img_corr_html, img_type_html, img_embed_html, img_bayesian_html,
                img_graph_html, img_tsne_html, img_heatmap_html, img_dist_html)
    except Exception as e:
        logger.error("Visualization failed: %s", str(e))
        return (f"Error: {str(e)}",) + (None,) * 19
# Gradio interface with dynamic filtering
with gr.Blocks(title="Log Probability Visualizer") as app:
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown(
        "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities. "
        "Use the filter to focus on specific log probability ranges."
    )
    with gr.Row():
        json_input = gr.Textbox(
            label="JSON Input",
            lines=10,
            placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
        )
        # Sliders need finite bounds; -100 is an arbitrary "show everything" lower limit
        prob_filter = gr.Slider(minimum=-100, maximum=0, value=-100, label="Log Probability Filter (≥)")
    with gr.Row():
        plot_output = gr.HTML(label="Log Probability Plot")
        cluster_output = gr.HTML(label="K-Means Clustering")
        drops_output = gr.HTML(label="Probability Drops")
    with gr.Row():
        ngram_output = gr.HTML(label="N-Gram Analysis")
        markov_output = gr.HTML(label="Markov Chain")
    with gr.Row():
        anomaly_output = gr.HTML(label="Anomaly Detection")
        autocorr_output = gr.HTML(label="Autocorrelation")
    with gr.Row():
        smoothing_output = gr.HTML(label="Smoothing (Moving Average)")
        uncertainty_output = gr.HTML(label="Uncertainty Propagation")
    with gr.Row():
        corr_output = gr.HTML(label="Correlation Heatmap")
        type_output = gr.HTML(label="Token Type Correlation")
    with gr.Row():
        embed_output = gr.HTML(label="Embedding Similarity vs. Probability")
        bayesian_output = gr.HTML(label="Bayesian Inference (Entropy)")
    with gr.Row():
        graph_output = gr.HTML(label="Graph of Token Transitions")
        tsne_output = gr.HTML(label="t-SNE of Log Probabilities")
    with gr.Row():
        heatmap_output = gr.HTML(label="Heatmap of Log Probabilities")
        dist_output = gr.HTML(label="Probability Distribution")
    table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
    text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
    alt_viz_output = gr.HTML(label="Top 3 Token Log Probabilities")
    btn = gr.Button("Visualize")
    btn.click(
        fn=visualize_logprobs,
        inputs=[json_input, prob_filter],
        outputs=[
            plot_output, table_output, text_output, alt_viz_output,
            cluster_output, drops_output, ngram_output, markov_output,
            anomaly_output, autocorr_output, smoothing_output, uncertainty_output,
            corr_output, type_output, embed_output, bayesian_output,
            graph_output, tsne_output, heatmap_output, dist_output
        ],
    )

app.launch()