hiyata committed on
Commit
d1cde92
·
verified ·
1 Parent(s): 2e254a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -0
app.py CHANGED
@@ -474,7 +474,230 @@ def analyze_subregion(state, header, region_start, region_end):
474
 
475
  return (region_info, heatmap_img, hist_img)
476
 
 
 
 
 
 
 
 
477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
  ###############################################################################
479
  # 9. BUILD GRADIO INTERFACE
480
  ###############################################################################
 
474
 
475
  return (region_info, heatmap_img, hist_img)
476
 
477
+ # Add these imports at the top of the file, after existing imports
478
+ from scipy.interpolate import interp1d
479
+ import numpy as np
480
+
481
+ ###############################################################################
482
+ # NEW SECTION: COMPARATIVE ANALYSIS FUNCTIONS
483
+ ###############################################################################
484
 
485
def normalize_shap_lengths(shap1, shap2, num_points=1000):
    """
    Resample two SHAP arrays onto a shared grid of ``num_points`` positions.

    Each input is treated as samples spread evenly over the relative
    position range [0, 1] and is linearly interpolated at ``num_points``
    evenly spaced positions in that same range, so arrays of different
    lengths can be compared element by element.

    Args:
        shap1: 1-D array-like of values for the first sequence.
        shap2: 1-D array-like of values for the second sequence.
        num_points: Number of positions in each resampled output.

    Returns:
        (normalized_shap1, normalized_shap2): two float arrays, each of
        length ``num_points``.

    Raises:
        ValueError: If either input is empty.
    """
    values1 = np.asarray(shap1, dtype=float)
    values2 = np.asarray(shap2, dtype=float)
    if values1.size == 0 or values2.size == 0:
        raise ValueError("Cannot normalize an empty SHAP array")

    # Common relative-position grid both sequences are resampled onto.
    x_new = np.linspace(0, 1, num_points)

    def _resample(values):
        # np.interp performs the same linear interpolation as
        # interp1d(kind='linear') inside [0, 1], but also accepts a
        # single-sample input (the result is then constant) instead of
        # raising.
        x_old = np.linspace(0, 1, len(values))
        return np.interp(x_new, x_old, values)

    return _resample(values1), _resample(values2)
506
+
507
def compute_shap_difference(shap1_norm, shap2_norm):
    """
    Return the element-wise difference ``shap2_norm - shap1_norm``.

    Both inputs are expected to have been resampled to a common length
    beforehand. Sign convention: a positive entry means sequence 2 is
    more "human-like" than sequence 1 at that position; a negative entry
    means the opposite.
    """
    seq2_minus_seq1 = shap2_norm - shap1_norm
    return seq2_minus_seq1
513
+
514
def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
    """
    Draw a one-row heatmap of per-position SHAP differences.

    The colour limits are symmetric around zero so that a difference of
    zero always lands on the midpoint of the colormap. According to the
    original author's note, red marks positions where seq2 is more
    human-like and blue where seq1 is more human-like; the actual colours
    come from ``get_zero_centered_cmap()``, which is defined elsewhere.

    Args:
        shap_diff: 1-D array-like of differences (seq2 - seq1).
        title: Title shown above the heatmap.

    Returns:
        The matplotlib Figure containing the heatmap and its colorbar.

    Raises:
        ValueError: If ``shap_diff`` is empty.
    """
    diff_values = np.asarray(shap_diff, dtype=float)
    if diff_values.size == 0:
        raise ValueError("shap_diff must contain at least one value")

    # imshow needs a 2-D array: a single row spanning all positions.
    heatmap_data = diff_values.reshape(1, -1)

    # Symmetric colour range around zero.
    extent = float(np.max(np.abs(diff_values)))
    if not np.isfinite(extent) or extent == 0:
        # Identical sequences give an all-zero difference. With
        # vmin == vmax Matplotlib's Normalize maps every value to the
        # bottom of the colormap, so zero would be drawn in the
        # "negative" colour instead of the neutral midpoint. Fall back
        # to an arbitrary non-degenerate range to keep zero centred.
        extent = 1.0

    # Wide, short figure: one row of data plus a horizontal colorbar.
    fig, ax = plt.subplots(figsize=(12, 1.8))

    custom_cmap = get_zero_centered_cmap()

    cax = ax.imshow(
        heatmap_data,
        aspect='auto',
        cmap=custom_cmap,
        vmin=-extent,
        vmax=+extent,
    )

    # Use the figure's own methods rather than pyplot's "current figure"
    # state so the calls always act on the figure created above.
    cbar = fig.colorbar(
        cax,
        ax=ax,
        orientation='horizontal',
        pad=0.25,
        aspect=40,
        shrink=0.8,
    )

    cbar.ax.tick_params(labelsize=8)
    cbar.set_label(
        'SHAP Difference (Seq2 - Seq1)',
        fontsize=9,
        labelpad=5,
    )

    # A single row has no meaningful y axis.
    ax.set_yticks([])
    ax.set_xlabel('Normalized Position (0-100%)', fontsize=10)
    ax.set_title(title, pad=10)

    fig.subplots_adjust(
        bottom=0.25,
        left=0.05,
        right=0.95,
    )

    return fig
