hiyata commited on
Commit
77621ec
·
verified ·
1 Parent(s): f5ea8d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -230
app.py CHANGED
@@ -8,6 +8,7 @@ import matplotlib.pyplot as plt
8
  import matplotlib.colors as mcolors
9
  import io
10
  from PIL import Image
 
11
 
12
  ###############################################################################
13
  # 1. MODEL DEFINITION
@@ -38,15 +39,12 @@ class VirusClassifier(nn.Module):
38
  ###############################################################################
39
 
40
  def parse_fasta(text):
41
- """Parse FASTA formatted text into a list of (header, sequence)."""
42
  sequences = []
43
  current_header = None
44
  current_sequence = []
45
-
46
  for line in text.strip().split('\n'):
47
  line = line.strip()
48
- if not line:
49
- continue
50
  if line.startswith('>'):
51
  if current_header:
52
  sequences.append((current_header, ''.join(current_sequence)))
@@ -59,20 +57,16 @@ def parse_fasta(text):
59
  return sequences
60
 
61
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
62
- """Convert a sequence to a k-mer frequency vector for classification."""
63
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
64
  kmer_dict = {km: i for i, km in enumerate(kmers)}
65
  vec = np.zeros(len(kmers), dtype=np.float32)
66
-
67
  for i in range(len(sequence) - k + 1):
68
  kmer = sequence[i:i+k]
69
  if kmer in kmer_dict:
70
  vec[kmer_dict[kmer]] += 1
71
-
72
  total_kmers = len(sequence) - k + 1
73
  if total_kmers > 0:
74
- vec = vec / total_kmers
75
-
76
  return vec
77
 
78
  ###############################################################################
@@ -80,18 +74,11 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
80
  ###############################################################################
81
 
82
  def calculate_shap_values(model, x_tensor):
83
- """
84
- Calculate SHAP values using a simple ablation approach.
85
- Returns shap_values, prob_human
86
- """
87
  model.eval()
88
  with torch.no_grad():
89
- # Baseline
90
  baseline_output = model(x_tensor)
91
  baseline_probs = torch.softmax(baseline_output, dim=1)
92
- baseline_prob = baseline_probs[0, 1].item() # Probability of 'human' class
93
-
94
- # Zeroing each feature to measure impact
95
  shap_values = []
96
  x_zeroed = x_tensor.clone()
97
  for i in range(x_tensor.shape[1]):
@@ -100,9 +87,8 @@ def calculate_shap_values(model, x_tensor):
100
  output = model(x_zeroed)
101
  probs = torch.softmax(output, dim=1)
102
  prob = probs[0, 1].item()
103
- impact = baseline_prob - prob
104
- shap_values.append(impact)
105
- x_zeroed[0, i] = original_val # restore
106
  return np.array(shap_values), baseline_prob
107
 
108
  ###############################################################################
@@ -110,27 +96,19 @@ def calculate_shap_values(model, x_tensor):
110
  ###############################################################################
111
 
112
  def compute_positionwise_scores(sequence, shap_values, k=4):
113
- """
114
- Returns an array of per-base SHAP contributions by averaging
115
- the k-mer SHAP values of all k-mers covering that base.
116
- """
117
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
118
  kmer_dict = {km: i for i, km in enumerate(kmers)}
119
-
120
  seq_len = len(sequence)
121
  shap_sums = np.zeros(seq_len, dtype=np.float32)
122
  coverage = np.zeros(seq_len, dtype=np.float32)
123
-
124
  for i in range(seq_len - k + 1):
125
  kmer = sequence[i:i+k]
126
  if kmer in kmer_dict:
127
  val = shap_values[kmer_dict[kmer]]
128
- shap_sums[i : i + k] += val
129
- coverage[i : i + k] += 1
130
-
131
  with np.errstate(divide='ignore', invalid='ignore'):
132
  shap_means = np.where(coverage > 0, shap_sums / coverage, 0.0)
133
-
134
  return shap_means
135
 
136
  ###############################################################################
@@ -138,39 +116,22 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
138
  ###############################################################################
139
 
140
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
141
- """
142
- Finds the subregion of length `window_size` that has the maximum
143
- (mode="max") or minimum (mode="min") average SHAP.
144
- Returns (best_start, best_end, best_avg).
145
- """
146
  n = len(shap_means)
147
- if n == 0:
148
- return (0, 0, 0.0)
149
  if window_size >= n:
150
- # entire sequence
151
- avg_val = float(np.mean(shap_means))
152
- return (0, n, avg_val)
153
-
154
- # We'll build csum of length n+1
155
  csum = np.zeros(n + 1, dtype=np.float32)
