hiyata committed on
Commit
dbad921
·
verified ·
1 Parent(s): 9a5c352

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +238 -72
app.py CHANGED
@@ -8,10 +8,14 @@ import matplotlib.pyplot as plt
8
  import matplotlib.colors as mcolors
9
  import io
10
  from io import BytesIO # Import io then BytesIO
11
- from PIL import Image, ImageDraw
12
  from Bio.Graphics import GenomeDiagram
13
  from Bio.SeqFeature import SeqFeature, FeatureLocation
14
  from reportlab.lib import colors
 
 
 
 
15
 
16
  ###############################################################################
17
  # 1. MODEL DEFINITION
@@ -563,8 +567,16 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
563
  # 11. GENE FEATURE ANALYSIS
564
  ###############################################################################
565
 
566
- def parse_gene_features(text):
567
- """Parse gene features from text file in FASTA-like format"""
 
 
 
 
 
 
 
 
568
  genes = []
569
  current_header = None
570
  current_sequence = []
@@ -573,6 +585,7 @@ def parse_gene_features(text):
573
  line = line.strip()
574
  if not line:
575
  continue
 
576
  if line.startswith('>'):
577
  if current_header:
578
  genes.append({
@@ -594,8 +607,16 @@ def parse_gene_features(text):
594
 
595
  return genes
596
 
597
- def parse_gene_metadata(header):
598
- """Extract metadata from gene header"""
 
 
 
 
 
 
 
 
599
  metadata = {}
600
  parts = header.split()
601
 
@@ -607,8 +628,92 @@ def parse_gene_metadata(header):
607
 
608
  return metadata
609
 
610
- def analyze_gene_features(sequence_file, features_file, fasta_text="", features_text=""):
611
- """Analyze SHAP values for each gene feature"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612
  # First analyze whole sequence
613
  sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
614
  if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
@@ -618,14 +723,14 @@ def analyze_gene_features(sequence_file, features_file, fasta_text="", features_
618
  shap_means = sequence_results[3]["shap_means"]
619
 
620
  # Parse gene features
621
- if features_text.strip():
622
- genes = parse_gene_features(features_text)
623
- else:
624
- try:
625
  with open(features_file, 'r') as f:
626
  genes = parse_gene_features(f.read())
627
- except Exception as e:
628
- return f"Error reading features file: {str(e)}", None, None
629
 
630
  # Analyze each gene
631
  gene_results = []
@@ -635,78 +740,142 @@ def analyze_gene_features(sequence_file, features_file, fasta_text="", features_
635
  if not location:
636
  continue
637
 
638
- # Parse location (assuming format like "21729..22861")
639
- start, end = map(int, location.split('..'))
640
-
 
641
  # Get SHAP values for this region
642
  gene_shap = shap_means[start:end]
643
- avg_shap = float(np.mean(gene_shap))
644
 
645
  gene_results.append({
646
  'gene_name': gene['metadata'].get('gene', 'Unknown'),
647
  'location': location,
648
- 'avg_shap': avg_shap,
649
  'start': start,
650
  'end': end,
651
  'locus_tag': gene['metadata'].get('locus_tag', ''),
652
- 'classification': 'Human' if avg_shap > 0 else 'Non-human',
653
- 'confidence': abs(avg_shap)
 
 
 
 
 
 
654
  })
655
 
656
  except Exception as e:
657
  print(f"Error processing gene {gene['metadata'].get('gene', 'Unknown')}: {str(e)}")
658
  continue
 
 
 
659
 
660
- # Create CSV output
661
- csv_output = "gene_name,location,avg_shap,classification,confidence,locus_tag\n"
662
- for result in gene_results:
663
- csv_output += f"{result['gene_name']},{result['location']},{result['avg_shap']:.4f},"
664
- csv_output += f"{result['classification']},{result['confidence']:.4f},{result['locus_tag']}\n"
665
-
666
- # Create genome diagram
667
- diagram_img = create_genome_diagram(gene_results, len(shap_means))
668
-
669
- return gene_results, csv_output, diagram_img
670
-
671
- def create_genome_diagram(gene_results, genome_length):
672
- """Create genome diagram using BioPython"""
673
- from Bio.Graphics import GenomeDiagram
674
- from Bio.SeqFeature import SeqFeature, FeatureLocation
675
- from reportlab.lib import colors
676
- from io import BytesIO
677
- from PIL import Image
678
- import io # Ensure io is imported at the top level
679
-
680
- # Create diagram
681
- gd_diagram = GenomeDiagram.Diagram("Genome SHAP Analysis")
682
- gd_track = gd_diagram.new_track(1, name="Genes")
683
- gd_feature_set = gd_track.new_set()
684
-
685
- # Add features
686
  for gene in gene_results:
687
- # Create feature
688
- feature = SeqFeature(
689
- FeatureLocation(gene['start'], gene['end']),
690
- type="gene"
 
691
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
692
 
693
- # Calculate color based on SHAP value
694
- if gene['avg_shap'] > 0:
695
- intensity = min(1.0, abs(gene['avg_shap']) * 2)
696
- color = colors.Color(1-intensity, 1-intensity, 1) # Red
697
- else:
698
- intensity = min(1.0, abs(gene['avg_shap']) * 2)
699
- color = colors.Color(1-intensity, 1-intensity, 1) # Blue
700
-
701
- # Add to diagram
702
- gd_feature_set.add_feature(
703
- feature,
704
- color=color,
705
- label=True,
706
- name=f"{gene['gene_name']}\n(SHAP: {gene['avg_shap']:.3f})"
707
- )
 
 
 
 
 
 
708
 
 
 
 
709
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
710
  # Draw diagram
711
  gd_diagram.draw(
712
  format="linear",
@@ -722,13 +891,10 @@ def create_genome_diagram(gene_results, genome_length):
722
  gd_diagram.write(buffer, "PNG")
723
  buffer.seek(0)
724
  return Image.open(buffer)
 
725
  except Exception as e:
726
  print(f"Error creating genome diagram: {str(e)}")
727
- # Create a simple error image
728
- error_img = Image.new('RGB', (800, 100), color='white')
729
- draw = ImageDraw.Draw(error_img)
730
- draw.text((10, 40), f"Error creating genome diagram: {str(e)}", fill='black')
731
- return error_img
732
 
733
  ###############################################################################
734
  # 12. DOWNLOAD FUNCTIONS
@@ -822,7 +988,7 @@ with gr.Blocks(css=css) as iface:
822
  Upload a FASTA file and corresponding gene features file to analyze SHAP values per gene.
823
  Gene features should be in the format:
824
  ```
825
- >gene_name [gene=X] [locus_tag=Y] [location=start..end]
826
  SEQUENCE
827
  ```
828
  The genome viewer will show genes color-coded by their contribution:
@@ -841,7 +1007,7 @@ with gr.Blocks(css=css) as iface:
841
  analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
842
  gene_results = gr.Textbox(label="Gene Analysis Results", lines=12, interactive=False)
843
  gene_diagram = gr.Image(label="Genome Diagram with Gene Features")
844
- download_gene_results = gr.File(label="Download Gene Analysis", visible=False, elem_classes="download-button")
845
 
846
  analyze_genes_btn.click(
847
  analyze_gene_features,
 
8
  import matplotlib.colors as mcolors
9
  import io
10
  from io import BytesIO # Import io then BytesIO
11
+ from PIL import Image, ImageDraw, ImageFont
12
  from Bio.Graphics import GenomeDiagram
13
  from Bio.SeqFeature import SeqFeature, FeatureLocation
14
  from reportlab.lib import colors
15
+ import pandas as pd
16
+ import tempfile
17
+ import os
18
+ from typing import List, Dict, Tuple, Optional, Any
19
 
20
  ###############################################################################
21
  # 1. MODEL DEFINITION
 
567
  # 11. GENE FEATURE ANALYSIS
568
  ###############################################################################
569
 
570
+ def parse_gene_features(text: str) -> List[Dict[str, Any]]:
571
+ """
572
+ Parse gene features from text file in FASTA-like format
573
+
574
+ Args:
575
+ text (str): Input text in FASTA format with gene metadata
576
+
577
+ Returns:
578
+ List[Dict]: List of gene dictionaries containing sequence and metadata
579
+ """
580
  genes = []
581
  current_header = None
582
  current_sequence = []
 
585
  line = line.strip()
586
  if not line:
587
  continue
588
+
589
  if line.startswith('>'):
590
  if current_header:
591
  genes.append({
 
607
 
608
  return genes
609
 
610
+ def parse_gene_metadata(header: str) -> Dict[str, str]:
611
+ """
612
+ Extract metadata from gene header
613
+
614
+ Args:
615
+ header (str): Gene header line starting with '>'
616
+
617
+ Returns:
618
+ Dict[str, str]: Dictionary of metadata key-value pairs
619
+ """
620
  metadata = {}
621
  parts = header.split()
622
 
 
628
 
629
  return metadata
630
 
631
def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
    """
    Parse a gene location string into integer start and end coordinates.

    Accepts plain ranges ("1234..5678"), complement-strand ranges
    ("complement(1234..5678)") and ranges carrying partial-feature markers
    ("<1..>200"). Strand information is not returned; the numbers are
    returned exactly as written in the string.

    NOTE(review): the caller in this file uses the result directly as slice
    bounds (shap_means[start:end]). If the input follows a 1-based inclusive
    convention, that slice is shifted by one base - confirm against the data.

    Args:
        location_str (str): Location string taken from the gene header

    Returns:
        Tuple[Optional[int], Optional[int]]: Start and end positions, or
        (None, None) if the string cannot be parsed as a single range
    """
    if not isinstance(location_str, str):
        return None, None

    # Drop the complement wrapper; a missing closing parenthesis is tolerated
    clean_loc = location_str.strip().replace('complement(', '').replace(')', '')

    # '<' and '>' only flag a partial feature; the number itself is still usable
    clean_loc = clean_loc.replace('<', '').replace('>', '')

    parts = clean_loc.split('..')
    if len(parts) != 2:
        # Single positions and malformed ranges are not supported
        return None, None

    try:
        return int(parts[0]), int(parts[1])
    except ValueError as e:
        # Non-numeric bounds, e.g. compound "join(...)" locations
        print(f"Error parsing location {location_str}: {str(e)}")
        return None, None
656
+
657
def save_results_to_temp(results: str, prefix: str = "analysis") -> Optional[str]:
    """
    Save results to a uniquely named temporary CSV file

    The file is created through the tempfile module, so the name is unique
    and the file is created atomically instead of being assembled by hand.
    It is not deleted automatically; the caller owns the returned path.

    Args:
        results (str): Content to save
        prefix (str): Prefix for the temporary file name

    Returns:
        Optional[str]: Path to temporary file, or None if save fails
    """
    temp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode='w',
            prefix=f"{prefix}_",
            suffix=".csv",
            encoding='utf-8',
            delete=False,
        ) as f:
            temp_path = f.name
            f.write(results)
        return temp_path
    except (OSError, TypeError, ValueError) as e:
        print(f"Error saving results: {str(e)}")
        if temp_path is not None:
            # Do not leave a partially written file behind
            try:
                os.remove(temp_path)
            except OSError:
                # Best-effort cleanup; the save has already failed
                pass
        return None
678
+
679
def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
    """
    Compute statistical measures for gene SHAP values

    Args:
        gene_shap (np.ndarray): SHAP values for a gene; any array-like of
            numbers is accepted and converted to a float array

    Returns:
        Dict[str, float]: Mean, median, standard deviation, maximum, minimum
        and the fraction of strictly positive values

    Raises:
        ValueError: If gene_shap is empty (for example a location that lies
            outside the analysed sequence)
    """
    values = np.asarray(gene_shap, dtype=float)

    # Fail early with a clear message; otherwise np.mean warns about an
    # empty slice before np.max raises a less helpful ValueError
    if values.size == 0:
        raise ValueError("Cannot compute statistics for an empty SHAP region")

    return {
        'avg_shap': float(np.mean(values)),
        'median_shap': float(np.median(values)),
        'std_shap': float(np.std(values)),
        'max_shap': float(np.max(values)),
        'min_shap': float(np.min(values)),
        'pos_fraction': float(np.mean(values > 0))
    }
697
+
698
+ def analyze_gene_features(sequence_file: str,
699
+ features_file: str,
700
+ fasta_text: str = "",
701
+ features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
702
+ """
703
+ Analyze SHAP values for each gene feature
704
+
705
+ Args:
706
+ sequence_file (str): Path to FASTA file
707
+ features_file (str): Path to features file
708
+ fasta_text (str): FASTA content if provided as text
709
+ features_text (str): Features content if provided as text
710
+
711
+ Returns:
712
+ Tuple[str, Optional[str], Optional[Image.Image]]:
713
+ - Analysis results text
714
+ - Path to CSV file
715
+ - Genome diagram image
716
+ """
717
  # First analyze whole sequence
718
  sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
719
  if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
 
723
  shap_means = sequence_results[3]["shap_means"]
724
 
725
  # Parse gene features
726
+ try:
727
+ if features_text.strip():
728
+ genes = parse_gene_features(features_text)
729
+ else:
730
  with open(features_file, 'r') as f:
731
  genes = parse_gene_features(f.read())
732
+ except Exception as e:
733
+ return f"Error reading features file: {str(e)}", None, None
734
 
735
  # Analyze each gene
736
  gene_results = []
 
740
  if not location:
741
  continue
742
 
743
+ start, end = parse_location(location)
744
+ if start is None or end is None:
745
+ continue
746
+
747
  # Get SHAP values for this region
748
  gene_shap = shap_means[start:end]
749
+ stats = compute_gene_statistics(gene_shap)
750
 
751
  gene_results.append({
752
  'gene_name': gene['metadata'].get('gene', 'Unknown'),
753
  'location': location,
 
754
  'start': start,
755
  'end': end,
756
  'locus_tag': gene['metadata'].get('locus_tag', ''),
757
+ 'avg_shap': stats['avg_shap'],
758
+ 'median_shap': stats['median_shap'],
759
+ 'std_shap': stats['std_shap'],
760
+ 'max_shap': stats['max_shap'],
761
+ 'min_shap': stats['min_shap'],
762
+ 'pos_fraction': stats['pos_fraction'],
763
+ 'classification': 'Human' if stats['avg_shap'] > 0 else 'Non-human',
764
+ 'confidence': abs(stats['avg_shap'])
765
  })
766
 
767
  except Exception as e:
768
  print(f"Error processing gene {gene['metadata'].get('gene', 'Unknown')}: {str(e)}")
769
  continue
770
+
771
+ if not gene_results:
772
+ return "No valid genes could be processed", None, None
773
 
774
+ # Create results text
775
+ results_text = "Gene Analysis Results:\n\n"
776
+ results_text += f"Total genes analyzed: {len(gene_results)}\n"
777
+ results_text += f"Human-like genes: {sum(1 for g in gene_results if g['classification'] == 'Human')}\n"
778
+ results_text += f"Non-human-like genes: {sum(1 for g in gene_results if g['classification'] == 'Non-human')}\n\n"
779
+
780
+ # Sort genes by absolute SHAP value for reporting
781
+ sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
782
+
783
+ results_text += "Top 10 genes by signal strength:\n"
784
+ for gene in sorted_genes[:10]:
785
+ results_text += (
786
+ f"Gene: {gene['gene_name']}\n"
787
+ f"Location: {gene['location']}\n"
788
+ f"Classification: {gene['classification']} "
789
+ f"(confidence: {gene['confidence']:.4f})\n"
790
+ f"Average SHAP: {gene['avg_shap']:.4f}\n\n"
791
+ )
792
+
793
+ # Create CSV content
794
+ csv_content = "gene_name,location,avg_shap,median_shap,std_shap,max_shap,min_shap,"
795
+ csv_content += "pos_fraction,classification,confidence,locus_tag\n"
796
+
 
 
 
797
  for gene in gene_results:
798
+ csv_content += (
799
+ f"{gene['gene_name']},{gene['location']},{gene['avg_shap']:.4f},"
800
+ f"{gene['median_shap']:.4f},{gene['std_shap']:.4f},{gene['max_shap']:.4f},"
801
+ f"{gene['min_shap']:.4f},{gene['pos_fraction']:.4f},{gene['classification']},"
802
+ f"{gene['confidence']:.4f},{gene['locus_tag']}\n"
803
  )
804
+
805
+ # Save CSV to temp file
806
+ csv_path = save_results_to_temp(csv_content, "gene_analysis")
807
+
808
+ try:
809
+ # Create genome diagram
810
+ diagram_img = create_genome_diagram(gene_results, len(shap_means))
811
+ except Exception as e:
812
+ print(f"Error creating genome diagram: {str(e)}")
813
+ diagram_img = create_error_image(str(e))
814
+
815
+ return results_text, csv_path, diagram_img
816
+
817
def create_error_image(error_message: str) -> Image.Image:
    """
    Create a placeholder image that displays an error message

    Args:
        error_message (str): Error message to display

    Returns:
        Image.Image: White 800x100 RGB image with the message drawn in black
    """
    img = Image.new('RGB', (800, 100), color='white')
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
    except (OSError, ImportError):
        # Font file missing or FreeType support unavailable: font=None makes
        # Pillow use its built-in default font
        font = None
    draw.text((10, 40), f"Error creating genome diagram: {error_message}",
              fill='black', font=font)
    return img
836
+
837
+ def create_genome_diagram(gene_results: List[Dict[str, Any]],
838
+ genome_length: int) -> Image.Image:
839
+ """
840
+ Create genome diagram using BioPython
841
+
842
+ Args:
843
+ gene_results (List[Dict]): List of gene analysis results
844
+ genome_length (int): Total length of the genome
845
 
846
+ Returns:
847
+ Image.Image: Genome diagram image
848
+ """
849
  try:
850
+ # Create diagram
851
+ gd_diagram = GenomeDiagram.Diagram("Genome SHAP Analysis")
852
+ gd_track = gd_diagram.new_track(1, name="Genes")
853
+ gd_feature_set = gd_track.new_set()
854
+
855
+ # Add features
856
+ for gene in gene_results:
857
+ # Create feature
858
+ feature = SeqFeature(
859
+ FeatureLocation(gene['start'], gene['end']),
860
+ type="gene"
861
+ )
862
+
863
+ # Calculate color based on SHAP value
864
+ if gene['avg_shap'] > 0:
865
+ intensity = min(1.0, abs(gene['avg_shap']) * 2)
866
+ color = colors.Color(1-intensity, 1-intensity, 1) # Red
867
+ else:
868
+ intensity = min(1.0, abs(gene['avg_shap']) * 2)
869
+ color = colors.Color(1-intensity, 1-intensity, 1) # Blue
870
+
871
+ # Add to diagram
872
+ gd_feature_set.add_feature(
873
+ feature,
874
+ color=color,
875
+ label=True,
876
+ name=f"{gene['gene_name']}\n(SHAP: {gene['avg_shap']:.3f})"
877
+ )
878
+
879
  # Draw diagram
880
  gd_diagram.draw(
881
  format="linear",
 
891
  gd_diagram.write(buffer, "PNG")
892
  buffer.seek(0)
893
  return Image.open(buffer)
894
+
895
  except Exception as e:
896
  print(f"Error creating genome diagram: {str(e)}")
897
+ return create_error_image(str(e))
 
 
 
 
898
 
899
  ###############################################################################
900
  # 12. DOWNLOAD FUNCTIONS
 
988
  Upload a FASTA file and corresponding gene features file to analyze SHAP values per gene.
989
  Gene features should be in the format:
990
  ```
991
+ >gene_name [gene=X] [locus_tag=Y] [location=start..end] or [location=complement(start..end)]
992
  SEQUENCE
993
  ```
994
  The genome viewer will show genes color-coded by their contribution:
 
1007
  analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
1008
  gene_results = gr.Textbox(label="Gene Analysis Results", lines=12, interactive=False)
1009
  gene_diagram = gr.Image(label="Genome Diagram with Gene Features")
1010
+ download_gene_results = gr.File(label="Download Gene Analysis (CSV)", visible=True)
1011
 
1012
  analyze_genes_btn.click(
1013
  analyze_gene_features,