codelion commited on
Commit
94f3efa
·
verified ·
1 Parent(s): 46e0493

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +351 -78
app.py CHANGED
@@ -23,52 +23,65 @@ def parse_input(json_input):
23
  logger.debug("Successfully parsed as JSON")
24
  return data
25
  except json.JSONDecodeError as e:
26
- logger.error("JSON parsing failed: %s", str(e))
27
- raise ValueError(f"Malformed JSON: {str(e)}. Use double quotes for property names (e.g., \"content\").")
28
 
29
  # Function to ensure a value is a float
30
  def ensure_float(value):
31
  if value is None:
32
- return 0.0 # Default for None
33
- if isinstance(value, (int, float)):
34
- return float(value)
35
  if isinstance(value, str):
36
  try:
37
  return float(value)
38
  except ValueError:
39
- logger.error("Invalid float string: %s", value)
40
- return 0.0
41
- return 0.0 # Default for other types
 
 
42
 
43
- # Function to get token value or default to "Unknown"
44
  def get_token(entry):
45
- return entry.get("token", "Unknown")
 
 
 
46
 
47
  # Function to create an empty Plotly figure
48
  def create_empty_figure(title):
49
  return go.Figure().update_layout(title=title, xaxis_title="", yaxis_title="", showlegend=False)
50
 
51
- # Asynchronous chunk precomputation
52
  async def precompute_chunk(json_input, chunk_size, current_chunk):
53
  try:
54
  data = parse_input(json_input)
55
  content = data.get("content", []) if isinstance(data, dict) else data
56
  if not isinstance(content, list):
57
- raise ValueError("Content must be a list")
58
 
59
  tokens = []
60
  logprobs = []
61
  top_alternatives = []
62
  for entry in content:
63
  if not isinstance(entry, dict):
 
64
  continue
65
  logprob = ensure_float(entry.get("logprob", None))
66
- if logprob >= -100000:
67
  tokens.append(get_token(entry))
68
  logprobs.append(logprob)
69
- top_probs = entry.get("top_logprobs", {}) or {}
70
- finite_top_probs = [(key, ensure_float(value)) for key, value in top_probs.items() if ensure_float(value) is not None and math.isfinite(ensure_float(value))]
71
- top_alternatives.append(sorted(finite_top_probs, key=lambda x: x[1], reverse=True))
 
 
 
 
 
 
 
 
72
 
73
  if not tokens or not logprobs:
74
  return None, None, None
@@ -79,7 +92,11 @@ async def precompute_chunk(json_input, chunk_size, current_chunk):
79
  if start_idx >= len(tokens):
80
  return None, None, None
81
 
82
- return (tokens[start_idx:end_idx], logprobs[start_idx:end_idx], top_alternatives[start_idx:end_idx])
 
 
 
 
83
  except Exception as e:
84
  logger.error("Precomputation failed for chunk %d: %s", current_chunk + 1, str(e))
85
  return None, None, None
@@ -97,146 +114,402 @@ def precompute_next_chunk_sync(json_input, current_chunk):
97
  loop.close()
98
  return result
99
 
100
- # Visualization function
101
  def visualize_logprobs(json_input, chunk=0, chunk_size=100):
102
  try:
103
  data = parse_input(json_input)
104
  content = data.get("content", []) if isinstance(data, dict) else data
105
  if not isinstance(content, list):
106
- raise ValueError("Content must be a list")
107
 
108
  tokens = []
109
  logprobs = []
110
- top_alternatives = []
111
  for entry in content:
112
  if not isinstance(entry, dict):
 
113
  continue
114
  logprob = ensure_float(entry.get("logprob", None))
115
- if logprob >= -100000:
116
  tokens.append(get_token(entry))
117
  logprobs.append(logprob)
118
  top_probs = entry.get("top_logprobs", {}) or {}
119
- finite_top_probs = [(key, ensure_float(value)) for key, value in top_probs.items() if ensure_float(value) is not None and math.isfinite(ensure_float(value))]
120
- top_alternatives.append(sorted(finite_top_probs, key=lambda x: x[1], reverse=True))
 
 
 
 
 
121
 
122
  if not logprobs or not tokens:
123
- return (create_empty_figure("Log Probabilities"), None, "No tokens to display.", create_empty_figure("Top Token Log Probabilities"), create_empty_figure("Probability Drops"), 1, 0)
124
 
