hiyata committed on
Commit
455bf4d
·
verified ·
1 Parent(s): 0d6258f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -201
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import gradio as gr
2
  import torch
3
- import joblib
4
  import numpy as np
5
  from itertools import product
6
  import torch.nn as nn
@@ -72,7 +71,7 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
72
 
73
  total_kmers = len(sequence) - k + 1
74
  if total_kmers > 0:
75
- vec = vec / total_kmers
76
 
77
  return vec
78
 
@@ -87,12 +86,10 @@ def calculate_shap_values(model, x_tensor):
87
  """
88
  model.eval()
89
  with torch.no_grad():
90
- # Baseline
91
  baseline_output = model(x_tensor)
92
  baseline_probs = torch.softmax(baseline_output, dim=1)
93
  baseline_prob = baseline_probs[0, 1].item() # Probability of 'human'
94
 
95
- # Zeroing each feature to measure impact
96
  shap_values = []
97
  x_zeroed = x_tensor.clone()
98
  for i in range(x_tensor.shape[1]):
@@ -100,10 +97,10 @@ def calculate_shap_values(model, x_tensor):
100
  x_zeroed[0, i] = 0.0
101
  output = model(x_zeroed)
102
  probs = torch.softmax(output, dim=1)
103
- prob = probs[0, 1].item() # Probability of 'human'
104
  impact = baseline_prob - prob
105
  shap_values.append(impact)
106
- x_zeroed[0, i] = original_val # restore
107
  return np.array(shap_values), baseline_prob
108
 
109
  ###############################################################################
@@ -111,10 +108,6 @@ def calculate_shap_values(model, x_tensor):
111
  ###############################################################################
112
 
113
  def compute_positionwise_scores(sequence, shap_values, k=4):
114
- """
115
- Returns an array of per-base SHAP contributions by averaging
116
- the k-mer SHAP values of all k-mers covering that base.
117
- """
118
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
119
  kmer_dict = {km: i for i, km in enumerate(kmers)}
120
 
@@ -139,20 +132,13 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
139
  ###############################################################################
140
 
141
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
142
- """
143
- Finds the subregion of length `window_size` that has the maximum
144
- (mode="max") or minimum (mode="min") average SHAP.
145
- Returns (best_start, best_end, best_avg).
146
- """
147
  n = len(shap_means)
148
  if n == 0:
149
  return (0, 0, 0.0)
150
  if window_size >= n:
151
- # entire sequence
152
  avg_val = float(np.mean(shap_means))
153
  return (0, n, avg_val)
154
 
155
- # We'll build csum of length n+1
156
  csum = np.zeros(n + 1, dtype=np.float32)
157
  csum[1:] = np.cumsum(shap_means)
158
 
@@ -179,7 +165,6 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
179
  ###############################################################################
180
 
181
  def fig_to_image(fig):
182
- """Convert a Matplotlib figure to a PIL Image for Gradio."""
183
  buf = io.BytesIO()
184
  fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
185
  buf.seek(0)
@@ -188,27 +173,14 @@ def fig_to_image(fig):
188
  return img
189
 
190
  def get_zero_centered_cmap():
191
- """
192
- Creates a custom diverging colormap that is:
193
- - Blue for negative
194
- - White for zero
195
- - Red for positive
196
- """
197
  colors = [
198
- (0.0, 'blue'), # negative
199
- (0.5, 'white'), # zero
200
- (1.0, 'red') # positive
201
  ]
202
- cmap = mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
203
- return cmap
204
 
205
  def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
206
- """
207
- Plots a 1D heatmap of per-base SHAP contributions with a custom colormap:
208
- - Negative = blue
209
- - 0 = white
210
- - Positive = red
211
- """
212
  if start is not None and end is not None:
213
  local_shap = shap_means[start:end]
214
  subtitle = f" (positions {start}-{end})"
@@ -219,73 +191,46 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
219
  if len(local_shap) == 0:
220
  local_shap = np.array([0.0])
221
 
222
- # Build 2D array for imshow
223
  heatmap_data = local_shap.reshape(1, -1)
224
-
225
- # Force symmetrical range
226
  min_val = np.min(local_shap)
227
  max_val = np.max(local_shap)
228
  extent = max(abs(min_val), abs(max_val))
 
229
 
230
- # Create custom colormap
231
- custom_cmap = get_zero_centered_cmap()
232
-
233
- # Create figure with adjusted height ratio
234
- fig, ax = plt.subplots(figsize=(12, 1.8)) # Reduced height
235
-
236
- # Plot heatmap
237
  cax = ax.imshow(
238
  heatmap_data,
239
  aspect='auto',
240
- cmap=custom_cmap,
241
  vmin=-extent,
242
- vmax=+extent
243
  )
244
-
245
- # Configure colorbar with more subtle positioning
246
  cbar = plt.colorbar(
247
  cax,
248
  orientation='horizontal',
249
- pad=0.25, # Reduced padding
250
- aspect=40, # Make colorbar thinner
251
- shrink=0.8 # Make colorbar shorter than plot width
252
- )
253
-
254
- # Style the colorbar
255
- cbar.ax.tick_params(labelsize=8) # Smaller tick labels
256
- cbar.set_label(
257
- 'SHAP Contribution',
258
- fontsize=9,
259
- labelpad=5
260
  )
 
 
261
 
262
- # Configure main plot
263
  ax.set_yticks([])
264
  ax.set_xlabel('Position in Sequence', fontsize=10)
265
  ax.set_title(f"{title}{subtitle}", pad=10)
266
-
267
- # Fine-tune layout
268
- plt.subplots_adjust(
269
- bottom=0.25, # Reduced bottom margin
270
- left=0.05, # Tighter left margin
271
- right=0.95 # Tighter right margin
272
- )
273
 
274
  return fig
275
 
276
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
277
- """Create a bar plot of the most important k-mers."""
278
  plt.rcParams.update({'font.size': 10})
279
  fig = plt.figure(figsize=(10, 5))
280
 
281
- # Sort by absolute importance
282
  indices = np.argsort(np.abs(shap_values))[-top_k:]
283
  values = shap_values[indices]
284
  features = [kmers[i] for i in indices]
285
 
286
- # negative -> blue, positive -> red
287
  colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
288
-
289
  plt.barh(range(len(values)), values, color=colors)
290
  plt.yticks(range(len(values)), features)
291
  plt.xlabel('SHAP Value (impact on model output)')
@@ -295,9 +240,6 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
295
  return fig
296
 
297
  def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
298
- """
299
- Simple histogram of SHAP values in the subregion.
300
- """
301
  fig, ax = plt.subplots(figsize=(6, 4))
302
  ax.hist(shap_array, bins=30, color='gray', edgecolor='black')
303
  ax.axvline(0, color='red', linestyle='--', label='0.0')
@@ -309,7 +251,6 @@ def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
309
  return fig
310
 
311
  def compute_gc_content(sequence):
312
- """Compute %GC in the sequence (A, C, G, T)."""
313
  if not sequence:
314
  return 0
315
  gc_count = sequence.count('G') + sequence.count('C')
@@ -319,78 +260,72 @@ def compute_gc_content(sequence):
319
  # 7. SEQUENCE ANALYSIS FUNCTIONS
320
  ###############################################################################
321
 
 
 
 
 
 
 
 
 
 
322
  def analyze_sequence(file_path, top_k=10, fasta_text="", window_size=500):
323
  """
