codelion commited on
Commit
6b2ca38
·
verified ·
1 Parent(s): b766b6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -82
app.py CHANGED
@@ -9,7 +9,6 @@ import ast
9
  import logging
10
  import numpy as np
11
  import plotly.graph_objects as go
12
- from plotly.subplots import make_subplots
13
 
14
  # Set up logging
15
  logging.basicConfig(level=logging.DEBUG)
@@ -58,8 +57,12 @@ def ensure_float(value):
58
  return float(value)
59
  return None
60
 
 
 
 
 
61
  # Function to process and visualize log probs with interactive Plotly plots
62
- def visualize_logprobs(json_input, page=0):
63
  try:
64
  # Parse the input (handles both JSON and Python dictionaries)
65
  data = parse_input(json_input)
@@ -72,7 +75,7 @@ def visualize_logprobs(json_input, page=0):
72
  else:
73
  raise ValueError("Input must be a list or dictionary with 'content' key")
74
 
75
- # Extract tokens, log probs, and top alternatives, skipping None or non-finite values with fixed filter of -100000
76
  tokens = []
77
  logprobs = []
78
  top_alternatives = [] # List to store top 3 log probs (selected token + 2 alternatives)
@@ -100,20 +103,11 @@ def visualize_logprobs(json_input, page=0):
100
 
101
  # Check if there's valid data after filtering
102
  if not logprobs or not tokens:
103
- return (gr.update(value="No finite log probabilities or tokens to visualize after filtering"), None, None, None, 1, 0)
104
-
105
- # Paginate data for large inputs (fixed page size of 1000)
106
- page_size = 1000
107
- total_pages = max(1, (len(logprobs) + page_size - 1) // page_size)
108
- start_idx = page * page_size
109
- end_idx = min((page + 1) * page_size, len(logprobs))
110
- paginated_tokens = tokens[start_idx:end_idx]
111
- paginated_logprobs = logprobs[start_idx:end_idx]
112
- paginated_alternatives = top_alternatives[start_idx:end_idx] if top_alternatives else []
113
 
114
  # 1. Main Log Probability Plot (Interactive Plotly)
115
  main_fig = go.Figure()
116
- main_fig.add_trace(go.Scatter(x=list(range(len(paginated_logprobs))), y=paginated_logprobs, mode='markers+lines', name='Log Prob', marker=dict(color='blue')))
117
  main_fig.update_layout(
118
  title="Log Probabilities of Generated Tokens",
119
  xaxis_title="Token Position",
@@ -122,16 +116,15 @@ def visualize_logprobs(json_input, page=0):
122
  clickmode='event+select'
123
  )
124
  main_fig.update_traces(
125
- customdata=[f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i+start_idx}" for i, (tok, prob) in enumerate(zip(paginated_tokens, paginated_logprobs))],
126
  hovertemplate='<b>%{customdata}</b><extra></extra>'
127
  )
128
 
129
  # 2. Probability Drop Analysis (Interactive Plotly)
130
- if len(paginated_logprobs) < 2:
131
- drops_fig = go.Figure()
132
- drops_fig.add_trace(go.Bar(x=list(range(len(paginated_logprobs)-1)), y=[0], name='Drop', marker_color='red'))
133
  else:
134
- drops = [paginated_logprobs[i+1] - paginated_logprobs[i] for i in range(len(paginated_logprobs)-1)]
135
  drops_fig = go.Figure()
136
  drops_fig.add_trace(go.Bar(x=list(range(len(drops))), y=drops, name='Drop', marker_color='red'))
137
  drops_fig.update_layout(
@@ -142,13 +135,13 @@ def visualize_logprobs(json_input, page=0):
142
  clickmode='event+select'
143
  )
144
  drops_fig.update_traces(
145
- customdata=[f"Drop: {drop:.4f}, From: {paginated_tokens[i]} to {paginated_tokens[i+1]}, Position: {i+start_idx}" for i, drop in enumerate(drops)],
146
  hovertemplate='<b>%{customdata}</b><extra></extra>'
147
  )
148
 
149
- # Create DataFrame for the table (paginated)
150
  table_data = []
151
- for i, entry in enumerate(content[start_idx:end_idx]):
152
  logprob = ensure_float(entry.get("logprob", None))
153
  if logprob is not None and math.isfinite(logprob) and logprob >= -100000 and "top_logprobs" in entry and entry["top_logprobs"] is not None:
154
  token = entry["token"]
@@ -183,38 +176,38 @@ def visualize_logprobs(json_input, page=0):
183
  else None
184
  )