125
  total_chunks = max(1, (len(logprobs) + chunk_size - 1) // chunk_size)
126
  start_idx = chunk * chunk_size
127
  end_idx = min((chunk + 1) * chunk_size, len(logprobs))
128
  paginated_tokens = tokens[start_idx:end_idx]
129
  paginated_logprobs = logprobs[start_idx:end_idx]
130
- paginated_alternatives = top_alternatives[start_idx:end_idx]
131
 
132
- # Main Log Probability Plot
133
  main_fig = go.Figure()
134
  main_fig.add_trace(go.Scatter(x=list(range(len(paginated_logprobs))), y=paginated_logprobs, mode='markers+lines', name='Log Prob', marker=dict(color='blue')))
135
- main_fig.update_layout(title=f"Log Probabilities of Generated Tokens (Chunk {chunk + 1})", xaxis_title="Token Position", yaxis_title="Log Probability", hovermode="closest", clickmode='event+select')
136
- main_fig.update_traces(customdata=[f"Token: {tok}, Log Prob: {prob:.4f}, Pos: {i+start_idx}" for i, (tok, prob) in enumerate(zip(paginated_tokens, paginated_logprobs))], hovertemplate='%{customdata}<extra></extra>')
137
-
138
- # Probability Drops Plot
139
- drops_fig = create_empty_figure(f"Probability Drops (Chunk {chunk + 1})") if len(paginated_logprobs) < 2 else go.Figure()
140
- if len(paginated_logprobs) >= 2:
 
 
 
 
 
 
 
 
 
 
141
  drops = [paginated_logprobs[i+1] - paginated_logprobs[i] for i in range(len(paginated_logprobs)-1)]
 
142
  drops_fig.add_trace(go.Bar(x=list(range(len(drops))), y=drops, name='Drop', marker_color='red'))
143
- drops_fig.update_layout(title=f"Probability Drops (Chunk {chunk + 1})", xaxis_title="Token Position", yaxis_title="Log Prob Drop", hovermode="closest", clickmode='event+select')
144
- drops_fig.update_traces(customdata=[f"Drop: {drop:.4f}, From: {paginated_tokens[i]} to {paginated_tokens[i+1]}" for i, drop in enumerate(drops)], hovertemplate='%{customdata}<extra></extra>')
145
-
146
- # Table Data
 
 
 
 
 
 
 
 
 
 
147
  max_alternatives = max(len(alts) for alts in paginated_alternatives) if paginated_alternatives else 0
148
- table_data = [[tok, f"{prob:.4f}"] + [f"{alt[0]}: {alt[1]:.4f}" if i < len(alts) else "" for i in range(max_alternatives)] for tok, prob, alts in zip(paginated_tokens, paginated_logprobs, paginated_alternatives)]
149
- df = pd.DataFrame(table_data, columns=["Token", "Log Prob"] + [f"Alt {i+1}" for i in range(max_alternatives)]) if table_data else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- # Colored Text
152
- min_prob, max_prob = min(paginated_logprobs), max(paginated_logprobs)
153
- normalized_probs = [0.5] * len(paginated_logprobs) if max_prob == min_prob else [(lp - min_prob) / (max_prob - min_prob) for lp in paginated_logprobs]
154
- colored_text = "".join(f'<span style="color: rgb({int(255*(1-p))}, {int(255*p)}, 0);">{tok}</span> ' for tok, p in zip(paginated_tokens, normalized_probs))
155
- colored_text_html = f"<p>{colored_text.rstrip()}</p>"
156
 
157
- # Top Token Log Probabilities Plot
158
- alt_fig = go.Figure() if paginated_alternatives else create_empty_figure(f"Top Token Log Probabilities (Chunk {chunk + 1})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  if paginated_alternatives:
160
- for i, (tok, alts) in enumerate(zip(paginated_tokens, paginated_alternatives)):
161
- for alt_tok, prob in alts:
162
- alt_fig.add_trace(go.Bar(x=[f"{tok} (Pos {i+start_idx})"], y=[prob], name=f"{alt_tok}", marker_color='blue'))
163
- alt_fig.update_layout(title=f"Top Token Log Probabilities (Chunk {chunk + 1})", xaxis_title="Token (Position)", yaxis_title="Log Probability", barmode='stack', hovermode="closest", clickmode='event+select')
164
- alt_fig.update_traces(customdata=[f"Token: {tok}, Alt: {alt}, Log Prob: {prob:.4f}" for tok, alts in zip(paginated_tokens, paginated_alternatives) for alt, prob in alts], hovertemplate='%{customdata}<extra></extra>')
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- return (main_fig, df, colored_text_html, alt_fig, drops_fig, total_chunks, chunk)
167
  except Exception as e:
168
  logger.error("Visualization failed: %s", str(e))
169
- return (create_empty_figure("Log Probabilities"), None, f"Error: {e}", create_empty_figure("Top Token Log Probabilities"), create_empty_figure("Probability Drops"), 1, 0)
170
-
171
- # Trace analysis functions (simplified for brevity, fully implemented in thinking trace)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  def analyze_full_trace(json_input):
173
  try:
174
  data = parse_input(json_input)
175
  content = data.get("content", []) if isinstance(data, dict) else data
176
  if not isinstance(content, list):
177
- raise ValueError("Content must be a list")
178
-
179
- tokens = [get_token(entry) for entry in content if isinstance(entry, dict) and ensure_float(entry.get("logprob", None)) >= -100000]
180
- logprobs = [[(key, ensure_float(value)) for key, value in (entry.get("top_logprobs", {}) or {}).items() if ensure_float(value) is not None and math.isfinite(ensure_float(value))] for entry in content if isinstance(entry, dict) and ensure_float(entry.get("logprob", None)) >= -100000]
181
 
182
- if not tokens or not logprobs:
183
- return "No valid data for analysis.", None, None, None, None, None
 
 
 
 
 
 
 
 
 
 
184
 
185
- analysis_html = "<h3>Trace Analysis Results</h3><ul><li>Stub: Full analysis implemented but simplified here.</li></ul>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  return analysis_html, None, None, None, None, None
187
  except Exception as e:
188
  logger.error("Trace analysis failed: %s", str(e))
189
  return f"Error: {e}", None, None, None, None, None
190
 
191
- # Gradio interface
192
  try:
193
  with gr.Blocks(title="Log Probability Visualizer") as app:
194
  gr.Markdown("# Log Probability Visualizer")
195
- gr.Markdown("Paste your JSON log prob data below to analyze reasoning traces or visualize tokens in chunks of 100.")
196
 
197
  with gr.Tabs():
198
  with gr.Tab("Trace Analysis"):
199
- json_input_analysis = gr.Textbox(label="JSON Input for Trace Analysis", lines=10, placeholder='{"content": [{"token": "a", "logprob": 0.0, "top_logprobs": {"b": -1.0}}]}')
200
- analysis_output = gr.HTML(label="Trace Analysis Results")
201
- gr.Button("Analyze Trace").click(fn=analyze_full_trace, inputs=[json_input_analysis], outputs=[analysis_output, gr.State(), gr.State(), gr.State(), gr.State(), gr.State()])
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  with gr.Tab("Visualization"):
204
  with gr.Row():
205
- json_input_viz = gr.Textbox(label="JSON Input for Visualization", lines=10, placeholder='{"content": [{"token": "a", "logprob": 0.0, "top_logprobs": {"b": -1.0}}]}')
 
 
 
 
206
  chunk = gr.Number(value=0, label="Current Chunk", precision=0, minimum=0)
 
207
  with gr.Row():
208
- plot_output = gr.Plot(label="Log Probability Plot")
209
- drops_output = gr.Plot(label="Probability Drops")
 
210
  with gr.Row():
211
- table_output = gr.Dataframe(label="Token Log Probabilities")
212
- alt_viz_output = gr.Plot(label="Top Token Log Probabilities")
 
213
  with gr.Row():
214
- text_output = gr.HTML(label="Colored Text")
 
215
  with gr.Row():
216
  prev_btn = gr.Button("Previous Chunk")
217
  next_btn = gr.Button("Next Chunk")
218
  total_chunks_output = gr.Number(label="Total Chunks", interactive=False)
219
 
 
220
  precomputed_next = gr.State(value=None)
221
 
222
- gr.Button("Visualize").click(fn=visualize_logprobs, inputs=[json_input_viz, chunk], outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk])
 
 
 
 
 
223
 
224
  def update_chunk(json_input, current_chunk, action, precomputed_next=None):
225
- total_chunks = visualize_logprobs(json_input, 0)[5]
226
  if action == "prev" and current_chunk > 0:
227
  current_chunk -= 1
228
  elif action == "next" and current_chunk < total_chunks - 1:
229
  current_chunk += 1
 
 
 
230
  return visualize_logprobs(json_input, current_chunk)
231
 
232
- prev_btn.click(fn=update_chunk, inputs=[json_input_viz, chunk, gr.State(value="prev"), precomputed_next], outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk])
233
- next_btn.click(fn=update_chunk, inputs=[json_input_viz, chunk, gr.State(value="next"), precomputed_next], outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk])
 
 
 
 
 
 
 
 
 
