codelion commited on
Commit
4615d41
·
verified ·
1 Parent(s): 374a5a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -105
app.py CHANGED
@@ -281,7 +281,7 @@ def visualize_logprobs(json_input, chunk=0, chunk_size=100):
281
  logger.error("Visualization failed: %s (Input: %s)", str(e), json_input[:100] + "..." if len(json_input) > 100 else json_input)
282
  return (create_empty_figure("Log Probabilities of Generated Tokens"), None, "No finite log probabilities to display.", create_empty_figure("Top Token Log Probabilities"), create_empty_figure("Significant Probability Drops"), 1, 0)
283
 
284
- # Analysis functions for detecting correct vs. incorrect traces (unchanged from previous)
285
  def analyze_confidence_signature(logprobs, tokens):
286
  if not logprobs or not tokens:
287
  return "No data for confidence signature analysis.", None
@@ -476,114 +476,120 @@ def analyze_full_trace(json_input):
476
  return analysis_html, None, None, None, None, None
477
 
478
  # Gradio interface with two tabs: Trace Analysis and Visualization
479
- with gr.Blocks(title="Log Probability Visualizer") as app:
480
- gr.Markdown("# Log Probability Visualizer")
481
- gr.Markdown(
482
- "Paste your JSON log prob data below to analyze reasoning traces and visualize tokens in chunks of 100. Fixed filter ≥ -100000, dynamic number of top_logprobs, handles missing or null fields. Next chunk is precomputed proactively."
483
- )
 
484
 
485
- with gr.Tabs():
486
- with gr.Tab("Trace Analysis"):
487
- with gr.Row():
488
- json_input_analysis = gr.Textbox(
489
- label="JSON Input for Trace Analysis",
490
- lines=10,
491
- placeholder="Paste your JSON (e.g., {\"content\": [{\"bytes\": [44], \"logprob\": 0.0, \"token\": \",\", \"top_logprobs\": {\" so\": -13.8046875, \".\": -13.8046875, \",\": -13.640625}}]}).",
 
 
 
 
 
 
 
 
 
492
  )
493
- with gr.Row():
494
- analysis_output = gr.HTML(label="Trace Analysis Results")
495
-
496
- btn_analyze = gr.Button("Analyze Trace")
497
- btn_analyze.click(
498
- fn=analyze_full_trace,
499
- inputs=[json_input_analysis],
500
- outputs=[analysis_output, gr.State(), gr.State(), gr.State(), gr.State(), gr.State()],
501
- )
502
 
503
- with gr.Tab("Visualization"):
504
- with gr.Row():
505
- json_input_viz = gr.Textbox(
506
- label="JSON Input for Visualization",
507
- lines=10,
508
- placeholder="Paste your JSON (e.g., {\"content\": [{\"bytes\": [44], \"logprob\": 0.0, \"token\": \",\", \"top_logprobs\": {\" so\": -13.8046875, \".\": -13.8046875, \",\": -13.640625}}]}).",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
  )
510
- chunk = gr.Number(value=0, label="Current Chunk", precision=0, minimum=0)
511
-
512
- with gr.Row():
513
- plot_output = gr.Plot(label="Log Probability Plot (Click for Tokens)")
514
- drops_output = gr.Plot(label="Probability Drops (Click for Details)")
515
-
516
- with gr.Row():
517
- table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
518
- alt_viz_output = gr.Plot(label="Top Token Log Probabilities (Click for Details)")
519
 
