hiyata commited on
Commit
ac53617
·
verified ·
1 Parent(s): ae32958

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -95
app.py CHANGED
@@ -645,36 +645,20 @@ def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
645
 
646
  def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
647
  """Create a simple genome diagram using PIL"""
648
- # Validate inputs and convert to proper types
649
  if not gene_results or genome_length <= 0:
650
  img = Image.new('RGB', (800, 100), color='white')
651
  draw = ImageDraw.Draw(img)
652
  draw.text((10, 40), "Error: Invalid input data", fill='black')
653
  return img
654
-
655
- # Ensure all gene coordinates are valid integers and within bounds
656
- valid_genes = []
657
  for gene in gene_results:
658
- try:
659
- start = max(0, int(float(gene['start'])))
660
- end = min(genome_length, int(float(gene['end'])))
661
- if start < end:
662
- gene_copy = gene.copy()
663
- gene_copy['start'] = start
664
- gene_copy['end'] = end
665
- valid_genes.append(gene_copy)
666
- else:
667
- print(f"Warning: Skipping gene {gene.get('gene_name', 'unknown')} due to invalid coordinates: {start}-{end}")
668
- except (ValueError, TypeError) as e:
669
- print(f"Warning: Skipping gene due to coordinate conversion error: {str(e)}")
670
- continue
671
-
672
- if not valid_genes:
673
- img = Image.new('RGB', (800, 100), color='white')
674
- draw = ImageDraw.Draw(img)
675
- draw.text((10, 40), "Error: No valid genes to display", fill='black')
676
- return img
677
-
678
  # Image dimensions
679
  width = 1500
680
  height = 600
@@ -685,104 +669,131 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
685
  img = Image.new('RGB', (width, height), 'white')
686
  draw = ImageDraw.Draw(img)
687
 
688
- # Use default font if custom font not available
689
  try:
690
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
691
  title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
692
  except:
693
- font = ImageFont.load_default()
694
- title_font = ImageFont.load_default()
695
 
696
  # Draw title