185
 
186
- # Generate colored text (paginated)
187
- if paginated_logprobs:
188
- min_logprob = min(paginated_logprobs)
189
- max_logprob = max(paginated_logprobs)
190
  if max_logprob == min_logprob:
191
- normalized_probs = [0.5] * len(paginated_logprobs)
192
  else:
193
  normalized_probs = [
194
- (lp - min_logprob) / (max_logprob - min_logprob) for lp in paginated_logprobs
195
  ]
196
 
197
  colored_text = ""
198
- for i, (token, norm_prob) in enumerate(zip(paginated_tokens, normalized_probs)):
199
  r = int(255 * (1 - norm_prob)) # Red for low confidence
200
  g = int(255 * norm_prob) # Green for high confidence
201
  b = 0
202
  color = f"rgb({r}, {g}, {b})"
203
  colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
204
- if i < len(paginated_tokens) - 1:
205
  colored_text += " "
206
  colored_text_html = f"<p>{colored_text}</p>"
207
  else:
208
  colored_text_html = "No finite log probabilities to display."
209
 
210
- # Top 3 Token Log Probabilities (paginated)
211
- alt_viz_fig = go.Figure()
212
- if paginated_logprobs and paginated_alternatives:
213
- for i, (token, probs) in enumerate(zip(paginated_tokens, paginated_alternatives)):
214
  for j, (alt_tok, prob) in enumerate(probs):
215
- alt_viz_fig.add_trace(go.Bar(x=[f"{token} (Pos {i+start_idx})"], y=[prob], name=f"{alt_tok}", marker_color=['blue', 'green', 'red'][j]))
216
  alt_viz_fig.update_layout(
217
- title="Top 3 Token Log Probabilities (Paginated)",
218
  xaxis_title="Token (Position)",
219
  yaxis_title="Log Probability",
220
  barmode='stack',
@@ -222,35 +215,29 @@ def visualize_logprobs(json_input, page=0):
222
  clickmode='event+select'
223
  )
224
  alt_viz_fig.update_traces(
225
- customdata=[f"Token: {tok}, Alt: {alt}, Log Prob: {prob:.4f}, Position: {i+start_idx}" for i, (tok, alts) in enumerate(zip(paginated_tokens, paginated_alternatives)) for alt, prob in alts],
226
  hovertemplate='<b>%{customdata}</b><extra></extra>'
227
  )
228
- alt_viz_html = alt_viz_fig.to_html(include_plotlyjs='cdn', full_html=False)
229
- else:
230
- alt_viz_html = "No finite log probabilities to display."
231
 
232
- return (main_fig, df, colored_text_html, alt_viz_html, drops_fig, total_pages, page)
233
 
234
  except Exception as e:
235
  logger.error("Visualization failed: %s", str(e))
236
- return (gr.update(value=f"Error: {str(e)}"), None, "No finite log probabilities to display.", None, gr.update(value="No data for probability drops."), 1, 0)
237
 
238
- # Gradio interface with interactive layout and pagination
239
  with gr.Blocks(title="Log Probability Visualizer") as app:
240
  gr.Markdown("# Log Probability Visualizer")
241
  gr.Markdown(
242
- "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities. Use pagination to navigate large inputs (fixed filter ≥ -100000, 1000 tokens per page)."
243
  )
244
 
245
  with gr.Row():
246
- with gr.Column(scale=1):
247
- json_input = gr.Textbox(
248
- label="JSON Input",
249
- lines=10,
250
- placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
251
- )
252
- with gr.Column(scale=1):
253
- page = gr.Number(value=0, label="Page Number", precision=0, minimum=0)
254
 
255
  with gr.Row():
256
  plot_output = gr.Plot(label="Log Probability Plot (Click for Tokens)")
@@ -266,36 +253,8 @@ with gr.Blocks(title="Log Probability Visualizer") as app:
266
  btn = gr.Button("Visualize")