520
- with gr.Row():
521
- text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
522
-
523
- with gr.Row():
524
- prev_btn = gr.Button("Previous Chunk")
525
- next_btn = gr.Button("Next Chunk")
526
- total_chunks_output = gr.Number(label="Total Chunks", interactive=False)
527
-
528
- # Precomputed next chunk state (hidden)
529
- precomputed_next = gr.State(value=None)
530
-
531
- btn_viz = gr.Button("Visualize")
532
- btn_viz.click(
533
- fn=visualize_logprobs,
534
- inputs=[json_input_viz, chunk],
535
- outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk],
536
- )
537
-
538
- # Precompute next chunk proactively when on current chunk
539
- async def precompute_next_chunk(json_input, current_chunk, precomputed_next):
540
- if precomputed_next is not None:
541
- return precomputed_next # Use cached precomputed chunk if available
542
- next_tokens, next_logprobs, next_alternatives = await precompute_chunk(json_input, 100, current_chunk)
543
- if next_tokens is None or next_logprobs is None or next_alternatives is None:
544
- return None
545
- return (next_tokens, next_logprobs, next_alternatives)
546
-
547
- # Update chunk on button clicks
548
- def update_chunk(json_input, current_chunk, action, precomputed_next=None):
549
- total_chunks = visualize_logprobs(json_input, 0)[5] # Get total chunks
550
- if action == "prev" and current_chunk > 0:
551
- current_chunk -= 1
552
- elif action == "next" and current_chunk < total_chunks - 1:
553
- current_chunk += 1
554
- # If precomputed next chunk exists, use it; otherwise, compute it
555
- if precomputed_next:
556
- next_tokens, next_logprobs, next_alternatives = precomputed_next
557
- if next_tokens and next_logprobs and next_alternatives:
558
- logger.debug("Using precomputed next chunk for chunk %d", current_chunk)
559
- return visualize_logprobs(json_input, current_chunk)
560
- return visualize_logprobs(json_input, current_chunk)
561
-
562
- prev_btn.click(
563
- fn=update_chunk,
564
- inputs=[json_input_viz, chunk, gr.State(value="prev"), precomputed_next],
565
- outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk],
566
- )
567
-
568
- next_btn.click(
569
- fn=update_chunk,
570
- inputs=[json_input_viz, chunk, gr.State(value="next"), precomputed_next],
571
- outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk],
572
- )
573
 
574
- # Trigger precomputation when chunk changes (via button clicks or initial load)
575
- def trigger_precomputation(json_input, current_chunk):
576
- try:
577
- asyncio.create_task(precompute_next_chunk(json_input, current_chunk, None))
578
- except Exception as e:
579
- logger.error("Precomputation trigger failed: %s", str(e))
580
- return gr.update(value=current_chunk)
581
-
582
- # Use a dummy event to trigger precomputation on chunk change (simplified for Gradio)
583
- chunk.change(
584
- fn=trigger_precomputation,
585
- inputs=[json_input_viz, chunk],
586
- outputs=[chunk],
587
- )
588
 
589
- app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  logger.error("Visualization failed: %s (Input: %s)", str(e), json_input[:100] + "..." if len(json_input) > 100 else json_input)
282
  return (create_empty_figure("Log Probabilities of Generated Tokens"), None, "No finite log probabilities to display.", create_empty_figure("Top Token Log Probabilities"), create_empty_figure("Significant Probability Drops"), 1, 0)
283
 
284
+ # Analysis functions for detecting correct vs. incorrect traces (unchanged)
285
  def analyze_confidence_signature(logprobs, tokens):
286
  if not logprobs or not tokens:
287
  return "No data for confidence signature analysis.", None
 
476
  return analysis_html, None, None, None, None, None
477
 
478
  # Gradio interface with two tabs: Trace Analysis and Visualization
479
+ try:
480
+ with gr.Blocks(title="Log Probability Visualizer") as app:
481
+ gr.Markdown("# Log Probability Visualizer")
482
+ gr.Markdown(
483
+ "Paste your JSON log prob data below to analyze reasoning traces and visualize tokens in chunks of 100. Fixed filter ≥ -100000, dynamic number of top_logprobs, handles missing or null fields. Next chunk is precomputed proactively."
484
+ )
485
 
486
+ with gr.Tabs():
487
+ with gr.Tab("Trace Analysis"):
488
+ with gr.Row():
489
+ json_input_analysis = gr.Textbox(
490
+ label="JSON Input for Trace Analysis",
491
+ lines=10,
492
+ placeholder="Paste your JSON (e.g., {\"content\": [{\"bytes\": [44], \"logprob\": 0.0, \"token\": \",\", \"top_logprobs\": {\" so\": -13.8046875, \".\": -13.8046875, \",\": -13.640625}}]}).",
493
+ )
494
+ with gr.Row():
495
+ analysis_output = gr.HTML(label="Trace Analysis Results")
496
+
497
+ btn_analyze = gr.Button("Analyze Trace")
498
+ btn_analyze.click(
499
+ fn=analyze_full_trace,
500
+ inputs=[json_input_analysis],
501
+ outputs=[analysis_output, gr.State(), gr.State(), gr.State(), gr.State(), gr.State()],
502
  )
 
 
 
 
 
 
 
 
 
503
 
