codelion committed on
Commit
9ba1537
·
verified ·
1 Parent(s): 0d41503

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -62
app.py CHANGED
@@ -9,6 +9,8 @@ import ast
9
  import logging
10
  import numpy as np
11
  import plotly.graph_objects as go
 
 
12
 
13
  # Set up logging
14
  logging.basicConfig(level=logging.DEBUG)
@@ -24,24 +26,7 @@ def parse_input(json_input):
24
  return data
25
  except json.JSONDecodeError as e:
26
  logger.error("JSON parsing failed: %s (Input: %s)", str(e), json_input[:100] + "..." if len(json_input) > 100 else json_input)
27
- try:
28
- # If JSON fails, try to parse as Python literal (e.g., with single quotes), but only for JSON-like strings
29
- data = ast.literal_eval(json_input)
30
- logger.debug("Successfully parsed as Python literal")
31
- # Convert Python dictionary to JSON-compatible format (replace single quotes with double quotes)
32
- def dict_to_json(obj):
33
- if isinstance(obj, dict):
34
- return {str(k): dict_to_json(v) for k, v in obj.items()}
35
- elif isinstance(obj, list):
36
- return [dict_to_json(item) for item in obj]
37
- else:
38
- return obj
39
- converted_data = dict_to_json(data)
40
- logger.debug("Converted to JSON-compatible format")
41
- return converted_data
42
- except (SyntaxError, ValueError) as e:
43
- logger.error("Python literal parsing failed: %s (Input: %s)", str(e), json_input[:100] + "..." if len(json_input) > 100 else json_input)
44
- raise ValueError(f"Malformed input: {str(e)}. Ensure property names are in double quotes (e.g., \"content\") and the format matches JSON (e.g., {{\"content\": [...]}}).")
45
 
46
  # Function to ensure a value is a float, converting from string if necessary
47
  def ensure_float(value):
@@ -69,10 +54,59 @@ def get_token(entry):
69
  def create_empty_figure(title):
70
  return go.Figure().update_layout(title=title, xaxis_title="", yaxis_title="", showlegend=False)
71
 
72
- # Function to process and visualize the full log probs with dynamic top_logprobs, handling missing tokens and JSON structure
73
- def visualize_logprobs(json_input):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  try:
75
- # Parse the input (handles JSON only, as specified)
76
  data = parse_input(json_input)
77
 
78
  # Ensure data is a dictionary with 'content' key containing a list
@@ -94,14 +128,13 @@ def visualize_logprobs(json_input):
94
  logger.warning("Skipping non-dictionary entry: %s", entry)
95
  continue
96
  logprob = ensure_float(entry.get("logprob", None))
97
- if logprob >= -100000: # Include all entries with default 0.0, removing math.isfinite check
98
- token = get_token(entry) # Safely get token, defaulting to "Unknown" if missing
99
- tokens.append(token)
100
  logprobs.append(logprob)
101
  # Get top_logprobs, default to empty dict if None
102
  top_probs = entry.get("top_logprobs", {})
103
  if top_probs is None:
104
- logger.debug("top_logprobs is None for token: %s, using empty dict", token)
105
  top_probs = {} # Default to empty dict for None
106
  # Ensure all values in top_logprobs are floats and create a list of tuples
107
  finite_top_probs = []
@@ -115,53 +148,61 @@ def visualize_logprobs(json_input):
115
  else:
116
  logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))
117
 
118
- # Check if there's valid data after filtering (including default 0.0)
119
  if not logprobs or not tokens:
120
- return (create_empty_figure("Log Probabilities of Generated Tokens"), None, "No tokens to display.", create_empty_figure("Top Token Log Probabilities"), create_empty_figure("Significant Probability Drops"))
 
 
 
 
 
 
 
 
121
 
122
  # 1. Main Log Probability Plot (Interactive Plotly)
123
  main_fig = go.Figure()