267
  btn.click(
268
  fn=visualize_logprobs,
269
- inputs=[json_input, page],
270
- outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, gr.State(), gr.State()],
271
- )
272
-
273
- # Pagination controls
274
- with gr.Row():
275
- prev_btn = gr.Button("Previous Page")
276
- next_btn = gr.Button("Next Page")
277
- total_pages_output = gr.Number(label="Total Pages", interactive=False)
278
- current_page_output = gr.Number(label="Current Page", interactive=False)
279
-
280
- def update_page(json_input, current_page, action):
281
- if action == "prev" and current_page > 0:
282
- current_page -= 1
283
- elif action == "next":
284
- total_pages = visualize_logprobs(json_input, 0)[5] # Get total pages
285
- if current_page < total_pages - 1:
286
- current_page += 1
287
- return gr.update(value=current_page), gr.update(value=total_pages)
288
-
289
- prev_btn.click(
290
- fn=update_page,
291
- inputs=[json_input, page, gr.State()],
292
- outputs=[page, total_pages_output]
293
- )
294
-
295
- next_btn.click(
296
- fn=update_page,
297
- inputs=[json_input, page, gr.State()],
298
- outputs=[page, total_pages_output]
299
  )
300
 
301
  app.launch()
 
9
  import logging
10
  import numpy as np
11
  import plotly.graph_objects as go
 
12
 
13
  # Set up logging
14
  logging.basicConfig(level=logging.DEBUG)
 
57
  return float(value)
58
  return None
59
 
60
+ # Function to create an empty Plotly figure
61
+ def create_empty_figure(title):
62
+ return go.Figure().update_layout(title=title, xaxis_title="", yaxis_title="", showlegend=False)
63
+
64
  # Function to process and visualize log probs with interactive Plotly plots
65
+ def visualize_logprobs(json_input):
66
  try:
67
  # Parse the input (handles both JSON and Python dictionaries)
68
  data = parse_input(json_input)
 
75
  else:
76
  raise ValueError("Input must be a list or dictionary with 'content' key")
77
 
78
+ # Extract tokens and log probs, skipping None or non-finite values with fixed filter of -100000
79
  tokens = []
80
  logprobs = []
81
  top_alternatives = [] # List to store top 3 log probs (selected token + 2 alternatives)
 
103
 
104
  # Check if there's valid data after filtering
105
  if not logprobs or not tokens:
106
+ return (create_empty_figure("Log Probabilities of Generated Tokens"), None, "No finite log probabilities to display.", create_empty_figure("Top 3 Token Log Probabilities"), create_empty_figure("Significant Probability Drops"))
 
 
 
 
 
 
 
 
 
107
 
108
  # 1. Main Log Probability Plot (Interactive Plotly)
109
  main_fig = go.Figure()
110
+ main_fig.add_trace(go.Scatter(x=list(range(len(logprobs))), y=logprobs, mode='markers+lines', name='Log Prob', marker=dict(color='blue')))
111
  main_fig.update_layout(
112
  title="Log Probabilities of Generated Tokens",
113
  xaxis_title="Token Position",
 
116
  clickmode='event+select'
117
  )
118
  main_fig.update_traces(
119
+ customdata=[f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i}" for i, (tok, prob) in enumerate(zip(tokens, logprobs))],
120
  hovertemplate='<b>%{customdata}</b><extra></extra>'
121
  )
122
 
123
  # 2. Probability Drop Analysis (Interactive Plotly)
124
+ if len(logprobs) < 2:
125
+ drops_fig = create_empty_figure("Significant Probability Drops")
 
126
  else:
127
+ drops = [logprobs[i+1] - logprobs[i] for i in range(len(logprobs)-1)]
128
  drops_fig = go.Figure()
129
  drops_fig.add_trace(go.Bar(x=list(range(len(drops))), y=drops, name='Drop', marker_color='red'))
130
  drops_fig.update_layout(
 
135
  clickmode='event+select'
136
  )
137
  drops_fig.update_traces(
138
+ customdata=[f"Drop: {drop:.4f}, From: {tokens[i]} to {tokens[i+1]}, Position: {i}" for i, drop in enumerate(drops)],
139
  hovertemplate='<b>%{customdata}</b><extra></extra>'
140
  )
