hiyata committed on
Commit
78f8b3b
·
verified ·
1 Parent(s): bc5e648

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -66
app.py CHANGED
@@ -82,7 +82,7 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
82
  return vec
83
 
84
  ###############################################################################
85
- # 3. SHAP-VALUE (ABLATION) CALCULATION
86
  ###############################################################################
87
 
88
  def calculate_shap_values(model, x_tensor):
@@ -105,7 +105,7 @@ def calculate_shap_values(model, x_tensor):
105
 
106
 
107
  ###############################################################################
108
- # 4. PER-BASE SHAP AGGREGATION
109
  ###############################################################################
110
 
111
  def compute_positionwise_scores(sequence, shap_values, k=4):
@@ -125,7 +125,7 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
125
  return shap_means
126
 
127
  ###############################################################################
128
- # 5. FIND EXTREME SHAP REGIONS
129
  ###############################################################################
130
 
131
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
@@ -166,7 +166,7 @@ def get_zero_centered_cmap():
166
  colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
167
  return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
168
 
169
- def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
170
  if start is not None and end is not None:
171
  local_shap = shap_means[start:end]
172
  subtitle = f" (positions {start}-{end})"
@@ -184,7 +184,7 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
184
  cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
185
  cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
186
  cbar.ax.tick_params(labelsize=8)
187
- cbar.set_label('SHAP Contribution', fontsize=9, labelpad=5)
188
  ax.set_yticks([])
189
  ax.set_xlabel('Position in Sequence', fontsize=10)
190
  ax.set_title(f"{title}{subtitle}", pad=10)
@@ -200,17 +200,17 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
200
  colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
201
  plt.barh(range(len(values)), values, color=colors)
202
  plt.yticks(range(len(values)), features)
203
- plt.xlabel('SHAP Value (impact on model output)')
204
  plt.title(f'Top {top_k} Most Influential k-mers')
205
  plt.gca().invert_yaxis()
206
  plt.tight_layout()
207
  return fig
208
 
209
- def plot_shap_histogram(shap_array, title="SHAP Distribution in Region", num_bins=30):
210
  fig, ax = plt.subplots(figsize=(6, 4))
211
  ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
212
  ax.axvline(0, color='red', linestyle='--', label='0.0')
213
- ax.set_xlabel("SHAP Value")
214
  ax.set_ylabel("Count")
215
  ax.set_title(title)
216
  ax.legend()
@@ -227,23 +227,23 @@ def compute_gc_content(sequence):
227
  # 7. MAIN ANALYSIS STEP (Gradio Step 1)
228
  ###############################################################################
229
  def create_kmer_shap_csv(kmers, shap_values):
230
- """Create a CSV file with k-mer SHAP values and return the filepath"""
231
- # Create DataFrame with k-mers and SHAP values
232
  kmer_df = pd.DataFrame({
233
  'kmer': kmers,
234
- 'shap_value': shap_values,
235
- 'abs_shap': np.abs(shap_values)
236
  })
237
 
238
- # Sort by absolute SHAP value (most influential first)
239
- kmer_df = kmer_df.sort_values('abs_shap', ascending=False)
240
 
241
- # Drop the abs_shap column used for sorting
242
- kmer_df = kmer_df[['kmer', 'shap_value']]
243
 
244
  # Save to temporary file
245
  temp_dir = tempfile.gettempdir()
246
- temp_path = os.path.join(temp_dir, f"kmer_shap_values_{os.urandom(4).hex()}.csv")
247
  kmer_df.to_csv(temp_path, index=False)
248
 
249
  return temp_path # Return only the file path, not a tuple
@@ -296,19 +296,19 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
296
  f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
297
  f"---\n"
298
  f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
299
- f"Start: {max_start}, End: {max_end}, Avg SHAP: {max_avg:.4f}\n\n"
300
  f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
301
- f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
302
  )
303
 
304
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
305
  bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
306
  bar_img = fig_to_image(bar_fig)
307
 
308
- heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
309
  heatmap_img = fig_to_image(heatmap_fig)
310
 
311
- # Create CSV with k-mer SHAP values and return the file path
312
  kmer_shap_csv = create_kmer_shap_csv(kmers, shap_values)
313
 
314
  # State dictionary for subregion analysis
@@ -347,14 +347,14 @@ def analyze_subregion(state, header, region_start, region_end):
347
  f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