569
+
570
def _extract_classification(summary_text):
    """Return the text after 'Classification: ' on its line, or 'Unknown'."""
    if not isinstance(summary_text, str):
        return "Unknown"
    _, marker, remainder = summary_text.partition("Classification: ")
    if not marker:
        return "Unknown"
    return remainder.split("\n", 1)[0]


def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
    """
    Compare two sequences by analyzing their SHAP differences.

    Each sequence is run through ``analyze_sequence``; the two per-position
    SHAP arrays are resampled to a common length, subtracted
    (seq2 - seq1), summarised, and plotted.

    This function relies on the following parts of the value returned by
    ``analyze_sequence`` (defined elsewhere in the file):
    index 0 is the result text (an error message when it contains
    "Error"), index 3 is a dict holding ``"shap_means"``, and index 4 is
    the label printed for the sequence.

    Args:
        file1: Uploaded file for the first sequence (passed through to
            ``analyze_sequence``).
        file2: Uploaded file for the second sequence.
        fasta1: Pasted FASTA text for the first sequence.
        fasta2: Pasted FASTA text for the second sequence.

    Returns:
        ``(comparison_text, heatmap_img, hist_img)``. On failure of either
        analysis the tuple is ``(error_message, None, None)``.
    """
    # Process first sequence
    results1 = analyze_sequence(file1, fasta_text=fasta1)
    if isinstance(results1[0], str) and "Error" in results1[0]:
        return (f"Error in sequence 1: {results1[0]}", None, None)

    # Process second sequence
    results2 = analyze_sequence(file2, fasta_text=fasta2)
    if isinstance(results2[0], str) and "Error" in results2[0]:
        return (f"Error in sequence 2: {results2[0]}", None, None)

    # Get SHAP means from state dictionaries
    shap1 = results1[3]["shap_means"]
    shap2 = results2[3]["shap_means"]

    # Bring both arrays to the same length, then subtract
    # (positive = seq2 more human-like).
    shap1_norm, shap2_norm = normalize_shap_lengths(shap1, shap2)
    shap_diff = compute_shap_difference(shap1_norm, shap2_norm)

    # Summary statistics of the difference
    avg_diff = np.mean(shap_diff)
    std_diff = np.std(shap_diff)
    max_diff = np.max(shap_diff)
    min_diff = np.min(shap_diff)

    # Fraction of positions whose absolute difference exceeds an
    # arbitrary "substantial" threshold.
    threshold = 0.05
    frac_different = np.mean(np.abs(shap_diff) > threshold)

    # Pull the classification lines out before formatting: a backslash
    # ('\n') inside an f-string replacement field is a SyntaxError before
    # Python 3.12, and a result text without the marker must not crash
    # the comparison.
    classification1 = _extract_classification(results1[0])
    classification2 = _extract_classification(results2[0])

    report_lines = [
        "Sequence Comparison Results:",
        f"Sequence 1: {results1[4]}",
        f"Length: {len(shap1):,} bases",
        f"Classification: {classification1}",
        "",
        f"Sequence 2: {results2[4]}",
        f"Length: {len(shap2):,} bases",
        f"Classification: {classification2}",
        "",
        "Comparison Statistics:",
        f"Average SHAP difference: {avg_diff:.4f}",
        f"Standard deviation: {std_diff:.4f}",
        f"Max difference: {max_diff:.4f} (Seq2 more human-like)",
        f"Min difference: {min_diff:.4f} (Seq1 more human-like)",
        f"Fraction of positions with substantial differences: {frac_different:.2%}",
        "",
        "Interpretation:",
        'Positive values (red) indicate regions where Sequence 2 is more "human-like"',
        'Negative values (blue) indicate regions where Sequence 1 is more "human-like"',
        "",  # keeps the trailing newline of the original report
    ]
    comparison_text = "\n".join(report_lines)

    # Create comparison heatmap
    heatmap_fig = plot_comparative_heatmap(shap_diff)
    heatmap_img = fig_to_image(heatmap_fig)

    # Create histogram of differences
    hist_fig = plot_shap_histogram(
        shap_diff,
        title="Distribution of SHAP Differences"
    )
    hist_img = fig_to_image(hist_fig)

    return comparison_text, heatmap_img, hist_img