141
 
142
+ # Create DataFrame for the table
143
  table_data = []
144
+ for i, entry in enumerate(content):
145
  logprob = ensure_float(entry.get("logprob", None))
146
  if logprob is not None and math.isfinite(logprob) and logprob >= -100000 and "top_logprobs" in entry and entry["top_logprobs"] is not None:
147
  token = entry["token"]
 
176
  else None
177
  )
178
 
179
+ # Generate colored text
180
+ if logprobs:
181
+ min_logprob = min(logprobs)
182
+ max_logprob = max(logprobs)
183
  if max_logprob == min_logprob:
184
+ normalized_probs = [0.5] * len(logprobs)
185
  else:
186
  normalized_probs = [
187
+ (lp - min_logprob) / (max_logprob - min_logprob) for lp in logprobs
188
  ]
189
 
190
  colored_text = ""
191
+ for i, (token, norm_prob) in enumerate(zip(tokens, normalized_probs)):
192
  r = int(255 * (1 - norm_prob)) # Red for low confidence
193
  g = int(255 * norm_prob) # Green for high confidence
194
  b = 0
195
  color = f"rgb({r}, {g}, {b})"
196
  colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
197
+ if i < len(tokens) - 1:
198
  colored_text += " "
199
  colored_text_html = f"<p>{colored_text}</p>"
200
  else:
201
  colored_text_html = "No finite log probabilities to display."
202
 
203
+ # Top 3 Token Log Probabilities (Interactive Plotly)
204
+ alt_viz_fig = create_empty_figure("Top 3 Token Log Probabilities") if not logprobs or not top_alternatives else go.Figure()
205
+ if logprobs and top_alternatives:
206
+ for i, (token, probs) in enumerate(zip(tokens, top_alternatives)):
207
  for j, (alt_tok, prob) in enumerate(probs):
208
+ alt_viz_fig.add_trace(go.Bar(x=[f"{token} (Pos {i})"], y=[prob], name=f"{alt_tok}", marker_color=['blue', 'green', 'red'][j]))
209
  alt_viz_fig.update_layout(
210
+ title="Top 3 Token Log Probabilities",
211
  xaxis_title="Token (Position)",
212
  yaxis_title="Log Probability",
213
  barmode='stack',
 
215
  clickmode='event+select'
216
  )
217
  alt_viz_fig.update_traces(
218
+ customdata=[f"Token: {tok}, Alt: {alt}, Log Prob: {prob:.4f}, Position: {i}" for i, (tok, alts) in enumerate(zip(tokens, top_alternatives)) for alt, prob in alts],
219
  hovertemplate='<b>%{customdata}</b><extra></extra>'
220
  )
 
 
 
221
 
222
+ return (main_fig, df, colored_text_html, alt_viz_fig, drops_fig)
223
 
224
  except Exception as e:
225
  logger.error("Visualization failed: %s", str(e))
226
+ return (create_empty_figure("Log Probabilities of Generated Tokens"), None, "No finite log probabilities to display.", create_empty_figure("Top 3 Token Log Probabilities"), create_empty_figure("Significant Probability Drops"))
227
 
228
+ # Gradio interface with improved layout
229
  with gr.Blocks(title="Log Probability Visualizer") as app:
230
  gr.Markdown("# Log Probability Visualizer")
231
  gr.Markdown(
232
+ "Paste your JSON or Python dictionary log prob data below to visualize the tokens and their probabilities. Fixed filter ≥ -100000, 1000 tokens per page."
233
  )
234
 
235
  with gr.Row():
236
+ json_input = gr.Textbox(
237
+ label="JSON Input",
238
+ lines=10,
239
+ placeholder="Paste your JSON (e.g., {\"content\": [...]}) or Python dict (e.g., {'content': [...]}) here...",
240
+ )
 
 
 
241
 
242
  with gr.Row():
243
  plot_output = gr.Plot(label="Log Probability Plot (Click for Tokens)")
 
253
  btn = gr.Button("Visualize")
254
  btn.click(
255
  fn=visualize_logprobs,
256
+ inputs=[json_input],
257
+ outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  )
259
 
260
  app.launch()