156
  csum[1:] = np.cumsum(shap_means)
157
-
158
  best_start = 0
159
  best_sum = csum[window_size] - csum[0]
160
  best_avg = best_sum / window_size
161
-
162
  for start in range(1, n - window_size + 1):
163
  wsum = csum[start + window_size] - csum[start]
164
  wavg = wsum / window_size
165
- if mode == "max":
166
- if wavg > best_avg:
167
- best_avg = wavg
168
- best_start = start
169
- else: # mode == "min"
170
- if wavg < best_avg:
171
- best_avg = wavg
172
- best_start = start
173
-
174
  return (best_start, best_start + window_size, float(best_avg))
175
 
176
  ###############################################################################
@@ -178,7 +139,6 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
178
  ###############################################################################
179
 
180
  def fig_to_image(fig):
181
- """Convert a Matplotlib figure to a PIL Image for Gradio."""
182
  buf = io.BytesIO()
183
  fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
184
  buf.seek(0)
@@ -187,104 +147,41 @@ def fig_to_image(fig):
187
  return img
188
 
189
  def get_zero_centered_cmap():
190
- """
191
- Creates a custom diverging colormap that is:
192
- - Blue for negative
193
- - White for zero
194
- - Red for positive
195
- """
196
- colors = [
197
- (0.0, 'blue'), # negative
198
- (0.5, 'white'), # zero
199
- (1.0, 'red') # positive
200
- ]
201
- cmap = mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
202
- return cmap
203
 
204
  def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
205
- """
206
- Plots a 1D heatmap of per-base SHAP contributions with a custom colormap:
207
- - Negative = blue
208
- - 0 = white
209
- - Positive = red
210
- """
211
  if start is not None and end is not None:
212
  local_shap = shap_means[start:end]
213
  subtitle = f" (positions {start}-{end})"
214
  else:
215
  local_shap = shap_means
216
  subtitle = ""
217
-
218
  if len(local_shap) == 0:
219
  local_shap = np.array([0.0])
220
-
221
- # Build 2D array for imshow
222
  heatmap_data = local_shap.reshape(1, -1)
223
-
224
- # Force symmetrical range
225
  min_val = np.min(local_shap)
226
  max_val = np.max(local_shap)
227
  extent = max(abs(min_val), abs(max_val))
228
-
229
- # Create custom colormap
230
- custom_cmap = get_zero_centered_cmap()
231
-
232
- # Create figure with adjusted height ratio
233
- fig, ax = plt.subplots(figsize=(12, 1.8)) # Reduced height
234
-
235
- # Plot heatmap
236
- cax = ax.imshow(
237
- heatmap_data,
238
- aspect='auto',
239
- cmap=custom_cmap,
240
- vmin=-extent,
241
- vmax=+extent
242
- )
243
-
244
- # Configure colorbar with more subtle positioning
245
- cbar = plt.colorbar(
246
- cax,
247
- orientation='horizontal',
248
- pad=0.25, # Reduced padding
249
- aspect=40, # Make colorbar thinner
250
- shrink=0.8 # Make colorbar shorter than plot width
251
- )
252
-
253
- # Style the colorbar
254
- cbar.ax.tick_params(labelsize=8) # Smaller tick labels
255
- cbar.set_label(
256
- 'SHAP Contribution',
257
- fontsize=9,
258
- labelpad=5
259
- )
260
-
261
- # Configure main plot
262
  ax.set_yticks([])
263
  ax.set_xlabel('Position in Sequence', fontsize=10)
264
  ax.set_title(f"{title}{subtitle}", pad=10)
265
-
266
- # Fine-tune layout
267
- plt.subplots_adjust(
268
- bottom=0.25, # Reduced bottom margin
269
- left=0.05, # Tighter left margin
270
- right=0.95 # Tighter right margin
271
- )
272
-
273
  return fig
274
 
275
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
276
- """Create a bar plot of the most important k-mers."""
277
  plt.rcParams.update({'font.size': 10})
278
  fig = plt.figure(figsize=(10, 5))
279
-
280
- # Sort by absolute importance
281
  indices = np.argsort(np.abs(shap_values))[-top_k:]
282
  values = shap_values[indices]
283
  features = [kmers[i] for i in indices]
284
-
285
- # negative -> blue, positive -> red
286
  colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
287
-
288
  plt.barh(range(len(values)), values, color=colors)
289
  plt.yticks(range(len(values)), features)
290
  plt.xlabel('SHAP Value (impact on model output)')
@@ -294,9 +191,6 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
294
  return fig
295
 
296
  def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
297
- """
298
- Simple histogram of SHAP values in the subregion.
299
- """
300
  fig, ax = plt.subplots(figsize=(6, 4))