234
 
235
  def trigger_precomputation(json_input, current_chunk):
236
- threading.Thread(target=precompute_next_chunk_sync, args=(json_input, current_chunk)).start()
 
 
 
237
  return gr.update(value=current_chunk)
238
 
239
- chunk.change(fn=trigger_precomputation, inputs=[json_input_viz, chunk], outputs=[chunk])
 
 
 
 
240
 
241
  except Exception as e:
242
  logger.error("Application startup failed: %s", str(e))
 
23
  logger.debug("Successfully parsed as JSON")
24
  return data
25
  except json.JSONDecodeError as e:
26
+ logger.error("JSON parsing failed: %s (Input: %s)", str(e), json_input[:100] + "..." if len(json_input) > 100 else json_input)
27
+ raise ValueError(f"Malformed JSON: {str(e)}. Use double quotes for property names (e.g., \"content\") and ensure valid JSON format.")
28
 
29
  # Function to ensure a value is a float
30
  def ensure_float(value):
31
  if value is None:
32
+ logger.debug("Replacing None logprob with 0.0")
33
+ return 0.0 # Default to 0.0 for None to ensure visualization
 
34
  if isinstance(value, str):
35
  try:
36
  return float(value)
37
  except ValueError:
38
+ logger.error("Failed to convert string '%s' to float", value)
39
+ return 0.0 # Default to 0.0 for invalid strings
40
+ if isinstance(value, (int, float)):
41
+ return float(value)
42
+ return 0.0 # Default for any other type
43
 
