hiyata committed on
Commit
6d0235b
·
verified ·
1 Parent(s): 56468ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -28
app.py CHANGED
@@ -150,6 +150,7 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
150
  Plots a 1D heatmap of per-base SHAP contributions.
151
  Negative = push toward Non-Human, Positive = push toward Human.
152
  Optionally can show only a subrange (start:end).
 
153
  """
154
  if start is not None and end is not None:
155
  shap_means = shap_means[start:end]
@@ -161,13 +162,18 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
161
 
162
  fig, ax = plt.subplots(figsize=(12, 2))
163
  cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r')
164
- cbar = plt.colorbar(cax, orientation='horizontal', pad=0.2)
165
  cbar.set_label('SHAP Contribution')
166
 
167
  ax.set_yticks([])
168
  ax.set_xlabel('Position in Sequence')
169
  ax.set_title(f"{title}{subtitle}")
 
 
170
  plt.tight_layout()
 
 
 
171
  return fig
172
 
173
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
@@ -187,6 +193,22 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
187
  plt.xlabel('SHAP Value (impact on model output)')
188
  plt.title(f'Top {top_k} Most Influential k-mers')
189
  plt.gca().invert_yaxis()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  return fig
191
 
192
  def compute_gc_content(sequence):
@@ -281,19 +303,22 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text=""):
281
  def analyze_subregion(state, header, region_start, region_end):
282
  """
283
  Takes stored data from step 1 and a user-chosen region.
284
- Returns a subregion heatmap and some stats (like GC content, average SHAP).
285
  """
286
  if not state or "seq" not in state or "shap_means" not in state:
287
- return ("No sequence data found. Please run Step 1 first.", None)
288
 
289
  seq = state["seq"]
290
  shap_means = state["shap_means"]
291
 
292
  # Validate bounds
 
 
 
293
  region_start = max(0, min(region_start, len(seq)))
294
  region_end = max(0, min(region_end, len(seq)))
295
  if region_end <= region_start:
296
- return ("Invalid region range. End must be > Start.", None)
297
 
298
  # Subsequence
299
  region_seq = seq[region_start:region_end]
@@ -302,23 +327,44 @@ def analyze_subregion(state, header, region_start, region_end):
302
  # Some stats
303
  gc_percent = compute_gc_content(region_seq)
304
  avg_shap = float(np.mean(region_shap))
305
-
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  region_info = (
307
  f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
308
  f"Region length: {len(region_seq)} bases\n"
309
  f"GC content: {gc_percent:.2f}%\n"
310
- f"Average SHAP in region: {avg_shap:.4f} "
311
- f"({'toward human' if avg_shap > 0 else 'toward non-human' if avg_shap < 0 else 'neutral'})"
 
 
312
  )
313
 
314
  # Plot region as small heatmap
315
- fig = plot_linear_heatmap(shap_means,
316
- title="Subregion SHAP",
317
- start=region_start,
318
- end=region_end)
319
- heatmap_img = fig_to_image(fig)
 
 
320
 
321
- return (region_info, heatmap_img)
 
 
 
 
322
 
323
 
324
  ###############################################################################
@@ -335,7 +381,7 @@ with gr.Blocks(css=css) as iface:
335
  gr.Markdown("""
336
  # Virus Host Classifier (with Interactive Region Viewer)
337
  **Step 1**: Predict overall viral sequence origin (human vs non-human)
338
- **Step 2**: Explore subregions to see local SHAP signals and GC content
339
  """)
340
 
341
  with gr.Tab("1) Full-Sequence Analysis"):
@@ -368,8 +414,8 @@ with gr.Blocks(css=css) as iface:
368
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
369
 
370
  # Hidden states that store data for step 2
371
- # "state" will hold (sequence, shap_means).
372
- # "header" is optional meta info
373
  seq_state = gr.State()
374
  header_state = gr.State()
375
 
@@ -382,7 +428,8 @@ with gr.Blocks(css=css) as iface:
382
 
383
  with gr.Tab("2) Subregion Exploration"):
384
  gr.Markdown("""