324
  Analyze a virus sequence from a FASTA file or text input.
325
  Returns (results_text, kmer_plot, heatmap_plot, state_dict, header)
326
  """
327
  try:
328
- # Load model and k-mer info
329
- model = VirusClassifier(256) # 4^4 = 256 k-mers for k=4
330
- model.load_state_dict(torch.load("model.pt"))
331
- model.eval()
332
- kmers = [''.join(p) for p in product("ACGT", repeat=4)]
333
-
334
- # Process input (file takes precedence over text)
335
  if file_path:
336
  with open(file_path, 'r') as f:
337
  fasta_text = f.read()
338
 
339
  if not fasta_text.strip():
340
  return ("Error: No sequence provided", None, None, {}, "")
341
-
342
- # Parse FASTA
343
  sequences = parse_fasta(fasta_text)
344
  if not sequences:
345
  return ("Error: No valid FASTA sequences found", None, None, {}, "")
346
-
347
- header, sequence = sequences[0] # Take first sequence
348
 
349
- # Convert to k-mer frequencies
350
- x = sequence_to_kmer_vector(sequence)
351
- x_tensor = torch.tensor(x).float().unsqueeze(0)
 
352
 
353
- # Get model prediction
354
  with torch.no_grad():
355
  output = model(x_tensor)
356
  probs = torch.softmax(output, dim=1)
357
- # Using index 1 for probability of human
358
  pred_human = probs[0, 1].item()
359
-
360
- # Calculate SHAP values
361
- shap_values, prob = calculate_shap_values(model, x_tensor)
362
 
363
- # Find most extreme regions
364
- shap_means = compute_positionwise_scores(sequence, shap_values)
 
 
 
 
365
  start_max, end_max, avg_max = find_extreme_subregion(shap_means, window_size, mode="max")
366
  start_min, end_min, avg_min = find_extreme_subregion(shap_means, window_size, mode="min")
367
 
368
- # Format results text
369
- classification = "Human" if pred_human > 0.5 else "Non-human"
370
  results = (
371
  f"Classification: {classification} "
372
  f"(probability of human = {pred_human:.3f})\n\n"
373
  f"Sequence length: {len(sequence):,} bases\n"
374
  f"Overall GC content: {compute_gc_content(sequence):.1f}%\n\n"
375
- f"Most human-like {window_size}bp region:\n"
376
  f"Position {start_max:,} to {end_max:,}\n"
377
  f"Average SHAP: {avg_max:.4f}\n"
378
  f"GC content: {compute_gc_content(sequence[start_max:end_max]):.1f}%\n\n"
379
- f"Least human-like {window_size}bp region:\n"
380
  f"Position {start_min:,} to {end_min:,}\n"
381
  f"Average SHAP: {avg_min:.4f}\n"
382
  f"GC content: {compute_gc_content(sequence[start_min:end_min]):.1f}%"
383
  )
384
 
385
- # Create k-mer importance plot
386
- kmer_fig = create_importance_bar_plot(shap_values, kmers, top_k)
387
  kmer_img = fig_to_image(kmer_fig)
388
 
389
- # Create genome-wide heatmap
390
  heatmap_fig = plot_linear_heatmap(shap_means)
391
  heatmap_img = fig_to_image(heatmap_fig)
392
 
393
- # Store data for subregion analysis
394
  state = {
395
  "seq": sequence,
396
  "shap_means": shap_means
@@ -399,21 +334,19 @@ def analyze_sequence(file_path, top_k=10, fasta_text="", window_size=500):
399
  return results, kmer_img, heatmap_img, state, header
400
 
401
  except Exception as e:
402
- error_msg = f"Error analyzing sequence: {str(e)}"
403
- return (error_msg, None, None, {}, "")
 
 
 
404
 
405
  def analyze_subregion(state, header, region_start, region_end):
406
- """
407
- Takes stored data from step 1 and a user-chosen region.
408
- Returns a subregion heatmap, histogram, and some stats (GC, average SHAP).
409
- """
410
  if not state or "seq" not in state or "shap_means" not in state:
411
  return ("No sequence data found. Please run Step 1 first.", None, None)
412
 
413
  seq = state["seq"]
414
  shap_means = state["shap_means"]
415
 
416
- # Validate bounds
417
  region_start = int(region_start)
418
  region_end = int(region_end)
419
 
@@ -422,19 +355,15 @@ def analyze_subregion(state, header, region_start, region_end):
422
  if region_end <= region_start:
423
  return ("Invalid region range. End must be > Start.", None, None)
424
 
425
- # Subsequence
426
  region_seq = seq[region_start:region_end]
427
  region_shap = shap_means[region_start:region_end]
428
 
429
- # Some stats
430
  gc_percent = compute_gc_content(region_seq)
431
  avg_shap = float(np.mean(region_shap))
432
 
433
- # Fraction pushing toward human vs. non-human
434
  positive_fraction = np.mean(region_shap > 0)
435
  negative_fraction = np.mean(region_shap < 0)
436
 
437
- # Simple logic-based interpretation
438
  if avg_shap > 0.05:
439
  region_classification = "Likely pushing toward human"
440
  elif avg_shap < -0.05:
@@ -452,7 +381,6 @@ def analyze_subregion(state, header, region_start, region_end):
452
  f"Subregion interpretation: {region_classification}\n"
453
  )
454
 
455
- # Plot region as small heatmap
456
  heatmap_fig = plot_linear_heatmap(
457
  shap_means,
458
  title="Subregion SHAP",
@@ -461,72 +389,45 @@ def analyze_subregion(state, header, region_start, region_end):
461
  )
462
  heatmap_img = fig_to_image(heatmap_fig)
463
 
464
- # Plot histogram of SHAP in region
465
  hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
466
  hist_img = fig_to_image(hist_fig)
467
 
468
  return (region_info, heatmap_img, hist_img)
469
 
470
  ###############################################################################
471
- # 8. COMPARISON ANALYSIS FUNCTIONS
472
  ###############################################################################
473
 
474
  def normalize_shap_lengths(shap1, shap2, num_points=1000):
475
- """
476
- Normalize two SHAP arrays to the same length using interpolation.
477
- Returns (normalized_shap1, normalized_shap2)
478
- """
479
- # Create x coordinates for both sequences
480
  x1 = np.linspace(0, 1, len(shap1))
481
  x2 = np.linspace(0, 1, len(shap2))
482
 
483
- # Create interpolation functions
484
  f1 = interp1d(x1, shap1, kind='linear')
485
  f2 = interp1d(x2, shap2, kind='linear')
486
 
487
- # Create new x coordinates for interpolation
488
  x_new = np.linspace(0, 1, num_points)
489
 
490
- # Interpolate both sequences to new length
491
  shap1_norm = f1(x_new)
492
  shap2_norm = f2(x_new)
493
 
494
  return shap1_norm, shap2_norm
495
 
496
  def compute_shap_difference(shap1_norm, shap2_norm):
497
- """
498
- Compute the difference between two normalized SHAP arrays.
499
- Positive values indicate seq2 is more "human-like" than seq1.
500
- """
501
  return shap2_norm - shap1_norm
502
 
503
  def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
504
- """
505
- Plot the difference between two sequences' SHAP values.
506
- Red indicates seq2 is more human-like, blue indicates seq1 is more human-like.
507
- """
508
- # Build 2D array for imshow
509
  heatmap_data = shap_diff.reshape(1, -1)
510
-
511
- # Force symmetrical range
512
  extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
 
513
 
514
- # Create figure with adjusted height ratio
515
  fig, ax = plt.subplots(figsize=(12, 1.8))
516
-
517
- # Create custom colormap
518
- custom_cmap = get_zero_centered_cmap()
519
-
520
- # Plot heatmap
521
  cax = ax.imshow(
522
  heatmap_data,
523
  aspect='auto',
524
- cmap=custom_cmap,
525
  vmin=-extent,
526
- vmax=+extent
527
  )
528
-
529
- # Configure colorbar
530
  cbar = plt.colorbar(
531
  cax,
532
  orientation='horizontal',
@@ -534,74 +435,47 @@ def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
534
  aspect=40,
535
  shrink=0.8
536
  )
537
-
538
- # Style the colorbar
539
  cbar.ax.tick_params(labelsize=8)
540
- cbar.set_label(
541
- 'SHAP Difference (Seq2 - Seq1)',
542
- fontsize=9,
543
- labelpad=5
544
- )
545
 
546
- # Configure main plot
547
  ax.set_yticks([])
548
  ax.set_xlabel('Normalized Position (0-100%)', fontsize=10)
549
  ax.set_title(title, pad=10)
550
-
551
- plt.subplots_adjust(
552
- bottom=0.25,
553
- left=0.05,
554
- right=0.95
555
- )
556
 
557
  return fig
558
 
559
  def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
560
- """
561
- Compare two sequences by analyzing their SHAP differences.
562
- Returns comparison text and visualizations.
563
- """
564
- # Process first sequence
565
- results1 = analyze_sequence(file1, fasta_text=fasta1)
566
  if isinstance(results1[0], str) and "Error" in results1[0]:
567
  return (f"Error in sequence 1: {results1[0]}", None, None)
568
 
569
- # Process second sequence
570
- results2 = analyze_sequence(file2, fasta_text=fasta2)
571
  if isinstance(results2[0], str) and "Error" in results2[0]:
572
  return (f"Error in sequence 2: {results2[0]}", None, None)
573
 
574
- # Get SHAP means from state dictionaries
575
  shap1 = results1[3]["shap_means"]
576
  shap2 = results2[3]["shap_means"]
577
 
578
- # Normalize lengths
579
  shap1_norm, shap2_norm = normalize_shap_lengths(shap1, shap2)
580
-
581
- # Compute difference (positive = seq2 more human-like)
582
  shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
583
 
584
- # Calculate statistics
585
  avg_diff = np.mean(shap_diff)
586
  std_diff = np.std(shap_diff)
587
  max_diff = np.max(shap_diff)
588
  min_diff = np.min(shap_diff)
589
 
590
- # Calculate what fraction of positions show substantial differences
591
- threshold = 0.05 # Arbitrary threshold for "substantial" difference
592
  substantial_diffs = np.abs(shap_diff) > threshold
593
  frac_different = np.mean(substantial_diffs)
594
 
595
- # Extract classifications safely
596
  classification1 = results1[0].split('Classification: ')[1].split('\n')[0].strip()
597
  classification2 = results2[0].split('Classification: ')[1].split('\n')[0].strip()
598
 
599
- # Format numbers
600
  len1_formatted = "{:,}".format(len(shap1))
601
  len2_formatted = "{:,}".format(len(shap2))
602
  frac_formatted = "{:.2%}".format(frac_different)
603
 
604
- # Build comparison text
605
  comparison_text = (
606
  "Sequence Comparison Results:\n"
607
  f"Sequence 1: {results1[4]}\n"
@@ -621,21 +495,16 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
621
  "Negative values (blue) indicate regions where Sequence 1 is more 'human-like'"
622
  )
623
 
624
- # Create comparison heatmap
625
  heatmap_fig = plot_comparative_heatmap(shap_diff)
626
  heatmap_img = fig_to_image(heatmap_fig)
627
 
628
- # Create histogram of differences
629
- hist_fig = plot_shap_histogram(
630
- shap_diff,
631
- title="Distribution of SHAP Differences"
632
- )
633
  hist_img = fig_to_image(hist_fig)
634
 
635
  return comparison_text, heatmap_img, hist_img
636
 
637
  ###############################################################################
638
- # 9. BUILD GRADIO INTERFACE
639
  ###############################################################################
640
 
641
  css = """