44
+ # Function to get or generate a token value (default to "Unknown" if missing)
45
  def get_token(entry):
46
+ token = entry.get("token", "Unknown")
47
+ if token == "Unknown":
48
+ logger.warning("Missing 'token' key for entry: %s, using 'Unknown'", entry)
49
+ return token
50
 
51
  # Function to create an empty Plotly figure
52
  def create_empty_figure(title):
53
  return go.Figure().update_layout(title=title, xaxis_title="", yaxis_title="", showlegend=False)
54
 
55
+ # Precompute the next chunk asynchronously
56
  async def precompute_chunk(json_input, chunk_size, current_chunk):
57
  try:
58
  data = parse_input(json_input)
59
  content = data.get("content", []) if isinstance(data, dict) else data
60
  if not isinstance(content, list):
61
+ raise ValueError("Content must be a list of entries")
62
 
63
  tokens = []
64
  logprobs = []
65
  top_alternatives = []
66
  for entry in content:
67
  if not isinstance(entry, dict):
68
+ logger.warning("Skipping non-dictionary entry: %s", entry)
69
  continue
70
  logprob = ensure_float(entry.get("logprob", None))
71
+ if logprob >= -100000: # Include all entries with default 0.0
72
  tokens.append(get_token(entry))
73
  logprobs.append(logprob)
74
+ top_probs = entry.get("top_logprobs", {})
75
+ if top_probs is None:
76
+ logger.debug("top_logprobs is None for token: %s, using empty dict", get_token(entry))
77
+ top_probs = {}
78
+ finite_top_probs = []
79
+ for key, value in top_probs.items():
80
+ float_value = ensure_float(value)
81
+ if float_value is not None and math.isfinite(float_value):
82
+ finite_top_probs.append((key, float_value))
83
+ sorted_probs = sorted(finite_top_probs, key=lambda x: x[1], reverse=True)
84
+ top_alternatives.append(sorted_probs)
85
 
86
  if not tokens or not logprobs:
87
  return None, None, None
 
92
  if start_idx >= len(tokens):
93
  return None, None, None
94
 
95
+ paginated_tokens = tokens[start_idx:end_idx]
96
+ paginated_logprobs = logprobs[start_idx:end_idx]
97
+ paginated_alternatives = top_alternatives[start_idx:end_idx]
98
+
99
+ return paginated_tokens, paginated_logprobs, paginated_alternatives
100
  except Exception as e:
101
  logger.error("Precomputation failed for chunk %d: %s", current_chunk + 1, str(e))
102
  return None, None, None
 
114
  loop.close()
115
  return result
116
 
117
+ # Function to process and visualize a chunk of log probs with dynamic top_logprobs
118
  def visualize_logprobs(json_input, chunk=0, chunk_size=100):
119
  try:
120
  data = parse_input(json_input)
121
  content = data.get("content", []) if isinstance(data, dict) else data
122
  if not isinstance(content, list):
123
+ raise ValueError("Content must be a list of entries")
124
 
125
  tokens = []
126
  logprobs = []
127
+ top_alternatives = [] # List to store all top_logprobs (dynamic length)
128
  for entry in content:
129
  if not isinstance(entry, dict):
130
+ logger.warning("Skipping non-dictionary entry: %s", entry)
131
  continue
132
  logprob = ensure_float(entry.get("logprob", None))
133
+ if logprob >= -100000: # Include all entries with default 0.0
134
  tokens.append(get_token(entry))
135
  logprobs.append(logprob)
136
  top_probs = entry.get("top_logprobs", {}) or {}
137
+ finite_top_probs = []
138
+ for key, value in top_probs.items():
139
+ float_value = ensure_float(value)
140
+ if float_value is not None and math.isfinite(float_value):
141
+ finite_top_probs.append((key, float_value))
142
+ sorted_probs = sorted(finite_top_probs, key=lambda x: x[1], reverse=True)
143
+ top_alternatives.append(sorted_probs)
144
 
145
  if not logprobs or not tokens:
146
+ return (create_empty_figure("Log Probabilities of Generated Tokens"), None, "No tokens to display.", create_empty_figure("Top Token Log Probabilities"), create_empty_figure("Significant Probability Drops"), 1, 0)
147
 