124
- main_fig.add_trace(go.Scatter(x=list(range(len(logprobs))), y=logprobs, mode='markers+lines', name='Log Prob', marker=dict(color='blue')))
125
  main_fig.update_layout(
126
- title="Log Probabilities of Generated Tokens",
127
- xaxis_title="Token Position",
128
  yaxis_title="Log Probability",
129
  hovermode="closest",
130
  clickmode='event+select'
131
  )
132
  main_fig.update_traces(
133
- customdata=[f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i}" for i, (tok, prob) in enumerate(zip(tokens, logprobs))],
134
  hovertemplate='<b>%{customdata}</b><extra></extra>'
135
  )
136
 
137
  # 2. Probability Drop Analysis (Interactive Plotly)
138
- if len(logprobs) < 2:
139
- drops_fig = create_empty_figure("Significant Probability Drops")
140
  else:
141
- drops = [logprobs[i+1] - logprobs[i] for i in range(len(logprobs)-1)]
142
  drops_fig = go.Figure()
143
  drops_fig.add_trace(go.Bar(x=list(range(len(drops))), y=drops, name='Drop', marker_color='red'))
144
  drops_fig.update_layout(
145
- title="Significant Probability Drops",
146
- xaxis_title="Token Position",
147
  yaxis_title="Log Probability Drop",
148
  hovermode="closest",
149
  clickmode='event+select'
150
  )
151
  drops_fig.update_traces(
152
- customdata=[f"Drop: {drop:.4f}, From: {tokens[i]} to {tokens[i+1]}, Position: {i}" for i, drop in enumerate(drops)],
153
  hovertemplate='<b>%{customdata}</b><extra></extra>'
154
  )
155
 
156
  # Create DataFrame for the table with dynamic top_logprobs
157
  table_data = []
158
- max_alternatives = max(len(alts) for alts in top_alternatives) if top_alternatives else 0
159
- for i, entry in enumerate(content):
160
  if not isinstance(entry, dict):
161
  continue
162
  logprob = ensure_float(entry.get("logprob", None))
163
  if logprob >= -100000 and "top_logprobs" in entry: # Include all entries with default 0.0
164
- token = get_token(entry) # Safely get token, defaulting to "Unknown" if missing
165
  top_logprobs = entry.get("top_logprobs", {})
166
  if top_logprobs is None:
167
  logger.debug("top_logprobs is None for token: %s, using empty dict", token)
@@ -191,38 +232,38 @@ def visualize_logprobs(json_input):
191
  else None
192
  )
193
 
194
- # Generate colored text
195
- if logprobs:
196
- min_logprob = min(logprobs)
197
- max_logprob = max(logprobs)
198
  if max_logprob == min_logprob:
199
- normalized_probs = [0.5] * len(logprobs)
200
  else:
201
  normalized_probs = [
202
- (lp - min_logprob) / (max_logprob - min_logprob) for lp in logprobs
203
  ]
204
 
205
  colored_text = ""
206
- for i, (token, norm_prob) in enumerate(zip(tokens, normalized_probs)):
207
  r = int(255 * (1 - norm_prob)) # Red for low confidence
208
  g = int(255 * norm_prob) # Green for high confidence
209
  b = 0
210
  color = f"rgb({r}, {g}, {b})"
211
  colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
212
- if i < len(tokens) - 1:
213
  colored_text += " "
214
  colored_text_html = f"<p>{colored_text}</p>"
215
  else:
216
- colored_text_html = "No tokens to display."
217
 
218
- # Top Token Log Probabilities (Interactive Plotly, dynamic length)
219
- alt_viz_fig = create_empty_figure("Top Token Log Probabilities") if not logprobs or not top_alternatives else go.Figure()
220
- if logprobs and top_alternatives:
221
- for i, (token, probs) in enumerate(zip(tokens, top_alternatives)):
222
  for j, (alt_tok, prob) in enumerate(probs):