640
+
641
+ ###############################################################################
642
+ # NEW GRADIO TAB: COMPARATIVE ANALYSIS
643
+ ###############################################################################
644
+
645
+ # Inside the Gradio interface definition, add this new tab:
646
# Comparative-analysis tab: two FASTA inputs (file upload or pasted text
# each), one button, and three outputs (text report, heatmap, histogram).
# NOTE(review): the comment preceding this block says it belongs inside
# the Gradio interface definition; the enclosing context is not visible
# here, so confirm its placement and indentation there.
with gr.Tab("3) Comparative Analysis"):
    # Static help text shown at the top of the tab.
    gr.Markdown("""
**Compare Two Sequences**
Upload or paste two FASTA sequences to compare their SHAP patterns.
The sequences will be normalized to the same length for comparison.

**Color Scale**:
- Red: Sequence 2 is more human-like in this region
- Blue: Sequence 1 is more human-like in this region
- White: No substantial difference
""")

    # Two equal-width columns, one per sequence.
    with gr.Row():
        with gr.Column(scale=1):
            # First sequence: uploaded file (delivered as a path) or pasted text.
            file_input1 = gr.File(
                label="Upload first FASTA file",
                file_types=[".fasta", ".fa", ".txt"],
                type="filepath"
            )
            text_input1 = gr.Textbox(
                label="Or paste first FASTA sequence",
                placeholder=">sequence1\nACGTACGT...",
                lines=5
            )

        with gr.Column(scale=1):
            # Second sequence: same pair of inputs as the first column.
            file_input2 = gr.File(
                label="Upload second FASTA file",
                file_types=[".fasta", ".fa", ".txt"],
                type="filepath"
            )
            text_input2 = gr.Textbox(
                label="Or paste second FASTA sequence",
                placeholder=">sequence2\nACGTACGT...",
                lines=5
            )

    compare_btn = gr.Button("Compare Sequences", variant="primary")

    # Read-only text report produced by the comparison.
    comparison_text = gr.Textbox(
        label="Comparison Results",
        lines=12,
        interactive=False
    )

    # The two result images sit side by side.
    with gr.Row():
        diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
        diff_hist = gr.Image(label="Distribution of SHAP Differences")

    # The inputs are listed in the order of
    # analyze_sequence_comparison(file1, file2, fasta1, fasta2); the
    # outputs match its (text, heatmap, histogram) return tuple.
    compare_btn.click(
        analyze_sequence_comparison,
        inputs=[file_input1, file_input2, text_input1, text_input2],
        outputs=[comparison_text, diff_heatmap, diff_hist]
    )
700
+
701
  ###############################################################################
702
  # 9. BUILD GRADIO INTERFACE
703
  ###############################################################################