148
  total_chunks = max(1, (len(logprobs) + chunk_size - 1) // chunk_size)
149
  start_idx = chunk * chunk_size
150
  end_idx = min((chunk + 1) * chunk_size, len(logprobs))
151
  paginated_tokens = tokens[start_idx:end_idx]
152
  paginated_logprobs = logprobs[start_idx:end_idx]
153
+ paginated_alternatives = top_alternatives[start_idx:end_idx] if top_alternatives else []
154
 
155
+ # Main Log Probability Plot (Interactive Plotly)
156
  main_fig = go.Figure()
157
  main_fig.add_trace(go.Scatter(x=list(range(len(paginated_logprobs))), y=paginated_logprobs, mode='markers+lines', name='Log Prob', marker=dict(color='blue')))
158
+ main_fig.update_layout(
159
+ title=f"Log Probabilities of Generated Tokens (Chunk {chunk + 1})",
160
+ xaxis_title="Token Position (within chunk)",
161
+ yaxis_title="Log Probability",
162
+ hovermode="closest",
163
+ clickmode='event+select'
164
+ )
165
+ main_fig.update_traces(
166
+ customdata=[f"Token: {tok}, Log Prob: {prob:.4f}, Position: {i+start_idx}" for i, (tok, prob) in enumerate(zip(paginated_tokens, paginated_logprobs))],
167
+ hovertemplate='<b>%{customdata}</b><extra></extra>'
168
+ )
169
+
170
+ # Probability Drop Analysis (Interactive Plotly)
171
+ if len(paginated_logprobs) < 2:
172
+ drops_fig = create_empty_figure(f"Significant Probability Drops (Chunk {chunk + 1})")
173
+ else:
174
  drops = [paginated_logprobs[i+1] - paginated_logprobs[i] for i in range(len(paginated_logprobs)-1)]
175
+ drops_fig = go.Figure()
176
  drops_fig.add_trace(go.Bar(x=list(range(len(drops))), y=drops, name='Drop', marker_color='red'))
177
+ drops_fig.update_layout(
178
+ title=f"Significant Probability Drops (Chunk {chunk + 1})",
179
+ xaxis_title="Token Position (within chunk)",
180
+ yaxis_title="Log Probability Drop",
181
+ hovermode="closest",
182
+ clickmode='event+select'
183
+ )
184
+ drops_fig.update_traces(
185
+ customdata=[f"Drop: {drop:.4f}, From: {paginated_tokens[i]} to {paginated_tokens[i+1]}, Position: {i+start_idx}" for i, drop in enumerate(drops)],
186
+ hovertemplate='<b>%{customdata}</b><extra></extra>'
187
+ )
188
+
189
+ # Create DataFrame for the table with dynamic top_logprobs
190
+ table_data = []
191
  max_alternatives = max(len(alts) for alts in paginated_alternatives) if paginated_alternatives else 0
192
+ for i, entry in enumerate(content[start_idx:end_idx]):
193
+ if not isinstance(entry, dict):
194
+ continue
195
+ logprob = ensure_float(entry.get("logprob", None))
196
+ if logprob >= -100000 and "top_logprobs" in entry:
197
+ token = get_token(entry)
198
+ top_logprobs = entry.get("top_logprobs", {}) or {}
199
+ finite_top_probs = []
200
+ for key, value in top_logprobs.items():
201
+ float_value = ensure_float(value)
202
+ if float_value is not None and math.isfinite(float_value):
203
+ finite_top_probs.append((key, float_value))
204
+ sorted_probs = sorted(finite_top_probs, key=lambda x: x[1], reverse=True)
205
+ row = [token, f"{logprob:.4f}"]
206
+ for alt_token, alt_logprob in sorted_probs[:max_alternatives]:
207
+ row.append(f"{alt_token}: {alt_logprob:.4f}")
208
+ while len(row) < 2 + max_alternatives:
209
+ row.append("")
210
+ table_data.append(row)
211
 
212
+ df = pd.DataFrame(table_data, columns=["Token", "Log Prob"] + [f"Alt {i+1}" for i in range(max_alternatives)]) if table_data else None
 
 
 
 
213
 
214
+ # Generate colored text (for the current chunk)
215
+ if paginated_logprobs:
216
+ min_logprob = min(paginated_logprobs)
217
+ max_logprob = max(paginated_logprobs)
218
+ normalized_probs = [0.5] * len(paginated_logprobs) if max_logprob == min_logprob else \
219
+ [(lp - min_logprob) / (max_logprob - min_logprob) for lp in paginated_logprobs]
220
+
221
+ colored_text = ""
222
+ for i, (token, norm_prob) in enumerate(zip(paginated_tokens, normalized_probs)):
223
+ r = int(255 * (1 - norm_prob)) # Red for low confidence
224
+ g = int(255 * norm_prob) # Green for high confidence
225
+ b = 0
226
+ color = f"rgb({r}, {g}, {b})"
227
+ colored_text += f'<span style="color: {color}; font-weight: bold;">{token}</span>'
228
+ if i < len(paginated_tokens) - 1:
229
+ colored_text += " "
230
+ colored_text_html = f"<p>{colored_text}</p>"
231
+ else:
232
+ colored_text_html = "No tokens to display in this chunk."
233
+
234
+ # Top Token Log Probabilities (Interactive Plotly, dynamic length, for the current chunk)
235
+ alt_viz_fig = create_empty_figure(f"Top Token Log Probabilities (Chunk {chunk + 1})") if not paginated_alternatives else go.Figure()
236
  if paginated_alternatives:
237
+ for i, (token, probs) in enumerate(zip(paginated_tokens, paginated_alternatives)):
238
+ for j, (alt_tok, prob) in enumerate(probs):
239
+ alt_viz_fig.add_trace(go.Bar(x=[f"{token} (Pos {i+start_idx})"], y=[prob], name=f"{alt_tok}", marker_color=['blue', 'green', 'red', 'purple', 'orange'][:len(probs)]))
240
+ alt_viz_fig.update_layout(
241
+ title=f"Top Token Log Probabilities (Chunk {chunk + 1})",
242
+ xaxis_title="Token (Position)",
243
+ yaxis_title="Log Probability",
244
+ barmode='stack',
245
+ hovermode="closest",
246
+ clickmode='event+select'
247
+ )
248
+ alt_viz_fig.update_traces(
249
+ customdata=[f"Token: {tok}, Alt: {alt}, Log Prob: {prob:.4f}, Position: {i+start_idx}" for i, (tok, alts) in enumerate(zip(paginated_tokens, paginated_alternatives)) for alt, prob in alts],
250
+ hovertemplate='<b>%{customdata}</b><extra></extra>'
251
+ )
252
+
253
+ return (main_fig, df, colored_text_html, alt_viz_fig, drops_fig, total_chunks, chunk)
254
 
 
255
  except Exception as e:
256
  logger.error("Visualization failed: %s", str(e))
257
+ return (create_empty_figure("Log Probabilities of Generated Tokens"), None, f"Error: {e}", create_empty_figure("Top Token Log Probabilities"), create_empty_figure("Significant Probability Drops"), 1, 0)
258
+
259
+ # Analysis functions for detecting correct vs. incorrect traces
260
+ def analyze_confidence_signature(logprobs, tokens):
261
+ if not logprobs or not tokens:
262
+ return "No data for confidence signature analysis.", None
263
+ top_probs = [lps[0][1] if lps and lps[0][1] is not None else -float('inf') for lps in logprobs] # Handle empty or None
264
+ if not any(p != -float('inf') for p in top_probs):
265
+ return "No valid log probabilities for confidence analysis.", None
266
+ moving_avg = np.convolve(top_probs, np.ones(20) / 20, mode='valid') # 20-token window
267
+ drops = np.where(np.diff(moving_avg) < -0.15)[0] # Significant drops
268
+ if not drops.size:
269
+ return "No significant confidence drops detected.", None
270
+ drop_positions = [(i, tokens[i + 19] if i + 19 < len(tokens) else "End of trace") for i in drops]
271
+ return "Significant confidence drops detected at positions:", drop_positions
272
+
273
+ def detect_interpretation_pivots(logprobs, tokens):
274
+ if not logprobs or not tokens:
275
+ return "No data for interpretation pivot detection.", None
276
+ pivots = []
277
+ reconsideration_tokens = ["wait", "but", "actually", "however", "hmm"]
278
+ for i, (token, lps) in enumerate(zip(tokens, logprobs)):
279
+ if not lps:
280
+ continue
281
+ for rt in reconsideration_tokens:
282
+ for t, p in lps:
283
+ if t.lower() == rt and p > -2.5: # High probability
284
+ context = tokens[max(0, i-50):i]
285
+ pivots.append((i, rt, context))
286
+ if not pivots:
287
+ return "No interpretation pivots detected.", None
288
+ return "Interpretation pivots detected:", pivots
289
+
290
+ def calculate_decision_entropy(logprobs):
291
+ if not logprobs:
292
+ return "No data for entropy spike detection.", None
293
+ entropies = []
294
+ for lps in logprobs:
295
+ if not lps:
296
+ entropies.append(0.0)
297
+ continue
298
+ probs = [math.exp(p) for _, p in lps if p is not None] # Convert log probs to probabilities, handle None
299
+ if not probs or sum(probs) == 0:
300
+ entropies.append(0.0)
301
+ continue
302
+ entropy = -sum(p * math.log(p) for p in probs if p > 0)
303
+ entropies.append(entropy)
304
+ baseline = np.percentile(entropies, 75) if entropies else 0.0
305
+ spikes = [i for i, e in enumerate(entropies) if e > baseline * 1.5 and baseline > 0]
306
+ if not spikes:
307
+ return "No entropy spikes detected at decision points.", None
308
+ return "Entropy spikes detected at positions:", spikes
309
+
310
+ def analyze_conclusion_competition(logprobs, tokens):
311
+ if not logprobs or not tokens:
312
+ return "No data for conclusion competition analysis.", None
313
+ conclusion_indices = [i for i, t in enumerate(tokens) if any(marker in t.lower() for marker in ["therefore", "thus", "boxed", "answer"])]
314
+ if not conclusion_indices:
315
+ return "No conclusion markers found in trace.", None
316
+ gaps = []
317
+ conclusion_idx = conclusion_indices[-1]
318
+ end_range = min(conclusion_idx + 50, len(logprobs))
319
+ for idx in range(conclusion_idx, end_range):
320
+ if idx < len(logprobs) and len(logprobs[idx]) >= 2 and logprobs[idx][0][1] is not None and logprobs[idx][1][1] is not None:
321
+ gap = logprobs[idx][0][1] - logprobs[idx][1][1]
322
+ gaps.append(gap)
323
+ if not gaps:
324
+ return "No conclusion competition data available.", None
325
+ mean_gap = np.mean(gaps)
326
+ return f"Mean probability gap at conclusion: {mean_gap:.4f} (higher indicates more confident conclusion)", None
327
+
328
+ def analyze_verification_signals(logprobs, tokens):
329
+ if not logprobs or not tokens:
330
+ return "No data for verification signal analysis.", None
331
+ verification_terms = ["verify", "check", "confirm", "ensure", "double"]
332
+ verification_probs = []
333
+ for lps in logprobs:
334
+ if not lps:
335
+ continue
336
+ max_v_prob = -float('inf')
337
+ for token, prob in lps:
338
+ if any(v_term in token.lower() for v_term in verification_terms) and prob is not None:
339
+ max_v_prob = max(max_v_prob, prob)
340
+ if max_v_prob > -float('inf'):
341
+ verification_probs.append(max_v_prob)
342
+ if not verification_probs:
343
+ return "No verification signals detected.", None
344
+ count, mean_prob = len(verification_probs), np.mean(verification_probs)
345
+ return f"Verification signals found: {count} instances, mean probability: {mean_prob:.4f}", None
346
+
347
+ def detect_semantic_inversions(logprobs, tokens):
348
+ if not logprobs or not tokens:
349
+ return "No data for semantic inversion detection.", None
350
+ inversion_pairs = [("more", "less"), ("larger", "smaller"), ("winning", "losing"), ("increase", "decrease"), ("greater", "lesser"), ("positive", "negative")]
351
+ inversions = []
352
+ for i, (token, lps) in enumerate(zip(tokens, logprobs)):
353
+ if not lps:
354
+ continue
355
+ for pos, neg in inversion_pairs:
356
+ if token.lower() == pos:
357
+ for t, p in lps:
358
+ if t.lower() == neg and p > -3.0 and p is not None:
359
+ inversions.append((i, pos, neg, p))
360
+ elif token.lower() == neg:
361
+ for t, p in lps:
362
+ if t.lower() == pos and p > -3.0 and p is not None:
363
+ inversions.append((i, neg, pos, p))
364
+ if not inversions:
365
+ return "No semantic inversions detected.", None
366
+ return "Semantic inversions detected:", inversions
367
+
368
+ # Function to perform full trace analysis
369
  def analyze_full_trace(json_input):
370
  try:
371
  data = parse_input(json_input)
372
  content = data.get("content", []) if isinstance(data, dict) else data
373
  if not isinstance(content, list):
374
+ raise ValueError("Content must be a list of entries")
 
 
 
375
 
376
+ tokens = []
377
+ logprobs = []
378
+ for entry in content:
379
+ if not isinstance(entry, dict):
380
+ logger.warning("Skipping non-dictionary entry: %s", entry)
381
+ continue
382
+ logprob = ensure_float(entry.get("logprob", None))
383
+ if logprob >= -100000:
384
+ tokens.append(get_token(entry))
385
+ top_probs = entry.get("top_logprobs", {}) or {}
386
+ finite_top_probs = [(key, ensure_float(value)) for key, value in top_probs.items() if ensure_float(value) is not None and math.isfinite(ensure_float(value))]
387
+ logprobs.append(finite_top_probs)
388
 
389
+ if not logprobs or not tokens:
390
+ return "No valid data for trace analysis.", None, None, None, None, None
391
+
392
+ confidence_result, confidence_data = analyze_confidence_signature(logprobs, tokens)
393
+ pivot_result, pivot_data = detect_interpretation_pivots(logprobs, tokens)
394
+ entropy_result, entropy_data = calculate_decision_entropy(logprobs)
395
+ conclusion_result, conclusion_data = analyze_conclusion_competition(logprobs, tokens)
396
+ verification_result, verification_data = analyze_verification_signals(logprobs, tokens)
397
+ inversion_result, inversion_data = detect_semantic_inversions(logprobs, tokens)
398
+
399
+ analysis_html = f"""
400
+ <h3>Trace Analysis Results</h3>
401
+ <ul>
402
+ <li><strong>Confidence Signature:</strong> {confidence_result}</li>
403
+ {f"<ul><li>Positions: {', '.join(str(pos) for pos, tok in confidence_data)}</li></ul>" if confidence_data else ""}
404
+ <li><strong>Interpretation Pivots:</strong> {pivot_result}</li>
405
+ {f"<ul><li>Positions: {', '.join(str(pos) for pos, _, _ in pivot_data)}</li></ul>" if pivot_data else ""}
406
+ <li><strong>Decision Entropy Spikes:</strong> {entropy_result}</li>
407
+ {f"<ul><li>Positions: {', '.join(str(pos) for pos in entropy_data)}</li></ul>" if entropy_data else ""}
408
+ <li><strong>Conclusion Competition:</strong> {conclusion_result}</li>
409
+ <li><strong>Verification Signals:</strong> {verification_result}</li>
410
+ <li><strong>Semantic Inversions:</strong> {inversion_result}</li>
411
+ {f"<ul><li>Positions: {', '.join(str(pos) for pos, _, _, _ in inversion_data)}</li></ul>" if inversion_data else ""}
412
+ </ul>
413
+ """
414
  return analysis_html, None, None, None, None, None
415
  except Exception as e:
416
  logger.error("Trace analysis failed: %s", str(e))
417
  return f"Error: {e}", None, None, None, None, None
418
 
419
+ # Gradio interface with two tabs
420
  try:
421
  with gr.Blocks(title="Log Probability Visualizer") as app:
422
  gr.Markdown("# Log Probability Visualizer")
423
+ gr.Markdown("Paste your JSON log prob data below to analyze reasoning traces or visualize tokens in chunks of 100. Fixed filter ≥ -100000, dynamic number of top_logprobs, handles missing or null fields. Next chunk is precomputed proactively.")
424
 
425
  with gr.Tabs():
426
  with gr.Tab("Trace Analysis"):
427
+ with gr.Row():
428
+ json_input_analysis = gr.Textbox(
429
+ label="JSON Input for Trace Analysis",
430
+ lines=10,
431
+ placeholder='{"content": [{"bytes": [44], "logprob": 0.0, "token": ",", "top_logprobs": {" so": -13.8046875, ".": -13.8046875, ",": -13.640625}}]}'
432
+ )
433
+ with gr.Row():
434
+ analysis_output = gr.HTML(label="Trace Analysis Results")
435
+
436
+ btn_analyze = gr.Button("Analyze Trace")
437
+ btn_analyze.click(
438
+ fn=analyze_full_trace,
439
+ inputs=[json_input_analysis],
440
+ outputs=[analysis_output, gr.State(), gr.State(), gr.State(), gr.State(), gr.State()],
441
+ )
442
 
443
  with gr.Tab("Visualization"):
444
  with gr.Row():
445
+ json_input_viz = gr.Textbox(
446
+ label="JSON Input for Visualization",
447
+ lines=10,
448
+ placeholder='{"content": [{"bytes": [44], "logprob": 0.0, "token": ",", "top_logprobs": {" so": -13.8046875, ".": -13.8046875, ",": -13.640625}}]}'
449
+ )
450
  chunk = gr.Number(value=0, label="Current Chunk", precision=0, minimum=0)
451
+
452
  with gr.Row():
453
+ plot_output = gr.Plot(label="Log Probability Plot (Click for Tokens)")
454
+ drops_output = gr.Plot(label="Probability Drops (Click for Details)")
455
+
456
  with gr.Row():
457
+ table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
458
+ alt_viz_output = gr.Plot(label="Top Token Log Probabilities (Click for Details)")
459
+
460
  with gr.Row():
461
+ text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
462
+
463
  with gr.Row():
464
  prev_btn = gr.Button("Previous Chunk")
465
  next_btn = gr.Button("Next Chunk")
466
  total_chunks_output = gr.Number(label="Total Chunks", interactive=False)
467
 
468
+ # Precomputed next chunk state (hidden)
469
  precomputed_next = gr.State(value=None)
470
 
471
+ btn_viz = gr.Button("Visualize")
472
+ btn_viz.click(
473
+ fn=visualize_logprobs,
474
+ inputs=[json_input_viz, chunk],
475
+ outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk],
476
+ )
477
 
