hiyata committed on
Commit
2fd86ff
·
verified ·
1 Parent(s): 1869cbd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +244 -18
app.py CHANGED
@@ -555,22 +555,209 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
555
  except Exception as e:
556
  error_msg = f"Error during sequence comparison: {str(e)}"
557
  return error_msg, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
 
559
  ###############################################################################
560
- # 10. BUILD GRADIO INTERFACE
561
  ###############################################################################
562
 
563
  css = """
564
  .gradio-container {
565
  font-family: 'IBM Plex Sans', sans-serif;
566
  }
 
 
 
567
  """
568
 
569
  with gr.Blocks(css=css) as iface:
570
  gr.Markdown("""
571
  # Virus Host Classifier
572
  **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
573
- **Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
 
 
574
 
575
  **Color Scale**: Negative SHAP = Blue, Zero = White, Positive = Red.
576
  """)
@@ -587,19 +774,20 @@ with gr.Blocks(css=css) as iface:
587
  results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
588
  kmer_img = gr.Image(label="Top k-mer SHAP")
589
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
 
590
  seq_state = gr.State()
591
  header_state = gr.State()
592
  analyze_btn.click(
593
  analyze_sequence,
594
  inputs=[file_input, top_k, text_input, win_size],
595
- outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
596
  )
597
 
598
  with gr.Tab("2) Subregion Exploration"):
599
  gr.Markdown("""
600
  **Subregion Analysis**
601
  Select start/end positions to view local SHAP signals, distribution, GC content, etc.
602
- The heatmap also uses the same Blue-White-Red scale.
603
  """)
604
  with gr.Row():
605
  region_start = gr.Number(label="Region Start", value=0)
@@ -609,13 +797,47 @@ with gr.Blocks(css=css) as iface:
609
  with gr.Row():
610
  subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
611
  subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
 
612
  region_btn.click(
613
  analyze_subregion,
614
  inputs=[seq_state, header_state, region_start, region_end],
615
- outputs=[subregion_info, subregion_img, subregion_hist_img]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
616
  )
617
 
