codelion commited on
Commit
0244d3c
·
verified ·
1 Parent(s): b9aefbe

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import matplotlib.pyplot as plt
4
+ import pandas as pd
5
+ from io import StringIO
6
+
7
+ # Function to process and visualize log probs
8
+ def visualize_logprobs(json_input):
9
+ try:
10
+ # Parse the JSON input
11
+ data = json.loads(json_input)
12
+
13
+ # Extract tokens and log probs, skipping None values
14
+ tokens = [entry['token'] for entry in data['content'] if entry['logprob'] is not None]
15
+ logprobs = [entry['logprob'] for entry in data['content'] if entry['logprob'] is not None]
16
+
17
+ # Prepare data for the table
18
+ table_data = []
19
+ for entry in data['content']:
20
+ if entry['logprob'] is not None:
21
+ token = entry['token']
22
+ logprob = entry['logprob']
23
+ top_logprobs = entry['top_logprobs']
24
+ # Extract top 3 alternatives
25
+ top_3 = sorted(top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
26
+ row = [token, f"{logprob:.4f}"]
27
+ for alt_token, alt_logprob in top_3:
28
+ row.append(f"{alt_token}: {alt_logprob:.4f}")
29
+ # Pad with empty strings if fewer than 3 alternatives
30
+ while len(row) < 5:
31
+ row.append("")
32
+ table_data.append(row)
33
+
34
+ # Create the plot
35
+ plt.figure(figsize=(10, 5))
36
+ plt.plot(range(len(logprobs)), logprobs, marker='o', linestyle='-', color='b')
37
+ plt.title("Log Probabilities of Generated Tokens")
38
+ plt.xlabel("Token Position")
39
+ plt.ylabel("Log Probability")
40
+ plt.grid(True)
41
+ plt.xticks(range(len(logprobs)), tokens, rotation=45, ha='right')
42
+ plt.tight_layout()
43
+
44
+ # Save plot to a buffer
45
+ img_buffer = StringIO()
46
+ plt.savefig(img_buffer, format='png', bbox_inches='tight')
47
+ img_buffer.seek(0)
48
+ plt.close()
49
+
50
+ # Create a DataFrame for the table
51
+ df = pd.DataFrame(
52
+ table_data,
53
+ columns=["Token", "Log Prob", "Top 1 Alternative", "Top 2 Alternative", "Top 3 Alternative"]
54
+ )
55
+
56
+ return img_buffer, df
57
+
58
+ except Exception as e:
59
+ return f"Error: {str(e)}", None
60
+
61
+ # Gradio interface
62
+ with gr.Blocks(title="Log Probability Visualizer") as app:
63
+ gr.Markdown("# Log Probability Visualizer")
64
+ gr.Markdown("Paste your JSON log prob data below to visualize the tokens and their probabilities.")
65
+
66
+ # Input
67
+ json_input = gr.Textbox(label="JSON Input", lines=10, placeholder="Paste your JSON here...")
68
+
69
+ # Outputs
70
+ plot_output = gr.Image(label="Log Probability Plot")
71
+ table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
72
+
73
+ # Button to trigger visualization
74
+ btn = gr.Button("Visualize")
75
+ btn.click(
76
+ fn=visualize_logprobs,
77
+ inputs=json_input,
78
+ outputs=[plot_output, table_output]
79
+ )
80
+
81
+ # Launch the app
82
+ app.launch()