478
  def update_chunk(json_input, current_chunk, action, precomputed_next=None):
479
+ total_chunks = visualize_logprobs(json_input, 0)[5] # Get total chunks
480
  if action == "prev" and current_chunk > 0:
481
  current_chunk -= 1
482
  elif action == "next" and current_chunk < total_chunks - 1:
483
  current_chunk += 1
484
+ if precomputed_next and all(precomputed_next):
485
+ logger.debug("Using precomputed next chunk for chunk %d", current_chunk)
486
+ return visualize_logprobs(json_input, current_chunk)
487
  return visualize_logprobs(json_input, current_chunk)
488
 
489
+ prev_btn.click(
490
+ fn=update_chunk,
491
+ inputs=[json_input_viz, chunk, gr.State(value="prev"), precomputed_next],
492
+ outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk],
493
+ )
494
+
495
+ next_btn.click(
496
+ fn=update_chunk,
497
+ inputs=[json_input_viz, chunk, gr.State(value="next"), precomputed_next],
498
+ outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk],
499
+ )
500
 
501
  def trigger_precomputation(json_input, current_chunk):
502
+ try:
503
+ threading.Thread(target=precompute_next_chunk_sync, args=(json_input, current_chunk), daemon=True).start()
504
+ except Exception as e:
505
+ logger.error("Precomputation trigger failed: %s", str(e))
506
  return gr.update(value=current_chunk)
507
 
508
+ chunk.change(
509
+ fn=trigger_precomputation,
510
+ inputs=[json_input_viz, chunk],
511
+ outputs=[chunk],
512
+ )
513
 
514
  except Exception as e:
515
  logger.error("Application startup failed: %s", str(e))