697
- draw.text((margin, margin//2), "Genome SHAP Analysis", fill='black', font=title_font)
 
698
 
699
  # Draw genome line
700
  line_y = height // 2
701
- draw.line([(margin, line_y), (width - margin, line_y)], fill='black', width=2)
702
 
703
  # Calculate scale factor
704
- scale = (width - 2 * margin) / genome_length
 
 
 
 
 
 
 
705
 
706
  # Draw scale markers
707
- for i in range(0, genome_length + 1, genome_length // 10):
708
- x = margin + int(i * scale)
709
- draw.line([(x, line_y - 5), (x, line_y + 5)], fill='black', width=1)
710
- draw.text((x - 20, line_y + 10), f"{i:,}", fill='black', font=font)
 
 
 
 
711
 
712
- # Sort genes by absolute SHAP value
713
- sorted_genes = sorted(valid_genes, key=lambda x: abs(float(x['avg_shap'])))
714
 
715
  # Draw genes
716
  for idx, gene in enumerate(sorted_genes):
717
- # Calculate position
718
- start_x = margin + int(float(gene['start']) * scale)
719
- end_x = margin + int(float(gene['end']) * scale)
720
 
721
  # Calculate color based on SHAP value
722
- avg_shap = float(gene['avg_shap'])
723
- if avg_shap > 0:
724
- intensity = min(255, int(abs(avg_shap * 500)))
725
- color = (255, 255 - intensity, 255 - intensity) # Red
726
  else:
727
- intensity = min(255, int(abs(avg_shap * 500)))
728
- color = (255 - intensity, 255 - intensity, 255) # Blue
729
 
730
  # Draw gene box
731
- y_top = line_y - track_height // 2
732
- y_bottom = line_y + track_height // 2
733
- draw.rectangle([(start_x, y_top), (end_x, y_bottom)],
734
- fill=color, outline='black')
735
 
736
- # Draw gene name
737
- label = str(gene['gene_name'])
738
- # Get text size for positioning
739
- if hasattr(font, 'getsize'):
740
- label_width, label_height = font.getsize(label)
741
- else:
742
- label_width = len(label) * 6 # Approximate width
743
- label_height = 12
744
 
745
- # Alternate label position above/below
746
  if idx % 2 == 0:
747
- text_y = y_top - label_height - 5
748
  else:
749
- text_y = y_bottom + 5
750
 
751
- # Draw label
752
  gene_width = end_x - start_x
753
  if gene_width > label_width:
754
- # Horizontal label
755
  text_x = start_x + (gene_width - label_width) // 2
756
- draw.text((text_x, text_y), label, fill='black', font=font)
757
  elif gene_width > 20:
758
- # Vertical label
759
  txt_img = Image.new('RGBA', (label_width, label_height), (255, 255, 255, 0))
760
  txt_draw = ImageDraw.Draw(txt_img)
761
  txt_draw.text((0, 0), label, font=font, fill='black')
762
- txt_img = txt_img.rotate(90, expand=True)
763
- img.paste(txt_img, (start_x, text_y), txt_img)
 
764
 
765
  # Draw legend
766
  legend_x = margin
767
  legend_y = height - margin
768
- draw.text((legend_x, legend_y - 60), "SHAP Values:", fill='black', font=font)
769
 
770
  # Draw legend boxes
771
  box_width = 20
772
  box_height = 20
773
  spacing = 15
774
 
775
- legend_items = [
776
- ((255, 0, 0), "Strong human-like signal", (legend_x, legend_y - 45)),
777
- ((255, 200, 200), "Weak human-like signal", (legend_x, legend_y - 20)),
778
- ((200, 200, 255), "Weak non-human-like signal", (legend_x + 250, legend_y - 45)),
779
- ((0, 0, 255), "Strong non-human-like signal", (legend_x + 250, legend_y - 20))
780
- ]
 
 
 
 
 
 
 
 
 
781
 
782
- for color, label, (x, y) in legend_items:
783
- draw.rectangle([(x, y, x + box_width, y + box_height)],
784
- fill=color, outline='black')
785
- draw.text((x + box_width + spacing, y), label, fill='black', font=font)
 
 
 
 
 
 
 
 
 
 
 
786
 
787
  return img
788
 
@@ -790,7 +801,11 @@ def analyze_gene_features(sequence_file: str,
790
  features_file: str,
791
  fasta_text: str = "",
792
  features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
793
- """Analyze SHAP values for each gene feature"""
 
 
 
 
794
  # First analyze whole sequence
795
  sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
796
  if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
@@ -882,23 +897,17 @@ def analyze_gene_features(sequence_file: str,
882
  # Save CSV to temp file
883
  try:
884
  temp_dir = tempfile.gettempdir()
 
 
 
 
 
 
885
  temp_path = None
886
 
887
- # Create visualization with robust error handling
888
  try:
889
- # Ensure all gene coordinates are numeric and valid
890
- for gene in gene_results:
891
- try:
892
- gene['start'] = int(float(gene['start']))
893
- gene['end'] = int(float(gene['end']))
894
- if gene['start'] >= gene['end']:
895
- raise ValueError(f"Invalid coordinates for gene {gene['gene_name']}: {gene['start']}-{gene['end']}")
896
- except (ValueError, TypeError) as e:
897
- print(f"Warning: Invalid coordinates for gene {gene['gene_name']}: {str(e)}")
898
- continue
899
-
900
  diagram_img = create_simple_genome_diagram(gene_results, len(shap_means))
901
-
902
  except Exception as e:
903
  print(f"Error creating visualization: {str(e)}")
904
  # Create error image
@@ -906,10 +915,8 @@ def analyze_gene_features(sequence_file: str,
906
  draw = ImageDraw.Draw(diagram_img)
907
  draw.text((10, 40), f"Error creating visualization: {str(e)}", fill='black')
908
 
909
- return results_text, temp_path, diagram_img os.path.join(temp_dir, f"gene_analysis_{os.urandom(4).hex()}.csv")
910
-
911
- with open(temp_path, 'w') as f:
912
- f.write(csv_content)
913
 
914
  ###############################################################################
915
  # 12. DOWNLOAD FUNCTIONS
 
645
 
646
  def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
647
  """Create a simple genome diagram using PIL"""
648
+ # Validate inputs
649
  if not gene_results or genome_length <= 0:
650
  img = Image.new('RGB', (800, 100), color='white')
651
  draw = ImageDraw.Draw(img)
652
  draw.text((10, 40), "Error: Invalid input data", fill='black')
653
  return img
654
+
655
+ # Ensure all gene coordinates are valid integers
 
656
  for gene in gene_results:
657
+ gene['start'] = max(0, int(gene['start']))
658
+ gene['end'] = min(genome_length, int(gene['end']))
659
+ if gene['start'] >= gene['end']:
660
+ print(f"Warning: Invalid coordinates for gene {gene['gene_name']}: {gene['start']}-{gene['end']}")
661
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662
  # Image dimensions
663
  width = 1500
664
  height = 600
 
669
  img = Image.new('RGB', (width, height), 'white')
670
  draw = ImageDraw.Draw(img)
671
 
672
+ # Try to load font, fall back to default if unavailable
673
  try:
674
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
675
  title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
676
  except:
677
+ font = None
678
+ title_font = None
679
 
680
  # Draw title
681
+ draw.text((margin, margin // 2), "Genome SHAP Analysis",
682
+ fill='black', font=title_font or font)
683
 
684
  # Draw genome line
685
  line_y = height // 2
686
+ draw.line([(int(margin), int(line_y)), (int(width - margin), int(line_y))], fill='black', width=2)
687
 
688
  # Calculate scale factor
689
+ scale = float(width - 2 * margin) / float(genome_length)
690
+
691
+ # Determine a reasonable step for scale markers (avoid zero step if genome_length<10)
692
+ num_ticks = 10
693
+ if genome_length < num_ticks:
694
+ step = 1
695
+ else:
696
+ step = genome_length // num_ticks
697
 
698
  # Draw scale markers
699
+ # Use min(genome_length, step * num_ticks) in range to avoid overshooting
700
+ for i in range(0, genome_length + 1, step):
701
+ x_coord = margin + i * scale
702
+ draw.line([
703
+ (int(x_coord), int(line_y - 5)),
704
+ (int(x_coord), int(line_y + 5))
705
+ ], fill='black', width=1)
706
+ draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
707
 
708
+ # Sort genes by absolute SHAP value for drawing
709
+ sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
710
 
711
  # Draw genes
712
  for idx, gene in enumerate(sorted_genes):
713
+ # Calculate position and ensure integers
714
+ start_x = margin + int(gene['start'] * scale)
715
+ end_x = margin + int(gene['end'] * scale)
716
 
717
  # Calculate color based on SHAP value
718
+ if gene['avg_shap'] > 0:
719
+ intensity = min(255, int(abs(gene['avg_shap'] * 500)))
720
+ color = (255, max(0, 255 - intensity), max(0, 255 - intensity)) # Red-ish
 
721
  else:
722
+ intensity = min(255, int(abs(gene['avg_shap'] * 500)))
723
+ color = (max(0, 255 - intensity), max(0, 255 - intensity), 255) # Blue-ish
724
 
725
  # Draw gene box
726
+ draw.rectangle([
727
+ (int(start_x), int(line_y - track_height // 2)),
728
+ (int(end_x), int(line_y + track_height // 2))
729
+ ], fill=color, outline='black')
730
 
731
+ # Prepare gene name label
732
+ label = f"{gene['gene_name']}"
733
+ label_width, label_height = draw.textsize(label, font=font)
 
 
 
 
 
734
 
735
+ # Alternate label positions above/below
736
  if idx % 2 == 0:
737
+ text_y = line_y - track_height - 15
738
  else:
739
+ text_y = line_y + track_height + 5
740
 
741
+ # Decide to rotate text or not based on available box width
742
  gene_width = end_x - start_x
743
  if gene_width > label_width:
744
+ # Draw horizontally
745
  text_x = start_x + (gene_width - label_width) // 2
746
+ draw.text((int(text_x), int(text_y)), label, fill='black', font=font)
747
  elif gene_width > 20:
748
+ # Create rotated text
749
  txt_img = Image.new('RGBA', (label_width, label_height), (255, 255, 255, 0))
750
  txt_draw = ImageDraw.Draw(txt_img)
751
  txt_draw.text((0, 0), label, font=font, fill='black')
752
+ rotated_img = txt_img.rotate(90, expand=True)
753
+ # Paste at (start_x, text_y) casted to int
754
+ img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)
755
 
756
  # Draw legend
757
  legend_x = margin
758
  legend_y = height - margin
759
+ draw.text((int(legend_x), int(legend_y - 60)), "SHAP Values:", fill='black', font=font)
760
 
761
  # Draw legend boxes
762
  box_width = 20
763
  box_height = 20
764
  spacing = 15
765
 
766
+ # Strong human-like
767
+ draw.rectangle([
768
+ (int(legend_x), int(legend_y - 45)),
769
+ (int(legend_x + box_width), int(legend_y - 45 + box_height))
770
+ ], fill=(255, 0, 0), outline='black')
771
+ draw.text((int(legend_x + box_width + spacing), int(legend_y - 45)),
772
+ "Strong human-like signal", fill='black', font=font)
773
+
774
+ # Weak human-like
775
+ draw.rectangle([
776
+ (int(legend_x), int(legend_y - 20)),
777
+ (int(legend_x + box_width), int(legend_y - 20 + box_height))
778
+ ], fill=(255, 200, 200), outline='black')
779
+ draw.text((int(legend_x + box_width + spacing), int(legend_y - 20)),
780
+ "Weak human-like signal", fill='black', font=font)
781
 
782
+ # Weak non-human-like
783
+ draw.rectangle([
784
+ (int(legend_x + 250), int(legend_y - 45)),
785
+ (int(legend_x + 250 + box_width), int(legend_y - 45 + box_height))
786
+ ], fill=(200, 200, 255), outline='black')
787
+ draw.text((int(legend_x + 250 + box_width + spacing), int(legend_y - 45)),
788
+ "Weak non-human-like signal", fill='black', font=font)
789
+
790
+ # Strong non-human-like
791
+ draw.rectangle([
792
+ (int(legend_x + 250), int(legend_y - 20)),
793
+ (int(legend_x + 250 + box_width), int(legend_y - 20 + box_height))
794
+ ], fill=(0, 0, 255), outline='black')
795
+ draw.text((int(legend_x + 250 + box_width + spacing), int(legend_y - 20)),
796
+ "Strong non-human-like signal", fill='black', font=font)
797
 
798
  return img
799
 
 
801
  features_file: str,
802
  fasta_text: str = "",
803
  features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
804
+ """
805
+ Analyze SHAP values for each gene feature.
806
+ NOTE: This function assumes there's an `analyze_sequence(...)` function
807
+ defined elsewhere that returns the needed SHAP information.
808
+ """
809
  # First analyze whole sequence
810
  sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
811
  if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
 
897
  # Save CSV to temp file
898
  try:
899
  temp_dir = tempfile.gettempdir()
900
+ temp_path = os.path.join(temp_dir, f"gene_analysis_{os.urandom(4).hex()}.csv")
901
+
902
+ with open(temp_path, 'w') as f:
903
+ f.write(csv_content)
904
+ except Exception as e:
905
+ print(f"Error saving CSV: {str(e)}")
906
  temp_path = None
907
 
908
+ # Create visualization
909
  try:
 
 
 
 
 
 
 
 
 
 
 
910
  diagram_img = create_simple_genome_diagram(gene_results, len(shap_means))
 
911
  except Exception as e:
912
  print(f"Error creating visualization: {str(e)}")
913
  # Create error image
 
915
  draw = ImageDraw.Draw(diagram_img)
916
  draw.text((10, 40), f"Error creating visualization: {str(e)}", fill='black')
917
 
918
+ return results_text, temp_path, diagram_img
919
+
 
 
920
 
921
  ###############################################################################
922
  # 12. DOWNLOAD FUNCTIONS