hiyata committed on
Commit
552aec4
·
verified ·
1 Parent(s): 6d0235b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -34
app.py CHANGED
@@ -133,7 +133,52 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
133
  return shap_means
134
 
135
  ###############################################################################
136
- # 5. PLOTTING / UTILITIES
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  ###############################################################################
138
 
139
  def fig_to_image(fig):
@@ -150,7 +195,7 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
150
  Plots a 1D heatmap of per-base SHAP contributions.
151
  Negative = push toward Non-Human, Positive = push toward Human.
152
  Optionally can show only a subrange (start:end).
153
- We'll add extra bottom margin to avoid x-axis overlap.
154
  """
155
  if start is not None and end is not None:
156
  shap_means = shap_means[start:end]
@@ -162,17 +207,17 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
162
 
163
  fig, ax = plt.subplots(figsize=(12, 2))
164
  cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r')
165
- cbar = plt.colorbar(cax, orientation='horizontal', pad=0.3)
 
 
 
166
  cbar.set_label('SHAP Contribution')
167
 
168
  ax.set_yticks([])
169
  ax.set_xlabel('Position in Sequence')
170
  ax.set_title(f"{title}{subtitle}")
171
-
172
- # Extra spacing for x-axis labels
173
- plt.tight_layout()
174
- # Or you can do something like:
175
- # plt.subplots_adjust(bottom=0.2)
176
 
177
  return fig
178
 
@@ -219,11 +264,14 @@ def compute_gc_content(sequence):
219
  return (gc_count / len(sequence)) * 100.0
220
 
221
  ###############################################################################
222
- # 6. MAIN ANALYSIS STEP (Gradio Step 1)
223
  ###############################################################################
224
 
225
- def analyze_sequence(file_obj, top_kmers=10, fasta_text=""):
226
- """Analyzes the entire genome, returning classification and a heatmap."""
 
 
 
227
  # Handle input
228
  if fasta_text.strip():
229
  text = fasta_text.strip()
@@ -232,14 +280,14 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text=""):
232
  with open(file_obj, 'r') as f:
233
  text = f.read()
234
  except Exception as e:
235
- return (f"Error reading file: {str(e)}", None, None, None, None)
236
  else:
237
- return ("Please provide a FASTA sequence.", None, None, None, None)
238
 
239
  # Parse FASTA
240
  sequences = parse_fasta(text)
241
  if not sequences:
242
- return ("No valid FASTA sequences found.", None, None, None, None)
243
 
244
  header, seq = sequences[0]
245
 
@@ -250,7 +298,7 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text=""):
250
  model.load_state_dict(torch.load('model.pt', map_location=device))
251
  scaler = joblib.load('scaler.pkl')
252
  except Exception as e:
253
- return (f"Error loading model: {str(e)}", None, None, None, None)
254
 
255
  # Vectorize + scale
256
  freq_vector = sequence_to_kmer_vector(seq)
@@ -264,13 +312,26 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text=""):
264
  classification = "Human" if prob_human > 0.5 else "Non-human"
265
  confidence = max(prob_human, prob_nonhuman)
266
 
 
 
 
 
 
 
 
 
267
  # Build results text
268
  results_text = (
269
  f"Sequence: {header}\n"
270
  f"Length: {len(seq):,} bases\n"
271
  f"Classification: {classification}\n"
272
  f"Confidence: {confidence:.3f}\n"
273
- f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})"
 
 
 
 
 
274
  )
275
 
276
  # K-mer importance plot
@@ -278,26 +339,27 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text=""):
278
  bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
279
  bar_img = fig_to_image(bar_fig)
280
 
281
- # Per-base SHAP for entire genome
282
- shap_means = compute_positionwise_scores(seq, shap_values, k=4)
283
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
284
  heatmap_img = fig_to_image(heatmap_fig)
285
 
286
  # Return:
287
- # 1) results text
288
- # 2) k-mer bar image
289
- # 3) full-genome heatmap
290
- # 4) the "state" we need for step 2: (sequence, shap_means)
291
- # We'll store these in a dictionary so we can pass it around in Gradio.
 
 
292
  state_dict = {
293
  "seq": seq,
294
  "shap_means": shap_means
295
  }
296
 
297
- return (results_text, bar_img, heatmap_img, state_dict, header)
298
 
299
  ###############################################################################
300
- # 7. SUBREGION ANALYSIS (Gradio Step 2)
301
  ###############################################################################
302
 
303
  def analyze_subregion(state, header, region_start, region_end):
@@ -333,7 +395,6 @@ def analyze_subregion(state, header, region_start, region_end):
333
  negative_fraction = np.mean(region_shap < 0)
334
 
335
  # Simple logic-based interpretation
336
- # Adjust thresholds as needed
337
  if avg_shap > 0.05:
338
  region_classification = "Likely pushing toward human"
339
  elif avg_shap < -0.05:
@@ -368,7 +429,7 @@ def analyze_subregion(state, header, region_start, region_end):
368
 
369
 
370
  ###############################################################################
371
- # 8. BUILD GRADIO INTERFACE
372
  ###############################################################################
373
 
374
  css = """