223
- alt_viz_fig.add_trace(go.Bar(x=[f"{token} (Pos {i})"], y=[prob], name=f"{alt_tok}", marker_color=['blue', 'green', 'red', 'purple', 'orange'][:len(probs)]))
224
  alt_viz_fig.update_layout(
225
- title="Top Token Log Probabilities",
226
  xaxis_title="Token (Position)",
227
  yaxis_title="Log Probability",
228
  barmode='stack',
@@ -230,21 +271,21 @@ def visualize_logprobs(json_input):
230
  clickmode='event+select'
231
  )
232
  alt_viz_fig.update_traces(
233
- customdata=[f"Token: {tok}, Alt: {alt}, Log Prob: {prob:.4f}, Position: {i}" for i, (tok, alts) in enumerate(zip(tokens, top_alternatives)) for alt, prob in alts],
234
  hovertemplate='<b>%{customdata}</b><extra></extra>'
235
  )
236
 
237
- return (main_fig, df, colored_text_html, alt_viz_fig, drops_fig)
238
 
239
  except Exception as e:
240
  logger.error("Visualization failed: %s (Input: %s)", str(e), json_input[:100] + "..." if len(json_input) > 100 else json_input)
241
- return (create_empty_figure("Log Probabilities of Generated Tokens"), None, "No finite log probabilities to display.", create_empty_figure("Top Token Log Probabilities"), create_empty_figure("Significant Probability Drops"))
242
 
243
- # Gradio interface with full dataset visualization, dynamic top_logprobs, and robust JSON handling
244
  with gr.Blocks(title="Log Probability Visualizer") as app:
245
  gr.Markdown("# Log Probability Visualizer")
246
  gr.Markdown(
247
- "Paste your JSON log prob data below to visualize all tokens at once. Fixed filter ≥ -100000, dynamic number of top_logprobs, handles missing or null fields."
248
  )
249
 
250
  with gr.Row():
@@ -253,6 +294,7 @@ with gr.Blocks(title="Log Probability Visualizer") as app:
253
  lines=10,
254
  placeholder="Paste your JSON (e.g., {\"content\": [{\"bytes\": [44], \"logprob\": 0.0, \"token\": \",\", \"top_logprobs\": {\" so\": -13.8046875, \".\": -13.8046875, \",\": -13.640625}}]}).",
255
  )
 
256
 
257
  with gr.Row():
258
  plot_output = gr.Plot(label="Log Probability Plot (Click for Tokens)")
@@ -265,11 +307,67 @@ with gr.Blocks(title="Log Probability Visualizer") as app:
265
  with gr.Row():
266
  text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
267
 
 
 
 
 
 
 
 
 
268
  btn = gr.Button("Visualize")
269
  btn.click(
270
  fn=visualize_logprobs,
271
- inputs=[json_input],
272
- outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  )
274
 
275
  app.launch()
 
9
  import logging
10
  import numpy as np
11
  import plotly.graph_objects as go
12
+ import asyncio
13
+ import anyio
14
 
15
  # Set up logging
16
  logging.basicConfig(level=logging.DEBUG)
 
26
  return data
27
  except json.JSONDecodeError as e:
28
  logger.error("JSON parsing failed: %s (Input: %s)", str(e), json_input[:100] + "..." if len(json_input) > 100 else json_input)