385
- Select start/end positions to view local SHAP signals.
 
386
  """)
387
  with gr.Row():
388
  region_start = gr.Number(label="Region Start", value=0)
@@ -391,15 +438,17 @@ with gr.Blocks(css=css) as iface:
391
 
392
  subregion_info = gr.Textbox(
393
  label="Subregion Analysis",
394
- lines=4,
395
  interactive=False
396
  )
397
- subregion_img = gr.Image(label="Subregion SHAP Heatmap")
398
-
 
 
399
  region_btn.click(
400
  analyze_subregion,
401
  inputs=[seq_state, header_state, region_start, region_end],
402
- outputs=[subregion_info, subregion_img]
403
  )
404
 
405
  gr.Markdown("""
@@ -407,13 +456,10 @@ with gr.Blocks(css=css) as iface:
407
  1. **Overall Classification** (human vs non-human), using a learned model on k-mer frequencies.
408
  2. **SHAP Analysis** (ablation-based) to see which k-mer features push classification toward or away from "human".
409
  3. **Genome-Wide SHAP Heatmap**: Each base's average SHAP across overlapping k-mers.
410
- 4. **Subregion Exploration**:
411
- - View SHAP signals in a user-chosen region.
412
- - Calculate local GC content, average SHAP, etc.
413
-
414
- ### Tips
415
- - For very large sequences (e.g., >100k bases), the full heatmap might be large; consider downsampling if needed.
416
- - Adjust *Region Start* and *End* to explore different parts of the genome.
417
  """)
418
 
419
  if __name__ == "__main__":
 
150
  Plots a 1D heatmap of per-base SHAP contributions.
151
  Negative = push toward Non-Human, Positive = push toward Human.
152
  Optionally can show only a subrange (start:end).
153
+ We'll add extra bottom margin to avoid x-axis overlap.
154
  """
155
  if start is not None and end is not None:
156
  shap_means = shap_means[start:end]
 
162
 
163
  fig, ax = plt.subplots(figsize=(12, 2))
164
  cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r')
165
+ cbar = plt.colorbar(cax, orientation='horizontal', pad=0.3)
166
  cbar.set_label('SHAP Contribution')
167
 
168
  ax.set_yticks([])
169
  ax.set_xlabel('Position in Sequence')
170
  ax.set_title(f"{title}{subtitle}")
171
+
172
+ # Extra spacing for x-axis labels
173
  plt.tight_layout()
174
+ # Or you can do something like:
175
+ # plt.subplots_adjust(bottom=0.2)
176
+
177
  return fig
178
 
179
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
 
193
  plt.xlabel('SHAP Value (impact on model output)')
194
  plt.title(f'Top {top_k} Most Influential k-mers')
195
  plt.gca().invert_yaxis()
196
+ plt.tight_layout()
197
+ return fig
198
+
199
+ def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
200
+ """
201
+ Simple histogram of SHAP values in the subregion.
202
+ Helps see how many positions push human vs non-human.
203
+ """
204
+ fig, ax = plt.subplots(figsize=(6, 4))
205
+ ax.hist(shap_array, bins=30, color='gray', edgecolor='black')
206
+ ax.axvline(0, color='red', linestyle='--', label='0.0')
207
+ ax.set_xlabel("SHAP Value")
208
+ ax.set_ylabel("Count")
209
+ ax.set_title(title)
210
+ ax.legend()
211
+ plt.tight_layout()
212
  return fig
213
 
214
  def compute_gc_content(sequence):
 
303
  def analyze_subregion(state, header, region_start, region_end):
304
  """
305
  Takes stored data from step 1 and a user-chosen region.
306
+ Returns a subregion heatmap, histogram, and some stats (GC, average SHAP).
307
  """
308
  if not state or "seq" not in state or "shap_means" not in state:
309
+ return ("No sequence data found. Please run Step 1 first.", None, None)
310
 
311
  seq = state["seq"]
312
  shap_means = state["shap_means"]
313
 
314
  # Validate bounds
315
+ region_start = int(region_start)
316
+ region_end = int(region_end)
317
+
318
  region_start = max(0, min(region_start, len(seq)))