@@ -380,7 +441,7 @@ css = """
380
  with gr.Blocks(css=css) as iface:
381
  gr.Markdown("""
382
  # Virus Host Classifier (with Interactive Region Viewer)
383
- **Step 1**: Predict overall viral sequence origin (human vs non-human)
384
  **Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
385
  """)
386
 
@@ -404,26 +465,37 @@ with gr.Blocks(css=css) as iface:
404
  step=1,
405
  label="Number of top k-mers to display"
406
  )
 
 
 
 
 
 
 
407
  analyze_btn = gr.Button("Analyze Sequence", variant="primary")
408
 
409
  with gr.Column(scale=2):
410
  results_box = gr.Textbox(
411
- label="Classification Results", lines=7, interactive=False
412
  )
413
  kmer_img = gr.Image(label="Top k-mer SHAP")
414
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
415
 
416
  # Hidden states that store data for step 2
417
- # "seq_state" will hold { seq, shap_means }.
418
- # "header_state" is optional meta info
419
  seq_state = gr.State()
420
  header_state = gr.State()
421
 
422
- # The "analyze_sequence" function returns 5 values, which we map here:
 
 
 
 
 
 
423
  analyze_btn.click(
424
  analyze_sequence,
425
- inputs=[file_input, top_k, text_input],
426
- outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
427
  )
428
 
429
  with gr.Tab("2) Subregion Exploration"):
@@ -460,6 +532,9 @@ with gr.Blocks(css=css) as iface:
460
  - Local SHAP signals (heatmap & histogram)
461
  - GC content, fraction of bases pushing "human" vs "non-human"
462
  - Simple logic-based interpretation based on average SHAP
 
 
 
463
  """)
464
 
465
  if __name__ == "__main__":
 
133
  return shap_means
134
 
135
  ###############################################################################
136
+ # 5. FIND EXTREME SHAP REGIONS
137
+ ###############################################################################
138
+
139
def find_extreme_subregion(shap_means, window_size=500, mode="max"):
    """
    Find the contiguous window with the most extreme average SHAP value.

    Args:
        shap_means: 1-D sequence (list or NumPy array) of per-position
            SHAP scores.
        window_size: Length of the sliding window, in positions.
        mode: "max" selects the window with the highest average; any other
            value selects the window with the lowest average.

    Returns:
        A tuple (best_start, best_end, avg_shap) where the window is the
        half-open range [best_start, best_end). If window_size is at least
        the sequence length, the whole sequence is returned as (0, n, mean),
        with a mean of 0.0 for an empty sequence. When several windows tie,
        the left-most one is returned.

    Raises:
        ValueError: If window_size is not positive while the sequence is
            longer than window_size (i.e. a real scan would be required).
    """
    values = np.asarray(shap_means, dtype=float)
    n = len(values)

    if window_size >= n:
        # Window covers everything: report the whole sequence.
        avg_val = float(np.mean(values)) if n > 0 else 0.0
        return (0, n, avg_val)

    if window_size < 1:
        raise ValueError("window_size must be a positive integer")

    # Zero-padded prefix sums so that prefix[i] == sum(values[:i]).
    # np.cumsum alone is inclusive (cumsum[i] == sum(values[:i + 1])), which
    # would shift every window by one position and make the final window
    # index one past the end of the array.
    prefix = np.concatenate(([0.0], np.cumsum(values)))

    # window_avgs[s] is the mean of values[s : s + window_size];
    # there are n - window_size + 1 candidate start positions.
    window_avgs = (prefix[window_size:] - prefix[:-window_size]) / window_size

    # argmax / argmin return the first occurrence, i.e. the left-most window
    # on ties, matching a left-to-right scan with strict comparisons.
    if mode == "max":
        best_start = int(np.argmax(window_avgs))
    else:  # any other mode is treated as "min"
        best_start = int(np.argmin(window_avgs))

    best_avg = float(window_avgs[best_start])
    return (best_start, best_start + window_size, best_avg)