348
  f"Region length: {len(region_seq)} bases\n"
349
  f"GC content: {gc_percent:.2f}%\n"
350
- f"Average SHAP in region: {avg_shap:.4f}\n"
351
- f"Fraction with SHAP > 0 (toward human): {positive_fraction:.2f}\n"
352
- f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
353
  f"Subregion interpretation: {region_classification}\n"
354
  )
355
- heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion SHAP", start=region_start, end=region_end)
356
  heatmap_img = fig_to_image(heatmap_fig)
357
- hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
358
  hist_img = fig_to_image(hist_fig)
359
 
360
  # For demonstration, returning None for the file download as well
@@ -370,10 +370,10 @@ def get_zero_centered_cmap():
370
  return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
371
 
372
  def compute_shap_difference(shap1_norm, shap2_norm):
373
- """Compute the SHAP difference between normalized sequences"""
374
  return shap2_norm - shap1_norm
375
 
376
- def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
377
  """
378
  Plot heatmap using relative positions (0-100%)
379
  """
@@ -393,7 +393,7 @@ def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
393
 
394
  cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
395
  cbar.ax.tick_params(labelsize=8)
396
- cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
397
 
398
  ax.set_yticks([])
399
  ax.set_xlabel('Relative Position in Sequence', fontsize=10)
@@ -402,14 +402,14 @@ def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
402
 
403
  return fig
404
 
405
- def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
406
  """
407
- Plot histogram of SHAP values with configurable number of bins
408
  """
409
  fig, ax = plt.subplots(figsize=(6, 4))
410
  ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
411
  ax.axvline(0, color='red', linestyle='--', label='0.0')
412
- ax.set_xlabel("SHAP Value")
413
  ax.set_ylabel("Count")
414
  ax.set_title(title)
415
  ax.legend()
@@ -483,7 +483,7 @@ def sliding_window_smooth(values, window_size=50):
483
 
484
  def normalize_shap_lengths(shap1, shap2):
485
  """
486
- Normalize and smooth SHAP values with dynamic adaptation
487
  """
488
  # Calculate adaptive parameters
489
  num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
@@ -517,7 +517,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
517
  if isinstance(res2[0], str) and "Error" in res2[0]:
518
  return (f"Error in sequence 2: {res2[0]}", None, None, None)
519
 
520
- # Extract SHAP values and sequence info
521
  shap1 = res1[3]["shap_means"]
522
  shap2 = res2[3]["shap_means"]
523
 
@@ -567,7 +567,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
567
  f"Smoothing Window: {smooth_window} points\n"
568
  f"Adaptive Threshold: {adaptive_threshold:.3f}\n\n"
569
  "Statistics:\n"
570
- f"Average SHAP difference: {avg_diff:.4f}\n"
571
  f"Standard deviation: {std_diff:.4f}\n"
572
  f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
573
  f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
@@ -582,7 +582,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
582
  # Generate visualizations
583
  heatmap_fig = plot_comparative_heatmap(
584
  shap_diff,
585
- title=f"SHAP Difference Heatmap (window: {smooth_window})"
586
  )
587
  heatmap_img = fig_to_image(heatmap_fig)
588
 
@@ -590,7 +590,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
590
  num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
591
  hist_fig = plot_shap_histogram(
592
  shap_diff,
593
- title="Distribution of SHAP Differences",
594
  num_bins=num_bins
595
  )
596
  hist_img = fig_to_image(hist_fig)
@@ -680,7 +680,7 @@ def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
680
  return None, None
681
 
682
  def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