301
  ax.hist(shap_array, bins=30, color='gray', edgecolor='black')
302
  ax.axvline(0, color='red', linestyle='--', label='0.0')
@@ -308,9 +202,7 @@ def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
308
  return fig
309
 
310
  def compute_gc_content(sequence):
311
- """Compute %GC in the sequence (A, C, G, T)."""
312
- if not sequence:
313
- return 0
314
  gc_count = sequence.count('G') + sequence.count('C')
315
  return (gc_count / len(sequence)) * 100.0
316
 
@@ -319,11 +211,6 @@ def compute_gc_content(sequence):
319
  ###############################################################################
320
 
321
  def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
322
- """
323
- Analyzes the entire genome, returning classification, full-genome heatmap,
324
- top k-mer bar plot, and identifies subregions with strongest positive/negative push.
325
- """
326
- # Handle input
327
  if fasta_text.strip():
328
  text = fasta_text.strip()
329
  elif file_obj is not None:
@@ -335,46 +222,33 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
335
  else:
336
  return ("Please provide a FASTA sequence.", None, None, None, None)
337
 
338
- # Parse FASTA
339
  sequences = parse_fasta(text)
340
  if not sequences:
341
  return ("No valid FASTA sequences found.", None, None, None, None)
342
-
343
  header, seq = sequences[0]
344
 
345
- # Load model and scaler
346
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
347
  try:
348
- # Use weights_only=True for safer loading
349
  state_dict = torch.load('model.pt', map_location=device, weights_only=True)
350
  model = VirusClassifier(256).to(device)
351
  model.load_state_dict(state_dict)
352
-
353
  scaler = joblib.load('scaler.pkl')
354
  except Exception as e:
355
  return (f"Error loading model/scaler: {str(e)}", None, None, None, None)
356
 
357
- # Vectorize + scale
358
  freq_vector = sequence_to_kmer_vector(seq)
359
  scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
360
  x_tensor = torch.FloatTensor(scaled_vector).to(device)
361
 
362
- # SHAP + classification
363
  shap_values, prob_human = calculate_shap_values(model, x_tensor)
364
  prob_nonhuman = 1.0 - prob_human
365
-
366
  classification = "Human" if prob_human > 0.5 else "Non-human"
367
  confidence = max(prob_human, prob_nonhuman)
368
 
369
- # Per-base SHAP
370
  shap_means = compute_positionwise_scores(seq, shap_values, k=4)
 
 
371
 
372
- # Find the most "human-pushing" region
373
- (max_start, max_end, max_avg) = find_extreme_subregion(shap_means, window_size, mode="max")
374
- # Find the most "non-human–pushing" region
375
- (min_start, min_end, min_avg) = find_extreme_subregion(shap_means, window_size, mode="min")
376
-
377
- # Build results text
378
  results_text = (
379
  f"Sequence: {header}\n"
380
  f"Length: {len(seq):,} bases\n"
@@ -388,20 +262,14 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
388
  f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
389
  )
390
 
391
- # K-mer importance plot
392
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
393
  bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
394
  bar_img = fig_to_image(bar_fig)
395
 
396
- # Full-genome SHAP heatmap
397
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
398
  heatmap_img = fig_to_image(heatmap_fig)
399
 
400
- # Store data for subregion analysis
401
- state_dict_out = {
402
- "seq": seq,
403
- "shap_means": shap_means
404
- }
405
 
406
  return (results_text, bar_img, heatmap_img, state_dict_out, header)
407
 
@@ -410,45 +278,28 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
410
  ###############################################################################
411
 
412
  def analyze_subregion(state, header, region_start, region_end):
413
- """
414
- Takes stored data from step 1 and a user-chosen region.
415
- Returns a subregion heatmap, histogram, and some stats (GC, average SHAP).
416
- """
417
  if not state or "seq" not in state or "shap_means" not in state:
418
  return ("No sequence data found. Please run Step 1 first.", None, None)
419
-
420
  seq = state["seq"]
421
  shap_means = state["shap_means"]
422
-
423
- # Validate bounds
424
  region_start = int(region_start)
425
  region_end = int(region_end)
426
-
427
  region_start = max(0, min(region_start, len(seq)))
428
  region_end = max(0, min(region_end, len(seq)))
429
  if region_end <= region_start:
430
  return ("Invalid region range. End must be > Start.", None, None)
431
-
432
- # Subsequence
433
  region_seq = seq[region_start:region_end]
434
  region_shap = shap_means[region_start:region_end]
435
-
436
- # Some stats
437
  gc_percent = compute_gc_content(region_seq)