179
+
180
+ ###############################################################################
181
+ # 6. PLOTTING / UTILITIES
182
  ###############################################################################
183
 
184
  def fig_to_image(fig):
 
195
  Plots a 1D heatmap of per-base SHAP contributions.
196
  Negative = push toward Non-Human, Positive = push toward Human.
197
  Optionally can show only a subrange (start:end).
198
+ We'll adjust layout so that the colorbar is below the x-axis and doesn't overlap.
199
  """
200
  if start is not None and end is not None:
201
  shap_means = shap_means[start:end]
 
207
 
208
  fig, ax = plt.subplots(figsize=(12, 2))
209
  cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r')
210
+
211
+ # Adjust colorbar with some extra margin
212
+ # We'll place the colorbar horizontally below
213
+ cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25)
214
  cbar.set_label('SHAP Contribution')
215
 
216
  ax.set_yticks([])
217
  ax.set_xlabel('Position in Sequence')
218
  ax.set_title(f"{title}{subtitle}")
219
+ # Additional spacing at bottom to avoid overlap
220
+ plt.subplots_adjust(bottom=0.3)
 
 
 
221
 
222
  return fig
223
 
 
264
  return (gc_count / len(sequence)) * 100.0
265
 
266
  ###############################################################################
267
+ # 7. MAIN ANALYSIS STEP (Gradio Step 1)
268
  ###############################################################################
269
 
270
+ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
271
+ """
272
+ Analyzes the entire genome, returning classification, full-genome heatmap,
273
+ top k-mer bar plot, and identifies subregions with strongest positive/negative push.
274
+ """
275
  # Handle input
276
  if fasta_text.strip():
277
  text = fasta_text.strip()
 
280
  with open(file_obj, 'r') as f:
281
  text = f.read()
282
  except Exception as e:
283
+ return (f"Error reading file: {str(e)}", None, None, None, None, None)
284
  else:
285
+ return ("Please provide a FASTA sequence.", None, None, None, None, None)
286
 
287
  # Parse FASTA
288
  sequences = parse_fasta(text)
289
  if not sequences:
290
+ return ("No valid FASTA sequences found.", None, None, None, None, None)
291
 
292
  header, seq = sequences[0]
293
 
 
298
  model.load_state_dict(torch.load('model.pt', map_location=device))
299
  scaler = joblib.load('scaler.pkl')
300
  except Exception as e:
301
+ return (f"Error loading model: {str(e)}", None, None, None, None, None)
302
 
303
  # Vectorize + scale
304
  freq_vector = sequence_to_kmer_vector(seq)
 
312
  classification = "Human" if prob_human > 0.5 else "Non-human"
313
  confidence = max(prob_human, prob_nonhuman)
314
 
315
+ # Per-base SHAP
316
+ shap_means = compute_positionwise_scores(seq, shap_values, k=4)
317
+
318
+ # Find the most "human-pushing" region
319
+ (max_start, max_end, max_avg) = find_extreme_subregion(shap_means, window_size, mode="max")
320
+ # Find the most "non-human–pushing" region
321
+ (min_start, min_end, min_avg) = find_extreme_subregion(shap_means, window_size, mode="min")
322
+
323
  # Build results text
324
  results_text = (
325
  f"Sequence: {header}\n"
326
  f"Length: {len(seq):,} bases\n"
327
  f"Classification: {classification}\n"
328
  f"Confidence: {confidence:.3f}\n"
329
+ f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
330
+ f"---\n"
331
+ f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
332
+ f"Start: {max_start}, End: {max_end}, Avg SHAP: {max_avg:.4f}\n\n"
333
+ f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
334
+ f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
335
  )
336
 
337
  # K-mer importance plot
 
339
  bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
340
  bar_img = fig_to_image(bar_fig)
341
 
342
+ # Full-genome SHAP heatmap
 
343
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
344
  heatmap_img = fig_to_image(heatmap_fig)
345
 
346
  # Return:
347
+ # 1) results text
348
+ # 2) k-mer bar image
349
+ # 3) full-genome heatmap
350
+ # 4) "state" with { seq, shap_means, header }, for subregion analysis
351
+ # 5) we also return "most pushing" subregion info if we want
352
+ # but for simplicity, we can just keep them in the text.
353
+ # 6) the sequence header
354
  state_dict = {
355
  "seq": seq,
356
  "shap_means": shap_means
357
  }
358
 
359
+ return (results_text, bar_img, heatmap_img, state_dict, header, None)
360
 
361
  ###############################################################################
362
+ # 8. SUBREGION ANALYSIS (Gradio Step 2)
363
  ###############################################################################
364
 
365
  def analyze_subregion(state, header, region_start, region_end):
 
395
  negative_fraction = np.mean(region_shap < 0)
396
 
397
  # Simple logic-based interpretation
 
398
  if avg_shap > 0.05:
399
  region_classification = "Likely pushing toward human"
400
  elif avg_shap < -0.05:
 
429
 
430
 
431
  ###############################################################################
432
+ # 9. BUILD GRADIO INTERFACE
433
  ###############################################################################
434
 
435
  css = """
 