29
+ raise ValueError(f"Malformed input: {str(e)}. Ensure property names are in double quotes (e.g., \"content\") and the format matches JSON (e.g., {{\"content\": [...]}}).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # Function to ensure a value is a float, converting from string if necessary
32
  def ensure_float(value):
 
54
def create_empty_figure(title):
    """Build a blank placeholder figure that carries only a title.

    Used wherever a chart slot must be filled but there is no data to plot.
    """
    empty = go.Figure()
    empty.update_layout(title=title, xaxis_title="", yaxis_title="", showlegend=False)
    return empty
56
 
57
# Precompute the next chunk asynchronously
async def precompute_chunk(json_input, chunk_size, current_chunk):
    """Parse the full input and return the slice of data for the NEXT chunk.

    Args:
        json_input: Raw JSON string pasted by the user.
        chunk_size: Number of tokens per chunk.
        current_chunk: Zero-based index of the chunk currently displayed.

    Returns:
        (tokens, logprobs, top_alternatives) lists for chunk
        ``current_chunk + 1``, or (None, None, None) when there is no such
        chunk or anything goes wrong.
    """
    try:
        parsed = parse_input(json_input)
        entries = parsed.get("content", []) if isinstance(parsed, dict) else parsed
        if not isinstance(entries, list):
            raise ValueError("Content must be a list of entries")

        tokens, logprobs, top_alternatives = [], [], []
        for item in entries:
            if not isinstance(item, dict):
                logger.warning("Skipping non-dictionary entry: %s", item)
                continue
            # NOTE(review): assumes ensure_float never returns None here,
            # otherwise the >= comparison would raise — confirm against
            # ensure_float's definition (not visible in this view).
            lp = ensure_float(item.get("logprob", None))
            if lp >= -100000:  # Include all entries with default 0.0
                tokens.append(get_token(item))
                logprobs.append(lp)
                raw_top = item.get("top_logprobs", {})
                if raw_top is None:
                    logger.debug("top_logprobs is None for token: %s, using empty dict", get_token(item))
                    raw_top = {}
                # Keep only finite float alternatives, best-first.
                finite_pairs = []
                for alt_token, alt_value in raw_top.items():
                    as_float = ensure_float(alt_value)
                    if as_float is not None and math.isfinite(as_float):
                        finite_pairs.append((alt_token, as_float))
                finite_pairs.sort(key=lambda pair: pair[1], reverse=True)
                top_alternatives.append(finite_pairs)

        if not tokens or not logprobs:
            return None, None, None

        # Slice out the chunk that FOLLOWS the one currently shown.
        upcoming = current_chunk + 1
        start_idx = upcoming * chunk_size
        end_idx = min((upcoming + 1) * chunk_size, len(tokens))
        if start_idx >= len(tokens):
            return None, None, None

        return (
            tokens[start_idx:end_idx],
            logprobs[start_idx:end_idx],
            top_alternatives[start_idx:end_idx],
        )
    except Exception as e:
        logger.error("Precomputation failed for chunk %d: %s", current_chunk + 1, str(e))
        return None, None, None
105
+
106
+ # Function to process and visualize a chunk of log probs with dynamic top_logprobs
107
+ def visualize_logprobs(json_input, chunk=0, chunk_size=1000):
108
  try:
109
+ # Parse the input (handles JSON only)
110
  data = parse_input(json_input)
111
 
112
  # Ensure data is a dictionary with 'content' key containing a list
 
128
  logger.warning("Skipping non-dictionary entry: %s", entry)
129
  continue
130
  logprob = ensure_float(entry.get("logprob", None))
131
+ if logprob >= -100000: # Include all entries with default 0.0
132
+ tokens.append(get_token(entry))
 
133
  logprobs.append(logprob)
134
  # Get top_logprobs, default to empty dict if None
135
  top_probs = entry.get("top_logprobs", {})
136
  if top_probs is None:
137
+ logger.debug("top_logprobs is None for token: %s, using empty dict", get_token(entry))
138
  top_probs = {} # Default to empty dict for None
139
  # Ensure all values in top_logprobs are floats and create a list of tuples
140
  finite_top_probs = []
 
148
  else:
149
  logger.debug("Skipping entry with logprob: %s (type: %s)", entry.get("logprob"), type(entry.get("logprob", None)))
150
 
151
+ # Check if there's valid data after filtering
152
  if not logprobs or not tokens:
153
+ return (create_empty_figure("Log Probabilities of Generated Tokens"), None, "No tokens to display.", create_empty_figure("Top Token Log Probabilities"), create_empty_figure("Significant Probability Drops"), 1, 0)
154
+
155
+ # Paginate data for chunks of 1,000 tokens
156
+ total_chunks = max(1, (len(logprobs) + chunk_size - 1) // chunk_size)
157
+ start_idx = chunk * chunk_size
158
+ end_idx = min((chunk + 1) * chunk_size, len(logprobs))
159
+ paginated_tokens = tokens[start_idx:end_idx]
160
+ paginated_logprobs = logprobs[start_idx:end_idx]
161
+ paginated_alternatives = top_alternatives[start_idx:end_idx] if top_alternatives else []
162
 
163
  # 1. Main Log Probability Plot (Interactive Plotly)
164
  main_fig = go.Figure()
165
+ main_fig.add_trace(go.Scatter(x=list(range(len(paginated_logprobs))), y=paginated_logprobs, mode='markers+lines', name='Log Prob', marker=dict(color='blue')))
166
  main_fig.update_layout(
167
+ title="Log Probabilities of Generated Tokens (Chunk %d)" % (chunk + 1),
168
+ xaxis_title="Token Position (within chunk)",
169
  yaxis_title="Log Probability",
170
  hovermode="closest",
171
  clickmode='event+select'
172
  )
173
  main_fig.update_traces(
174
+ customdata=[f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i+start_idx}" for i, (tok, prob) in enumerate(zip(paginated_tokens, paginated_logprobs))],
175
  hovertemplate='<b>%{customdata}</b><extra></extra>'
176
  )
177
 
178
  # 2. Probability Drop Analysis (Interactive Plotly)
179
+ if len(paginated_logprobs) < 2:
180
+ drops_fig = create_empty_figure("Significant Probability Drops (Chunk %d)" % (chunk + 1))
181
  else:
182
+ drops = [paginated_logprobs[i+1] - paginated_logprobs[i] for i in range(len(paginated_logprobs)-1)]
183
  drops_fig = go.Figure()
184
  drops_fig.add_trace(go.Bar(x=list(range(len(drops))), y=drops, name='Drop', marker_color='red'))
185
  drops_fig.update_layout(
186
+ title="Significant Probability Drops (Chunk %d)" % (chunk + 1),
187
+ xaxis_title="Token Position (within chunk)",
188
  yaxis_title="Log Probability Drop",
189
  hovermode="closest",
190
  clickmode='event+select'
191
  )
192
  drops_fig.update_traces(
193
+ customdata=[f"Drop: {drop:.4f}, From: {paginated_tokens[i]} to {paginated_tokens[i+1]}, Position: {i+start_idx}" for i, drop in enumerate(drops)],
194
  hovertemplate='<b>%{customdata}</b><extra></extra>'
195
  )
196
 
197
  # Create DataFrame for the table with dynamic top_logprobs
198
  table_data = []
199
+ max_alternatives = max(len(alts) for alts in paginated_alternatives) if paginated_alternatives else 0
200
+ for i, entry in enumerate(content[start_idx:end_idx]):
201
  if not isinstance(entry, dict):
202
  continue
203
  logprob = ensure_float(entry.get("logprob", None))
204
  if logprob >= -100000 and "top_logprobs" in entry: # Include all entries with default 0.0
205
+ token = get_token(entry)
206
  top_logprobs = entry.get("top_logprobs", {})
207
  if top_logprobs is None:
208
  logger.debug("top_logprobs is None for token: %s, using empty dict", token)
 
232
  else None
233
  )
234
 
235
+ # Generate colored text (for the current chunk)
236
+ if paginated_logprobs:
237
+ min_logprob = min(paginated_logprobs)
238
+ max_logprob = max(paginated_logprobs)
239
  if max_logprob == min_logprob:
240
+ normalized_probs = [0.5] * len(paginated_logprobs)
241
  else:
242
  normalized_probs = [
243
+ (lp - min_logprob) / (max_logprob - min_logprob) for lp in paginated_logprobs
244
  ]
245
 
246
  colored_text = ""
247
+ for i, (token, norm_prob) in enumerate(zip(paginated_tokens, normalized_probs)):
248
  r = int(255 * (1 - norm_prob)) # Red for low confidence
249
  g = int(255 * norm_prob) # Green for high confidence
250
  b = 0
251
  color = f"rgb({r}, {g}, {b})"
252
  colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
253
+ if i < len(paginated_tokens) - 1:
254
  colored_text += " "
255
  colored_text_html = f"<p>{colored_text}</p>"
256
  else:
257
+ colored_text_html = "No tokens to display in this chunk."
258
 
259
+ # Top Token Log Probabilities (Interactive Plotly, dynamic length, for the current chunk)
260
+ alt_viz_fig = create_empty_figure("Top Token Log Probabilities (Chunk %d)" % (chunk + 1)) if not paginated_logprobs or not paginated_alternatives else go.Figure()
261
+ if paginated_logprobs and paginated_alternatives:
262
+ for i, (token, probs) in enumerate(zip(paginated_tokens, paginated_alternatives)):
263
  for j, (alt_tok, prob) in enumerate(probs):
264
+ alt_viz_fig.add_trace(go.Bar(x=[f"{token} (Pos {i+start_idx})"], y=[prob], name=f"{alt_tok}", marker_color=['blue', 'green', 'red', 'purple', 'orange'][:len(probs)]))
265
  alt_viz_fig.update_layout(
266
+ title="Top Token Log Probabilities (Chunk %d)" % (chunk + 1),
267
  xaxis_title="Token (Position)",
268
  yaxis_title="Log Probability",
269
  barmode='stack',
 
271
  clickmode='event+select'
272
  )
273
  alt_viz_fig.update_traces(
274
+ customdata=[f"Token: {tok}, Alt: {alt}, Log Prob: {prob:.4f}, Position: {i+start_idx}" for i, (tok, alts) in enumerate(zip(paginated_tokens, paginated_alternatives)) for alt, prob in alts],
275
  hovertemplate='<b>%{customdata}</b><extra></extra>'
276
  )
277
 
278
+ return (main_fig, df, colored_text_html, alt_viz_fig, drops_fig, total_chunks, chunk)
279
 
280
  except Exception as e:
281
  logger.error("Visualization failed: %s (Input: %s)", str(e), json_input[:100] + "..." if len(json_input) > 100 else json_input)
282
+ return (create_empty_figure("Log Probabilities of Generated Tokens"), None, "No finite log probabilities to display.", create_empty_figure("Top Token Log Probabilities"), create_empty_figure("Significant Probability Drops"), 1, 0)
283
 
284
+ # Gradio interface with chunked visualization and proactive precomputation
285
  with gr.Blocks(title="Log Probability Visualizer") as app:
286
  gr.Markdown("# Log Probability Visualizer")
287
  gr.Markdown(
288
+ "Paste your JSON log prob data below to visualize tokens in chunks of 1,000. Fixed filter ≥ -100000, dynamic number of top_logprobs, handles missing or null fields. Next chunk is precomputed proactively."
289
  )
290
 
291
  with gr.Row():
 
294
  lines=10,
295
  placeholder="Paste your JSON (e.g., {\"content\": [{\"bytes\": [44], \"logprob\": 0.0, \"token\": \",\", \"top_logprobs\": {\" so\": -13.8046875, \".\": -13.8046875, \",\": -13.640625}}]}).",
296
  )