438
  avg_shap = float(np.mean(region_shap))
439
-
440
- # Fraction pushing toward human vs. non-human
441
  positive_fraction = np.mean(region_shap > 0)
442
  negative_fraction = np.mean(region_shap < 0)
443
-
444
- # Simple logic-based interpretation
445
  if avg_shap > 0.05:
446
  region_classification = "Likely pushing toward human"
447
  elif avg_shap < -0.05:
448
  region_classification = "Likely pushing toward non-human"
449
  else:
450
  region_classification = "Near neutral (no strong push)"
451
-
452
  region_info = (
453
  f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
454
  f"Region length: {len(region_seq)} bases\n"
@@ -458,25 +309,100 @@ def analyze_subregion(state, header, region_start, region_end):
458
  f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
459
  f"Subregion interpretation: {region_classification}\n"
460
  )
461
-
462
- # Plot region as small heatmap
463
- heatmap_fig = plot_linear_heatmap(
464
- shap_means,
465
- title="Subregion SHAP",
466
- start=region_start,
467
- end=region_end
468
- )
469
  heatmap_img = fig_to_image(heatmap_fig)
470
-
471
- # Plot histogram of SHAP in region
472
  hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
473
  hist_img = fig_to_image(hist_fig)
474
-
475
  return (region_info, heatmap_img, hist_img)
476
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
 
478
  ###############################################################################
479
- # 9. BUILD GRADIO INTERFACE
480
  ###############################################################################
481
 
482
  css = """
@@ -497,75 +423,72 @@ with gr.Blocks(css=css) as iface:
497
  with gr.Tab("1) Full-Sequence Analysis"):
498
  with gr.Row():
499
  with gr.Column(scale=1):
500
- file_input = gr.File(
501
- label="Upload FASTA file",
502
- file_types=[".fasta", ".fa", ".txt"],
503
- type="filepath"
504
- )
505
- text_input = gr.Textbox(
506
- label="Or paste FASTA sequence",
507
- placeholder=">sequence_name\nACGTACGT...",
508
- lines=5
509
- )
510
- top_k = gr.Slider(
511
- minimum=5,
512
- maximum=30,
513
- value=10,
514
- step=1,
515
- label="Number of top k-mers to display"
516
- )
517
- win_size = gr.Slider(
518
- minimum=100,
519
- maximum=5000,
520
- value=500,
521
- step=100,
522
- label="Window size for 'most pushing' subregions"
523
- )
524
  analyze_btn = gr.Button("Analyze Sequence", variant="primary")
525
-
526
  with gr.Column(scale=2):
527
- results_box = gr.Textbox(
528
- label="Classification Results", lines=12, interactive=False
529
- )
530
  kmer_img = gr.Image(label="Top k-mer SHAP")
531
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
532
-
533
  seq_state = gr.State()
534
  header_state = gr.State()
535
-
536
- # analyze_sequence(...) returns 5 items
537
  analyze_btn.click(
538
  analyze_sequence,
539
  inputs=[file_input, top_k, text_input, win_size],
540
  outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
541
  )
542
-
543
  with gr.Tab("2) Subregion Exploration"):
544
  gr.Markdown("""
545
  **Subregion Analysis**
546
- Select start/end positions to view local SHAP signals, distribution, and GC content.
547
- The heatmap also uses the same Blue-White-Red scale.
548
  """)
549
  with gr.Row():
550
  region_start = gr.Number(label="Region Start", value=0)
551
  region_end = gr.Number(label="Region End", value=500)
552
  region_btn = gr.Button("Analyze Subregion")
553
-
554
- subregion_info = gr.Textbox(
555
- label="Subregion Analysis",
556
- lines=7,
557
- interactive=False
558
- )
559
  with gr.Row():
560
  subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
561
  subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
562
-
563
  region_btn.click(
564
  analyze_subregion,
565
  inputs=[seq_state, header_state, region_start, region_end],
566
  outputs=[subregion_info, subregion_img, subregion_hist_img]
567
  )
568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
569
  gr.Markdown("""
570
  ### Interface Features
571
  - **Overall Classification** (human vs non-human) using k-mer frequencies.
@@ -578,7 +501,11 @@ with gr.Blocks(css=css) as iface:
578
  - GC content
579
  - Fraction of positions pushing human vs. non-human
580
  - Simple logic-based classification
 
 
 
 
