hiyata committed on
Commit
6be7ede
·
verified ·
1 Parent(s): 910c6c2

Update app.py

Files changed (1)
  1. app.py +453 -421
app.py CHANGED
@@ -6,13 +6,34 @@ from itertools import product
6
  import torch.nn as nn
7
  import matplotlib.pyplot as plt
8
  import matplotlib.colors as mcolors
9
- import io
10
  from PIL import Image
11
 
12
  ###############################################################################
13
- # 1. MODEL DEFINITION
14
  ###############################################################################
15
 
16
  class VirusClassifier(nn.Module):
17
  def __init__(self, input_shape: int):
18
  super(VirusClassifier, self).__init__()
@@ -29,16 +50,16 @@ class VirusClassifier(nn.Module):
29
  nn.GELU(),
30
  nn.Linear(32, 2)
31
  )
32
-
33
  def forward(self, x):
34
  return self.network(x)
35
 
36
  ###############################################################################
37
- # 2. FASTA PARSING & K-MER FEATURE ENGINEERING
38
  ###############################################################################
39
 
40
- def parse_fasta(text):
41
- """Parse FASTA formatted text into a list of (header, sequence)."""
42
  sequences = []
43
  current_header = None
44
  current_sequence = []
@@ -53,67 +74,68 @@ def parse_fasta(text):
53
  current_header = line[1:]
54
  current_sequence = []
55
  else:
56
- current_sequence.append(line.upper())
57
  if current_header:
58
  sequences.append((current_header, ''.join(current_sequence)))
59
  return sequences
60
 
61
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
62
- """Convert a sequence to a k-mer frequency vector for classification."""
63
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
64
  kmer_dict = {km: i for i, km in enumerate(kmers)}
65
  vec = np.zeros(len(kmers), dtype=np.float32)
66
 
 
67
  for i in range(len(sequence) - k + 1):
68
  kmer = sequence[i:i+k]
69
- if kmer in kmer_dict:
70
  vec[kmer_dict[kmer]] += 1
71
-
 
72
  total_kmers = len(sequence) - k + 1
73
  if total_kmers > 0:
74
  vec = vec / total_kmers
75
-
76
  return vec
77
 
78
  ###############################################################################
79
- # 3. SHAP-VALUE (ABLATION) CALCULATION
80
  ###############################################################################
81
 
82
- def calculate_shap_values(model, x_tensor):
83
- """
84
- Calculate SHAP values using a simple ablation approach.
85
- Returns shap_values, prob_human
86
- """
87
  model.eval()
88
  with torch.no_grad():
89
- # Baseline
90
  baseline_output = model(x_tensor)
91
  baseline_probs = torch.softmax(baseline_output, dim=1)
92
- baseline_prob = baseline_probs[0, 1].item() # Probability of 'human' class
93
 
94
- # Zeroing each feature to measure impact
95
  shap_values = []
96
  x_zeroed = x_tensor.clone()
 
 
97
  for i in range(x_tensor.shape[1]):
98
- original_val = x_zeroed[0, i].item()
99
  x_zeroed[0, i] = 0.0
100
  output = model(x_zeroed)
101
  probs = torch.softmax(output, dim=1)
102
- prob = probs[0, 1].item()
103
- impact = baseline_prob - prob
104
  shap_values.append(impact)
105
- x_zeroed[0, i] = original_val # restore
 
106
  return np.array(shap_values), baseline_prob
107
 
108
- ###############################################################################
109
- # 4. PER-BASE SHAP AGGREGATION
110
- ###############################################################################
111
-
112
- def compute_positionwise_scores(sequence, shap_values, k=4):
113
- """
114
- Returns an array of per-base SHAP contributions by averaging
115
- the k-mer SHAP values of all k-mers covering that base.
116
- """
117
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
118
  kmer_dict = {km: i for i, km in enumerate(kmers)}
119
 
@@ -121,447 +143,457 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
121
  shap_sums = np.zeros(seq_len, dtype=np.float32)
122
  coverage = np.zeros(seq_len, dtype=np.float32)
123
 
 
124
  for i in range(seq_len - k + 1):
125
  kmer = sequence[i:i+k]
126
  if kmer in kmer_dict:
127
- val = shap_values[kmer_dict[kmer]]
128
- shap_sums[i : i + k] += val
129
- coverage[i : i + k] += 1
130
-
131
  with np.errstate(divide='ignore', invalid='ignore'):
132
  shap_means = np.where(coverage > 0, shap_sums / coverage, 0.0)
133
-
134
  return shap_means
135
 
136
- ###############################################################################
137
- # 5. FIND EXTREME SHAP REGIONS
138
- ###############################################################################
139
-
140
- def find_extreme_subregion(shap_means, window_size=500, mode="max"):
141
- """
142
- Finds the subregion of length `window_size` that has the maximum
143
- (mode="max") or minimum (mode="min") average SHAP.
144
- Returns (best_start, best_end, best_avg).
145
- """
146
- n = len(shap_means)
147
- if n == 0:
148
- return (0, 0, 0.0)
149
- if window_size >= n:
150
- # entire sequence
151
- avg_val = float(np.mean(shap_means))
152
- return (0, n, avg_val)
153
-
154
- # We'll build csum of length n+1
155
- csum = np.zeros(n + 1, dtype=np.float32)
156
- csum[1:] = np.cumsum(shap_means)
157
-
158
- best_start = 0
159
- best_sum = csum[window_size] - csum[0]
160
- best_avg = best_sum / window_size
161
-
162
- for start in range(1, n - window_size + 1):
163
- wsum = csum[start + window_size] - csum[start]
164
- wavg = wsum / window_size
165
- if mode == "max":
166
- if wavg > best_avg:
167
- best_avg = wavg
168
- best_start = start
169
- else: # mode == "min"
170
- if wavg < best_avg:
171
- best_avg = wavg
172
- best_start = start
173
-
174
- return (best_start, best_start + window_size, float(best_avg))
175
 
176
  ###############################################################################
177
- # 6. PLOTTING / UTILITIES
178
  ###############################################################################
179
 
180
- def fig_to_image(fig):
181
- """Convert a Matplotlib figure to a PIL Image for Gradio."""
182
- buf = io.BytesIO()
183
- fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
184
- buf.seek(0)
185
- img = Image.open(buf)
186
- plt.close(fig)
187
- return img
188
-
189
- def get_zero_centered_cmap():
190
- """
191
- Creates a custom diverging colormap that is:
192
- - Blue for negative
193
- - White for zero
194
- - Red for positive
195
- """
196
- colors = [
197
- (0.0, 'blue'), # negative
198
- (0.5, 'white'), # zero
199
- (1.0, 'red') # positive
200
- ]
201
- cmap = mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
202
- return cmap
203
-
204
- def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
205
- """
206
- Plots a 1D heatmap of per-base SHAP contributions with a custom colormap:
207
- - Negative = blue
208
- - 0 = white
209
- - Positive = red
210
- We'll force the range to be symmetrical around 0 by using:
211
- vmin=-extent, vmax=+extent
212
- so 0 is in the middle.
213
- """
214
- if start is not None and end is not None:
215
- local_shap = shap_means[start:end]
216
- subtitle = f" (positions {start}-{end})"
217
- else:
218
- local_shap = shap_means
219
- subtitle = ""
220
-
221
- if len(local_shap) == 0:
222
- # Edge case: no data to plot
223
- local_shap = np.array([0.0])
224
-
225
- # Build 2D array for imshow
226
- heatmap_data = local_shap.reshape(1, -1)
227
-
228
- # Force symmetrical range
229
- min_val = np.min(local_shap)
230
- max_val = np.max(local_shap)
231
- extent = max(abs(min_val), abs(max_val))
232
-
233
- # Create custom colormap
234
- custom_cmap = get_zero_centered_cmap()
235
-
236
- fig, ax = plt.subplots(figsize=(12, 2))
237
- cax = ax.imshow(
238
- heatmap_data,
239
- aspect='auto',
240
- cmap=custom_cmap,
241
- vmin=-extent,
242
- vmax=+extent
243
  )
244
 
245
- # Place colorbar below with plenty of margin
246
- cbar = plt.colorbar(cax, orientation='horizontal', pad=0.35)
247
- cbar.set_label('SHAP Contribution (negative=blue, zero=white, positive=red)')
248
-
249
- ax.set_yticks([])
250
- ax.set_xlabel('Position in Sequence')
251
- ax.set_title(f"{title}{subtitle}")
 
252
 
253
- # Extra bottom margin so colorbar won't overlap x-axis labels
254
- plt.subplots_adjust(bottom=0.4)
255
 
256
  return fig
257
 
258
- def create_importance_bar_plot(shap_values, kmers, top_k=10):
259
- """Create a bar plot of the most important k-mers."""
260
- plt.rcParams.update({'font.size': 10})
261
- fig = plt.figure(figsize=(10, 5))
262
-
263
- # Sort by absolute importance
264
- indices = np.argsort(np.abs(shap_values))[-top_k:]
265
- values = shap_values[indices]
266
- features = [kmers[i] for i in indices]
267
-
268
- # negative -> blue, positive -> red
269
- colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
270
-
271
- plt.barh(range(len(values)), values, color=colors)
272
- plt.yticks(range(len(values)), features)
273
- plt.xlabel('SHAP Value (impact on model output)')
274
- plt.title(f'Top {top_k} Most Influential k-mers')
275
- plt.gca().invert_yaxis()
276
- plt.tight_layout()
277
  return fig
278
 
279
- def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
280
- """
281
- Simple histogram of SHAP values in the subregion.
282
- """
283
- fig, ax = plt.subplots(figsize=(6, 4))
284
- ax.hist(shap_array, bins=30, color='gray', edgecolor='black')
285
- ax.axvline(0, color='red', linestyle='--', label='0.0')
286
- ax.set_xlabel("SHAP Value")
287
- ax.set_ylabel("Count")
288
- ax.set_title(title)
289
- ax.legend()
290
- plt.tight_layout()
291
  return fig
292
 
293
- def compute_gc_content(sequence):
294
- """Compute %GC in the sequence (A, C, G, T)."""
295
- if not sequence:
296
- return 0
297
- gc_count = sequence.count('G') + sequence.count('C')
298
- return (gc_count / len(sequence)) * 100.0
299
-
300
  ###############################################################################
301
- # 7. MAIN ANALYSIS STEP (Gradio Step 1)
302
  ###############################################################################
303
 
304
- def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
305
- """
306
- Analyzes the entire genome, returning classification, full-genome heatmap,
307
- top k-mer bar plot, and identifies subregions with strongest positive/negative push.
308
- """
309
  # Handle input
310
  if fasta_text.strip():
311
  text = fasta_text.strip()
312
  elif file_obj is not None:
313
- try:
314
- with open(file_obj, 'r') as f:
315
- text = f.read()
316
- except Exception as e:
317
- return (f"Error reading file: {str(e)}", None, None, None, None)
318
  else:
319
- return ("Please provide a FASTA sequence.", None, None, None, None)
320
-
321
  # Parse FASTA
322
  sequences = parse_fasta(text)
323
  if not sequences:
324
- return ("No valid FASTA sequences found.", None, None, None, None)
325
 
326
  header, seq = sequences[0]
327
-
328
  # Load model and scaler
329
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
330
- try:
331
- # Use weights_only=True for safer loading
332
- state_dict = torch.load('model.pt', map_location=device, weights_only=True)
333
- model = VirusClassifier(256).to(device)
334
- model.load_state_dict(state_dict)
335
-
336
- scaler = joblib.load('scaler.pkl')
337
- except Exception as e:
338
- return (f"Error loading model/scaler: {str(e)}", None, None, None, None)
339
-
340
- # Vectorize + scale
341
  freq_vector = sequence_to_kmer_vector(seq)
342
  scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
343
  x_tensor = torch.FloatTensor(scaled_vector).to(device)
344
-
345
- # SHAP + classification
346
  shap_values, prob_human = calculate_shap_values(model, x_tensor)
347
  prob_nonhuman = 1.0 - prob_human
348
 
349
- classification = "Human" if prob_human > 0.5 else "Non-human"
350
- confidence = max(prob_human, prob_nonhuman)
351
-
352
- # Per-base SHAP
353
- shap_means = compute_positionwise_scores(seq, shap_values, k=4)
354
-
355
- # Find the most "human-pushing" region
356
- (max_start, max_end, max_avg) = find_extreme_subregion(shap_means, window_size, mode="max")
357
- # Find the most "non-human–pushing" region
358
- (min_start, min_end, min_avg) = find_extreme_subregion(shap_means, window_size, mode="min")
359
-
360
- # Build results text
361
- results_text = (
362
- f"Sequence: {header}\n"
363
- f"Length: {len(seq):,} bases\n"
364
- f"Classification: {classification}\n"
365
- f"Confidence: {confidence:.3f}\n"
366
- f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
367
- f"---\n"
368
- f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
369
- f"Start: {max_start}, End: {max_end}, Avg SHAP: {max_avg:.4f}\n\n"
370
- f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
371
- f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
372
  )
373
 
374
- # K-mer importance plot
375
- kmers = [''.join(p) for p in product("ACGT", repeat=4)]
376
- bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
377
- bar_img = fig_to_image(bar_fig)
378
-
379
- # Full-genome SHAP heatmap
380
- heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
381
- heatmap_img = fig_to_image(heatmap_fig)
382
-
383
- # Store data for subregion analysis
384
- state_dict_out = {
385
- "seq": seq,
386
- "shap_means": shap_means
387
- }
388
-
389
- return (results_text, bar_img, heatmap_img, state_dict_out, header)
390
-
391
  ###############################################################################
392
- # 8. SUBREGION ANALYSIS (Gradio Step 2)
393
  ###############################################################################
394
 
395
- def analyze_subregion(state, header, region_start, region_end):
396
- """
397
- Takes stored data from step 1 and a user-chosen region.
398
- Returns a subregion heatmap, histogram, and some stats (GC, average SHAP).
399
- """
400
- if not state or "seq" not in state or "shap_means" not in state:
401
- return ("No sequence data found. Please run Step 1 first.", None, None)
402
-
403
- seq = state["seq"]
404
- shap_means = state["shap_means"]
405
-
406
- # Validate bounds
407
- region_start = int(region_start)
408
- region_end = int(region_end)
409
-
410
- region_start = max(0, min(region_start, len(seq)))
411
- region_end = max(0, min(region_end, len(seq)))
412
- if region_end <= region_start:
413
- return ("Invalid region range. End must be > Start.", None, None)
414
-
415
- # Subsequence
416
- region_seq = seq[region_start:region_end]
417
- region_shap = shap_means[region_start:region_end]
418
-
419
- # Some stats
420
- gc_percent = compute_gc_content(region_seq)
421
- avg_shap = float(np.mean(region_shap))
422
-
423
- # Fraction pushing toward human vs. non-human
424
- positive_fraction = np.mean(region_shap > 0)
425
- negative_fraction = np.mean(region_shap < 0)
426
-
427
- # Simple logic-based interpretation
428
- if avg_shap > 0.05:
429
- region_classification = "Likely pushing toward human"
430
- elif avg_shap < -0.05:
431
- region_classification = "Likely pushing toward non-human"
432
- else:
433
- region_classification = "Near neutral (no strong push)"
434
-
435
- region_info = (
436
- f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
437
- f"Region length: {len(region_seq)} bases\n"
438
- f"GC content: {gc_percent:.2f}%\n"
439
- f"Average SHAP in region: {avg_shap:.4f}\n"
440
- f"Fraction with SHAP > 0 (toward human): {positive_fraction:.2f}\n"
441
- f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
442
- f"Subregion interpretation: {region_classification}\n"
443
- )
444
-
445
- # Plot region as small heatmap
446
- heatmap_fig = plot_linear_heatmap(
447
- shap_means,
448
- title="Subregion SHAP",
449
- start=region_start,
450
- end=region_end
451
  )
452
- heatmap_img = fig_to_image(heatmap_fig)
453
-
454
- # Plot histogram of SHAP in region
455
- hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
456
- hist_img = fig_to_image(hist_fig)
457
-
458
- return (region_info, heatmap_img, hist_img)
459
-
460
-
461
- ###############################################################################
462
- # 9. BUILD GRADIO INTERFACE
463
- ###############################################################################
464
-
465
- css = """
466
- .gradio-container {
467
- font-family: 'IBM Plex Sans', sans-serif;
468
- }
469
- """
470
-
471
- with gr.Blocks(css=css) as iface:
472
- gr.Markdown("""
473
- # Virus Host Classifier with White-Centered Gradient
474
- **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
475
- **Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
476
-
477
- **Color Scale**: Negative SHAP = Blue, Zero = White, Positive = Red.
478
- """)
479
-
480
- with gr.Tab("1) Full-Sequence Analysis"):
481
- with gr.Row():
482
- with gr.Column(scale=1):
483
- file_input = gr.File(
484
- label="Upload FASTA file",
485
- file_types=[".fasta", ".fa", ".txt"],
486
- type="filepath"
487
- )
488
- text_input = gr.Textbox(
489
- label="Or paste FASTA sequence",
490
- placeholder=">sequence_name\nACGTACGT...",
491
- lines=5
492
- )
493
- top_k = gr.Slider(
494
- minimum=5,
495
- maximum=30,
496
- value=10,
497
- step=1,
498
- label="Number of top k-mers to display"
499
- )
500
- win_size = gr.Slider(
501
- minimum=100,
502
- maximum=5000,
503
- value=500,
504
- step=100,
505
- label="Window size for 'most pushing' subregions"
506
- )
507
- analyze_btn = gr.Button("Analyze Sequence", variant="primary")
508
-
509
- with gr.Column(scale=2):
510
- results_box = gr.Textbox(
511
- label="Classification Results", lines=12, interactive=False
512
- )
513
- kmer_img = gr.Image(label="Top k-mer SHAP")
514
- genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
515
-
516
- seq_state = gr.State()
517
- header_state = gr.State()
518
-
519
- # analyze_sequence(...) returns 5 items
520
- analyze_btn.click(
521
- analyze_sequence,
522
- inputs=[file_input, top_k, text_input, win_size],
523
- outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
524
- )
525
 
526
- with gr.Tab("2) Subregion Exploration"):
527
  gr.Markdown("""
528
- **Subregion Analysis**
529
- Select start/end positions to view local SHAP signals, distribution, and GC content.
530
- The heatmap also uses the same Blue-White-Red scale.
531
  """)
532
- with gr.Row():
533
- region_start = gr.Number(label="Region Start", value=0)
534
- region_end = gr.Number(label="Region End", value=500)
535
- region_btn = gr.Button("Analyze Subregion")
536
 
537
- subregion_info = gr.Textbox(
538
- label="Subregion Analysis",
539
- lines=7,
540
- interactive=False
541
- )
542
- with gr.Row():
543
- subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
544
- subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
545
 
546
- region_btn.click(
547
- analyze_subregion,
548
- inputs=[seq_state, header_state, region_start, region_end],
549
- outputs=[subregion_info, subregion_img, subregion_hist_img]
550
  )
551
-
552
- gr.Markdown("""
553
- ### Interface Features
554
- - **Overall Classification** (human vs non-human) using k-mer frequencies.
555
- - **SHAP Analysis** to see which k-mers push classification toward or away from human.
556
- - **White-Centered SHAP Gradient**:
557
- - Negative (blue), 0 (white), Positive (red), with symmetrical color range around 0.
558
- - **Identify Subregions** with the strongest push for human or non-human.
559
- - **Subregion Exploration**:
560
- - Local SHAP heatmap & histogram
561
- - GC content
562
- - Fraction of positions pushing human vs. non-human
563
- - Simple logic-based classification
564
- """)
565
 
566
  if __name__ == "__main__":
567
- iface.launch()
 
6
  import torch.nn as nn
7
  import matplotlib.pyplot as plt
8
  import matplotlib.colors as mcolors
9
+ import seaborn as sns
10
  from PIL import Image
11
+ import io
12
+ import pandas as pd
13
+ from typing import Tuple, List, Dict, Any
14
+ from dataclasses import dataclass
15
+ import plotly.graph_objects as go
16
+ import plotly.express as px
17
+ from plotly.subplots import make_subplots
18
 
19
  ###############################################################################
20
+ # 1. DATA STRUCTURES & MODEL
21
  ###############################################################################
22
 
23
+ @dataclass
24
+ class SequenceAnalysis:
25
+ """Container for sequence analysis results"""
26
+ header: str
27
+ sequence: str
28
+ length: int
29
+ gc_content: float
30
+ classification: str
31
+ human_prob: float
32
+ nonhuman_prob: float
33
+ shap_values: np.ndarray
34
+ shap_means: np.ndarray
35
+ extreme_regions: Dict[str, Dict[str, Any]]
36
+
37
  class VirusClassifier(nn.Module):
38
  def __init__(self, input_shape: int):
39
  super(VirusClassifier, self).__init__()
 
50
  nn.GELU(),
51
  nn.Linear(32, 2)
52
  )
53
+
54
  def forward(self, x):
55
  return self.network(x)
56
 
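For orientation, a minimal usage sketch of the classifier (illustrative only, assuming the full class definition in app.py): downstream code instantiates it with input_shape=256, matching the 4-mer vocabulary (4**4), and treats index 1 of the output as the "human" class.

```python
import torch

model = VirusClassifier(input_shape=256)    # 256 = 4**4 possible 4-mers
model.eval()
with torch.no_grad():
    x = torch.randn(1, 256)                 # stands in for one scaled k-mer frequency vector
    probs = torch.softmax(model(x), dim=1)  # shape (1, 2); index 1 is the 'human' class
print(probs[0, 1].item())
```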
57
  ###############################################################################
58
+ # 2. SEQUENCE PROCESSING
59
  ###############################################################################
60
 
61
+ def parse_fasta(text: str) -> List[Tuple[str, str]]:
62
+ """Parse FASTA formatted text with improved robustness"""
63
  sequences = []
64
  current_header = None
65
  current_sequence = []
 
74
  current_header = line[1:]
75
  current_sequence = []
76
  else:
77
+ # Filter out non-ACGT characters and convert to uppercase
78
+ filtered_line = ''.join(c for c in line.upper() if c in 'ACGT')
79
+ current_sequence.append(filtered_line)
80
+
81
  if current_header:
82
  sequences.append((current_header, ''.join(current_sequence)))
83
  return sequences
84
 
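A small usage example for the parser above (the input string is illustrative); the new version uppercases lines and drops non-ACGT characters.

```python
records = parse_fasta(">seq1\nacgtACGTnn\n>seq2\nTTTT\n")
print(records)
# [('seq1', 'ACGTACGT'), ('seq2', 'TTTT')]  -- lowercase is normalized, non-ACGT bases are dropped
```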
85
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
86
+ """Convert sequence to k-mer frequency vector with optimizations"""
87
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
88
  kmer_dict = {km: i for i, km in enumerate(kmers)}
89
  vec = np.zeros(len(kmers), dtype=np.float32)
90
 
91
+ # Use sliding window for efficiency
92
  for i in range(len(sequence) - k + 1):
93
  kmer = sequence[i:i+k]
94
+ if kmer in kmer_dict: # Handle non-ACGT kmers
95
  vec[kmer_dict[kmer]] += 1
96
+
97
+ # Normalize
98
  total_kmers = len(sequence) - k + 1
99
  if total_kmers > 0:
100
  vec = vec / total_kmers
101
+
102
  return vec
103
 
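A quick sanity check of the vectorizer above, using a toy sequence:

```python
import numpy as np

vec = sequence_to_kmer_vector("ACGTACGT", k=4)
print(vec.shape)            # (256,) -- one slot per 4-mer
print(np.flatnonzero(vec))  # the 4 distinct 4-mers seen (ACGT occurs in two of the 5 windows)
print(float(vec.sum()))     # ~1.0  -- counts are normalized by the number of windows
```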
104
+ def compute_gc_content(sequence: str) -> float:
105
+ """Compute GC content percentage"""
106
+ if not sequence:
107
+ return 0.0
108
+ gc_count = sum(1 for base in sequence if base in 'GC')
109
+ return (gc_count / len(sequence)) * 100.0
110
+
111
  ###############################################################################
112
+ # 3. SHAP & ANALYSIS
113
  ###############################################################################
114
 
115
+ def calculate_shap_values(model: nn.Module, x_tensor: torch.Tensor) -> Tuple[np.ndarray, float]:
116
+ """Calculate SHAP values using ablation with improved efficiency"""
117
  model.eval()
118
  with torch.no_grad():
 
119
  baseline_output = model(x_tensor)
120
  baseline_probs = torch.softmax(baseline_output, dim=1)
121
+ baseline_prob = baseline_probs[0, 1].item()
122
 
 
123
  shap_values = []
124
  x_zeroed = x_tensor.clone()
125
+
126
+ # Vectorized computation where possible
127
  for i in range(x_tensor.shape[1]):
 
128
  x_zeroed[0, i] = 0.0
129
  output = model(x_zeroed)
130
  probs = torch.softmax(output, dim=1)
131
+ impact = baseline_prob - probs[0, 1].item()
 
132
  shap_values.append(impact)
133
+ x_zeroed[0, i] = x_tensor[0, i]
134
+
135
  return np.array(shap_values), baseline_prob
136
 
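A toy illustration of the ablation approach above, using a stand-in model and made-up feature values (not the real classifier or data):

```python
import torch
import torch.nn as nn

toy_model = nn.Linear(4, 2)                    # 4 features, 2 classes, random weights
x = torch.tensor([[0.2, -0.1, 0.4, 0.3]])
shap, p_human = calculate_shap_values(toy_model, x)
# shap[i] = P(human | all features) - P(human | feature i zeroed):
# positive entries push the prediction toward 'human', negative ones away from it.
print(p_human, shap)
```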
137
+ def compute_positionwise_scores(sequence: str, shap_values: np.ndarray, k: int = 4) -> np.ndarray:
138
+ """Compute per-base SHAP scores with optimized memory usage"""
139
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
140
  kmer_dict = {km: i for i, km in enumerate(kmers)}
141
 
 
143
  shap_sums = np.zeros(seq_len, dtype=np.float32)
144
  coverage = np.zeros(seq_len, dtype=np.float32)
145
 
146
+ # Vectorized operations where possible
147
  for i in range(seq_len - k + 1):
148
  kmer = sequence[i:i+k]
149
  if kmer in kmer_dict:
150
+ idx = kmer_dict[kmer]
151
+ shap_sums[i:i+k] += shap_values[idx]
152
+ coverage[i:i+k] += 1
153
+
154
  with np.errstate(divide='ignore', invalid='ignore'):
155
  shap_means = np.where(coverage > 0, shap_sums / coverage, 0.0)
156
+
157
  return shap_means
158
 
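An illustration of the per-base averaging above, with hypothetical SHAP values:

```python
import numpy as np

shap_vals = np.zeros(256, dtype=np.float32)  # one value per 4-mer
shap_vals[0] = 1.0                           # pretend only 'AAAA' carries signal
per_base = compute_positionwise_scores("AAAAAC", shap_vals, k=4)
print(per_base)
# Each base receives the mean SHAP of the 4-mers covering it, so the leading A's
# score high while the trailing C, covered only by zero-valued k-mers, stays at 0.
```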
159
+ def find_extreme_regions(shap_means: np.ndarray, window_size: int = 500) -> Dict[str, Dict[str, Any]]:
160
+ """Find regions with extreme SHAP values using efficient sliding window"""
161
+ if len(shap_means) < window_size:
162
+ window_size = len(shap_means)
163
+
164
+ # Compute cumulative sum for efficient sliding window
165
+ cumsum = np.concatenate(([0.0], np.cumsum(shap_means)))  # prefix sums with a leading zero
166
+
167
+ # Sliding window calculation
168
+ window_avgs = (cumsum[window_size:] - cumsum[:-window_size]) / window_size
169
+
170
+ max_idx = np.argmax(window_avgs)
171
+ min_idx = np.argmin(window_avgs)
172
+
173
+ return {
174
+ "human": {
175
+ "start": max_idx,
176
+ "end": max_idx + window_size,
177
+ "avg_shap": float(window_avgs[max_idx])
178
+ },
179
+ "nonhuman": {
180
+ "start": min_idx,
181
+ "end": min_idx + window_size,
182
+ "avg_shap": float(window_avgs[min_idx])
183
+ }
184
+ }
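The sliding window relies on a standard prefix-sum identity; a minimal check with toy numbers (not taken from the app):

```python
import numpy as np

scores = np.array([0.1, -0.2, 0.4, 0.3, -0.1], dtype=np.float32)
w = 3
csum = np.concatenate(([0.0], np.cumsum(scores)))  # prefix sums with a leading zero
window_avgs = (csum[w:] - csum[:-w]) / w           # average of the window starting at each position s
naive = [scores[s:s + w].mean() for s in range(len(scores) - w + 1)]
assert np.allclose(window_avgs, naive)
```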
185
 
186
  ###############################################################################
187
+ # 4. VISUALIZATION
188
  ###############################################################################
189
 
190
+ def create_genome_overview_plot(analysis: SequenceAnalysis) -> go.Figure:
191
+ """Create an interactive genome overview using Plotly"""
192
+ fig = make_subplots(
193
+ rows=2, cols=1,
194
+ subplot_titles=("SHAP Values Along Genome", "GC Content"),
195
+ row_heights=[0.7, 0.3],
196
+ vertical_spacing=0.1
197
  )
198
 
199
+ # SHAP trace
200
+ fig.add_trace(
201
+ go.Scatter(
202
+ x=list(range(len(analysis.shap_means))),
203
+ y=analysis.shap_means,
204
+ name="SHAP",
205
+ line=dict(color='rgba(31, 119, 180, 0.8)'),
206
+ hovertemplate="Position: %{x}<br>SHAP: %{y:.4f}<extra></extra>"
207
+ ),
208
+ row=1, col=1
209
+ )
210
+
211
+ # Highlight extreme regions
212
+ for region_type, region in analysis.extreme_regions.items():
213
+ color = 'rgba(255, 0, 0, 0.2)' if region_type == 'human' else 'rgba(0, 0, 255, 0.2)'
214
+ fig.add_vrect(
215
+ x0=region['start'],
216
+ x1=region['end'],
217
+ fillcolor=color,
218
+ opacity=0.5,
219
+ layer="below",
220
+ line_width=0,
221
+ row=1, col=1
222
+ )
223
 
224
+ # Calculate rolling GC content
225
+ window = 100
226
+ gc_content = np.array([
227
+ compute_gc_content(analysis.sequence[i:i+window])
228
+ for i in range(0, len(analysis.sequence) - window + 1, window)
229
+ ])
230
+
231
+ # GC content trace
232
+ fig.add_trace(
233
+ go.Scatter(
234
+ x=np.arange(len(gc_content)) * window,
235
+ y=gc_content,
236
+ name="GC%",
237
+ line=dict(color='rgba(44, 160, 44, 0.8)'),
238
+ hovertemplate="Position: %{x}<br>GC%: %{y:.1f}%<extra></extra>"
239
+ ),
240
+ row=2, col=1
241
+ )
242
+
243
+ # Update layout
244
+ fig.update_layout(
245
+ height=800,
246
+ title=dict(
247
+ text=f"Genome Analysis Overview<br><sub>{analysis.header}</sub>",
248
+ x=0.5
249
+ ),
250
+ showlegend=False,
251
+ plot_bgcolor='white'
252
+ )
253
+
254
+ # Update axes
255
+ fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
256
+ fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
257
 
258
  return fig
259
 
260
+ def create_kmer_importance_plot(analysis: SequenceAnalysis, top_k: int = 10) -> go.Figure:
261
+ """Create interactive k-mer importance plot using Plotly"""
262
+ # Get top k-mers by absolute SHAP value
263
+ kmers = [''.join(p) for p in product("ACGT", repeat=4)]
264
+ indices = np.argsort(np.abs(analysis.shap_values))[-top_k:]
265
+
266
+ # Create DataFrame for plotting
267
+ df = pd.DataFrame({
268
+ 'k-mer': [kmers[i] for i in indices],
269
+ 'SHAP': analysis.shap_values[indices]
270
+ })
271
+
272
+ # Create plot
273
+ fig = px.bar(
274
+ df,
275
+ x='SHAP',
276
+ y='k-mer',
277
+ orientation='h',
278
+ color='SHAP',
279
+ color_continuous_scale='RdBu',
280
+ title=f'Top {top_k} Most Influential k-mers'
281
+ )
282
+
283
+ # Update layout
284
+ fig.update_layout(
285
+ height=400,
286
+ plot_bgcolor='white',
287
+ yaxis_title='',
288
+ xaxis_title='SHAP Value',
289
+ coloraxis_showscale=False
290
+ )
291
+
292
  return fig
293
 
294
+ def create_shap_distribution_plot(analysis: SequenceAnalysis) -> go.Figure:
295
+ """Create SHAP distribution plot using Plotly"""
296
+ fig = go.Figure()
297
+
298
+ # Add histogram
299
+ fig.add_trace(go.Histogram(
300
+ x=analysis.shap_means,
301
+ nbinsx=50,
302
+ name='SHAP Values',
303
+ marker_color='rgba(31, 119, 180, 0.6)'
304
+ ))
305
+
306
+ # Add vertical line at x=0
307
+ fig.add_vline(
308
+ x=0,
309
+ line_dash="dash",
310
+ line_color="red",
311
+ annotation_text="Neutral",
312
+ annotation_position="top"
313
+ )
314
+
315
+ # Update layout
316
+ fig.update_layout(
317
+ title='Distribution of SHAP Values',
318
+ xaxis_title='SHAP Value',
319
+ yaxis_title='Count',
320
+ plot_bgcolor='white',
321
+ height=400
322
+ )
323
+
324
  return fig
325
 
326
  ###############################################################################
327
+ # 5. MAIN ANALYSIS
328
  ###############################################################################
329
 
330
+ def analyze_sequence(
331
+ file_obj: str = None,
332
+ fasta_text: str = "",
333
+ window_size: int = 500,
334
+ model_path: str = 'model.pt',
335
+ scaler_path: str = 'scaler.pkl'
336
+ ) -> SequenceAnalysis:
337
+ """Main sequence analysis function"""
338
  # Handle input
339
  if fasta_text.strip():
340
  text = fasta_text.strip()
341
  elif file_obj is not None:
342
+ with open(file_obj, 'r') as f:
343
+ text = f.read()
344
  else:
345
+ raise ValueError("No input provided")
346
+
347
  # Parse FASTA
348
  sequences = parse_fasta(text)
349
  if not sequences:
350
+ raise ValueError("No valid FASTA sequences found")
351
 
352
  header, seq = sequences[0]
353
+
354
  # Load model and scaler
355
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
356
+ state_dict = torch.load(model_path, map_location=device)
357
+ model = VirusClassifier(256).to(device)
358
+ model.load_state_dict(state_dict)
359
+
360
+ scaler = joblib.load(scaler_path)
361
+
362
+ # Process sequence
363
  freq_vector = sequence_to_kmer_vector(seq)
364
  scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
365
  x_tensor = torch.FloatTensor(scaled_vector).to(device)
366
+
367
+ # Get SHAP values and classification
368
  shap_values, prob_human = calculate_shap_values(model, x_tensor)
369
  prob_nonhuman = 1.0 - prob_human
370
 
371
+ # Get per-base SHAP scores
372
+ shap_means = compute_positionwise_scores(seq, shap_values)
373
+
374
+ # Find extreme regions
375
+ extreme_regions = find_extreme_regions(shap_means, window_size)
376
+
377
+ # Create analysis object
378
+ return SequenceAnalysis(
379
+ header=header,
380
+ sequence=seq,
381
+ length=len(seq),
382
+ gc_content=compute_gc_content(seq),
383
+ classification="Human" if prob_human > 0.5 else "Non-human",
384
+ human_prob=prob_human,
385
+ nonhuman_prob=prob_nonhuman,
386
+ shap_values=shap_values,
387
+ shap_means=shap_means,
388
+ extreme_regions=extreme_regions
389
  )
390
 
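A hedged, script-style usage sketch of analyze_sequence ("example.fasta" is a placeholder; model.pt and scaler.pkl are assumed to sit next to app.py, as in the loading code above):

```python
analysis = analyze_sequence(file_obj="example.fasta", window_size=500)
print(analysis.classification, f"P(human)={analysis.human_prob:.3f}", f"GC={analysis.gc_content:.1f}%")
print("Most human-like window:", analysis.extreme_regions["human"])
print("Most non-human-like window:", analysis.extreme_regions["nonhuman"])
```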
391
  ###############################################################################
392
+ # 6. GRADIO INTERFACE
393
  ###############################################################################
394
 
395
+ def create_interface():
396
+ """Create enhanced Gradio interface with improved layout and interactivity"""
397
+
398
+ def process_sequence(
399
+ file_obj: str,
400
+ fasta_text: str,
401
+ window_size: int,
402
+ top_kmers: int
403
+ ) -> Tuple[str, go.Figure, go.Figure, go.Figure, SequenceAnalysis]:
404
+ """Process sequence and return formatted results and plots"""
405
+ try:
406
+ # Run analysis
407
+ analysis = analyze_sequence(
408
+ file_obj=file_obj,
409
+ fasta_text=fasta_text,
410
+ window_size=window_size
411
+ )
412
+
413
+ # Format results text
414
+ results = f"""
415
+ ### Sequence Analysis Results
416
+
417
+ **Basic Information**
418
+ - Sequence: {analysis.header}
419
+ - Length: {analysis.length:,} bases
420
+ - GC Content: {analysis.gc_content:.1f}%
421
+
422
+ **Classification**
423
+ - Prediction: {analysis.classification}
424
+ - Human Probability: {analysis.human_prob:.3f}
425
+ - Non-human Probability: {analysis.nonhuman_prob:.3f}
426
+
427
+ **Extreme Regions (window size: {window_size}bp)**
428
+ Most Human-like Region:
429
+ - Position: {analysis.extreme_regions['human']['start']:,} - {analysis.extreme_regions['human']['end']:,}
430
+ - Average SHAP: {analysis.extreme_regions['human']['avg_shap']:.4f}
431
+
432
+ Most Non-human-like Region:
433
+ - Position: {analysis.extreme_regions['nonhuman']['start']:,} - {analysis.extreme_regions['nonhuman']['end']:,}
434
+ - Average SHAP: {analysis.extreme_regions['nonhuman']['avg_shap']:.4f}
435
+ """
436
+
437
+ # Create plots
438
+ genome_plot = create_genome_overview_plot(analysis)
439
+ kmer_plot = create_kmer_importance_plot(analysis, top_kmers)
440
+ dist_plot = create_shap_distribution_plot(analysis)
441
+
442
+ return results, genome_plot, kmer_plot, dist_plot, analysis
443
+
444
+ except Exception as e:
445
+ return f"Error: {str(e)}", None, None, None, None
446
+
447
+ # Create theme and styling
448
+ theme = gr.themes.Soft(
449
+ primary_hue="blue",
450
+ secondary_hue="gray",
451
+ ).set(
452
+ body_text_color="gray-dark",
453
+ background_fill_primary="*gray-50",
454
+ block_shadow="*shadow-sm",
455
+ block_background_fill="white",
456
  )
457
 
458
+ # Build interface
459
+ with gr.Blocks(theme=theme, css="""
460
+ .container { margin: 0 auto; max-width: 1200px; padding: 20px; }
461
+ .results { margin-top: 20px; }
462
+ .plot-container { margin-top: 10px; }
463
+ """) as interface:
464
  gr.Markdown("""
465
+ # 🧬 Enhanced Virus Host Classifier
466
+
467
+ This tool analyzes viral sequences to predict their host (human vs. non-human) and provides detailed visualizations
468
+ of the features influencing this classification. Upload or paste a FASTA sequence to begin.
469
+
470
+ *Using advanced SHAP analysis and interactive visualizations for interpretable results.*
471
  """)
472
 
473
+ # Input section
474
+ with gr.Tab("Sequence Analysis"):
475
+ with gr.Row():
476
+ with gr.Column(scale=1):
477
+ file_input = gr.File(
478
+ label="Upload FASTA File",
479
+ file_types=[".fasta", ".fa", ".txt"],
480
+ type="filepath"
481
+ )
482
+
483
+ text_input = gr.Textbox(
484
+ label="Or Paste FASTA Sequence",
485
+ placeholder=">sequence_name\nACGTACGT...",
486
+ lines=5
487
+ )
488
+
489
+ with gr.Row():
490
+ window_size = gr.Slider(
491
+ minimum=100,
492
+ maximum=5000,
493
+ value=500,
494
+ step=100,
495
+ label="Window Size for Region Analysis"
496
+ )
497
+
498
+ top_kmers = gr.Slider(
499
+ minimum=5,
500
+ maximum=30,
501
+ value=10,
502
+ step=1,
503
+ label="Number of Top k-mers to Display"
504
+ )
505
+
506
+ analyze_btn = gr.Button(
507
+ "🔍 Analyze Sequence",
508
+ variant="primary"
509
+ )
510
+
511
+ # Results section
512
+ with gr.Column(scale=2):
513
+ results_text = gr.Markdown(
514
+ label="Analysis Results"
515
+ )
516
+
517
+ # Plots
518
+ genome_plot = gr.Plot(
519
+ label="Genome Overview"
520
+ )
521
+
522
+ with gr.Row():
523
+ kmer_plot = gr.Plot(
524
+ label="k-mer Importance"
525
+ )
526
+ dist_plot = gr.Plot(
527
+ label="SHAP Distribution"
528
+ )
529
+
530
+ # Help tab
531
+ with gr.Tab("Help & Information"):
532
+ gr.Markdown("""
533
+ ### 📖 How to Use This Tool
534
+
535
+ 1. **Input Your Sequence**
536
+ - Upload a FASTA file or paste your sequence in FASTA format
537
+ - The sequence should contain only ACGT bases (non-standard bases will be filtered)
538
+
539
+ 2. **Adjust Parameters**
540
+ - Window Size: Controls the length of regions analyzed for extreme patterns
541
+ - Top k-mers: Number of most influential sequence patterns to display
542
+
543
+ 3. **Interpret Results**
544
+ - Classification: Predicted host (human vs. non-human)
545
+ - Genome Overview: Interactive plot showing SHAP values and GC content
546
+ - k-mer Importance: Most influential sequence patterns
547
+ - SHAP Distribution: Overall distribution of feature importance
548
+
549
+ ### 🎨 Visualization Guide
550
+
551
+ - **SHAP Values**:
552
+ - Positive (red) = pushing toward human classification
553
+ - Negative (blue) = pushing toward non-human classification
554
+ - Zero (white) = neutral impact
555
+
556
+ - **Extreme Regions**:
557
+ - Highlighted in the genome overview plot
558
+ - Red regions = most human-like
559
+ - Blue regions = most non-human-like
560
+
561
+ ### 🔬 Technical Details
562
+
563
+ - The classifier uses k-mer frequencies (k=4) as features
564
+ - SHAP values are calculated using an ablation-based approach
565
+ - GC content is calculated using a sliding window
566
+ """)
567
+
568
+ # Connect components
569
+ sequence_state = gr.State()
570
 
571
+ analyze_btn.click(
572
+ process_sequence,
573
+ inputs=[
574
+ file_input,
575
+ text_input,
576
+ window_size,
577
+ top_kmers
578
+ ],
579
+ outputs=[
580
+ results_text,
581
+ [genome_plot, kmer_plot, dist_plot],
582
+ sequence_state
583
+ ]
584
  )
585
+
586
+ return interface
587
+
588
+ ###############################################################################
589
+ # 7. MAIN ENTRY POINT
590
+ ###############################################################################
591
 
592
  if __name__ == "__main__":
593
+ iface = create_interface()
594
+ iface.launch(
595
+ share=True,
596
+ server_name="0.0.0.0",
597
+ show_error=True
598
+ )
599
+ #