codelion committed on
Commit
181b7be
·
verified ·
1 Parent(s): f2687d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -113
app.py CHANGED
@@ -4,91 +4,51 @@ import matplotlib.pyplot as plt
4
  import pandas as pd
5
  import io
6
  import base64
7
- import ast
8
  import math
9
 
10
- # Function to safely convert string representations of infinity
11
- def parse_infinity(value):
12
- if isinstance(value, str):
13
- if value.lower() == '-infinity' or value.lower() == '-inf':
14
- return float('-inf')
15
- elif value.lower() == 'infinity' or value.lower() == 'inf':
16
- return float('inf')
17
- return value
18
-
19
  # Function to process and visualize log probs
20
  def visualize_logprobs(json_input):
21
  try:
22
- # Try to parse as JSON first, handling string representations of infinity
23
- try:
24
- # Attempt to load JSON, replacing -inf with "-Infinity" if needed
25
- def replace_inf(s):
26
- import re
27
- return re.sub(r'-inf', '"-Infinity"', re.sub(r'inf', '"Infinity"', s))
28
-
29
- data = json.loads(replace_inf(json_input))
30
- # Convert string "Infinity" or "-Infinity" back to float if needed
31
- if isinstance(data, dict) and 'content' in data:
32
- for entry in data['content']:
33
- if 'logprob' in entry:
34
- entry['logprob'] = parse_infinity(entry['logprob'])
35
- if 'top_logprobs' in entry:
36
- entry['top_logprobs'] = {k: parse_infinity(v) for k, v in entry['top_logprobs'].items()}
37
- elif isinstance(data, list):
38
- for entry in data:
39
- if 'logprob' in entry:
40
- entry['logprob'] = parse_infinity(entry['logprob'])
41
- if 'top_logprobs' in entry:
42
- entry['top_logprobs'] = {k: parse_infinity(v) for k, v in entry['top_logprobs'].items()}
43
-
44
- except json.JSONDecodeError:
45
- # If JSON fails, try to parse as Python literal (e.g., with single quotes)
46
- try:
47
- data = ast.literal_eval(json_input)
48
- # Ensure -inf is handled as float('-inf')
49
- if isinstance(data, dict) and 'content' in data:
50
- for entry in data['content']:
51
- if 'logprob' in entry and isinstance(entry['logprob'], str):
52
- entry['logprob'] = parse_infinity(entry['logprob'])
53
- if 'top_logprobs' in entry:
54
- entry['top_logprobs'] = {k: parse_infinity(v) for k, v in entry['top_logprobs'].items()}
55
- elif isinstance(data, list):
56
- for entry in data:
57
- if 'logprob' in entry and isinstance(entry['logprob'], str):
58
- entry['logprob'] = parse_infinity(entry['logprob'])
59
- if 'top_logprobs' in entry:
60
- entry['top_logprobs'] = {k: parse_infinity(v) for k, v in entry['top_logprobs'].items()}
61
- except (SyntaxError, ValueError) as e:
62
- raise ValueError(f"Malformed input: {str(e)}")
63
-
64
- # Ensure data is a list or dictionary with 'content'
65
- if isinstance(data, dict) and 'content' in data:
66
- content = data['content']
67
  elif isinstance(data, list):
68
  content = data
69
  else:
70
  raise ValueError("Input must be a list or dictionary with 'content' key")
71
-
72
- # Extract tokens and log probs, skipping None values and handling non-finite values
73
  tokens = []
74
  logprobs = []
75
  for entry in content:
76
- if 'logprob' in entry and entry['logprob'] is not None and math.isfinite(entry['logprob']):
77
- tokens.append(entry['token'])
78
- logprobs.append(entry['logprob'])
79
-
80
- # Prepare data for the table
 
 
 
 
81
  table_data = []
82
  for entry in content:
83
- if 'logprob' in entry and entry['logprob'] is not None and math.isfinite(entry['logprob']):
84
- token = entry['token']
85
- logprob = entry['logprob']
86
- top_logprobs = entry.get('top_logprobs', {})
87
-
88
- # Filter out non-finite (e.g., -inf, inf, nan) log probs from top_logprobs
89
- finite_top_logprobs = {k: v for k, v in top_logprobs.items() if math.isfinite(v)}
90
- # Extract top 3 finite alternatives, sorted by log prob (most probable first)
91
- top_3 = sorted(finite_top_logprobs.items(), key=lambda x: x[1], reverse=True)[:3]
 
 
 
 
 
 
 
92
  row = [token, f"{logprob:.4f}"]
93
  for alt_token, alt_logprob in top_3:
94
  row.append(f"{alt_token}: {alt_logprob:.4f}")
@@ -96,88 +56,98 @@ def visualize_logprobs(json_input):
96
  while len(row) < 5:
97
  row.append("")
98
  table_data.append(row)
99
-
100
- # Create the plot (only for finite log probs)
101
  if logprobs:
102
  plt.figure(figsize=(10, 5))
103
- plt.plot(range(len(logprobs)), logprobs, marker='o', linestyle='-', color='b')
104
  plt.title("Log Probabilities of Generated Tokens")
105
  plt.xlabel("Token Position")
106
  plt.ylabel("Log Probability")
107
  plt.grid(True)
108
- plt.xticks(range(len(logprobs)), tokens, rotation=45, ha='right')
109
  plt.tight_layout()
110
-
111
  # Save plot to a bytes buffer
112
  buf = io.BytesIO()
113
- plt.savefig(buf, format='png', bbox_inches='tight')
114
  buf.seek(0)
115
  plt.close()
116
-
117
- # Convert buffer to base64 for Gradio
118
  img_bytes = buf.getvalue()
119
- img_base64 = base64.b64encode(img_bytes).decode('utf-8')
120
  img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
121
  else:
122
  img_html = "No finite log probabilities to plot."
123
-
124
- # Create a DataFrame for the table
125
- df = pd.DataFrame(
126
- table_data,
127
- columns=["Token", "Log Prob", "Top 1 Alternative", "Top 2 Alternative", "Top 3 Alternative"]
128
- ) if table_data else None
129
-
130
- # Generate colored text based on log probabilities
 
 
 
 
 
 
 
 
 
 
131
  if logprobs:
132
- # Normalize log probs to [0, 1] for color scaling (0 = most uncertain, 1 = most confident)
133
  min_logprob = min(logprobs)
134
  max_logprob = max(logprobs)
135
  if max_logprob == min_logprob:
136
- normalized_probs = [0.5] * len(logprobs) # Avoid division by zero
137
  else:
138
- normalized_probs = [(lp - min_logprob) / (max_logprob - min_logprob) for lp in logprobs]
139
-
140
- # Create HTML for colored text
 
141
  colored_text = ""
142
  for i, (token, norm_prob) in enumerate(zip(tokens, normalized_probs)):
143
- # Map normalized probability to RGB color (green for high confidence, red for low)
144
- r = int(255 * (1 - norm_prob)) # Red increases as uncertainty increases
145
- g = int(255 * norm_prob) # Green decreases as uncertainty increases
146
- b = 0 # Blue stays 0 for simplicity
147
- color = f'rgb({r}, {g}, {b})'
148
  colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
149
  if i < len(tokens) - 1:
150
- colored_text += " " # Add space between tokens
151
-
152
- colored_text_html = f'<p>{colored_text}</p>'
153
  else:
154
  colored_text_html = "No finite log probabilities to display."
155
-
156
  return img_html, df, colored_text_html
157
-
158
  except Exception as e:
159
  return f"Error: {str(e)}", None, None
160
 
161
  # Gradio interface
162
  with gr.Blocks(title="Log Probability Visualizer") as app:
163
  gr.Markdown("# Log Probability Visualizer")
164
- gr.Markdown("Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities.")
165
-
166
- # Input
167
- json_input = gr.Textbox(label="JSON Input", lines=10, placeholder="Paste your JSON or Python dict here...")
168
-
169
- # Outputs
 
 
 
 
170
  plot_output = gr.HTML(label="Log Probability Plot")
171
  table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
172
  text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
173
-
174
- # Button to trigger visualization
175
  btn = gr.Button("Visualize")
176
  btn.click(
177
  fn=visualize_logprobs,
178
  inputs=json_input,
179
- outputs=[plot_output, table_output, text_output]
180
  )
181
 
182
- # Launch the app
183
  app.launch()
 
4
  import pandas as pd
5
  import io
6
  import base64
 
7
  import math
8
 
 
 
 
 
 
 
 
 
 
9
  # Function to process and visualize log probs
10
  def visualize_logprobs(json_input):
11
  try:
12
+ # Parse the JSON input
13
+ data = json.loads(json_input)
14
+ if isinstance(data, dict) and "content" in data:
15
+ content = data["content"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  elif isinstance(data, list):
17
  content = data
18
  else:
19
  raise ValueError("Input must be a list or dictionary with 'content' key")
20
+
21
+ # Extract tokens and log probs, skipping None or non-finite values
22
  tokens = []
23
  logprobs = []
24
  for entry in content:
25
+ if (
26
+ "logprob" in entry
27
+ and entry["logprob"] is not None
28
+ and math.isfinite(entry["logprob"])
29
+ ):
30
+ tokens.append(entry["token"])
31
+ logprobs.append(entry["logprob"])
32
+
33
+ # Prepare table data, handling None in top_logprobs
34
  table_data = []
35
  for entry in content:
36
+ # Only include entries with finite logprob and non-None top_logprobs
37
+ if (
38
+ "logprob" in entry
39
+ and entry["logprob"] is not None
40
+ and math.isfinite(entry["logprob"])
41
+ and "top_logprobs" in entry
42
+ and entry["top_logprobs"] is not None
43
+ ):
44
+ token = entry["token"]
45
+ logprob = entry["logprob"]
46
+ top_logprobs = entry["top_logprobs"]
47
+
48
+ # Extract top 3 alternatives from top_logprobs
49
+ top_3 = sorted(
50
+ top_logprobs.items(), key=lambda x: x[1], reverse=True
51
+ )[:3]
52
  row = [token, f"{logprob:.4f}"]
53
  for alt_token, alt_logprob in top_3:
54
  row.append(f"{alt_token}: {alt_logprob:.4f}")
 
56
  while len(row) < 5:
57
  row.append("")
58
  table_data.append(row)
59
+
60
+ # Create the plot
61
  if logprobs:
62
  plt.figure(figsize=(10, 5))
63
+ plt.plot(range(len(logprobs)), logprobs, marker="o", linestyle="-", color="b")
64
  plt.title("Log Probabilities of Generated Tokens")
65
  plt.xlabel("Token Position")
66
  plt.ylabel("Log Probability")
67
  plt.grid(True)
68
+ plt.xticks(range(len(logprobs)), tokens, rotation=45, ha="right")
69
  plt.tight_layout()
70
+
71
  # Save plot to a bytes buffer
72
  buf = io.BytesIO()
73
+ plt.savefig(buf, format="png", bbox_inches="tight")
74
  buf.seek(0)
75
  plt.close()
76
+
77
+ # Convert to base64 for Gradio
78
  img_bytes = buf.getvalue()
79
+ img_base64 = base64.b64encode(img_bytes).decode("utf-8")
80
  img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%; height: auto;">'
81
  else:
82
  img_html = "No finite log probabilities to plot."
83
+
84
+ # Create DataFrame for the table
85
+ df = (
86
+ pd.DataFrame(
87
+ table_data,
88
+ columns=[
89
+ "Token",
90
+ "Log Prob",
91
+ "Top 1 Alternative",
92
+ "Top 2 Alternative",
93
+ "Top 3 Alternative",
94
+ ],
95
+ )
96
+ if table_data
97
+ else None
98
+ )
99
+
100
+ # Generate colored text
101
  if logprobs:
 
102
  min_logprob = min(logprobs)
103
  max_logprob = max(logprobs)
104
  if max_logprob == min_logprob:
105
+ normalized_probs = [0.5] * len(logprobs)
106
  else:
107
+ normalized_probs = [
108
+ (lp - min_logprob) / (max_logprob - min_logprob) for lp in logprobs
109
+ ]
110
+
111
  colored_text = ""
112
  for i, (token, norm_prob) in enumerate(zip(tokens, normalized_probs)):
113
+ r = int(255 * (1 - norm_prob)) # Red for low confidence
114
+ g = int(255 * norm_prob) # Green for high confidence
115
+ b = 0
116
+ color = f"rgb({r}, {g}, {b})"
 
117
  colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
118
  if i < len(tokens) - 1:
119
+ colored_text += " "
120
+ colored_text_html = f"<p>{colored_text}</p>"
 
121
  else:
122
  colored_text_html = "No finite log probabilities to display."
123
+
124
  return img_html, df, colored_text_html
125
+
126
  except Exception as e:
127
  return f"Error: {str(e)}", None, None
128
 
129
  # Gradio interface
130
  with gr.Blocks(title="Log Probability Visualizer") as app:
131
  gr.Markdown("# Log Probability Visualizer")
132
+ gr.Markdown(
133
+ "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities."
134
+ )
135
+
136
+ json_input = gr.Textbox(
137
+ label="JSON Input",
138
+ lines=10,
139
+ placeholder="Paste your JSON or Python dict here...",
140
+ )
141
+
142
  plot_output = gr.HTML(label="Log Probability Plot")
143
  table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
144
  text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
145
+
 
146
  btn = gr.Button("Visualize")
147
  btn.click(
148
  fn=visualize_logprobs,
149
  inputs=json_input,
150
+ outputs=[plot_output, table_output, text_output],
151
  )
152
 
 
153
  app.launch()