581
  """)
582
 
583
  if __name__ == "__main__":
584
- iface.launch()
 
8
  import matplotlib.colors as mcolors
9
  import io
10
  from PIL import Image
11
+ from scipy.interpolate import interp1d
12
 
13
  ###############################################################################
14
  # 1. MODEL DEFINITION
 
39
  ###############################################################################
40
 
41
  def parse_fasta(text):
 
42
  sequences = []
43
  current_header = None
44
  current_sequence = []
 
45
  for line in text.strip().split('\n'):
46
  line = line.strip()
47
+ if not line: continue
 
48
  if line.startswith('>'):
49
  if current_header:
50
  sequences.append((current_header, ''.join(current_sequence)))
 
57
  return sequences
58
 
59
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
 
60
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
61
  kmer_dict = {km: i for i, km in enumerate(kmers)}
62
  vec = np.zeros(len(kmers), dtype=np.float32)
 
63
  for i in range(len(sequence) - k + 1):
64
  kmer = sequence[i:i+k]
65
  if kmer in kmer_dict:
66
  vec[kmer_dict[kmer]] += 1
 
67
  total_kmers = len(sequence) - k + 1
68
  if total_kmers > 0:
69
+ vec /= total_kmers
 
70
  return vec
71
 
72
  ###############################################################################
 
74
  ###############################################################################
75
 
76
  def calculate_shap_values(model, x_tensor):
 
 
 
 
77
  model.eval()
78
  with torch.no_grad():
 
79
  baseline_output = model(x_tensor)
80
  baseline_probs = torch.softmax(baseline_output, dim=1)
81
+ baseline_prob = baseline_probs[0, 1].item() # Prob of 'human'
 
 
82
  shap_values = []
83
  x_zeroed = x_tensor.clone()
84
  for i in range(x_tensor.shape[1]):
 
87
  output = model(x_zeroed)
88
  probs = torch.softmax(output, dim=1)
89
  prob = probs[0, 1].item()
90
+ shap_values.append(baseline_prob - prob)
91
+ x_zeroed[0, i] = original_val
 
92
  return np.array(shap_values), baseline_prob
93
 
94
  ###############################################################################
 
96
  ###############################################################################
97
 
98
  def compute_positionwise_scores(sequence, shap_values, k=4):
 
 
 
 
99
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
100
  kmer_dict = {km: i for i, km in enumerate(kmers)}
 
101
  seq_len = len(sequence)
102
  shap_sums = np.zeros(seq_len, dtype=np.float32)
103
  coverage = np.zeros(seq_len, dtype=np.float32)
 
104
  for i in range(seq_len - k + 1):
105
  kmer = sequence[i:i+k]
106
  if kmer in kmer_dict:
107
  val = shap_values[kmer_dict[kmer]]
108
+ shap_sums[i:i+k] += val
109
+ coverage[i:i+k] += 1
 
110
  with np.errstate(divide='ignore', invalid='ignore'):
111
  shap_means = np.where(coverage > 0, shap_sums / coverage, 0.0)
 
112
  return shap_means
113
 
114
  ###############################################################################
 
116
  ###############################################################################
117
 
118
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
 
 
 
 
 
119
  n = len(shap_means)
120
+ if n == 0: return (0, 0, 0.0)
 
121
  if window_size >= n:
122
+ return (0, n, float(np.mean(shap_means)))
 
 
 
 
123
  csum = np.zeros(n + 1, dtype=np.float32)
124
  csum[1:] = np.cumsum(shap_means)
 
125
  best_start = 0
126
  best_sum = csum[window_size] - csum[0]
127
  best_avg = best_sum / window_size
 
128
  for start in range(1, n - window_size + 1):
129
  wsum = csum[start + window_size] - csum[start]
130
  wavg = wsum / window_size
131
+ if mode == "max" and wavg > best_avg:
132
+ best_avg = wavg; best_start = start
133
+ elif mode == "min" and wavg < best_avg:
134
+ best_avg = wavg; best_start = start
 
 
 
 
 
135
  return (best_start, best_start + window_size, float(best_avg))
136
 
137
  ###############################################################################
 
139
  ###############################################################################
140
 
141
  def fig_to_image(fig):
 
142
  buf = io.BytesIO()
143
  fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
144
  buf.seek(0)
 
147
  return img
148
 
149
  def get_zero_centered_cmap():
150
+ colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
151
+ return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
 
 
 
 
 
 
154
  if start is not None and end is not None:
155
  local_shap = shap_means[start:end]
156
  subtitle = f" (positions {start}-{end})"
157
  else:
158
  local_shap = shap_means
159
  subtitle = ""
 
160
  if len(local_shap) == 0:
161
  local_shap = np.array([0.0])
 
 
162
  heatmap_data = local_shap.reshape(1, -1)
 
 
163
  min_val = np.min(local_shap)
164
  max_val = np.max(local_shap)
165
  extent = max(abs(min_val), abs(max_val))
166
+ cmap = get_zero_centered_cmap()
167
+ fig, ax = plt.subplots(figsize=(12, 1.8))
168
+ cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
169
+ cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
170
+ cbar.ax.tick_params(labelsize=8)
171
+ cbar.set_label('SHAP Contribution', fontsize=9, labelpad=5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  ax.set_yticks([])
173
  ax.set_xlabel('Position in Sequence', fontsize=10)
174
  ax.set_title(f"{title}{subtitle}", pad=10)
175
+ plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
 
 
 
 
 
 
 
176
  return fig
177
 
178
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
 
179
  plt.rcParams.update({'font.size': 10})
180
  fig = plt.figure(figsize=(10, 5))
 
 
181
  indices = np.argsort(np.abs(shap_values))[-top_k:]
182
  values = shap_values[indices]
183
  features = [kmers[i] for i in indices]
 
 
184
  colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
 
185
  plt.barh(range(len(values)), values, color=colors)
186
  plt.yticks(range(len(values)), features)
187
  plt.xlabel('SHAP Value (impact on model output)')
 
191
  return fig
192
 
193
  def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
 
 
 
194
  fig, ax = plt.subplots(figsize=(6, 4))
195
  ax.hist(shap_array, bins=30, color='gray', edgecolor='black')
196
  ax.axvline(0, color='red', linestyle='--', label='0.0')
 
202
  return fig
203
 
204
  def compute_gc_content(sequence):
205
+ if not sequence: return 0
 
 
206
  gc_count = sequence.count('G') + sequence.count('C')
207
  return (gc_count / len(sequence)) * 100.0
208
 
 
211
  ###############################################################################
212
 
213
  def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
 
 
 
 
 
214
  if fasta_text.strip():
215
  text = fasta_text.strip()
216
  elif file_obj is not None:
 
222
  else:
223
  return ("Please provide a FASTA sequence.", None, None, None, None)
224
 
 
225
  sequences = parse_fasta(text)
226
  if not sequences:
227
  return ("No valid FASTA sequences found.", None, None, None, None)
 
228
  header, seq = sequences[0]
229
 
 
230
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
231
  try:
 
232
  state_dict = torch.load('model.pt', map_location=device, weights_only=True)
233
  model = VirusClassifier(256).to(device)
234
  model.load_state_dict(state_dict)
 
235
  scaler = joblib.load('scaler.pkl')
236
  except Exception as e:
237
  return (f"Error loading model/scaler: {str(e)}", None, None, None, None)
238
 
 
239
  freq_vector = sequence_to_kmer_vector(seq)
240
  scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
241
  x_tensor = torch.FloatTensor(scaled_vector).to(device)
242
 
 
243
  shap_values, prob_human = calculate_shap_values(model, x_tensor)
244
  prob_nonhuman = 1.0 - prob_human
 
245
  classification = "Human" if prob_human > 0.5 else "Non-human"
246
  confidence = max(prob_human, prob_nonhuman)
247
 
 
248
  shap_means = compute_positionwise_scores(seq, shap_values, k=4)
249
+ max_start, max_end, max_avg = find_extreme_subregion(shap_means, window_size, mode="max")
250
+ min_start, min_end, min_avg = find_extreme_subregion(shap_means, window_size, mode="min")
251
 
 
 
 
 
 
 
252
  results_text = (
253
  f"Sequence: {header}\n"
254
  f"Length: {len(seq):,} bases\n"
 
262
  f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
263
  )
264
 
 
265
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
266
  bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
267
  bar_img = fig_to_image(bar_fig)
268
 
 
269
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
270
  heatmap_img = fig_to_image(heatmap_fig)
271
 
272
+ state_dict_out = {"seq": seq, "shap_means": shap_means}
 
 
 
 
273
 
274
  return (results_text, bar_img, heatmap_img, state_dict_out, header)
275
 
 
278
  ###############################################################################
279
 
280
  def analyze_subregion(state, header, region_start, region_end):
 
 
 
 
281
  if not state or "seq" not in state or "shap_means" not in state:
282
  return ("No sequence data found. Please run Step 1 first.", None, None)
 
283
  seq = state["seq"]
284
  shap_means = state["shap_means"]
 
 
285
  region_start = int(region_start)
286
  region_end = int(region_end)
 
287
  region_start = max(0, min(region_start, len(seq)))
288
  region_end = max(0, min(region_end, len(seq)))
289
  if region_end <= region_start:
290
  return ("Invalid region range. End must be > Start.", None, None)
 
 
291
  region_seq = seq[region_start:region_end]
292
  region_shap = shap_means[region_start:region_end]
 
 
293
  gc_percent = compute_gc_content(region_seq)
294
  avg_shap = float(np.mean(region_shap))
 
 
295
  positive_fraction = np.mean(region_shap > 0)
296
  negative_fraction = np.mean(region_shap < 0)
 
 
297
  if avg_shap > 0.05:
298
  region_classification = "Likely pushing toward human"
299
  elif avg_shap < -0.05:
300
  region_classification = "Likely pushing toward non-human"
301
  else:
302
  region_classification = "Near neutral (no strong push)"
 
303
  region_info = (
304
  f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
305
  f"Region length: {len(region_seq)} bases\n"
 
309
  f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
310
  f"Subregion interpretation: {region_classification}\n"
311
  )
312
+ heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion SHAP", start=region_start, end=region_end)
 
 
 
 
 
 
 
313
  heatmap_img = fig_to_image(heatmap_fig)
 
 
314
  hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
315
  hist_img = fig_to_image(hist_fig)
 
316
  return (region_info, heatmap_img, hist_img)
317
 
318
+ ###############################################################################
319
+ # 9. COMPARISON ANALYSIS FUNCTIONS
320
+ ###############################################################################
321
+
322
+ def normalize_shap_lengths(shap1, shap2, num_points=1000):
323
+ x1 = np.linspace(0, 1, len(shap1))
324
+ x2 = np.linspace(0, 1, len(shap2))
325
+ f1 = interp1d(x1, shap1, kind='linear')
326
+ f2 = interp1d(x2, shap2, kind='linear')
327
+ x_new = np.linspace(0, 1, num_points)
328
+ shap1_norm = f1(x_new)
329
+ shap2_norm = f2(x_new)
330
+ return shap1_norm, shap2_norm
331
+
332
+ def compute_shap_difference(shap1_norm, shap2_norm):
333
+ return shap2_norm - shap1_norm
334
+
335
+ def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
336
+ heatmap_data = shap_diff.reshape(1, -1)
337
+ extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
338
+ cmap = get_zero_centered_cmap()
339
+ fig, ax = plt.subplots(figsize=(12, 1.8))
340
+ cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
341
+ cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
342
+ cbar.ax.tick_params(labelsize=8)
343
+ cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
344
+ ax.set_yticks([])
345
+ ax.set_xlabel('Normalized Position (0-100%)', fontsize=10)
346
+ ax.set_title(title, pad=10)
347
+ plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
348
+ return fig
349
+
350
+ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
351
+ # Analyze first sequence
352
+ res1 = analyze_sequence(file1, top_kmers=10, fasta_text=fasta1, window_size=500)
353
+ if isinstance(res1[0], str) and "Error" in res1[0]:
354
+ return (f"Error in sequence 1: {res1[0]}", None, None)
355
+ # Analyze second sequence
356
+ res2 = analyze_sequence(file2, top_kmers=10, fasta_text=fasta2, window_size=500)
357
+ if isinstance(res2[0], str) and "Error" in res2[0]:
358
+ return (f"Error in sequence 2: {res2[0]}", None, None)
359
+
360
+ shap1 = res1[3]["shap_means"]
361
+ shap2 = res2[3]["shap_means"]
362
+ shap1_norm, shap2_norm = normalize_shap_lengths(shap1, shap2)
363
+ shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
364
+
365
+ avg_diff = np.mean(shap_diff)
366
+ std_diff = np.std(shap_diff)
367
+ max_diff = np.max(shap_diff)
368
+ min_diff = np.min(shap_diff)
369
+ threshold = 0.05
370
+ substantial_diffs = np.abs(shap_diff) > threshold
371
+ frac_different = np.mean(substantial_diffs)
372
+
373
+ classification1 = res1[0].split('Classification: ')[1].split('\n')[0].strip()
374
+ classification2 = res2[0].split('Classification: ')[1].split('\n')[0].strip()
375
+ len1_formatted = "{:,}".format(len(shap1))
376
+ len2_formatted = "{:,}".format(len(shap2))
377
+ frac_formatted = "{:.2%}".format(frac_different)
378
+
379
+ comparison_text = (
380
+ "Sequence Comparison Results:\n"
381
+ f"Sequence 1: {res1[4]}\n"
382
+ f"Length: {len1_formatted} bases\n"
383
+ f"Classification: {classification1}\n\n"
384
+ f"Sequence 2: {res2[4]}\n"
385
+ f"Length: {len2_formatted} bases\n"
386
+ f"Classification: {classification2}\n\n"
387
+ "Comparison Statistics:\n"
388
+ f"Average SHAP difference: {avg_diff:.4f}\n"
389
+ f"Standard deviation: {std_diff:.4f}\n"
390
+ f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
391
+ f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
392
+ f"Fraction of positions with substantial differences: {frac_formatted}\n\n"
393
+ "Interpretation:\n"
394
+ "Positive values (red) indicate regions where Sequence 2 is more 'human-like'\n"
395
+ "Negative values (blue) indicate regions where Sequence 1 is more 'human-like'"
396
+ )
397
+
398
+ heatmap_fig = plot_comparative_heatmap(shap_diff)
399
+ heatmap_img = fig_to_image(heatmap_fig)
400
+ hist_fig = plot_shap_histogram(shap_diff, title="Distribution of SHAP Differences")
401
+ hist_img = fig_to_image(hist_fig)
402
+ return comparison_text, heatmap_img, hist_img
403
 
404
  ###############################################################################
405
+ # 10. BUILD GRADIO INTERFACE
406
  ###############################################################################
407
 
408
  css = """
 