297
+ chunk = gr.Number(value=0, label="Current Chunk", precision=0, minimum=0)
298
 
299
  with gr.Row():
300
  plot_output = gr.Plot(label="Log Probability Plot (Click for Tokens)")
 
307
  with gr.Row():
308
  text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
309
 
310
+ with gr.Row():
311
+ prev_btn = gr.Button("Previous Chunk")
312
+ next_btn = gr.Button("Next Chunk")
313
+ total_chunks_output = gr.Number(label="Total Chunks", interactive=False)
314
+
315
+ # Precomputed next chunk state (hidden)
316
+ precomputed_next = gr.State(value=None)
317
+
318
  btn = gr.Button("Visualize")
319
  btn.click(
320
  fn=visualize_logprobs,
321
+ inputs=[json_input, chunk],
322
+ outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk],
323
+ )
324
+
325
# Precompute next chunk proactively when on current chunk
async def precompute_next_chunk(json_input, current_chunk, precomputed_next):
    """Return the cached next-chunk data, or compute it via precompute_chunk.

    Returns None when the next chunk does not exist or computation failed.
    """
    if precomputed_next is not None:
        return precomputed_next  # Use cached precomputed chunk if available
    chunk_data = await precompute_chunk(json_input, 1000, current_chunk)
    if any(part is None for part in chunk_data):
        return None
    return chunk_data