683
- """Compute statistical measures for gene SHAP values"""
684
  return {
685
  'avg_shap': float(np.mean(gene_shap)),
686
  'median_shap': float(np.median(gene_shap)),
@@ -693,7 +693,7 @@ def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
693
  def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
694
  """
695
  Create a simple genome diagram using PIL, forcing a minimum color intensity
696
- so that small SHAP values don't appear white.
697
  """
698
  from PIL import Image, ImageDraw, ImageFont
699
 
@@ -730,7 +730,7 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
730
  title_font = ImageFont.load_default()
731
 
732
  # Draw title
733
- draw.text((margin, margin // 2), "Genome SHAP Analysis", fill='black', font=title_font or font)
734
 
735
  # Draw genome line
736
  line_y = height // 2
@@ -755,7 +755,7 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
755
  ], fill='black', width=1)
756
  draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
757
 
758
- # Sort genes by absolute SHAP value for drawing
759
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
760
 
761
  # Draw genes
@@ -764,10 +764,10 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
764
  start_x = margin + int(gene['start'] * scale)
765
  end_x = margin + int(gene['end'] * scale)
766
 
767
- # Calculate color based on SHAP value
768
  avg_shap = gene['avg_shap']
769
 
770
- # Convert shap -> color intensity (0 to 255)
771
  # Then clamp to a minimum intensity so it never ends up plain white
772
  intensity = int(abs(avg_shap) * 500)
773
  intensity = max(50, min(255, intensity)) # clamp between 50 and 255
@@ -813,7 +813,7 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
813
  # Draw legend
814
  legend_x = margin
815
  legend_y = height - margin
816
- draw.text((int(legend_x), int(legend_y - 60)), "SHAP Values:", fill='black', font=font)
817
 
818
  # Draw legend boxes
819
  box_width = 20
@@ -858,13 +858,13 @@ def analyze_gene_features(sequence_file: str,
858
  features_file: str,
859
  fasta_text: str = "",
860
  features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
861
- """Analyze SHAP values for each gene feature"""
862
  # First analyze whole sequence
863
  sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
864
  if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
865
  return f"Error in sequence analysis: {sequence_results[0]}", None, None
866
 
867
- # Get SHAP values
868
  shap_means = sequence_results[3]["shap_means"]
869
 
870
  # Parse gene features
@@ -889,7 +889,7 @@ def analyze_gene_features(sequence_file: str,
889
  if start is None or end is None:
890
  continue
891
 
892
- # Get SHAP values for this region
893
  gene_shap = shap_means[start:end]
894
  stats = compute_gene_statistics(gene_shap)
895
 
@@ -916,7 +916,7 @@ def analyze_gene_features(sequence_file: str,
916
  if not gene_results:
917
  return "No valid genes could be processed", None, None
918
 
919
- # Sort genes by absolute SHAP value
920
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
921
 
922
  # Create results text
@@ -932,11 +932,11 @@ def analyze_gene_features(sequence_file: str,
932
  f"Location: {gene['location']}\n"
933
  f"Classification: {gene['classification']} "
934
  f"(confidence: {gene['confidence']:.4f})\n"
935
- f"Average SHAP: {gene['avg_shap']:.4f}\n\n"
936
  )
937
 
938
  # Create CSV content
939
- csv_content = "gene_name,location,avg_shap,median_shap,std_shap,max_shap,min_shap,"
940
  csv_content += "pos_fraction,classification,confidence,locus_tag\n"
941
 
942
  for gene in gene_results:
@@ -1020,11 +1020,11 @@ with gr.Blocks(css=css) as iface:
1020
  gr.Markdown("""
1021
  # Virus Host Classifier
1022
  **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
1023
- **Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
1024
  **Step 3**: Analyze gene features and their contributions.
1025
  **Step 4**: Compare sequences and analyze differences.
1026
 
1027
- **Color Scale**: Negative SHAP = Blue, Zero = White, Positive SHAP = Red.
1028
  """)
1029
 
1030
  with gr.Tab("1) Full-Sequence Analysis"):
@@ -1043,11 +1043,11 @@ with gr.Blocks(css=css) as iface:
1043
 
1044
  with gr.Column(scale=2):
1045
  results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
1046
- kmer_img = gr.Image(label="Top k-mer SHAP")
1047
- genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
1048
 
1049
  # File components with the correct type parameter
1050
- download_kmer_shap = gr.File(label="Download k-mer SHAP Values (CSV)", visible=True, type="filepath")
1051
  download_results = gr.File(label="Download Results", visible=True, elem_classes="download-button")
1052
 
1053
  seq_state = gr.State()
@@ -1071,7 +1071,7 @@ with gr.Blocks(css=css) as iface:
1071
  with gr.Tab("2) Subregion Exploration"):
1072
  gr.Markdown("""
1073
  **Subregion Analysis**
1074
- Select start/end positions to view local SHAP signals, distribution, GC content, etc.
1075
  The heatmap uses the same Blue-White-Red scale.
1076
  """)
1077
  with gr.Row():
@@ -1080,8 +1080,8 @@ with gr.Blocks(css=css) as iface:
1080
  region_btn = gr.Button("Analyze Subregion")
1081
  subregion_info = gr.Textbox(label="Subregion Analysis", lines=7, interactive=False)
1082
  with gr.Row():
1083
- subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
1084
- subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
1085
  download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
1086
 
1087
  region_btn.click(
@@ -1093,12 +1093,11 @@ with gr.Blocks(css=css) as iface:
1093
  with gr.Tab("3) Gene Features Analysis"):
1094
  gr.Markdown("""
1095
  **Analyze Gene Features**
1096
- Upload a FASTA file and corresponding gene features file to analyze SHAP values per gene.
1097
  Gene features should be in the format:
1098
 
1099
  >gene_name [gene=X] [locus_tag=Y] [location=start..end] or [location=complement(start..end)]
1100
  SEQUENCE
1101
-
1102
  The genome viewer will show genes color-coded by their contribution:
1103
  - Red: Genes pushing toward human origin
1104
  - Blue: Genes pushing toward non-human origin
@@ -1126,7 +1125,7 @@ with gr.Blocks(css=css) as iface:
1126
  with gr.Tab("4) Comparative Analysis"):
1127
  gr.Markdown("""
1128
  **Compare Two Sequences**
1129
- Upload or paste two FASTA sequences to compare their SHAP patterns.
1130
  The sequences will be normalized to the same length for comparison.
1131
 
1132
  **Color Scale**:
@@ -1144,8 +1143,8 @@ with gr.Blocks(css=css) as iface:
1144
  compare_btn = gr.Button("Compare Sequences", variant="primary")
1145
  comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
1146
  with gr.Row():
1147
- diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
1148
- diff_hist = gr.Image(label="Distribution of SHAP Differences")
1149
  download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
1150
 
1151
  compare_btn.click(
@@ -1157,8 +1156,8 @@ with gr.Blocks(css=css) as iface:
1157
  gr.Markdown("""
1158
  ### Interface Features
1159
  - **Overall Classification** (human vs non-human) using k-mer frequencies
1160
- - **SHAP Analysis** shows which k-mers push classification toward or away from human
1161
- - **White-Centered SHAP Gradient**:
1162
  - Negative (blue), 0 (white), Positive (red)
1163
  - Symmetrical color range around 0
1164
  - **Identify Subregions** with strongest push for human or non-human
@@ -1172,7 +1171,7 @@ with gr.Blocks(css=css) as iface:
1172
  - Statistical summary of differences
1173
  - **Data Export**:
1174
  - Download results as CSV files
1175
- - Download k-mer SHAP values
1176
  - Save analysis outputs for further processing
1177
  """)
1178
 
 
82
  return vec
83
 
84
  ###############################################################################
85
+ # 3. FEATURE IMPORTANCE (ABLATION) CALCULATION
86
  ###############################################################################
87
 
88
  def calculate_shap_values(model, x_tensor):
 
105
 
106
 
107
  ###############################################################################
108
+ # 4. PER-BASE FEATURE IMPORTANCE AGGREGATION
109
  ###############################################################################
110
 
111
  def compute_positionwise_scores(sequence, shap_values, k=4):
 
125
  return shap_means
126
 
127
  ###############################################################################
128
+ # 5. FIND EXTREME IMPORTANCE REGIONS
129
  ###############################################################################
130
 
131
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
 
166
  colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
167
  return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
168
 
169
+ def plot_linear_heatmap(shap_means, title="Per-base Feature Importance Heatmap", start=None, end=None):
170
  if start is not None and end is not None:
171
  local_shap = shap_means[start:end]
172
  subtitle = f" (positions {start}-{end})"
 
184
  cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
185
  cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
186
  cbar.ax.tick_params(labelsize=8)
187
+ cbar.set_label('Feature Importance', fontsize=9, labelpad=5)
188
  ax.set_yticks([])
189
  ax.set_xlabel('Position in Sequence', fontsize=10)
190
  ax.set_title(f"{title}{subtitle}", pad=10)
 
200
  colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
201
  plt.barh(range(len(values)), values, color=colors)
202
  plt.yticks(range(len(values)), features)
203
+ plt.xlabel('Feature Importance (impact on model output)')
204
  plt.title(f'Top {top_k} Most Influential k-mers')
205
  plt.gca().invert_yaxis()
206
  plt.tight_layout()
207
  return fig
208
 
209
+ def plot_shap_histogram(shap_array, title="Feature Importance Distribution in Region", num_bins=30):
210
  fig, ax = plt.subplots(figsize=(6, 4))
211
  ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
212
  ax.axvline(0, color='red', linestyle='--', label='0.0')
213
+ ax.set_xlabel("Feature Importance Value")
214
  ax.set_ylabel("Count")
215
  ax.set_title(title)
216
  ax.legend()
 
227
  # 7. MAIN ANALYSIS STEP (Gradio Step 1)
228
  ###############################################################################
229
  def create_kmer_shap_csv(kmers, shap_values):
230
+ """Create a CSV file with k-mer importance values and return the filepath"""
231
+ # Create DataFrame with k-mers and importance values
232
  kmer_df = pd.DataFrame({
233
  'kmer': kmers,
234
+ 'importance_value': shap_values,
235
+ 'abs_importance': np.abs(shap_values)
236
  })
237
 
238
+ # Sort by absolute importance value (most influential first)
239
+ kmer_df = kmer_df.sort_values('abs_importance', ascending=False)
240
 
241
+ # Drop the abs_importance column used for sorting
242
+ kmer_df = kmer_df[['kmer', 'importance_value']]
243
 
244
  # Save to temporary file
245
  temp_dir = tempfile.gettempdir()
246
+ temp_path = os.path.join(temp_dir, f"kmer_importance_values_{os.urandom(4).hex()}.csv")
247
  kmer_df.to_csv(temp_path, index=False)
248
 
249
  return temp_path # Return only the file path, not a tuple
 
296
  f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
297
  f"---\n"
298
  f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
299
+ f"Start: {max_start}, End: {max_end}, Avg Importance: {max_avg:.4f}\n\n"
300
  f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
301
+ f"Start: {min_start}, End: {min_end}, Avg Importance: {min_avg:.4f}"
302
  )
303
 
304
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
305
  bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
306
  bar_img = fig_to_image(bar_fig)
307
 
308
+ heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide Feature Importance")
309
  heatmap_img = fig_to_image(heatmap_fig)
310
 
311
+ # Create CSV with k-mer importance values and return the file path
312
  kmer_shap_csv = create_kmer_shap_csv(kmers, shap_values)
313
 
314
  # State dictionary for subregion analysis
 
347
  f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
348
  f"Region length: {len(region_seq)} bases\n"
349
  f"GC content: {gc_percent:.2f}%\n"
350
+ f"Average importance in region: {avg_shap:.4f}\n"
351
+ f"Fraction with importance > 0 (toward human): {positive_fraction:.2f}\n"
352
+ f"Fraction with importance < 0 (toward non-human): {negative_fraction:.2f}\n"
353
  f"Subregion interpretation: {region_classification}\n"
354
  )
355
+ heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion Feature Importance", start=region_start, end=region_end)
356
  heatmap_img = fig_to_image(heatmap_fig)
357
+ hist_fig = plot_shap_histogram(region_shap, title="Feature Importance Distribution in Subregion")
358
  hist_img = fig_to_image(hist_fig)
359
 
360
  # For demonstration, returning None for the file download as well
 
370
  return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
371
 
372
  def compute_shap_difference(shap1_norm, shap2_norm):
373
+ """Compute the feature importance difference between normalized sequences"""
374
  return shap2_norm - shap1_norm
375
 
376
+ def plot_comparative_heatmap(shap_diff, title="Feature Importance Difference Heatmap"):
377
  """
378
  Plot heatmap using relative positions (0-100%)
379
  """
 
393
 
394
  cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
395
  cbar.ax.tick_params(labelsize=8)
396
+ cbar.set_label('Feature Importance Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
397
 
398
  ax.set_yticks([])
399
  ax.set_xlabel('Relative Position in Sequence', fontsize=10)
 
402
 
403
  return fig
404
 
405
+ def plot_shap_histogram(shap_array, title="Feature Importance Distribution", num_bins=30):
406
  """
407
+ Plot histogram of feature importance values with configurable number of bins
408
  """
409
  fig, ax = plt.subplots(figsize=(6, 4))
410
  ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
411
  ax.axvline(0, color='red', linestyle='--', label='0.0')
412
+ ax.set_xlabel("Feature Importance Value")
413
  ax.set_ylabel("Count")
414
  ax.set_title(title)
415
  ax.legend()
 
483
 
484
  def normalize_shap_lengths(shap1, shap2):
485
  """
486
+ Normalize and smooth feature importance values with dynamic adaptation
487
  """
488
  # Calculate adaptive parameters
489
  num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
 
517
  if isinstance(res2[0], str) and "Error" in res2[0]:
518
  return (f"Error in sequence 2: {res2[0]}", None, None, None)
519
 
520
+ # Extract feature importance values and sequence info
521
  shap1 = res1[3]["shap_means"]
522
  shap2 = res2[3]["shap_means"]
523
 
 
567
  f"Smoothing Window: {smooth_window} points\n"
568
  f"Adaptive Threshold: {adaptive_threshold:.3f}\n\n"
569
  "Statistics:\n"
570
+ f"Average feature importance difference: {avg_diff:.4f}\n"
571
  f"Standard deviation: {std_diff:.4f}\n"
572
  f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
573
  f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
 
582
  # Generate visualizations
583
  heatmap_fig = plot_comparative_heatmap(
584
  shap_diff,
585
+ title=f"Feature Importance Difference Heatmap (window: {smooth_window})"
586
  )
587
  heatmap_img = fig_to_image(heatmap_fig)
588
 
 
590
  num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
591
  hist_fig = plot_shap_histogram(
592
  shap_diff,
593
+ title="Distribution of Feature Importance Differences",
594
  num_bins=num_bins
595
  )
596
  hist_img = fig_to_image(hist_fig)
 
680
  return None, None
681
 
682
  def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
683
+ """Compute statistical measures for gene feature importance values"""
684
  return {
685
  'avg_shap': float(np.mean(gene_shap)),
686
  'median_shap': float(np.median(gene_shap)),
 
693
  def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
694
  """
695
  Create a simple genome diagram using PIL, forcing a minimum color intensity
696
+ so that small feature importance values don't appear white.
697
  """
698
  from PIL import Image, ImageDraw, ImageFont
699
 
 
730
  title_font = ImageFont.load_default()
731
 
732
  # Draw title
733
+ draw.text((margin, margin // 2), "Genome Feature Importance Analysis", fill='black', font=title_font or font)
734
 
735
  # Draw genome line
736
  line_y = height // 2
 
755
  ], fill='black', width=1)
756
  draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
757
 
758
+ # Sort genes by absolute feature importance value for drawing
759
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
760
 
761
  # Draw genes
 
764
  start_x = margin + int(gene['start'] * scale)
765
  end_x = margin + int(gene['end'] * scale)
766
 
767
+ # Calculate color based on feature importance value
768
  avg_shap = gene['avg_shap']
769
 
770
+ # Convert importance -> color intensity (0 to 255)
771
  # Then clamp to a minimum intensity so it never ends up plain white
772
  intensity = int(abs(avg_shap) * 500)
773
  intensity = max(50, min(255, intensity)) # clamp between 50 and 255
 
813
  # Draw legend
814
  legend_x = margin
815
  legend_y = height - margin
816
+ draw.text((int(legend_x), int(legend_y - 60)), "Feature Importance Values:", fill='black', font=font)
817
 
818
  # Draw legend boxes
819
  box_width = 20
 
858
  features_file: str,
859
  fasta_text: str = "",
860
  features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
861
+ """Analyze feature importance values for each gene feature"""
862
  # First analyze whole sequence
863
  sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
864
  if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
865
  return f"Error in sequence analysis: {sequence_results[0]}", None, None
866
 
867
+ # Get feature importance values
868
  shap_means = sequence_results[3]["shap_means"]
869
 
870
  # Parse gene features
 
889
  if start is None or end is None:
890
  continue
891
 
892
+ # Get feature importance values for this region
893
  gene_shap = shap_means[start:end]
894
  stats = compute_gene_statistics(gene_shap)
895
 
 
916
  if not gene_results:
917
  return "No valid genes could be processed", None, None
918
 
919
+ # Sort genes by absolute feature importance value
920
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
921
 
922
  # Create results text
 
932
  f"Location: {gene['location']}\n"
933
  f"Classification: {gene['classification']} "
934
  f"(confidence: {gene['confidence']:.4f})\n"
935
+ f"Average Feature Importance: {gene['avg_shap']:.4f}\n\n"
936
  )
937
 
938
  # Create CSV content
939
+ csv_content = "gene_name,location,avg_importance,median_importance,std_importance,max_importance,min_importance,"
940
  csv_content += "pos_fraction,classification,confidence,locus_tag\n"
941
 
942
  for gene in gene_results:
 
1020
  gr.Markdown("""
1021
  # Virus Host Classifier
1022
  **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
1023
+ **Step 2**: Explore subregions to see local feature influence, distribution, GC content, etc.
1024
  **Step 3**: Analyze gene features and their contributions.
1025
  **Step 4**: Compare sequences and analyze differences.
1026
 
1027
+ **Color Scale**: Negative values = Blue, Zero = White, Positive values = Red.
1028
  """)
1029
 
1030
  with gr.Tab("1) Full-Sequence Analysis"):
 
1043
 
1044
  with gr.Column(scale=2):
1045
  results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
1046
+ kmer_img = gr.Image(label="Top k-mer Importance")
1047
+ genome_img = gr.Image(label="Genome-wide Feature Importance Heatmap (Blue=neg, White=0, Red=pos)")
1048
 
1049
  # File components with the correct type parameter
1050
+ download_kmer_shap = gr.File(label="Download k-mer Importance Values (CSV)", visible=True, type="filepath")
1051
  download_results = gr.File(label="Download Results", visible=True, elem_classes="download-button")
1052
 
1053
  seq_state = gr.State()
 
1071
  with gr.Tab("2) Subregion Exploration"):
1072
  gr.Markdown("""
1073
  **Subregion Analysis**
1074
+ Select start/end positions to view local feature importance, distribution, GC content, etc.
1075
  The heatmap uses the same Blue-White-Red scale.
1076
  """)
1077
  with gr.Row():
 
1080
  region_btn = gr.Button("Analyze Subregion")
1081
  subregion_info = gr.Textbox(label="Subregion Analysis", lines=7, interactive=False)
1082
  with gr.Row():
1083
+ subregion_img = gr.Image(label="Subregion Feature Importance Heatmap (B-W-R)")
1084
+ subregion_hist_img = gr.Image(label="Feature Importance Distribution (Histogram)")
1085
  download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
1086
 
1087
  region_btn.click(
 
1093
  with gr.Tab("3) Gene Features Analysis"):
1094
  gr.Markdown("""
1095
  **Analyze Gene Features**
1096
+ Upload a FASTA file and corresponding gene features file to analyze feature importance values per gene.
1097
  Gene features should be in the format:
1098
 
1099
  >gene_name [gene=X] [locus_tag=Y] [location=start..end] or [location=complement(start..end)]
1100
  SEQUENCE
 
1101
  The genome viewer will show genes color-coded by their contribution:
1102
  - Red: Genes pushing toward human origin
1103
  - Blue: Genes pushing toward non-human origin
 
1125
  with gr.Tab("4) Comparative Analysis"):
1126
  gr.Markdown("""
1127
  **Compare Two Sequences**
1128
+ Upload or paste two FASTA sequences to compare their feature importance patterns.
1129
  The sequences will be normalized to the same length for comparison.
1130
 
1131
  **Color Scale**:
 
1143
  compare_btn = gr.Button("Compare Sequences", variant="primary")
1144
  comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
1145
  with gr.Row():
1146
+ diff_heatmap = gr.Image(label="Feature Importance Difference Heatmap")
1147
+ diff_hist = gr.Image(label="Distribution of Feature Importance Differences")
1148
  download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
1149
 
1150
  compare_btn.click(
 
1156
  gr.Markdown("""
1157
  ### Interface Features
1158
  - **Overall Classification** (human vs non-human) using k-mer frequencies
1159
+ - **Feature Importance Analysis** shows which k-mers push classification toward or away from human
1160
+ - **White-Centered Gradient**:
1161
  - Negative (blue), 0 (white), Positive (red)
1162
  - Symmetrical color range around 0
1163
  - **Identify Subregions** with strongest push for human or non-human
 
1171
  - Statistical summary of differences
1172
  - **Data Export**:
1173
  - Download results as CSV files
1174
+ - Download k-mer importance values
1175
  - Save analysis outputs for further processing
1176
  """)
1177