423
  with gr.Tab("1) Full-Sequence Analysis"):
424
  with gr.Row():
425
  with gr.Column(scale=1):
426
+ file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
427
+ text_input = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
428
+ top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
429
+ win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Window size for 'most pushing' subregions")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  analyze_btn = gr.Button("Analyze Sequence", variant="primary")
 
431
  with gr.Column(scale=2):
432
+ results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
 
 
433
  kmer_img = gr.Image(label="Top k-mer SHAP")
434
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
 
435
  seq_state = gr.State()
436
  header_state = gr.State()
 
 
437
  analyze_btn.click(
438
  analyze_sequence,
439
  inputs=[file_input, top_k, text_input, win_size],
440
  outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
441
  )
442
+
443
  with gr.Tab("2) Subregion Exploration"):
444
  gr.Markdown("""
445
  **Subregion Analysis**
446
+ Select start/end positions to view local SHAP signals, distribution, GC content, etc.
447
+ The heatmap also uses the same Blue-White-Red scale.
448
  """)
449
  with gr.Row():
450
  region_start = gr.Number(label="Region Start", value=0)
451
  region_end = gr.Number(label="Region End", value=500)
452
  region_btn = gr.Button("Analyze Subregion")
453
+ subregion_info = gr.Textbox(label="Subregion Analysis", lines=7, interactive=False)
 
 
 
 
 