333
+
334
# Update chunk on button clicks
def update_chunk(json_input, current_chunk, action, precomputed_next=None):
    """Step to the previous/next chunk and re-render the visualization.

    Args:
        json_input: Raw JSON string pasted by the user.
        current_chunk: Zero-based index of the chunk currently shown.
        action: Either "prev" or "next".
        precomputed_next: Optional cached (tokens, logprobs, alternatives)
            for the upcoming chunk. Rendering always goes through
            visualize_logprobs so all seven outputs stay consistent; the
            cache presence is only logged.

    Returns:
        The 7-tuple produced by visualize_logprobs for the new chunk.
    """
    if action == "prev":
        if current_chunk > 0:
            current_chunk -= 1
    elif action == "next":
        # Only "next" needs the chunk count; the original code ran a full
        # visualization pass here even when stepping backwards.
        total_chunks = visualize_logprobs(json_input, 0)[5]
        if current_chunk < total_chunks - 1:
            current_chunk += 1
    if precomputed_next:
        next_tokens, next_logprobs, next_alternatives = precomputed_next
        if next_tokens and next_logprobs and next_alternatives:
            logger.debug("Using precomputed next chunk for chunk %d", current_chunk)
    # Single exit point: the original had two identical return statements,
    # which made the precomputed branch look load-bearing when it was not.
    return visualize_logprobs(json_input, current_chunk)