618
- with gr.Tab("3) Comparative Analysis"):
619
  gr.Markdown("""
620
  **Compare Two Sequences**
621
  Upload or paste two FASTA sequences to compare their SHAP patterns.
@@ -638,29 +860,33 @@ with gr.Blocks(css=css) as iface:
638
  with gr.Row():
639
  diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
640
  diff_hist = gr.Image(label="Distribution of SHAP Differences")
 
641
  compare_btn.click(
642
  analyze_sequence_comparison,
643
  inputs=[file_input1, file_input2, text_input1, text_input2],
644
- outputs=[comparison_text, diff_heatmap, diff_hist]
645
  )
646
 
647
  gr.Markdown("""
648
  ### Interface Features
649
- - **Overall Classification** (human vs non-human) using k-mer frequencies.
650
- - **SHAP Analysis** to see which k-mers push classification toward or away from human.
651
  - **White-Centered SHAP Gradient**:
652
- - Negative (blue), 0 (white), Positive (red), with symmetrical color range around 0.
653
- - **Identify Subregions** with the strongest push for human or non-human.
654
- - **Subregion Exploration**:
655
- - Local SHAP heatmap & histogram
656
- - GC content
657
- - Fraction of positions pushing human vs. non-human
658
- - Simple logic-based classification
659
  - **Sequence Comparison**:
660
  - Compare two sequences to identify regions of difference
661
- - Normalized comparison to handle different sequence lengths
662
  - Statistical summary of differences
 
 
 
663
  """)
664
 
665
  if __name__ == "__main__":
666
- iface.launch()
 
555
  except Exception as e:
556
  error_msg = f"Error during sequence comparison: {str(e)}"
557
  return error_msg, None, None
558
+
559
+ ###############################################################################
560
+ # 11. GENE FEATURE ANALYSIS
561
+ ###############################################################################
562
+
563
def parse_gene_features(text):
    """Parse FASTA-like gene feature text into a list of gene records.

    Each record is a header line starting with ``>`` followed by zero or
    more sequence lines. Blank lines are ignored, and sequence lines that
    appear before the first header are discarded.

    Args:
        text: The full feature text.

    Returns:
        A list of dicts in input order, each with the keys:
            'header':   the header line without the leading ``>``.
            'sequence': the record's sequence lines, upper-cased and joined.
            'metadata': the result of ``parse_gene_metadata(header)``.
    """
    # Each entry is (header, [sequence chunks]). Tracking records in a list
    # instead of testing the header for truthiness keeps a record whose
    # header is a bare ">" (empty string) from being silently dropped.
    records = []

    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('>'):
            records.append((line[1:], []))
        elif records:
            records[-1][1].append(line.upper())

    return [
        {
            'header': header,
            'sequence': ''.join(chunks),
            'metadata': parse_gene_metadata(header),
        }
        for header, chunks in records
    ]
593
+
594
def parse_gene_metadata(header):
    """Extract ``[key=value]`` tags from a gene header line.

    Tags are located by their brackets rather than by splitting the header
    on whitespace, so values that contain spaces (for example
    ``[protein=surface glycoprotein]``) and tags written without a
    separating space (``[a=b][c=d]``) are both captured.

    Args:
        header: Header text, with or without the leading ``>``.

    Returns:
        A dict mapping each tag key to its value. Only the first ``=``
        inside a tag separates key from value. Tags with an empty key are
        ignored; if a key repeats, the last occurrence wins.
    """
    import re

    metadata = {}
    # Key: anything up to the first '='; value: anything up to the closing
    # bracket. Brackets are excluded from both so a tag cannot span tags.
    for match in re.finditer(r'\[([^\[\]=]+)=([^\[\]]*)\]', header):
        key = match.group(1).strip()
        if key:
            metadata[key] = match.group(2).strip()

    return metadata
606
+
607
def analyze_gene_features(sequence_file, features_file, fasta_text="", features_text=""):
    """Summarise SHAP values for each annotated gene of a sequence.

    The whole sequence is first analysed with ``analyze_sequence``; the
    ``"shap_means"`` entry of its fourth return value (presumably one mean
    SHAP value per base -- TODO confirm) is then averaged over each gene's
    location.

    Args:
        sequence_file: Path to the sequence FASTA file, passed to
            ``analyze_sequence``.
        features_file: Path to the gene features file; used only when
            ``features_text`` is blank.
        fasta_text: Pasted FASTA text, passed to ``analyze_sequence``.
        features_text: Pasted gene features text; takes precedence over
            ``features_file``.

    Returns:
        On success, a 3-tuple ``(gene_results, csv_output, diagram_img)``:
        a list of per-gene dicts, the same data as CSV text, and the image
        returned by ``create_genome_diagram``.
        On failure, ``(error_message, None, None)``.
    """
    import csv
    import re
    from io import StringIO

    # First analyze whole sequence
    sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
    if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
        return f"Error in sequence analysis: {sequence_results[0]}", None, None

    # Get SHAP values
    shap_means = sequence_results[3]["shap_means"]
    genome_length = len(shap_means)

    # Parse gene features (pasted text wins over the uploaded file)
    if features_text and features_text.strip():
        genes = parse_gene_features(features_text)
    elif not features_file:
        # Report the real problem instead of the opaque error open(None) raises.
        return "Error reading features file: no features file or text provided", None, None
    else:
        try:
            with open(features_file, 'r') as f:
                genes = parse_gene_features(f.read())
        except Exception as e:
            return f"Error reading features file: {str(e)}", None, None

    # Analyze each gene
    gene_results = []
    for gene in genes:
        metadata = gene['metadata']
        location = metadata.get('location', '')
        if not location:
            continue

        try:
            # Accept plain "start..end" as well as complement(...), join(...)
            # and partial markers ("<1..>200"); multi-span locations are
            # reduced to their overall extent.
            spans = re.findall(r'(\d+)\.\.[<>]?(\d+)', location)
            if not spans:
                raise ValueError(f"cannot parse location '{location}'")
            positions = [int(value) for span in spans for value in span]
            start, end = min(positions), max(positions)

            # Feature locations in this header format are 1-based and
            # inclusive, so base `start` lives at index start - 1. The max()
            # guard keeps a 0 start from wrapping to the end of the array.
            gene_shap = shap_means[max(start - 1, 0):end]
            if len(gene_shap) == 0:
                # Averaging an empty slice yields NaN, which would be
                # reported as a "Non-human" gene with NaN confidence.
                raise ValueError(
                    f"location '{location}' lies outside the {genome_length}-base sequence"
                )
            avg_shap = float(np.mean(gene_shap))
        except Exception as e:
            # Best effort: one malformed gene must not abort the whole analysis.
            print(f"Error processing gene {metadata.get('gene', 'Unknown')}: {str(e)}")
            continue

        gene_results.append({
            'gene_name': metadata.get('gene', 'Unknown'),
            'location': location,
            'avg_shap': avg_shap,
            'start': start,
            'end': end,
            'locus_tag': metadata.get('locus_tag', ''),
            'classification': 'Human' if avg_shap > 0 else 'Non-human',
            'confidence': abs(avg_shap)
        })

    # Create CSV output. The csv module quotes fields containing commas
    # (e.g. join(...) locations); "\n" matches the original line endings.
    csv_buffer = StringIO()
    writer = csv.writer(csv_buffer, lineterminator="\n")
    writer.writerow(["gene_name", "location", "avg_shap", "classification", "confidence", "locus_tag"])
    for result in gene_results:
        writer.writerow([
            result['gene_name'],
            result['location'],
            f"{result['avg_shap']:.4f}",
            result['classification'],
            f"{result['confidence']:.4f}",
            result['locus_tag'],
        ])
    csv_output = csv_buffer.getvalue()

    # Create genome diagram
    diagram_img = create_genome_diagram(gene_results, genome_length)

    # NOTE(review): the Gradio click handler wires the second return value to
    # a gr.File component, which presumably expects a file path rather than
    # CSV text -- confirm, and write csv_output to a file at the call site if so.
    return gene_results, csv_output, diagram_img
667
+
668
def create_genome_diagram(gene_results, genome_length):
    """Render per-gene SHAP results as a linear genome diagram.

    Each gene is drawn as a feature coloured on a white-centred scale:
    positive average SHAP fades from white to red, negative from white to
    blue, saturating once ``abs(avg_shap) * 2`` reaches 1.

    Args:
        gene_results: Iterable of dicts providing 'start', 'end',
            'gene_name' and 'avg_shap' for each gene.
        genome_length: End coordinate of the drawn region (drawing starts
            at 0).

    Returns:
        A PIL image opened from the PNG rendering of the diagram.
    """
    from Bio.Graphics import GenomeDiagram
    from Bio.SeqFeature import SeqFeature, FeatureLocation
    from reportlab.lib import colors
    from reportlab.lib.units import inch
    from io import BytesIO
    from PIL import Image

    # Create diagram
    gd_diagram = GenomeDiagram.Diagram("Genome SHAP Analysis")
    gd_track = gd_diagram.new_track(1, name="Genes")
    gd_feature_set = gd_track.new_set()

    # Add features
    for gene in gene_results:
        feature = SeqFeature(
            FeatureLocation(gene['start'], gene['end']),
            type="gene"
        )

        # Colour by SHAP sign. The original built the same blue colour in
        # both branches, so positive genes were never drawn in red.
        intensity = min(1.0, abs(gene['avg_shap']) * 2)
        fade = 1 - intensity
        if gene['avg_shap'] > 0:
            color = colors.Color(1, fade, fade)  # Red
        else:
            color = colors.Color(fade, fade, 1)  # Blue

        gd_feature_set.add_feature(
            feature,
            color=color,
            label=True,
            name=f"{gene['gene_name']}\n(SHAP: {gene['avg_shap']:.3f})"
        )

    # Draw diagram. Page sizes are in points, so a bare (15, 5) would give
    # a page only a few pixels across; scale to a 15 x 5 inch canvas.
    gd_diagram.draw(
        format="linear",
        orientation="landscape",
        pagesize=(15 * inch, 5 * inch),
        start=0,
        end=genome_length,
        fragments=1
    )

    # Save to BytesIO and convert to PIL Image
    buffer = BytesIO()
    gd_diagram.write(buffer, "PNG")
    buffer.seek(0)
    return Image.open(buffer)
720
+
721
+ ###############################################################################
722
+ # 12. DOWNLOAD FUNCTIONS
723
+ ###############################################################################
724
+
725
def prepare_csv_download(data, filename="analysis_results.csv"):
    """Encode analysis data as CSV bytes for download.

    Args:
        data: Either CSV text (returned encoded, otherwise unchanged), a
            list of row dicts, or a single row dict.
        filename: Name to pair with the encoded data.

    Returns:
        A ``(csv_bytes, filename)`` tuple. An empty list yields empty bytes.

    Raises:
        ValueError: If ``data`` is not a str, a dict, or a list of dicts.
    """
    if isinstance(data, str):
        return data.encode("utf-8"), filename

    if isinstance(data, dict):
        # A lone dict is one row; indexing it with [0] would just fail.
        rows = [data]
    elif isinstance(data, list):
        rows = data
    else:
        raise ValueError("Unsupported data type for CSV download")

    if not rows:
        # Nothing to write, and no first row to take a header from.
        return b"", filename
    if not all(isinstance(row, dict) for row in rows):
        raise ValueError("Unsupported data type for CSV download")

    import csv
    from io import StringIO

    # Ordered union of keys, so a later row with an extra column no longer
    # makes DictWriter raise; missing values are written as empty fields.
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))

    output = StringIO()
    writer = csv.DictWriter(output, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)
    return output.getvalue().encode("utf-8"), filename
740
 
741
  ###############################################################################
742
+ # 13. BUILD GRADIO INTERFACE
743
  ###############################################################################
744
 
745
  css = """
746
  .gradio-container {
747
  font-family: 'IBM Plex Sans', sans-serif;
748
  }
749
+ .download-button {
750
+ margin-top: 10px;
751
+ }
752
  """
753
 
754
  with gr.Blocks(css=css) as iface:
755
  gr.Markdown("""
756
  # Virus Host Classifier
757
  **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
758
+ **Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
759
+ **Step 3**: Analyze gene features and their contributions.
760
+ **Step 4**: Compare sequences and analyze differences.
761
 
762
  **Color Scale**: Negative SHAP = Blue, Zero = White, Positive = Red.
763
  """)
 
774
  results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
775
  kmer_img = gr.Image(label="Top k-mer SHAP")
776
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
777
+ download_results = gr.File(label="Download Results", visible=False, elem_classes="download-button")
778
  seq_state = gr.State()
779
  header_state = gr.State()
780
  analyze_btn.click(
781
  analyze_sequence,
782
  inputs=[file_input, top_k, text_input, win_size],
783
+ outputs=[results_box, kmer_img, genome_img, seq_state, header_state, download_results]
784
  )
785
 
786
  with gr.Tab("2) Subregion Exploration"):
787
  gr.Markdown("""
788
  **Subregion Analysis**
789
  Select start/end positions to view local SHAP signals, distribution, GC content, etc.
790
+ The heatmap uses the same Blue-White-Red scale.
791
  """)
792
  with gr.Row():
793
  region_start = gr.Number(label="Region Start", value=0)
 
797
  with gr.Row():
798
  subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
799
  subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
800
+ download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
801
  region_btn.click(
802
  analyze_subregion,
803
  inputs=[seq_state, header_state, region_start, region_end],
804
+ outputs=[subregion_info, subregion_img, subregion_hist_img, download_subregion]
805
+ )
806
+
807
+ with gr.Tab("3) Gene Features Analysis"):
808
+ gr.Markdown("""
809
+ **Analyze Gene Features**
810
+ Upload a FASTA file and corresponding gene features file to analyze SHAP values per gene.
811
+ Gene features should be in the format:
812
+ ```
813
+ >gene_name [gene=X] [locus_tag=Y] [location=start..end]
814
+ SEQUENCE
815
+ ```
816
+ The genome viewer will show genes color-coded by their contribution:
817
+ - Red: Genes pushing toward human origin
818
+ - Blue: Genes pushing toward non-human origin
819
+ - Color intensity indicates strength of signal
820
+ """)
821
+ with gr.Row():
822
+ with gr.Column(scale=1):
823
+ gene_fasta_file = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
824
+ gene_fasta_text = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
825
+ with gr.Column(scale=1):
826
+ features_file = gr.File(label="Upload gene features file", file_types=[".txt"], type="filepath")
827
+ features_text = gr.Textbox(label="Or paste gene features", placeholder=">gene_1 [gene=U12]...\nACGT...", lines=5)
828
+
829
+ analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
830
+ gene_results = gr.Textbox(label="Gene Analysis Results", lines=12, interactive=False)
831
+ gene_diagram = gr.Image(label="Genome Diagram with Gene Features")
832
+ download_gene_results = gr.File(label="Download Gene Analysis", visible=False, elem_classes="download-button")
833
+
834
+ analyze_genes_btn.click(
835
+ analyze_gene_features,
836
+ inputs=[gene_fasta_file, features_file, gene_fasta_text, features_text],
837
+ outputs=[gene_results, download_gene_results, gene_diagram]
838
  )
839
 
840
+ with gr.Tab("4) Comparative Analysis"):
841
  gr.Markdown("""
842
  **Compare Two Sequences**
843
  Upload or paste two FASTA sequences to compare their SHAP patterns.
 
860
  with gr.Row():
861
  diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
862
  diff_hist = gr.Image(label="Distribution of SHAP Differences")
863
+ download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
864
  compare_btn.click(
865
  analyze_sequence_comparison,
866
  inputs=[file_input1, file_input2, text_input1, text_input2],
867
+ outputs=[comparison_text, diff_heatmap, diff_hist, download_comparison]
868
  )
869
 
870
  gr.Markdown("""
871
  ### Interface Features
872
+ - **Overall Classification** (human vs non-human) using k-mer frequencies
873
+ - **SHAP Analysis** shows which k-mers push classification toward or away from human
874
  - **White-Centered SHAP Gradient**:
875
+ - Negative (blue), 0 (white), Positive (red)
876
+ - Symmetrical color range around 0
877
+ - **Identify Subregions** with strongest push for human or non-human
878
+ - **Gene Feature Analysis**:
879
+ - Analyze individual genes' contributions
880
+ - Interactive genome viewer
881
+ - Gene-level statistics and classification
882
  - **Sequence Comparison**:
883
  - Compare two sequences to identify regions of difference
884
+ - Normalized comparison to handle different lengths
885
  - Statistical summary of differences
886
+ - **Data Export**:
887
+ - Download results as CSV files
888
+ - Save analysis outputs for further processing
889
  """)
890
 
891
  if __name__ == "__main__":
892
+ iface.launch()