454
  with gr.Row():
455
  subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
456
  subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
 
457
  region_btn.click(
458
  analyze_subregion,
459
  inputs=[seq_state, header_state, region_start, region_end],
460
  outputs=[subregion_info, subregion_img, subregion_hist_img]
461
  )
462
 
463
+ with gr.Tab("3) Comparative Analysis"):
464
+ gr.Markdown("""
465
+ **Compare Two Sequences**
466
+ Upload or paste two FASTA sequences to compare their SHAP patterns.
467
+ The sequences will be normalized to the same length for comparison.
468
+
469
+ **Color Scale**:
470
+ - Red: Sequence 2 is more human-like in this region
471
+ - Blue: Sequence 1 is more human-like in this region
472
+ - White: No substantial difference
473
+ """)
474
+ with gr.Row():
475
+ with gr.Column(scale=1):
476
+ file_input1 = gr.File(label="Upload first FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
477
+ text_input1 = gr.Textbox(label="Or paste first FASTA sequence", placeholder=">sequence1\nACGTACGT...", lines=5)
478
+ with gr.Column(scale=1):
479
+ file_input2 = gr.File(label="Upload second FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
480
+ text_input2 = gr.Textbox(label="Or paste second FASTA sequence", placeholder=">sequence2\nACGTACGT...", lines=5)
481
+ compare_btn = gr.Button("Compare Sequences", variant="primary")
482
+ comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
483
+ with gr.Row():
484
+ diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
485
+ diff_hist = gr.Image(label="Distribution of SHAP Differences")
486
+ compare_btn.click(
487
+ analyze_sequence_comparison,
488
+ inputs=[file_input1, file_input2, text_input1, text_input2],
489
+ outputs=[comparison_text, diff_heatmap, diff_hist]
490
+ )
491
+
492
  gr.Markdown("""
493
  ### Interface Features
494
  - **Overall Classification** (human vs non-human) using k-mer frequencies.
 
501
  - GC content
502
  - Fraction of positions pushing human vs. non-human
503
  - Simple logic-based classification
504
+ - **Sequence Comparison**:
505
+ - Compare two sequences to identify regions of difference
506
+ - Normalized comparison to handle different sequence lengths
507
+ - Statistical summary of differences
508
  """)
509
 
510
  if __name__ == "__main__":
511
+ iface.launch()