504
+ with gr.Tab("Visualization"):
505
+ with gr.Row():
506
+ json_input_viz = gr.Textbox(
507
+ label="JSON Input for Visualization",
508
+ lines=10,
509
+ placeholder="Paste your JSON (e.g., {\"content\": [{\"bytes\": [44], \"logprob\": 0.0, \"token\": \",\", \"top_logprobs\": {\" so\": -13.8046875, \".\": -13.8046875, \",\": -13.640625}}]}).",
510
+ )
511
+ chunk = gr.Number(value=0, label="Current Chunk", precision=0, minimum=0)
512
+
513
+ with gr.Row():
514
+ plot_output = gr.Plot(label="Log Probability Plot (Click for Tokens)")
515
+ drops_output = gr.Plot(label="Probability Drops (Click for Details)")
516
+
517
+ with gr.Row():
518
+ table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
519
+ alt_viz_output = gr.Plot(label="Top Token Log Probabilities (Click for Details)")
520
+
521
+ with gr.Row():
522
+ text_output = gr.HTML(label="Colored Text (Confidence Visualization)")
523
+
524
+ with gr.Row():
525
+ prev_btn = gr.Button("Previous Chunk")
526
+ next_btn = gr.Button("Next Chunk")
527
+ total_chunks_output = gr.Number(label="Total Chunks", interactive=False)
528
+
529
+ # Precomputed next chunk state (hidden)
530
+ precomputed_next = gr.State(value=None)
531
+
532
+ btn_viz = gr.Button("Visualize")
533
+ btn_viz.click(
534
+ fn=visualize_logprobs,
535
+ inputs=[json_input_viz, chunk],
536
+ outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk],
537
  )
 
 
 
 
 
 
 
 
 
538
 
539
+ # Precompute next chunk proactively when on current chunk
540
+ async def precompute_next_chunk(json_input, current_chunk, precomputed_next):
541
+ if precomputed_next is not None:
542
+ return precomputed_next # Use cached precomputed chunk if available
543
+ try:
544
+ next_tokens, next_logprobs, next_alternatives = await precompute_chunk(json_input, 100, current_chunk)
545
+ if next_tokens is None or next_logprobs is None or next_alternatives is None:
546
+ return None
547
+ return (next_tokens, next_logprobs, next_alternatives)
548
+ except Exception as e:
549
+ logger.error("Precomputation failed for chunk %d: %s", current_chunk + 1, str(e))
550
+ return None
551
+
552
+ # Update chunk on button clicks
553
+ def update_chunk(json_input, current_chunk, action, precomputed_next=None):
554
+ total_chunks = visualize_logprobs(json_input, 0)[5] # Get total chunks
555
+ if action == "prev" and current_chunk > 0:
556
+ current_chunk -= 1
557
+ elif action == "next" and current_chunk < total_chunks - 1:
558
+ current_chunk += 1
559
+ # If precomputed next chunk exists, use it; otherwise, compute it
560
+ if precomputed_next:
561
+ next_tokens, next_logprobs, next_alternatives = precomputed_next
562
+ if next_tokens and next_logprobs and next_alternatives:
563
+ logger.debug("Using precomputed next chunk for chunk %d", current_chunk)
564
+ return visualize_logprobs(json_input, current_chunk)
565
+ return visualize_logprobs(json_input, current_chunk)
566
+
567
+ prev_btn.click(
568
+ fn=update_chunk,
569
+ inputs=[json_input_viz, chunk, gr.State(value="prev"), precomputed_next],
570
+ outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk],
571
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
572
 
573
+ next_btn.click(
574
+ fn=update_chunk,
575
+ inputs=[json_input_viz, chunk, gr.State(value="next"), precomputed_next],
576
+ outputs=[plot_output, table_output, text_output, alt_viz_output, drops_output, total_chunks_output, chunk],
577
+ )
 
 
 
 
 
 
 
 
 
578
 
579
+ # Trigger precomputation when chunk changes (via button clicks or initial load)
580
+ def trigger_precomputation(json_input, current_chunk):
581
+ try:
582
+ asyncio.create_task(precompute_next_chunk(json_input, current_chunk, None))
583
+ except Exception as e:
584
+ logger.error("Precomputation trigger failed: %s", str(e))
585
+ return gr.update(value=current_chunk)
586
+
587
+ # Use a dummy event to trigger precomputation on chunk change (simplified for Gradio)
588
+ chunk.change(
589
+ fn=trigger_precomputation,
590
+ inputs=[json_input_viz, chunk],
591
+ outputs=[chunk],
592
+ )
593
+ except Exception as e:
594
+ logger.error("Application startup failed: %s", str(e))
595
+ raise