@@ -666,14 +535,14 @@ with gr.Blocks(css=css) as iface:
666
  placeholder=">sequence_name\nACGTACGT...",
667
  lines=5
668
  )
669
- top_k = gr.Slider(
670
  minimum=5,
671
  maximum=30,
672
  value=10,
673
  step=1,
674
  label="Number of top k-mers to display"
675
  )
676
- win_size = gr.Slider(
677
  minimum=100,
678
  maximum=5000,
679
  value=500,
@@ -694,7 +563,7 @@ with gr.Blocks(css=css) as iface:
694
 
695
  analyze_btn.click(
696
  analyze_sequence,
697
- inputs=[file_input, top_k, text_input, win_size],
698
  outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
699
  )
700
 
@@ -797,22 +666,16 @@ with gr.Blocks(css=css) as iface:
797
  - Statistical summary of differences
798
  """)
799
 
800
- ###############################################################################
801
- # 10. MAIN EXECUTION
802
- ###############################################################################
803
-
804
  if __name__ == "__main__":
805
- # Set up any global configurations if needed
806
  plt.style.use('default')
807
  plt.rcParams['figure.figsize'] = [10, 6]
808
  plt.rcParams['figure.dpi'] = 100
809
  plt.rcParams['font.size'] = 10
810
 
811
- # Launch the interface
812
  iface.launch(
813
- share=False, # Set to True to create a public link
814
- server_name="0.0.0.0", # Listen on all network interfaces
815
- server_port=7860, # Default Gradio port
816
- show_api=False, # Hide API docs
817
- debug=False # Set to True for debugging
818
  )
 
1
  import gradio as gr
2
  import torch
 
3
  import numpy as np
4
  from itertools import product
5
  import torch.nn as nn
 
71
 
72
  total_kmers = len(sequence) - k + 1
73
  if total_kmers > 0:
74
+ vec /= total_kmers
75
 
76
  return vec
77
 
 
86
  """
87
  model.eval()
88
  with torch.no_grad():
 
89
  baseline_output = model(x_tensor)
90
  baseline_probs = torch.softmax(baseline_output, dim=1)
91
  baseline_prob = baseline_probs[0, 1].item() # Probability of 'human'
92
 
 
93
  shap_values = []
94
  x_zeroed = x_tensor.clone()
95
  for i in range(x_tensor.shape[1]):
 
97
  x_zeroed[0, i] = 0.0
98
  output = model(x_zeroed)
99
  probs = torch.softmax(output, dim=1)
100
+ prob = probs[0, 1].item()
101
  impact = baseline_prob - prob
102
  shap_values.append(impact)
103
+ x_zeroed[0, i] = original_val
104
  return np.array(shap_values), baseline_prob
105
 
106
  ###############################################################################
 
108
  ###############################################################################
109
 
110
  def compute_positionwise_scores(sequence, shap_values, k=4):
 
 
 
 
111
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
112
  kmer_dict = {km: i for i, km in enumerate(kmers)}
113
 
 
132
  ###############################################################################
133
 
134
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
 
 
 
 
 
135
  n = len(shap_means)
136
  if n == 0:
137
  return (0, 0, 0.0)
138
  if window_size >= n:
 
139
  avg_val = float(np.mean(shap_means))
140
  return (0, n, avg_val)
141
 
 
142
  csum = np.zeros(n + 1, dtype=np.float32)
143
  csum[1:] = np.cumsum(shap_means)
144
 
 
165
  ###############################################################################
166
 
167
  def fig_to_image(fig):
 
168
  buf = io.BytesIO()
169
  fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
170
  buf.seek(0)
 
173
  return img
174
 
175
  def get_zero_centered_cmap():
 
 
 
 
 
 
176
  colors = [
177
+ (0.0, 'blue'),
178
+ (0.5, 'white'),
179
+ (1.0, 'red')
180
  ]
181
+ return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
 
182
 
183
  def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
 
 
 
 
 
 
184
  if start is not None and end is not None:
185
  local_shap = shap_means[start:end]
186
  subtitle = f" (positions {start}-{end})"
 
191
  if len(local_shap) == 0:
192
  local_shap = np.array([0.0])
193
 
 
194
  heatmap_data = local_shap.reshape(1, -1)
 
 
195
  min_val = np.min(local_shap)
196
  max_val = np.max(local_shap)
197
  extent = max(abs(min_val), abs(max_val))
198
+ cmap = get_zero_centered_cmap()
199
 
200
+ fig, ax = plt.subplots(figsize=(12, 1.8))
 
 
 
 
 
 
201
  cax = ax.imshow(
202
  heatmap_data,
203
  aspect='auto',
204
+ cmap=cmap,
205
  vmin=-extent,
206
+ vmax=extent
207
  )
 
 
208
  cbar = plt.colorbar(
209
  cax,
210
  orientation='horizontal',
211
+ pad=0.25,
212
+ aspect=40,
213
+ shrink=0.8
 
 
 
 
 
 
 
 
214
  )
215
+ cbar.ax.tick_params(labelsize=8)
216
+ cbar.set_label('SHAP Contribution', fontsize=9, labelpad=5)
217
 
 
218
  ax.set_yticks([])
219
  ax.set_xlabel('Position in Sequence', fontsize=10)
220
  ax.set_title(f"{title}{subtitle}", pad=10)
221
+ plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
 
 
 
 
 
 
222
 
223
  return fig
224
 
225
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
 
226
  plt.rcParams.update({'font.size': 10})
227
  fig = plt.figure(figsize=(10, 5))
228
 
 
229
  indices = np.argsort(np.abs(shap_values))[-top_k:]
230
  values = shap_values[indices]
231
  features = [kmers[i] for i in indices]
232
 
 
233
  colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
 
234
  plt.barh(range(len(values)), values, color=colors)
235
  plt.yticks(range(len(values)), features)
236
  plt.xlabel('SHAP Value (impact on model output)')
 
240
  return fig
241
 
242
  def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
 
 
 
243
  fig, ax = plt.subplots(figsize=(6, 4))
244
  ax.hist(shap_array, bins=30, color='gray', edgecolor='black')
245
  ax.axvline(0, color='red', linestyle='--', label='0.0')
 
251
  return fig
252
 
253
  def compute_gc_content(sequence):
 
254
  if not sequence:
255
  return 0
256
  gc_count = sequence.count('G') + sequence.count('C')
 
260
  # 7. SEQUENCE ANALYSIS FUNCTIONS
261
  ###############################################################################
262
 
263
+ # Set up device and load the model once globally
264
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
265
+ model = VirusClassifier(256)
266
+ model.load_state_dict(torch.load("model.pt", map_location=device))
267
+ model.to(device)
268
+ model.eval()
269
+
270
+ KMERS_4 = [''.join(p) for p in product("ACGT", repeat=4)]
271
+
272
  def analyze_sequence(file_path, top_k=10, fasta_text="", window_size=500):
273
  """
274
  Analyze a virus sequence from a FASTA file or text input.
275
  Returns (results_text, kmer_plot, heatmap_plot, state_dict, header)
276
  """
277
  try:
 
 
 
 
 
 
 
278
  if file_path:
279
  with open(file_path, 'r') as f:
280
  fasta_text = f.read()
281
 
282
  if not fasta_text.strip():
283
  return ("Error: No sequence provided", None, None, {}, "")
284
+
 
285
  sequences = parse_fasta(fasta_text)
286
  if not sequences:
287
  return ("Error: No valid FASTA sequences found", None, None, {}, "")
 
 
288
 
289
+ header, sequence = sequences[0]
290
+
291
+ x = sequence_to_kmer_vector(sequence, k=4)
292
+ x_tensor = torch.tensor(x).float().unsqueeze(0).to(device)
293
 
 
294
  with torch.no_grad():
295
  output = model(x_tensor)
296
  probs = torch.softmax(output, dim=1)
 
297
  pred_human = probs[0, 1].item()
 
 
 
298
 
299
+ classification = "Human" if pred_human > 0.5 else "Non-human"
300
+
301
+ shap_values, baseline_prob = calculate_shap_values(model, x_tensor)
302
+
303
+ shap_means = compute_positionwise_scores(sequence, shap_values, k=4)
304
+
305
  start_max, end_max, avg_max = find_extreme_subregion(shap_means, window_size, mode="max")
306
  start_min, end_min, avg_min = find_extreme_subregion(shap_means, window_size, mode="min")
307
 
 
 
308
  results = (
309
  f"Classification: {classification} "
310
  f"(probability of human = {pred_human:.3f})\n\n"
311
  f"Sequence length: {len(sequence):,} bases\n"
312
  f"Overall GC content: {compute_gc_content(sequence):.1f}%\n\n"
313
+ f"Most human-like {window_size} bp region:\n"
314
  f"Position {start_max:,} to {end_max:,}\n"
315
  f"Average SHAP: {avg_max:.4f}\n"
316
  f"GC content: {compute_gc_content(sequence[start_max:end_max]):.1f}%\n\n"
317
+ f"Least human-like {window_size} bp region:\n"
318
  f"Position {start_min:,} to {end_min:,}\n"
319
  f"Average SHAP: {avg_min:.4f}\n"
320
  f"GC content: {compute_gc_content(sequence[start_min:end_min]):.1f}%"
321
  )
322
 
323
+ kmer_fig = create_importance_bar_plot(shap_values, KMERS_4, top_k=top_k)
 
324
  kmer_img = fig_to_image(kmer_fig)
325
 
 
326
  heatmap_fig = plot_linear_heatmap(shap_means)
327
  heatmap_img = fig_to_image(heatmap_fig)
328
 
 
329
  state = {
330
  "seq": sequence,
331
  "shap_means": shap_means
 
334
  return results, kmer_img, heatmap_img, state, header
335
 
336
  except Exception as e:
337
+ return (f"Error analyzing sequence: {str(e)}", None, None, {}, "")
338
+
339
+ ###############################################################################
340
+ # 8. SUBREGION ANALYSIS FUNCTION
341
+ ###############################################################################
342
 
343
  def analyze_subregion(state, header, region_start, region_end):
 
 
 
 
344
  if not state or "seq" not in state or "shap_means" not in state:
345
  return ("No sequence data found. Please run Step 1 first.", None, None)
346
 
347
  seq = state["seq"]
348
  shap_means = state["shap_means"]
349
 
 
350
  region_start = int(region_start)
351
  region_end = int(region_end)
352
 
 
355
  if region_end <= region_start:
356
  return ("Invalid region range. End must be > Start.", None, None)
357
 
 
358
  region_seq = seq[region_start:region_end]
359
  region_shap = shap_means[region_start:region_end]
360
 
 
361
  gc_percent = compute_gc_content(region_seq)
362
  avg_shap = float(np.mean(region_shap))
363
 
 
364
  positive_fraction = np.mean(region_shap > 0)
365
  negative_fraction = np.mean(region_shap < 0)
366
 
 
367
  if avg_shap > 0.05:
368
  region_classification = "Likely pushing toward human"
369
  elif avg_shap < -0.05:
 
381
  f"Subregion interpretation: {region_classification}\n"
382
  )
383
 
 
384
  heatmap_fig = plot_linear_heatmap(
385
  shap_means,
386
  title="Subregion SHAP",
 
389
  )
390
  heatmap_img = fig_to_image(heatmap_fig)
391
 
 
392
  hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
393
  hist_img = fig_to_image(hist_fig)
394
 
395
  return (region_info, heatmap_img, hist_img)
396
 
397
  ###############################################################################
398
+ # 9. COMPARISON ANALYSIS FUNCTIONS
399
  ###############################################################################
400
 
401
  def normalize_shap_lengths(shap1, shap2, num_points=1000):
 
 
 
 
 
402
  x1 = np.linspace(0, 1, len(shap1))
403
  x2 = np.linspace(0, 1, len(shap2))
404
 
 
405
  f1 = interp1d(x1, shap1, kind='linear')
406
  f2 = interp1d(x2, shap2, kind='linear')
407
 
 
408
  x_new = np.linspace(0, 1, num_points)
409
 
 
410
  shap1_norm = f1(x_new)
411
  shap2_norm = f2(x_new)
412
 
413
  return shap1_norm, shap2_norm
414
 
415
  def compute_shap_difference(shap1_norm, shap2_norm):
 
 
 
 
416
  return shap2_norm - shap1_norm
417
 
418
  def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
 
 
 
 
 
419
  heatmap_data = shap_diff.reshape(1, -1)
 
 
420
  extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
421
+ cmap = get_zero_centered_cmap()
422
 
 
423
  fig, ax = plt.subplots(figsize=(12, 1.8))
 
 
 
 
 
424
  cax = ax.imshow(
425
  heatmap_data,
426
  aspect='auto',
427
+ cmap=cmap,
428
  vmin=-extent,
429
+ vmax=extent
430
  )
 
 
431
  cbar = plt.colorbar(
432
  cax,
433
  orientation='horizontal',
 
435
  aspect=40,
436
  shrink=0.8
437
  )
 
 
438
  cbar.ax.tick_params(labelsize=8)
439
+ cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
 
 
 
 
440
 
 
441
  ax.set_yticks([])
442
  ax.set_xlabel('Normalized Position (0-100%)', fontsize=10)
443
  ax.set_title(title, pad=10)
444
+ plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
 
 
 
 
 
445
 
446
  return fig
447
 
448
  def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
449
+ results1 = analyze_sequence(file1, top_k=10, fasta_text=fasta1, window_size=500)
 
 
 
 
 
450
  if isinstance(results1[0], str) and "Error" in results1[0]:
451
  return (f"Error in sequence 1: {results1[0]}", None, None)
452
 
453
+ results2 = analyze_sequence(file2, top_k=10, fasta_text=fasta2, window_size=500)
 
454
  if isinstance(results2[0], str) and "Error" in results2[0]:
455
  return (f"Error in sequence 2: {results2[0]}", None, None)
456
 
 
457
  shap1 = results1[3]["shap_means"]
458
  shap2 = results2[3]["shap_means"]
459
 
 
460
  shap1_norm, shap2_norm = normalize_shap_lengths(shap1, shap2)
 
 
461
  shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
462
 
 
463
  avg_diff = np.mean(shap_diff)
464
  std_diff = np.std(shap_diff)
465
  max_diff = np.max(shap_diff)
466
  min_diff = np.min(shap_diff)
467
 
468
+ threshold = 0.05
 
469
  substantial_diffs = np.abs(shap_diff) > threshold
470
  frac_different = np.mean(substantial_diffs)
471
 
 
472
  classification1 = results1[0].split('Classification: ')[1].split('\n')[0].strip()
473
  classification2 = results2[0].split('Classification: ')[1].split('\n')[0].strip()
474
 
 
475
  len1_formatted = "{:,}".format(len(shap1))
476
  len2_formatted = "{:,}".format(len(shap2))
477
  frac_formatted = "{:.2%}".format(frac_different)
478
 
 
479
  comparison_text = (
480
  "Sequence Comparison Results:\n"
481
  f"Sequence 1: {results1[4]}\n"
 
495
  "Negative values (blue) indicate regions where Sequence 1 is more 'human-like'"
496
  )
497
 
 
498
  heatmap_fig = plot_comparative_heatmap(shap_diff)
499
  heatmap_img = fig_to_image(heatmap_fig)
500
 
501
+ hist_fig = plot_shap_histogram(shap_diff, title="Distribution of SHAP Differences")
 
 
 
 
502
  hist_img = fig_to_image(hist_fig)
503
 
504
  return comparison_text, heatmap_img, hist_img
505
 
506
  ###############################################################################
507
+ # 10. BUILD GRADIO INTERFACE
508
  ###############################################################################
509
 
510
  css = """
 
535
  placeholder=">sequence_name\nACGTACGT...",
536
  lines=5
537
  )
538
+ top_k_slider = gr.Slider(
539
  minimum=5,
540
  maximum=30,
541
  value=10,
542
  step=1,
543
  label="Number of top k-mers to display"
544
  )
545
+ win_size_slider = gr.Slider(
546
  minimum=100,
547
  maximum=5000,
548
  value=500,
 
563
 
564
  analyze_btn.click(
565
  analyze_sequence,
566
+ inputs=[file_input, top_k_slider, text_input, win_size_slider],
567
  outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
568
  )
569
 
 
666
  - Statistical summary of differences
667
  """)
668
 
 
 
 
 
669
  if __name__ == "__main__":
 
670
  plt.style.use('default')
671
  plt.rcParams['figure.figsize'] = [10, 6]
672
  plt.rcParams['figure.dpi'] = 100
673
  plt.rcParams['font.size'] = 10
674
 
 
675
  iface.launch(
676
+ share=False,
677
+ server_name="0.0.0.0",
678
+ server_port=7860,
679
+ show_api=False,
680
+ debug=False
681
  )