441
  with gr.Blocks(css=css) as iface:
442
  gr.Markdown("""
443
  # Virus Host Classifier (with Interactive Region Viewer)
444
+ **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
445
  **Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
446
  """)
447
 
 
465
  step=1,
466
  label="Number of top k-mers to display"
467
  )
468
+ win_size = gr.Slider(
469
+ minimum=100,
470
+ maximum=5000,
471
+ value=500,
472
+ step=100,
473
+ label="Window size for 'most pushing' subregions"
474
+ )
475
  analyze_btn = gr.Button("Analyze Sequence", variant="primary")
476
 
477
  with gr.Column(scale=2):
478
  results_box = gr.Textbox(
479
+ label="Classification Results", lines=12, interactive=False
480
  )
481
  kmer_img = gr.Image(label="Top k-mer SHAP")
482
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
483
 
484
  # Hidden states that store data for step 2
 
 
485
  seq_state = gr.State()
486
  header_state = gr.State()
487
 
488
+ # The "analyze_sequence" function returns 6 values, which we map here:
489
+ # 1) results_text
490
+ # 2) bar_img
491
+ # 3) heatmap_img
492
+ # 4) state_dict
493
+ # 5) header
494
+ # 6) None placeholder
495
  analyze_btn.click(
496
  analyze_sequence,
497
+ inputs=[file_input, top_k, text_input, win_size],
498
+ outputs=[results_box, kmer_img, genome_img, seq_state, header_state, None]
499
  )
500
 
501
  with gr.Tab("2) Subregion Exploration"):
 
532
  - Local SHAP signals (heatmap & histogram)
533
  - GC content, fraction of bases pushing "human" vs "non-human"
534
  - Simple logic-based interpretation based on average SHAP
535
+ 5. **Identification of the most 'human-pushing' subregion** (max average SHAP)
536
+ and the most 'non-human–pushing' subregion (min average SHAP),
537
+ each of a chosen window size.
538
  """)
539
 
540
  if __name__ == "__main__":