File size: 4,967 Bytes
0244d3c
 
 
 
6934db6
 
7e141c2
0244d3c
 
 
 
9825333
 
 
 
 
 
0244d3c
 
 
 
 
 
 
 
 
 
 
 
6934db6
0244d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6934db6
 
 
 
0244d3c
 
6934db6
 
 
 
 
0244d3c
 
 
 
 
 
7e141c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0244d3c
 
7e141c2
0244d3c
 
 
 
9825333
0244d3c
 
9825333
0244d3c
 
6934db6
0244d3c
7e141c2
0244d3c
 
 
 
 
 
7e141c2
0244d3c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import ast
import base64
import html
import io
import json

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd

# Function to process and visualize log probs
def visualize_logprobs(json_input):
    """Parse token log-prob data and build three linked views of it.

    Parameters
    ----------
    json_input : str
        JSON text — or a Python-literal dict (e.g. single-quoted, as pasted
        from a REPL) — with a top-level ``content`` list whose entries carry
        ``token``, ``logprob``, and ``top_logprobs`` (a mapping of
        alternative token -> log prob).

    Returns
    -------
    tuple
        ``(img_html, df, colored_text_html)``: an ``<img>`` tag holding the
        base64-encoded matplotlib plot, a DataFrame of tokens with up to three
        alternatives, and HTML of the tokens colored green (confident) to red
        (uncertain). On any failure returns ``("Error: ...", None, None)`` so
        the Gradio outputs degrade gracefully instead of raising.
    """
    try:
        # Accept strict JSON first; fall back to Python-literal parsing so
        # single-quoted dicts pasted from Python code still work.
        try:
            data = json.loads(json_input)
        except json.JSONDecodeError:
            data = ast.literal_eval(json_input)

        # Single pass over content: collect plot series and table rows
        # together, skipping entries whose logprob is None.
        tokens = []
        logprobs = []
        table_data = []
        for entry in data['content']:
            logprob = entry['logprob']
            if logprob is None:
                continue
            token = entry['token']
            tokens.append(token)
            logprobs.append(logprob)
            # Top 3 alternatives, most probable (highest log prob) first.
            top_3 = sorted(entry['top_logprobs'].items(), key=lambda x: x[1], reverse=True)[:3]
            row = [token, f"{logprob:.4f}"]
            row.extend(f"{alt_token}: {alt_logprob:.4f}" for alt_token, alt_logprob in top_3)
            row.extend([""] * (5 - len(row)))  # pad to 5 columns when < 3 alternatives
            table_data.append(row)

        # Guard: empty content, or every logprob was None — min()/plotting
        # below would otherwise raise an unhelpful ValueError.
        if not logprobs:
            return "Error: no tokens with log probabilities found in input", None, None

        # Create the plot
        plt.figure(figsize=(10, 5))
        plt.plot(range(len(logprobs)), logprobs, marker='o', linestyle='-', color='b')
        plt.title("Log Probabilities of Generated Tokens")
        plt.xlabel("Token Position")
        plt.ylabel("Log Probability")
        plt.grid(True)
        plt.xticks(range(len(logprobs)), tokens, rotation=45, ha='right')
        plt.tight_layout()

        # Render the figure to an in-memory PNG and close it so repeated
        # calls don't accumulate open pyplot figures.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        plt.close()

        # Embed as a base64 data URI so Gradio's HTML component can show it.
        img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
        img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'

        # Create a DataFrame for the table
        df = pd.DataFrame(
            table_data,
            columns=["Token", "Log Prob", "Top 1 Alternative", "Top 2 Alternative", "Top 3 Alternative"]
        )

        # Normalize log probs to [0, 1] for color scaling
        # (0 = most uncertain, 1 = most confident).
        min_logprob = min(logprobs)
        max_logprob = max(logprobs)
        if max_logprob == min_logprob:
            normalized_probs = [0.5] * len(logprobs)  # avoid division by zero
        else:
            span = max_logprob - min_logprob
            normalized_probs = [(lp - min_logprob) / span for lp in logprobs]

        # Colored text: linear green (confident) -> red (uncertain) ramp.
        spans = []
        for token, norm_prob in zip(tokens, normalized_probs):
            r = int(255 * (1 - norm_prob))  # red grows with uncertainty
            g = int(255 * norm_prob)        # green grows with confidence
            color = f'rgb({r}, {g}, 0)'
            # Escape the token: model tokens may contain <, >, &, quotes —
            # raw interpolation would break or inject markup into the page.
            safe_token = html.escape(token)
            spans.append(f'<span style="color: {color}; font-weight: bold;">{safe_token}</span>')
        colored_text_html = f"<p>{' '.join(spans)}</p>"

        return img_html, df, colored_text_html

    except Exception as e:
        # Top-level UI boundary: surface the message in the plot slot rather
        # than crashing the Gradio callback.
        return f"Error: {str(e)}", None, None

# Gradio interface: Blocks renders components in definition order, so the
# page shows title, input box, then the three outputs, then the button.
with gr.Blocks(title="Log Probability Visualizer") as app:
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown("Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities.")
    
    # Input: raw JSON (or Python-literal dict) text pasted by the user.
    json_input = gr.Textbox(label="JSON Input", lines=10, placeholder="Paste your JSON or Python dict here...")
    
    # Outputs — matched positionally to the 3-tuple that
    # visualize_logprobs returns: (plot HTML, DataFrame, colored-text HTML).
    plot_output = gr.HTML(label="Log Probability Plot")
    table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
    text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
    
    # Button to trigger visualization; click wires input -> function -> outputs.
    btn = gr.Button("Visualize")
    btn.click(
        fn=visualize_logprobs,
        inputs=json_input,
        outputs=[plot_output, table_output, text_output]
    )

# Launch the app (blocks here serving the UI; share/host defaults apply).
app.launch()