hiyata committed on
Commit f5ea8d6 · verified · 1 Parent(s): 455bf4d

Update app.py

Files changed (1): app.py (+176 -273)
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torch
 
3
  import numpy as np
4
  from itertools import product
5
  import torch.nn as nn
@@ -7,7 +8,6 @@ import matplotlib.pyplot as plt
7
  import matplotlib.colors as mcolors
8
  import io
9
  from PIL import Image
10
- from scipy.interpolate import interp1d
11
 
12
  ###############################################################################
13
  # 1. MODEL DEFINITION
@@ -71,7 +71,7 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
71
 
72
  total_kmers = len(sequence) - k + 1
73
  if total_kmers > 0:
74
- vec /= total_kmers
75
 
76
  return vec
77
 
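The change above replaces the in-place division with vec = vec / total_kmers. For context, a minimal standalone sketch of the 4-mer frequency vector that sequence_to_kmer_vector builds (256 slots for k = 4 over ACGT); the handling of non-ACGT characters is an assumption, since that part of the function is not shown in this diff:

    import numpy as np
    from itertools import product

    def kmer_frequency_vector(sequence: str, k: int = 4) -> np.ndarray:
        # One slot per possible k-mer over ACGT (4**k = 256 slots for k = 4).
        kmers = [''.join(p) for p in product("ACGT", repeat=k)]
        index = {km: i for i, km in enumerate(kmers)}
        vec = np.zeros(len(kmers), dtype=np.float32)
        for i in range(len(sequence) - k + 1):
            kmer = sequence[i:i + k]
            if kmer in index:                 # assumption: skip k-mers containing N or other symbols
                vec[index[kmer]] += 1
        total_kmers = len(sequence) - k + 1
        if total_kmers > 0:
            vec = vec / total_kmers           # counts -> frequencies, as in the hunk above
        return vec

    print(kmer_frequency_vector("ACGTACGTAC").shape)   # (256,)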
@@ -86,10 +86,12 @@ def calculate_shap_values(model, x_tensor):
86
  """
87
  model.eval()
88
  with torch.no_grad():
 
89
  baseline_output = model(x_tensor)
90
  baseline_probs = torch.softmax(baseline_output, dim=1)
91
- baseline_prob = baseline_probs[0, 1].item() # Probability of 'human'
92
 
 
93
  shap_values = []
94
  x_zeroed = x_tensor.clone()
95
  for i in range(x_tensor.shape[1]):
@@ -100,7 +102,7 @@ def calculate_shap_values(model, x_tensor):
100
  prob = probs[0, 1].item()
101
  impact = baseline_prob - prob
102
  shap_values.append(impact)
103
- x_zeroed[0, i] = original_val
104
  return np.array(shap_values), baseline_prob
105
 
106
  ###############################################################################
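calculate_shap_values, shown in the two hunks above, scores each of the 256 k-mer features by zeroing it out and measuring how much the softmax probability of the 'human' class drops relative to the unmodified input. A self-contained sketch of that loop, assuming any two-class torch module whose output feeds softmax; as the loop shows, this is an occlusion-style score against a single zero baseline rather than exact SHAP:

    import numpy as np
    import torch

    def zero_out_importance(model: torch.nn.Module, x_tensor: torch.Tensor):
        """Baseline 'human' probability minus the probability with feature i zeroed."""
        model.eval()
        with torch.no_grad():
            baseline_prob = torch.softmax(model(x_tensor), dim=1)[0, 1].item()
            impacts = []
            x_work = x_tensor.clone()
            for i in range(x_tensor.shape[1]):
                original_val = x_work[0, i].item()
                x_work[0, i] = 0.0
                prob = torch.softmax(model(x_work), dim=1)[0, 1].item()
                impacts.append(baseline_prob - prob)   # > 0: the feature pushed toward 'human'
                x_work[0, i] = original_val            # restore before testing the next feature
        return np.array(impacts), baseline_prob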
@@ -108,6 +110,10 @@ def calculate_shap_values(model, x_tensor):
108
  ###############################################################################
109
 
110
  def compute_positionwise_scores(sequence, shap_values, k=4):
 
 
 
 
111
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
112
  kmer_dict = {km: i for i, km in enumerate(kmers)}
113
 
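The new docstring for compute_positionwise_scores (visible in the updated file further down) describes averaging, at every base, the SHAP values of all k-mer windows that cover it. A sketch of that averaging under the same assumptions (256-entry k-mer SHAP vector, k = 4):

    import numpy as np
    from itertools import product

    def positionwise_scores(sequence: str, kmer_shap: np.ndarray, k: int = 4) -> np.ndarray:
        kmers = [''.join(p) for p in product("ACGT", repeat=k)]
        index = {km: i for i, km in enumerate(kmers)}
        n = len(sequence)
        totals = np.zeros(n, dtype=np.float32)
        counts = np.zeros(n, dtype=np.float32)
        for i in range(n - k + 1):
            kmer = sequence[i:i + k]
            if kmer in index:
                # spread this window's SHAP value over the k bases it covers
                totals[i:i + k] += kmer_shap[index[kmer]]
                counts[i:i + k] += 1
        # average over covering windows; bases with no coverage stay at 0
        return np.divide(totals, counts, out=np.zeros_like(totals), where=counts > 0)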
@@ -132,13 +138,20 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
132
  ###############################################################################
133
 
134
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
 
 
 
 
 
135
  n = len(shap_means)
136
  if n == 0:
137
  return (0, 0, 0.0)
138
  if window_size >= n:
 
139
  avg_val = float(np.mean(shap_means))
140
  return (0, n, avg_val)
141
 
 
142
  csum = np.zeros(n + 1, dtype=np.float32)
143
  csum[1:] = np.cumsum(shap_means)
144
 
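find_extreme_subregion (its new docstring appears in the updated file below) looks for the fixed-length window with the highest or lowest mean per-base SHAP. With a cumulative sum, every window mean comes from two prefix sums, so the scan is linear in the sequence length. A sketch of that idea:

    import numpy as np

    def extreme_window(values: np.ndarray, window: int, mode: str = "max"):
        n = len(values)
        if n == 0:
            return (0, 0, 0.0)
        if window >= n:
            return (0, n, float(np.mean(values)))
        csum = np.zeros(n + 1, dtype=np.float64)
        csum[1:] = np.cumsum(values)
        # mean of values[s:s+window] for every start s, computed from prefix sums
        window_means = (csum[window:] - csum[:-window]) / window
        s = int(np.argmax(window_means) if mode == "max" else np.argmin(window_means))
        return (s, s + window, float(window_means[s]))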
@@ -165,6 +178,7 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
165
  ###############################################################################
166
 
167
  def fig_to_image(fig):
 
168
  buf = io.BytesIO()
169
  fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
170
  buf.seek(0)
@@ -173,14 +187,27 @@ def fig_to_image(fig):
173
  return img
174
 
175
  def get_zero_centered_cmap():
 
 
 
 
 
 
176
  colors = [
177
- (0.0, 'blue'),
178
- (0.5, 'white'),
179
- (1.0, 'red')
180
  ]
181
- return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
 
182
 
183
  def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
 
 
 
 
 
 
184
  if start is not None and end is not None:
185
  local_shap = shap_means[start:end]
186
  subtitle = f" (positions {start}-{end})"
@@ -191,46 +218,73 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
191
  if len(local_shap) == 0:
192
  local_shap = np.array([0.0])
193
 
 
194
  heatmap_data = local_shap.reshape(1, -1)
 
 
195
  min_val = np.min(local_shap)
196
  max_val = np.max(local_shap)
197
  extent = max(abs(min_val), abs(max_val))
198
- cmap = get_zero_centered_cmap()
199
 
200
- fig, ax = plt.subplots(figsize=(12, 1.8))
 
 
 
 
 
 
201
  cax = ax.imshow(
202
  heatmap_data,
203
  aspect='auto',
204
- cmap=cmap,
205
  vmin=-extent,
206
- vmax=extent
207
  )
 
 
208
  cbar = plt.colorbar(
209
  cax,
210
  orientation='horizontal',
211
- pad=0.25,
212
- aspect=40,
213
- shrink=0.8
214
  )
215
- cbar.ax.tick_params(labelsize=8)
216
- cbar.set_label('SHAP Contribution', fontsize=9, labelpad=5)
 
 
218
  ax.set_yticks([])
219
  ax.set_xlabel('Position in Sequence', fontsize=10)
220
  ax.set_title(f"{title}{subtitle}", pad=10)
221
- plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
 
 
 
 
 
 
222
 
223
  return fig
224
 
225
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
 
226
  plt.rcParams.update({'font.size': 10})
227
  fig = plt.figure(figsize=(10, 5))
228
 
 
229
  indices = np.argsort(np.abs(shap_values))[-top_k:]
230
  values = shap_values[indices]
231
  features = [kmers[i] for i in indices]
232
 
 
233
  colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
 
234
  plt.barh(range(len(values)), values, color=colors)
235
  plt.yticks(range(len(values)), features)
236
  plt.xlabel('SHAP Value (impact on model output)')
@@ -240,6 +294,9 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
240
  return fig
241
 
242
  def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
 
 
 
243
  fig, ax = plt.subplots(figsize=(6, 4))
244
  ax.hist(shap_array, bins=30, color='gray', edgecolor='black')
245
  ax.axvline(0, color='red', linestyle='--', label='0.0')
@@ -251,102 +308,119 @@ def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
251
  return fig
252
 
253
  def compute_gc_content(sequence):
 
254
  if not sequence:
255
  return 0
256
  gc_count = sequence.count('G') + sequence.count('C')
257
  return (gc_count / len(sequence)) * 100.0
258
 
259
  ###############################################################################
260
- # 7. SEQUENCE ANALYSIS FUNCTIONS
261
  ###############################################################################
262
 
263
- # Set up device and load the model once globally
264
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
265
- model = VirusClassifier(256)
266
- model.load_state_dict(torch.load("model.pt", map_location=device))
267
- model.to(device)
268
- model.eval()
269
-
270
- KMERS_4 = [''.join(p) for p in product("ACGT", repeat=4)]
271
-
272
- def analyze_sequence(file_path, top_k=10, fasta_text="", window_size=500):
273
  """
274
- Analyze a virus sequence from a FASTA file or text input.
275
- Returns (results_text, kmer_plot, heatmap_plot, state_dict, header)
276
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  try:
278
- if file_path:
279
- with open(file_path, 'r') as f:
280
- fasta_text = f.read()
281
-
282
- if not fasta_text.strip():
283
- return ("Error: No sequence provided", None, None, {}, "")
284
-
285
- sequences = parse_fasta(fasta_text)
286
- if not sequences:
287
- return ("Error: No valid FASTA sequences found", None, None, {}, "")
288
-
289
- header, sequence = sequences[0]
290
-
291
- x = sequence_to_kmer_vector(sequence, k=4)
292
- x_tensor = torch.tensor(x).float().unsqueeze(0).to(device)
293
-
294
- with torch.no_grad():
295
- output = model(x_tensor)
296
- probs = torch.softmax(output, dim=1)
297
- pred_human = probs[0, 1].item()
298
-
299
- classification = "Human" if pred_human > 0.5 else "Non-human"
300
-
301
- shap_values, baseline_prob = calculate_shap_values(model, x_tensor)
302
-
303
- shap_means = compute_positionwise_scores(sequence, shap_values, k=4)
304
-
305
- start_max, end_max, avg_max = find_extreme_subregion(shap_means, window_size, mode="max")
306
- start_min, end_min, avg_min = find_extreme_subregion(shap_means, window_size, mode="min")
307
-
308
- results = (
309
- f"Classification: {classification} "
310
- f"(probability of human = {pred_human:.3f})\n\n"
311
- f"Sequence length: {len(sequence):,} bases\n"
312
- f"Overall GC content: {compute_gc_content(sequence):.1f}%\n\n"
313
- f"Most human-like {window_size} bp region:\n"
314
- f"Position {start_max:,} to {end_max:,}\n"
315
- f"Average SHAP: {avg_max:.4f}\n"
316
- f"GC content: {compute_gc_content(sequence[start_max:end_max]):.1f}%\n\n"
317
- f"Least human-like {window_size} bp region:\n"
318
- f"Position {start_min:,} to {end_min:,}\n"
319
- f"Average SHAP: {avg_min:.4f}\n"
320
- f"GC content: {compute_gc_content(sequence[start_min:end_min]):.1f}%"
321
- )
322
-
323
- kmer_fig = create_importance_bar_plot(shap_values, KMERS_4, top_k=top_k)
324
- kmer_img = fig_to_image(kmer_fig)
325
-
326
- heatmap_fig = plot_linear_heatmap(shap_means)
327
- heatmap_img = fig_to_image(heatmap_fig)
328
-
329
- state = {
330
- "seq": sequence,
331
- "shap_means": shap_means
332
- }
333
-
334
- return results, kmer_img, heatmap_img, state, header
335
 
 
336
  except Exception as e:
337
- return (f"Error analyzing sequence: {str(e)}", None, None, {}, "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
  ###############################################################################
340
- # 8. SUBREGION ANALYSIS FUNCTION
341
  ###############################################################################
342
 
343
  def analyze_subregion(state, header, region_start, region_end):
 
 
 
 
344
  if not state or "seq" not in state or "shap_means" not in state:
345
  return ("No sequence data found. Please run Step 1 first.", None, None)
346
 
347
  seq = state["seq"]
348
  shap_means = state["shap_means"]
349
 
 
350
  region_start = int(region_start)
351
  region_end = int(region_end)
352
 
@@ -355,15 +429,19 @@ def analyze_subregion(state, header, region_start, region_end):
355
  if region_end <= region_start:
356
  return ("Invalid region range. End must be > Start.", None, None)
357
 
 
358
  region_seq = seq[region_start:region_end]
359
  region_shap = shap_means[region_start:region_end]
360
 
 
361
  gc_percent = compute_gc_content(region_seq)
362
  avg_shap = float(np.mean(region_shap))
363
 
 
364
  positive_fraction = np.mean(region_shap > 0)
365
  negative_fraction = np.mean(region_shap < 0)
366
 
 
367
  if avg_shap > 0.05:
368
  region_classification = "Likely pushing toward human"
369
  elif avg_shap < -0.05:
@@ -381,6 +459,7 @@ def analyze_subregion(state, header, region_start, region_end):
381
  f"Subregion interpretation: {region_classification}\n"
382
  )
383
 
 
384
  heatmap_fig = plot_linear_heatmap(
385
  shap_means,
386
  title="Subregion SHAP",
@@ -389,122 +468,15 @@ def analyze_subregion(state, header, region_start, region_end):
389
  )
390
  heatmap_img = fig_to_image(heatmap_fig)
391
 
 
392
  hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
393
  hist_img = fig_to_image(hist_fig)
394
 
395
  return (region_info, heatmap_img, hist_img)
396
 
397
- ###############################################################################
398
- # 9. COMPARISON ANALYSIS FUNCTIONS
399
- ###############################################################################
400
-
401
- def normalize_shap_lengths(shap1, shap2, num_points=1000):
402
- x1 = np.linspace(0, 1, len(shap1))
403
- x2 = np.linspace(0, 1, len(shap2))
404
-
405
- f1 = interp1d(x1, shap1, kind='linear')
406
- f2 = interp1d(x2, shap2, kind='linear')
407
-
408
- x_new = np.linspace(0, 1, num_points)
409
-
410
- shap1_norm = f1(x_new)
411
- shap2_norm = f2(x_new)
412
-
413
- return shap1_norm, shap2_norm
414
-
415
- def compute_shap_difference(shap1_norm, shap2_norm):
416
- return shap2_norm - shap1_norm
417
-
418
- def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
419
- heatmap_data = shap_diff.reshape(1, -1)
420
- extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
421
- cmap = get_zero_centered_cmap()
422
-
423
- fig, ax = plt.subplots(figsize=(12, 1.8))
424
- cax = ax.imshow(
425
- heatmap_data,
426
- aspect='auto',
427
- cmap=cmap,
428
- vmin=-extent,
429
- vmax=extent
430
- )
431
- cbar = plt.colorbar(
432
- cax,
433
- orientation='horizontal',
434
- pad=0.25,
435
- aspect=40,
436
- shrink=0.8
437
- )
438
- cbar.ax.tick_params(labelsize=8)
439
- cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
440
-
441
- ax.set_yticks([])
442
- ax.set_xlabel('Normalized Position (0-100%)', fontsize=10)
443
- ax.set_title(title, pad=10)
444
- plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
445
-
446
- return fig
447
-
448
- def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
449
- results1 = analyze_sequence(file1, top_k=10, fasta_text=fasta1, window_size=500)
450
- if isinstance(results1[0], str) and "Error" in results1[0]:
451
- return (f"Error in sequence 1: {results1[0]}", None, None)
452
-
453
- results2 = analyze_sequence(file2, top_k=10, fasta_text=fasta2, window_size=500)
454
- if isinstance(results2[0], str) and "Error" in results2[0]:
455
- return (f"Error in sequence 2: {results2[0]}", None, None)
456
-
457
- shap1 = results1[3]["shap_means"]
458
- shap2 = results2[3]["shap_means"]
459
-
460
- shap1_norm, shap2_norm = normalize_shap_lengths(shap1, shap2)
461
- shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
462
-
463
- avg_diff = np.mean(shap_diff)
464
- std_diff = np.std(shap_diff)
465
- max_diff = np.max(shap_diff)
466
- min_diff = np.min(shap_diff)
467
-
468
- threshold = 0.05
469
- substantial_diffs = np.abs(shap_diff) > threshold
470
- frac_different = np.mean(substantial_diffs)
471
-
472
- classification1 = results1[0].split('Classification: ')[1].split('\n')[0].strip()
473
- classification2 = results2[0].split('Classification: ')[1].split('\n')[0].strip()
474
-
475
- len1_formatted = "{:,}".format(len(shap1))
476
- len2_formatted = "{:,}".format(len(shap2))
477
- frac_formatted = "{:.2%}".format(frac_different)
478
-
479
- comparison_text = (
480
- "Sequence Comparison Results:\n"
481
- f"Sequence 1: {results1[4]}\n"
482
- f"Length: {len1_formatted} bases\n"
483
- f"Classification: {classification1}\n\n"
484
- f"Sequence 2: {results2[4]}\n"
485
- f"Length: {len2_formatted} bases\n"
486
- f"Classification: {classification2}\n\n"
487
- "Comparison Statistics:\n"
488
- f"Average SHAP difference: {avg_diff:.4f}\n"
489
- f"Standard deviation: {std_diff:.4f}\n"
490
- f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
491
- f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
492
- f"Fraction of positions with substantial differences: {frac_formatted}\n\n"
493
- "Interpretation:\n"
494
- "Positive values (red) indicate regions where Sequence 2 is more 'human-like'\n"
495
- "Negative values (blue) indicate regions where Sequence 1 is more 'human-like'"
496
- )
497
-
498
- heatmap_fig = plot_comparative_heatmap(shap_diff)
499
- heatmap_img = fig_to_image(heatmap_fig)
500
-
501
- hist_fig = plot_shap_histogram(shap_diff, title="Distribution of SHAP Differences")
502
- hist_img = fig_to_image(hist_fig)
503
-
504
- return comparison_text, heatmap_img, hist_img
505
 
506
  ###############################################################################
507
- # 10. BUILD GRADIO INTERFACE
508
  ###############################################################################
509
 
510
  css = """
@@ -535,14 +507,14 @@ with gr.Blocks(css=css) as iface:
535
  placeholder=">sequence_name\nACGTACGT...",
536
  lines=5
537
  )
538
- top_k_slider = gr.Slider(
539
  minimum=5,
540
  maximum=30,
541
  value=10,
542
  step=1,
543
  label="Number of top k-mers to display"
544
  )
545
- win_size_slider = gr.Slider(
546
  minimum=100,
547
  maximum=5000,
548
  value=500,
@@ -561,9 +533,10 @@ with gr.Blocks(css=css) as iface:
561
  seq_state = gr.State()
562
  header_state = gr.State()
563
 
 
564
  analyze_btn.click(
565
  analyze_sequence,
566
- inputs=[file_input, top_k_slider, text_input, win_size_slider],
567
  outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
568
  )
569
 
@@ -592,61 +565,6 @@ with gr.Blocks(css=css) as iface:
592
  inputs=[seq_state, header_state, region_start, region_end],
593
  outputs=[subregion_info, subregion_img, subregion_hist_img]
594
  )
595
-
596
- with gr.Tab("3) Comparative Analysis"):
597
- gr.Markdown("""
598
- **Compare Two Sequences**
599
- Upload or paste two FASTA sequences to compare their SHAP patterns.
600
- The sequences will be normalized to the same length for comparison.
601
-
602
- **Color Scale**:
603
- - Red: Sequence 2 is more human-like in this region
604
- - Blue: Sequence 1 is more human-like in this region
605
- - White: No substantial difference
606
- """)
607
-
608
- with gr.Row():
609
- with gr.Column(scale=1):
610
- file_input1 = gr.File(
611
- label="Upload first FASTA file",
612
- file_types=[".fasta", ".fa", ".txt"],
613
- type="filepath"
614
- )
615
- text_input1 = gr.Textbox(
616
- label="Or paste first FASTA sequence",
617
- placeholder=">sequence1\nACGTACGT...",
618
- lines=5
619
- )
620
-
621
- with gr.Column(scale=1):
622
- file_input2 = gr.File(
623
- label="Upload second FASTA file",
624
- file_types=[".fasta", ".fa", ".txt"],
625
- type="filepath"
626
- )
627
- text_input2 = gr.Textbox(
628
- label="Or paste second FASTA sequence",
629
- placeholder=">sequence2\nACGTACGT...",
630
- lines=5
631
- )
632
-
633
- compare_btn = gr.Button("Compare Sequences", variant="primary")
634
-
635
- comparison_text = gr.Textbox(
636
- label="Comparison Results",
637
- lines=12,
638
- interactive=False
639
- )
640
-
641
- with gr.Row():
642
- diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
643
- diff_hist = gr.Image(label="Distribution of SHAP Differences")
644
-
645
- compare_btn.click(
646
- analyze_sequence_comparison,
647
- inputs=[file_input1, file_input2, text_input1, text_input2],
648
- outputs=[comparison_text, diff_heatmap, diff_hist]
649
- )
650
 
651
  gr.Markdown("""
652
  ### Interface Features
@@ -660,22 +578,7 @@ with gr.Blocks(css=css) as iface:
660
  - GC content
661
  - Fraction of positions pushing human vs. non-human
662
  - Simple logic-based classification
663
- - **Sequence Comparison**:
664
- - Compare two sequences to identify regions of difference
665
- - Normalized comparison to handle different sequence lengths
666
- - Statistical summary of differences
667
  """)
668
 
669
  if __name__ == "__main__":
670
- plt.style.use('default')
671
- plt.rcParams['figure.figsize'] = [10, 6]
672
- plt.rcParams['figure.dpi'] = 100
673
- plt.rcParams['font.size'] = 10
674
-
675
- iface.launch(
676
- share=False,
677
- server_name="0.0.0.0",
678
- server_port=7860,
679
- show_api=False,
680
- debug=False
681
- )
 
1
  import gradio as gr
2
  import torch
3
+ import joblib
4
  import numpy as np
5
  from itertools import product
6
  import torch.nn as nn
 
8
  import matplotlib.colors as mcolors
9
  import io
10
  from PIL import Image
 
11
 
12
  ###############################################################################
13
  # 1. MODEL DEFINITION
 
71
 
72
  total_kmers = len(sequence) - k + 1
73
  if total_kmers > 0:
74
+ vec = vec / total_kmers
75
 
76
  return vec
77
 
 
86
  """
87
  model.eval()
88
  with torch.no_grad():
89
+ # Baseline
90
  baseline_output = model(x_tensor)
91
  baseline_probs = torch.softmax(baseline_output, dim=1)
92
+ baseline_prob = baseline_probs[0, 1].item() # Probability of 'human' class
93
 
94
+ # Zeroing each feature to measure impact
95
  shap_values = []
96
  x_zeroed = x_tensor.clone()
97
  for i in range(x_tensor.shape[1]):
 
102
  prob = probs[0, 1].item()
103
  impact = baseline_prob - prob
104
  shap_values.append(impact)
105
+ x_zeroed[0, i] = original_val # restore
106
  return np.array(shap_values), baseline_prob
107
 
108
  ###############################################################################
 
110
  ###############################################################################
111
 
112
  def compute_positionwise_scores(sequence, shap_values, k=4):
113
+ """
114
+ Returns an array of per-base SHAP contributions by averaging
115
+ the k-mer SHAP values of all k-mers covering that base.
116
+ """
117
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
118
  kmer_dict = {km: i for i, km in enumerate(kmers)}
119
 
 
138
  ###############################################################################
139
 
140
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
141
+ """
142
+ Finds the subregion of length `window_size` that has the maximum
143
+ (mode="max") or minimum (mode="min") average SHAP.
144
+ Returns (best_start, best_end, best_avg).
145
+ """
146
  n = len(shap_means)
147
  if n == 0:
148
  return (0, 0, 0.0)
149
  if window_size >= n:
150
+ # entire sequence
151
  avg_val = float(np.mean(shap_means))
152
  return (0, n, avg_val)
153
 
154
+ # We'll build csum of length n+1
155
  csum = np.zeros(n + 1, dtype=np.float32)
156
  csum[1:] = np.cumsum(shap_means)
157
 
 
178
  ###############################################################################
179
 
180
  def fig_to_image(fig):
181
+ """Convert a Matplotlib figure to a PIL Image for Gradio."""
182
  buf = io.BytesIO()
183
  fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
184
  buf.seek(0)
 
187
  return img
188
 
189
  def get_zero_centered_cmap():
190
+ """
191
+ Creates a custom diverging colormap that is:
192
+ - Blue for negative
193
+ - White for zero
194
+ - Red for positive
195
+ """
196
  colors = [
197
+ (0.0, 'blue'), # negative
198
+ (0.5, 'white'), # zero
199
+ (1.0, 'red') # positive
200
  ]
201
+ cmap = mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
202
+ return cmap
203
 
204
  def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
205
+ """
206
+ Plots a 1D heatmap of per-base SHAP contributions with a custom colormap:
207
+ - Negative = blue
208
+ - 0 = white
209
+ - Positive = red
210
+ """
211
  if start is not None and end is not None:
212
  local_shap = shap_means[start:end]
213
  subtitle = f" (positions {start}-{end})"
 
218
  if len(local_shap) == 0:
219
  local_shap = np.array([0.0])
220
 
221
+ # Build 2D array for imshow
222
  heatmap_data = local_shap.reshape(1, -1)
223
+
224
+ # Force symmetrical range
225
  min_val = np.min(local_shap)
226
  max_val = np.max(local_shap)
227
  extent = max(abs(min_val), abs(max_val))
 
228
 
229
+ # Create custom colormap
230
+ custom_cmap = get_zero_centered_cmap()
231
+
232
+ # Create figure with adjusted height ratio
233
+ fig, ax = plt.subplots(figsize=(12, 1.8)) # Reduced height
234
+
235
+ # Plot heatmap
236
  cax = ax.imshow(
237
  heatmap_data,
238
  aspect='auto',
239
+ cmap=custom_cmap,
240
  vmin=-extent,
241
+ vmax=+extent
242
  )
243
+
244
+ # Configure colorbar with more subtle positioning
245
  cbar = plt.colorbar(
246
  cax,
247
  orientation='horizontal',
248
+ pad=0.25, # Reduced padding
249
+ aspect=40, # Make colorbar thinner
250
+ shrink=0.8 # Make colorbar shorter than plot width
251
  )
 
 
252
 
253
+ # Style the colorbar
254
+ cbar.ax.tick_params(labelsize=8) # Smaller tick labels
255
+ cbar.set_label(
256
+ 'SHAP Contribution',
257
+ fontsize=9,
258
+ labelpad=5
259
+ )
260
+
261
+ # Configure main plot
262
  ax.set_yticks([])
263
  ax.set_xlabel('Position in Sequence', fontsize=10)
264
  ax.set_title(f"{title}{subtitle}", pad=10)
265
+
266
+ # Fine-tune layout
267
+ plt.subplots_adjust(
268
+ bottom=0.25, # Reduced bottom margin
269
+ left=0.05, # Tighter left margin
270
+ right=0.95 # Tighter right margin
271
+ )
272
 
273
  return fig
274
 
275
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
276
+ """Create a bar plot of the most important k-mers."""
277
  plt.rcParams.update({'font.size': 10})
278
  fig = plt.figure(figsize=(10, 5))
279
 
280
+ # Sort by absolute importance
281
  indices = np.argsort(np.abs(shap_values))[-top_k:]
282
  values = shap_values[indices]
283
  features = [kmers[i] for i in indices]
284
 
285
+ # negative -> blue, positive -> red
286
  colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
287
+
288
  plt.barh(range(len(values)), values, color=colors)
289
  plt.yticks(range(len(values)), features)
290
  plt.xlabel('SHAP Value (impact on model output)')
 
294
  return fig
295
 
296
  def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
297
+ """
298
+ Simple histogram of SHAP values in the subregion.
299
+ """
300
  fig, ax = plt.subplots(figsize=(6, 4))
301
  ax.hist(shap_array, bins=30, color='gray', edgecolor='black')
302
  ax.axvline(0, color='red', linestyle='--', label='0.0')
 
308
  return fig
309
 
310
  def compute_gc_content(sequence):
311
+ """Compute %GC in the sequence (A, C, G, T)."""
312
  if not sequence:
313
  return 0
314
  gc_count = sequence.count('G') + sequence.count('C')
315
  return (gc_count / len(sequence)) * 100.0
316
 
317
  ###############################################################################
318
+ # 7. MAIN ANALYSIS STEP (Gradio Step 1)
319
  ###############################################################################
320
 
321
+ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
 
 
322
  """
323
+ Analyzes the entire genome, returning classification, full-genome heatmap,
324
+ top k-mer bar plot, and identifies subregions with strongest positive/negative push.
325
  """
326
+ # Handle input
327
+ if fasta_text.strip():
328
+ text = fasta_text.strip()
329
+ elif file_obj is not None:
330
+ try:
331
+ with open(file_obj, 'r') as f:
332
+ text = f.read()
333
+ except Exception as e:
334
+ return (f"Error reading file: {str(e)}", None, None, None, None)
335
+ else:
336
+ return ("Please provide a FASTA sequence.", None, None, None, None)
337
+
338
+ # Parse FASTA
339
+ sequences = parse_fasta(text)
340
+ if not sequences:
341
+ return ("No valid FASTA sequences found.", None, None, None, None)
342
+
343
+ header, seq = sequences[0]
344
+
345
+ # Load model and scaler
346
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
347
  try:
348
+ # Use weights_only=True for safer loading
349
+ state_dict = torch.load('model.pt', map_location=device, weights_only=True)
350
+ model = VirusClassifier(256).to(device)
351
+ model.load_state_dict(state_dict)
 
 
 
 
352
 
353
+ scaler = joblib.load('scaler.pkl')
354
  except Exception as e:
355
+ return (f"Error loading model/scaler: {str(e)}", None, None, None, None)
356
+
357
+ # Vectorize + scale
358
+ freq_vector = sequence_to_kmer_vector(seq)
359
+ scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
360
+ x_tensor = torch.FloatTensor(scaled_vector).to(device)
361
+
362
+ # SHAP + classification
363
+ shap_values, prob_human = calculate_shap_values(model, x_tensor)
364
+ prob_nonhuman = 1.0 - prob_human
365
+
366
+ classification = "Human" if prob_human > 0.5 else "Non-human"
367
+ confidence = max(prob_human, prob_nonhuman)
368
+
369
+ # Per-base SHAP
370
+ shap_means = compute_positionwise_scores(seq, shap_values, k=4)
371
+
372
+ # Find the most "human-pushing" region
373
+ (max_start, max_end, max_avg) = find_extreme_subregion(shap_means, window_size, mode="max")
374
+ # Find the most "non-human–pushing" region
375
+ (min_start, min_end, min_avg) = find_extreme_subregion(shap_means, window_size, mode="min")
376
+
377
+ # Build results text
378
+ results_text = (
379
+ f"Sequence: {header}\n"
380
+ f"Length: {len(seq):,} bases\n"
381
+ f"Classification: {classification}\n"
382
+ f"Confidence: {confidence:.3f}\n"
383
+ f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
384
+ f"---\n"
385
+ f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
386
+ f"Start: {max_start}, End: {max_end}, Avg SHAP: {max_avg:.4f}\n\n"
387
+ f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
388
+ f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
389
+ )
390
+
391
+ # K-mer importance plot
392
+ kmers = [''.join(p) for p in product("ACGT", repeat=4)]
393
+ bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
394
+ bar_img = fig_to_image(bar_fig)
395
+
396
+ # Full-genome SHAP heatmap
397
+ heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
398
+ heatmap_img = fig_to_image(heatmap_fig)
399
+
400
+ # Store data for subregion analysis
401
+ state_dict_out = {
402
+ "seq": seq,
403
+ "shap_means": shap_means
404
+ }
405
+
406
+ return (results_text, bar_img, heatmap_img, state_dict_out, header)
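The new analyze_sequence above chains the pieces: parse the FASTA text, build the 256-dimensional 4-mer frequency vector, scale it with the scaler loaded from scaler.pkl, classify, and then derive per-base scores and extreme subregions. A condensed sketch of just the classification path, assuming scaler.pkl holds a fitted scikit-learn scaler and that parse_fasta (not shown in this diff) returns a list of (header, sequence) pairs; VirusClassifier and sequence_to_kmer_vector are the functions defined earlier in app.py:

    import joblib
    import torch

    def classify_fasta(fasta_text: str, model_path: str = "model.pt",
                       scaler_path: str = "scaler.pkl"):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = VirusClassifier(256).to(device)
        model.load_state_dict(torch.load(model_path, map_location=device,
                                         weights_only=True))
        model.eval()
        scaler = joblib.load(scaler_path)                 # assumed: fitted sklearn scaler

        header, seq = parse_fasta(fasta_text)[0]          # assumed: (header, sequence) pairs
        freq = sequence_to_kmer_vector(seq)               # 256-dim 4-mer frequencies
        x = torch.FloatTensor(scaler.transform(freq.reshape(1, -1))).to(device)
        with torch.no_grad():
            prob_human = torch.softmax(model(x), dim=1)[0, 1].item()
        label = "Human" if prob_human > 0.5 else "Non-human"
        return header, label, prob_human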
407
 
408
  ###############################################################################
409
+ # 8. SUBREGION ANALYSIS (Gradio Step 2)
410
  ###############################################################################
411
 
412
  def analyze_subregion(state, header, region_start, region_end):
413
+ """
414
+ Takes stored data from step 1 and a user-chosen region.
415
+ Returns a subregion heatmap, histogram, and some stats (GC, average SHAP).
416
+ """
417
  if not state or "seq" not in state or "shap_means" not in state:
418
  return ("No sequence data found. Please run Step 1 first.", None, None)
419
 
420
  seq = state["seq"]
421
  shap_means = state["shap_means"]
422
 
423
+ # Validate bounds
424
  region_start = int(region_start)
425
  region_end = int(region_end)
426
 
 
429
  if region_end <= region_start:
430
  return ("Invalid region range. End must be > Start.", None, None)
431
 
432
+ # Subsequence
433
  region_seq = seq[region_start:region_end]
434
  region_shap = shap_means[region_start:region_end]
435
 
436
+ # Some stats
437
  gc_percent = compute_gc_content(region_seq)
438
  avg_shap = float(np.mean(region_shap))
439
 
440
+ # Fraction pushing toward human vs. non-human
441
  positive_fraction = np.mean(region_shap > 0)
442
  negative_fraction = np.mean(region_shap < 0)
443
 
444
+ # Simple logic-based interpretation
445
  if avg_shap > 0.05:
446
  region_classification = "Likely pushing toward human"
447
  elif avg_shap < -0.05:
 
459
  f"Subregion interpretation: {region_classification}\n"
460
  )
461
 
462
+ # Plot region as small heatmap
463
  heatmap_fig = plot_linear_heatmap(
464
  shap_means,
465
  title="Subregion SHAP",
 
468
  )
469
  heatmap_img = fig_to_image(heatmap_fig)
470
 
471
+ # Plot histogram of SHAP in region
472
  hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
473
  hist_img = fig_to_image(hist_fig)
474
 
475
  return (region_info, heatmap_img, hist_img)
 
 
 
 
477
 
478
  ###############################################################################
479
+ # 9. BUILD GRADIO INTERFACE
480
  ###############################################################################
481
 
482
  css = """
 
507
  placeholder=">sequence_name\nACGTACGT...",
508
  lines=5
509
  )
510
+ top_k = gr.Slider(
511
  minimum=5,
512
  maximum=30,
513
  value=10,
514
  step=1,
515
  label="Number of top k-mers to display"
516
  )
517
+ win_size = gr.Slider(
518
  minimum=100,
519
  maximum=5000,
520
  value=500,
 
533
  seq_state = gr.State()
534
  header_state = gr.State()
535
 
536
+ # analyze_sequence(...) returns 5 items
537
  analyze_btn.click(
538
  analyze_sequence,
539
+ inputs=[file_input, top_k, text_input, win_size],
540
  outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
541
  )
542
 
 
565
  inputs=[seq_state, header_state, region_start, region_end],
566
  outputs=[subregion_info, subregion_img, subregion_hist_img]
567
  )
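Step 1 returns the sequence and its per-base SHAP means through gr.State outputs, and Step 2 (analyze_subregion) receives that same state as an input, so the subregion view never recomputes the model pass. A minimal sketch of that state handoff pattern, with hypothetical component names:

    import gradio as gr

    def step1(text):
        data = {"seq": text.upper()}        # stand-in for the real analysis results
        return f"Stored {len(text)} characters.", data

    def step2(data, start, end):
        if not data or "seq" not in data:
            return "Run Step 1 first."
        return data["seq"][int(start):int(end)]

    with gr.Blocks() as demo:
        txt = gr.Textbox(label="Sequence")
        out1 = gr.Textbox(label="Step 1 result")
        state = gr.State()                  # carries results from Step 1 to Step 2
        gr.Button("Step 1").click(step1, inputs=[txt], outputs=[out1, state])
        start = gr.Number(value=0, label="Start")
        end = gr.Number(value=10, label="End")
        out2 = gr.Textbox(label="Step 2 result")
        gr.Button("Step 2").click(step2, inputs=[state, start, end], outputs=[out2])

    # demo.launch()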
 
 
 
 
 
 
 
 
 
568
 
569
  gr.Markdown("""
570
  ### Interface Features
 
578
  - GC content
579
  - Fraction of positions pushing human vs. non-human
580
  - Simple logic-based classification
 
 
 
 
581
  """)
582
 
583
  if __name__ == "__main__":
584
+ iface.launch()