File size: 2,972 Bytes
0244d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import json
from io import BytesIO
from io import StringIO

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd

# Number of alternative tokens shown per row, and the matching table headers.
_NUM_ALTERNATIVES = 3
_TABLE_COLUMNS = ["Token", "Log Prob"] + [
    f"Top {rank} Alternative" for rank in range(1, _NUM_ALTERNATIVES + 1)
]


def _top_alternatives(top_logprobs, k=_NUM_ALTERNATIVES):
    """Return up to *k* ``(token, logprob)`` pairs, highest logprob first.

    Accepts either a ``{token: logprob}`` mapping or a list of
    ``{"token": ..., "logprob": ...}`` objects (the list form is what e.g. the
    OpenAI chat-completions API emits). A missing/empty value yields ``[]``;
    alternatives whose logprob is ``None`` are dropped.
    """
    if not top_logprobs:
        return []
    if isinstance(top_logprobs, dict):
        pairs = top_logprobs.items()
    else:
        pairs = ((alt["token"], alt["logprob"]) for alt in top_logprobs)
    scored = [(tok, lp) for tok, lp in pairs if lp is not None]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


def _parse_logprobs(json_input):
    """Parse *json_input* into parallel token / logprob lists plus table rows.

    Raises ValueError (including json.JSONDecodeError), KeyError, TypeError or
    AttributeError when the text is not JSON of the expected shape.
    """
    data = json.loads(json_input)
    tokens, logprobs, rows = [], [], []
    for entry in data["content"]:
        logprob = entry.get("logprob")
        # Entries without a logprob cannot be plotted, so they are skipped.
        if logprob is None:
            continue
        token = entry["token"]
        tokens.append(token)
        logprobs.append(logprob)

        row = [token, f"{logprob:.4f}"]
        row += [
            f"{alt_token}: {alt_logprob:.4f}"
            for alt_token, alt_logprob in _top_alternatives(entry.get("top_logprobs"))
        ]
        # Pad with empty cells when fewer alternatives are available.
        row += [""] * (len(_TABLE_COLUMNS) - len(row))
        rows.append(row)
    return tokens, logprobs, rows


def _render_logprob_plot(tokens, logprobs):
    """Plot *logprobs* against token position; return an RGB uint8 image array."""
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        positions = list(range(len(logprobs)))
        ax.plot(positions, logprobs, marker="o", linestyle="-", color="b")
        ax.set_title("Log Probabilities of Generated Tokens")
        ax.set_xlabel("Token Position")
        ax.set_ylabel("Log Probability")
        ax.grid(True)
        ax.set_xticks(positions)
        # Escape "$" so tokens like "$x$" are drawn literally, not as mathtext.
        labels = [str(tok).replace("$", r"\$") for tok in tokens]
        ax.set_xticklabels(labels, rotation=45, ha="right")
        fig.tight_layout()

        # PNG output is bytes, so it must go into a BytesIO (StringIO raises TypeError).
        buffer = BytesIO()
        fig.savefig(buffer, format="png", bbox_inches="tight")
    finally:
        # Release the figure even on failure; pyplot keeps open figures alive.
        plt.close(fig)

    buffer.seek(0)
    rgba = plt.imread(buffer, format="png")
    # matplotlib decodes PNGs to floats in [0, 1]; gr.Image gets 8-bit RGB.
    return (rgba[..., :3] * 255).round().astype("uint8")


def visualize_logprobs(json_input):
    """Build a log-prob plot and a per-token table from pasted JSON text.

    ``json_input`` must be a JSON object with a ``"content"`` list whose items
    carry ``"token"``, ``"logprob"`` and optionally ``"top_logprobs"``. Items
    whose ``logprob`` is missing or null are skipped.

    Returns:
        (image, dataframe): an RGB uint8 array suitable for ``gr.Image`` and a
        DataFrame with one row per token plus its top alternatives.

    Raises:
        gr.Error: if the text is not valid JSON of that shape. Gradio shows
        this error's message in the UI.
    """
    try:
        tokens, logprobs, rows = _parse_logprobs(json_input)
    except (ValueError, KeyError, TypeError, AttributeError) as err:
        raise gr.Error(f"Could not parse log-prob JSON: {err}") from err

    image = _render_logprob_plot(tokens, logprobs)
    table = pd.DataFrame(rows, columns=_TABLE_COLUMNS)
    return image, table

# Gradio interface: the user pastes log-prob JSON, clicks "Visualize", and
# gets a line plot plus a table of each token's log prob and top alternatives.
with gr.Blocks(title="Log Probability Visualizer") as app:
    gr.Markdown("# Log Probability Visualizer")
    gr.Markdown("Paste your JSON log prob data below to visualize the tokens and their probabilities.")

    # Input
    json_input = gr.Textbox(label="JSON Input", lines=10, placeholder="Paste your JSON here...")

    # Outputs, filled in order from visualize_logprobs' (image, table) result.
    plot_output = gr.Image(label="Log Probability Plot")
    table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")

    # Clicking the button runs visualize_logprobs on the pasted text.
    btn = gr.Button("Visualize")
    btn.click(
        fn=visualize_logprobs,
        inputs=json_input,
        outputs=[plot_output, table_output],
    )

# Launch the server only when run as a script, so importing this module
# (e.g. from tests or another app) does not start it as a side effect.
if __name__ == "__main__":
    app.launch()