348
+
349
+ prev_btn.click(
350
+ fn=update_chunk,
351
+ inputs=[json_input, chunk, gr.State(value="prev"), precomputed_next],
352
+ outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk],
353
+ )
354
+
355
+ next_btn.click(
356
+ fn=update_chunk,
357
+ inputs=[json_input, chunk, gr.State(value="next"), precomputed_next],
358
+ outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk],
359
+ )
360
+
361
# Trigger precomputation when chunk changes (via button clicks or initial load)
async def trigger_precomputation(json_input, current_chunk):
    """Fire-and-forget background precomputation of the next chunk.

    Defined as a coroutine so asyncio.create_task has a running event
    loop to attach to: Gradio executes *sync* handlers in a worker
    thread where create_task raises RuntimeError ("no running event
    loop"); async handlers run on Gradio's own loop.

    NOTE(review): the task object is not retained, so it may be
    garbage-collected before completing — acceptable for an optional
    cache warm-up, but worth confirming.
    """
    asyncio.create_task(precompute_next_chunk(json_input, current_chunk, None))
    return gr.update(value=current_chunk)
365
+
366
+ # Use a dummy event to trigger precomputation on chunk change (simplified for Gradio)
367
+ chunk.change(
368
+ fn=trigger_precomputation,
369
+ inputs=[json_input, chunk],
370
+ outputs=[chunk],
371
  )
372
 
373
  app.launch()