319
  region_end = max(0, min(region_end, len(seq)))
320
  if region_end <= region_start:
321
+ return ("Invalid region range. End must be > Start.", None, None)
322
 
323
  # Subsequence
324
  region_seq = seq[region_start:region_end]
 
327
  # Some stats
328
  gc_percent = compute_gc_content(region_seq)
329
  avg_shap = float(np.mean(region_shap))
330
+
331
+ # Fraction pushing toward human vs. non-human
332
+ positive_fraction = np.mean(region_shap > 0)
333
+ negative_fraction = np.mean(region_shap < 0)
334
+
335
+ # Simple logic-based interpretation
336
+ # Adjust thresholds as needed
337
+ if avg_shap > 0.05:
338
+ region_classification = "Likely pushing toward human"
339
+ elif avg_shap < -0.05:
340
+ region_classification = "Likely pushing toward non-human"
341
+ else:
342
+ region_classification = "Near neutral (no strong push)"
343
+
344
  region_info = (
345
  f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
346
  f"Region length: {len(region_seq)} bases\n"
347
  f"GC content: {gc_percent:.2f}%\n"
348
+ f"Average SHAP in region: {avg_shap:.4f}\n"
349
+ f"Fraction with SHAP > 0 (toward human): {positive_fraction:.2f}\n"
350
+ f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
351
+ f"Subregion interpretation: {region_classification}\n"
352
  )
353
 
354
  # Plot region as small heatmap
355
+ heatmap_fig = plot_linear_heatmap(
356
+ shap_means,
357
+ title="Subregion SHAP",
358
+ start=region_start,
359
+ end=region_end
360
+ )
361
+ heatmap_img = fig_to_image(heatmap_fig)
362
 
363
+ # Plot histogram of SHAP in region
364
+ hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
365
+ hist_img = fig_to_image(hist_fig)
366
+
367
+ return (region_info, heatmap_img, hist_img)
368
 
369
 
370
  ###############################################################################
 
381
  gr.Markdown("""
382
  # Virus Host Classifier (with Interactive Region Viewer)
383
  **Step 1**: Predict overall viral sequence origin (human vs non-human)
384
+ **Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
385
  """)
386
 
387
  with gr.Tab("1) Full-Sequence Analysis"):
 
414
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
415
 
416
  # Hidden states that store data for step 2
417
+ # "seq_state" will hold { seq, shap_means }.
418
+ # "header_state" is optional meta info
419
  seq_state = gr.State()
420
  header_state = gr.State()
421
 
 
428
 
429
  with gr.Tab("2) Subregion Exploration"):
430
  gr.Markdown("""
431
+ **Subregion Analysis**
432
+ Select start/end positions to view local SHAP signals, distribution, and GC content.
433
  """)
434
  with gr.Row():
435
  region_start = gr.Number(label="Region Start", value=0)
 
438
 
439
  subregion_info = gr.Textbox(
440
  label="Subregion Analysis",
441
+ lines=7,
442
  interactive=False
443
  )
444
+ with gr.Row():
445
+ subregion_img = gr.Image(label="Subregion SHAP Heatmap")
446
+ subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
447
+
448
  region_btn.click(
449
  analyze_subregion,
450
  inputs=[seq_state, header_state, region_start, region_end],
451
+ outputs=[subregion_info, subregion_img, subregion_hist_img]
452
  )
453
 
454
  gr.Markdown("""
 
456
  1. **Overall Classification** (human vs non-human), using a learned model on k-mer frequencies.
457
  2. **SHAP Analysis** (ablation-based) to see which k-mer features push classification toward or away from "human".
458
  3. **Genome-Wide SHAP Heatmap**: Each base's average SHAP across overlapping k-mers.
459
+ 4. **Subregion Exploration**:
460
+ - Local SHAP signals (heatmap & histogram)
461
+ - GC content, fraction of bases pushing "human" vs "non-human"
462
+ - Simple logic-based interpretation based on average SHAP
 
 
 
463
  """)
464
 
465
  if __name__ == "__main__":