hiyata committed on
Commit
de0719b
·
verified ·
1 Parent(s): 1d54b05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +421 -468
app.py CHANGED
@@ -6,34 +6,13 @@ from itertools import product
6
  import torch.nn as nn
7
  import matplotlib.pyplot as plt
8
  import matplotlib.colors as mcolors
9
- import seaborn as sns
10
- from PIL import Image
11
  import io
12
- import pandas as pd
13
- from typing import Tuple, List, Dict, Any
14
- from dataclasses import dataclass
15
- import plotly.graph_objects as go
16
- import plotly.express as px
17
- from plotly.subplots import make_subplots
18
 
19
  ###############################################################################
20
- # 1. DATA STRUCTURES & MODEL
21
  ###############################################################################
22
 
23
- @dataclass
24
- class SequenceAnalysis:
25
- """Container for sequence analysis results"""
26
- header: str
27
- sequence: str
28
- length: int
29
- gc_content: float
30
- classification: str
31
- human_prob: float
32
- nonhuman_prob: float
33
- shap_values: np.ndarray
34
- shap_means: np.ndarray
35
- extreme_regions: Dict[str, Dict[str, Any]]
36
-
37
  class VirusClassifier(nn.Module):
38
  def __init__(self, input_shape: int):
39
  super(VirusClassifier, self).__init__()
@@ -50,16 +29,16 @@ class VirusClassifier(nn.Module):
50
  nn.GELU(),
51
  nn.Linear(32, 2)
52
  )
53
-
54
  def forward(self, x):
55
  return self.network(x)
56
 
57
  ###############################################################################
58
- # 2. SEQUENCE PROCESSING
59
  ###############################################################################
60
 
61
- def parse_fasta(text: str) -> List[Tuple[str, str]]:
62
- """Parse FASTA formatted text with improved robustness"""
63
  sequences = []
64
  current_header = None
65
  current_sequence = []
@@ -74,68 +53,67 @@ def parse_fasta(text: str) -> List[Tuple[str, str]]:
74
  current_header = line[1:]
75
  current_sequence = []
76
  else:
77
- # Filter out non-ACGT characters and convert to uppercase
78
- filtered_line = ''.join(c for c in line.upper() if c in 'ACGT')
79
- current_sequence.append(filtered_line)
80
-
81
  if current_header:
82
  sequences.append((current_header, ''.join(current_sequence)))
83
  return sequences
84
 
85
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
86
- """Convert sequence to k-mer frequency vector with optimizations"""
87
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
88
  kmer_dict = {km: i for i, km in enumerate(kmers)}
89
  vec = np.zeros(len(kmers), dtype=np.float32)
90
 
91
- # Use sliding window for efficiency
92
  for i in range(len(sequence) - k + 1):
93
  kmer = sequence[i:i+k]
94
- if kmer in kmer_dict: # Handle non-ACGT kmers
95
  vec[kmer_dict[kmer]] += 1
96
-
97
- # Normalize
98
  total_kmers = len(sequence) - k + 1
99
  if total_kmers > 0:
100
  vec = vec / total_kmers
101
-
102
- return vec
103
 
104
- def compute_gc_content(sequence: str) -> float:
105
- """Compute GC content percentage"""
106
- if not sequence:
107
- return 0.0
108
- gc_count = sum(1 for base in sequence if base in 'GC')
109
- return (gc_count / len(sequence)) * 100.0
110
 
111
  ###############################################################################
112
- # 3. SHAP & ANALYSIS
113
  ###############################################################################
114
 
115
- def calculate_shap_values(model: nn.Module, x_tensor: torch.Tensor) -> Tuple[np.ndarray, float]:
116
- """Calculate SHAP values using ablation with improved efficiency"""
 
 
 
117
  model.eval()
118
  with torch.no_grad():
 
119
  baseline_output = model(x_tensor)
120
  baseline_probs = torch.softmax(baseline_output, dim=1)
121
- baseline_prob = baseline_probs[0, 1].item()
122
 
 
123
  shap_values = []
124
  x_zeroed = x_tensor.clone()
125
-
126
- # Vectorized computation where possible
127
  for i in range(x_tensor.shape[1]):
 
128
  x_zeroed[0, i] = 0.0
129
  output = model(x_zeroed)
130
  probs = torch.softmax(output, dim=1)
131
- impact = baseline_prob - probs[0, 1].item()
 
132
  shap_values.append(impact)
133
- x_zeroed[0, i] = x_tensor[0, i]
134
-
135
  return np.array(shap_values), baseline_prob
136
 
137
- def compute_positionwise_scores(sequence: str, shap_values: np.ndarray, k: int = 4) -> np.ndarray:
138
- """Compute per-base SHAP scores with optimized memory usage"""
 
 
 
 
 
 
 
139
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
140
  kmer_dict = {km: i for i, km in enumerate(kmers)}
141
 
@@ -143,472 +121,447 @@ def compute_positionwise_scores(sequence: str, shap_values: np.ndarray, k: int =
143
  shap_sums = np.zeros(seq_len, dtype=np.float32)
144
  coverage = np.zeros(seq_len, dtype=np.float32)
145
 
146
- # Vectorized operations where possible
147
  for i in range(seq_len - k + 1):
148
  kmer = sequence[i:i+k]
149
  if kmer in kmer_dict:
150
- idx = kmer_dict[kmer]
151
- shap_sums[i:i+k] += shap_values[idx]
152
- coverage[i:i+k] += 1
153
-
154
  with np.errstate(divide='ignore', invalid='ignore'):
155
  shap_means = np.where(coverage > 0, shap_sums / coverage, 0.0)
156
-
157
  return shap_means
158
 
159
- def find_extreme_regions(shap_means: np.ndarray, window_size: int = 500) -> Dict[str, Dict[str, Any]]:
160
- """Find regions with extreme SHAP values using efficient sliding window"""
161
- if len(shap_means) < window_size:
162
- window_size = len(shap_means)
163
-
164
- # Compute cumulative sum for efficient sliding window
165
- cumsum = np.cumsum(np.pad(shap_means, (0, 1)))
166
-
167
- # Sliding window calculation
168
- window_avgs = (cumsum[window_size:] - cumsum[:-window_size]) / window_size
169
-
170
- max_idx = np.argmax(window_avgs)
171
- min_idx = np.argmin(window_avgs)
172
-
173
- return {
174
- "human": {
175
- "start": max_idx,
176
- "end": max_idx + window_size,
177
- "avg_shap": float(window_avgs[max_idx])
178
- },
179
- "nonhuman": {
180
- "start": min_idx,
181
- "end": min_idx + window_size,
182
- "avg_shap": float(window_avgs[min_idx])
183
- }
184
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  ###############################################################################
187
- # 4. VISUALIZATION
188
  ###############################################################################
189
 
190
- def create_genome_overview_plot(analysis: SequenceAnalysis) -> go.Figure:
191
- """Create an interactive genome overview using Plotly"""
192
- fig = make_subplots(
193
- rows=2, cols=1,
194
- subplot_titles=("SHAP Values Along Genome", "GC Content"),
195
- row_heights=[0.7, 0.3],
196
- vertical_spacing=0.1
197
- )
198
-
199
- # SHAP trace
200
- fig.add_trace(
201
- go.Scatter(
202
- x=list(range(len(analysis.shap_means))),
203
- y=analysis.shap_means,
204
- name="SHAP",
205
- line=dict(color='rgba(31, 119, 180, 0.8)'),
206
- hovertemplate="Position: %{x}<br>SHAP: %{y:.4f}<extra></extra>"
207
- ),
208
- row=1, col=1
209
- )
210
-
211
- # Highlight extreme regions
212
- for region_type, region in analysis.extreme_regions.items():
213
- color = 'rgba(255, 0, 0, 0.2)' if region_type == 'human' else 'rgba(0, 0, 255, 0.2)'
214
- fig.add_vrect(
215
- x0=region['start'],
216
- x1=region['end'],
217
- fillcolor=color,
218
- opacity=0.5,
219
- layer="below",
220
- line_width=0,
221
- row=1, col=1
222
- )
223
-
224
- # Calculate rolling GC content
225
- window = 100
226
- gc_content = np.array([
227
- compute_gc_content(analysis.sequence[i:i+window])
228
- for i in range(0, len(analysis.sequence) - window + 1, window)
229
- ])
230
-
231
- # GC content trace
232
- fig.add_trace(
233
- go.Scatter(
234
- x=np.arange(len(gc_content)) * window,
235
- y=gc_content,
236
- name="GC%",
237
- line=dict(color='rgba(44, 160, 44, 0.8)'),
238
- hovertemplate="Position: %{x}<br>GC%: %{y:.1f}%<extra></extra>"
239
- ),
240
- row=2, col=1
 
 
 
 
 
 
 
 
 
 
 
 
241
  )
242
 
243
- # Update layout
244
- fig.update_layout(
245
- height=800,
246
- title=dict(
247
- text=f"Genome Analysis Overview<br><sub>{analysis.header}</sub>",
248
- x=0.5
249
- ),
250
- showlegend=False,
251
- plot_bgcolor='white'
252
- )
253
 
254
- # Update axes
255
- fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
256
- fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
257
 
258
  return fig
259
 
260
- def create_kmer_importance_plot(analysis: SequenceAnalysis, top_k: int = 10) -> go.Figure:
261
- """Create interactive k-mer importance plot using Plotly"""
262
- # Get top k-mers by absolute SHAP value
263
- kmers = [''.join(p) for p in product("ACGT", repeat=4)]
264
- indices = np.argsort(np.abs(analysis.shap_values))[-top_k:]
265
-
266
- # Create DataFrame for plotting
267
- df = pd.DataFrame({
268
- 'k-mer': [kmers[i] for i in indices],
269
- 'SHAP': analysis.shap_values[indices]
270
- })
271
-
272
- # Create plot
273
- fig = px.bar(
274
- df,
275
- x='SHAP',
276
- y='k-mer',
277
- orientation='h',
278
- color='SHAP',
279
- color_continuous_scale='RdBu',
280
- title=f'Top {top_k} Most Influential k-mers'
281
- )
282
-
283
- # Update layout
284
- fig.update_layout(
285
- height=400,
286
- plot_bgcolor='white',
287
- yaxis_title='',
288
- xaxis_title='SHAP Value',
289
- coloraxis_showscale=False
290
- )
291
-
292
  return fig
293
 
294
- def create_shap_distribution_plot(analysis: SequenceAnalysis) -> go.Figure:
295
- """Create SHAP distribution plot using Plotly"""
296
- fig = go.Figure()
297
-
298
- # Add histogram
299
- fig.add_trace(go.Histogram(
300
- x=analysis.shap_means,
301
- nbinsx=50,
302
- name='SHAP Values',
303
- marker_color='rgba(31, 119, 180, 0.6)'
304
- ))
305
-
306
- # Add vertical line at x=0
307
- fig.add_vline(
308
- x=0,
309
- line_dash="dash",
310
- line_color="red",
311
- annotation_text="Neutral",
312
- annotation_position="top"
313
- )
314
-
315
- # Update layout
316
- fig.update_layout(
317
- title='Distribution of SHAP Values',
318
- xaxis_title='SHAP Value',
319
- yaxis_title='Count',
320
- plot_bgcolor='white',
321
- height=400
322
- )
323
-
324
  return fig
325
 
 
 
 
 
 
 
 
326
  ###############################################################################
327
- # 5. MAIN ANALYSIS
328
  ###############################################################################
329
 
330
- def analyze_sequence(
331
- file_obj: str = None,
332
- fasta_text: str = "",
333
- window_size: int = 500,
334
- model_path: str = 'model.pt',
335
- scaler_path: str = 'scaler.pkl'
336
- ) -> SequenceAnalysis:
337
- """Main sequence analysis function"""
338
  # Handle input
339
  if fasta_text.strip():
340
  text = fasta_text.strip()
341
  elif file_obj is not None:
342
- with open(file_obj, 'r') as f:
343
- text = f.read()
 
 
 
344
  else:
345
- raise ValueError("No input provided")
346
-
347
  # Parse FASTA
348
  sequences = parse_fasta(text)
349
  if not sequences:
350
- raise ValueError("No valid FASTA sequences found")
351
 
352
  header, seq = sequences[0]
353
-
354
  # Load model and scaler
355
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
356
- state_dict = torch.load(model_path, map_location=device)
357
- model = VirusClassifier(256).to(device)
358
- model.load_state_dict(state_dict)
359
-
360
- scaler = joblib.load(scaler_path)
361
-
362
- # Process sequence
 
 
 
 
363
  freq_vector = sequence_to_kmer_vector(seq)
364
  scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
365
  x_tensor = torch.FloatTensor(scaled_vector).to(device)
366
-
367
- # Get SHAP values and classification
368
  shap_values, prob_human = calculate_shap_values(model, x_tensor)
369
  prob_nonhuman = 1.0 - prob_human
370
 
371
- # Get per-base SHAP scores
372
- shap_means = compute_positionwise_scores(seq, shap_values)
373
-
374
- # Find extreme regions
375
- extreme_regions = find_extreme_regions(shap_means, window_size)
376
-
377
- # Create analysis object
378
- return SequenceAnalysis(
379
- header=header,
380
- sequence=seq,
381
- length=len(seq),
382
- gc_content=compute_gc_content(seq),
383
- classification="Human" if prob_human > 0.5 else "Non-human",
384
- human_prob=prob_human,
385
- nonhuman_prob=prob_nonhuman,
386
- shap_values=shap_values,
387
- shap_means=shap_means,
388
- extreme_regions=extreme_regions
 
 
 
 
 
389
  )
390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  ###############################################################################
392
- # 6. GRADIO INTERFACE
393
  ###############################################################################
394
 
395
- def create_interface():
396
- """Create enhanced Gradio interface with improved layout and interactivity"""
397
-
398
- def process_sequence(
399
- file_obj: str,
400
- fasta_text: str,
401
- window_size: int,
402
- top_kmers: int
403
- ) -> Tuple[str, List[go.Figure]]:
404
- """Process sequence and return formatted results and plots"""
405
- try:
406
- # Run analysis
407
- analysis = analyze_sequence(
408
- file_obj=file_obj,
409
- fasta_text=fasta_text,
410
- window_size=window_size
411
- )
412
-
413
- # Format results text
414
- results = f"""
415
- ### Sequence Analysis Results
416
-
417
- **Basic Information**
418
- - Sequence: {analysis.header}
419
- - Length: {analysis.length:,} bases
420
- - GC Content: {analysis.gc_content:.1f}%
421
-
422
- **Classification**
423
- - Prediction: {analysis.classification}
424
- - Human Probability: {analysis.human_prob:.3f}
425
- - Non-human Probability: {analysis.nonhuman_prob:.3f}
426
-
427
- **Extreme Regions (window size: {window_size}bp)**
428
- Most Human-like Region:
429
- - Position: {analysis.extreme_regions['human']['start']:,} - {analysis.extreme_regions['human']['end']:,}
430
- - Average SHAP: {analysis.extreme_regions['human']['avg_shap']:.4f}
431
-
432
- Most Non-human-like Region:
433
- - Position: {analysis.extreme_regions['nonhuman']['start']:,} - {analysis.extreme_regions['nonhuman']['end']:,}
434
- - Average SHAP: {analysis.extreme_regions['nonhuman']['avg_shap']:.4f}
435
- """
436
-
437
- # Create plots
438
- genome_plot = create_genome_overview_plot(analysis)
439
- kmer_plot = create_kmer_importance_plot(analysis, top_kmers)
440
- dist_plot = create_shap_distribution_plot(analysis)
441
-
442
- return results, [genome_plot, kmer_plot, dist_plot], analysis
443
-
444
- except Exception as e:
445
- return f"Error: {str(e)}", [], None
446
-
447
- # Create theme and styling
448
- theme = gr.themes.Soft(
449
- primary_hue="blue",
450
- secondary_hue="gray",
451
- ).set(
452
- body_text_color="gray-dark",
453
- background_fill_primary="*gray-50",
454
- block_shadow="*shadow-sm",
455
- block_background_fill="white",
456
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
 
458
- # Build interface
459
- with gr.Blocks(theme=theme, css="""
460
- .container { margin: 0 auto; max-width: 1200px; padding: 20px; }
461
- .results { margin-top: 20px; }
462
- .plot-container { margin-top: 10px; }
463
- """) as interface:
464
  gr.Markdown("""
465
- # 🧬 Enhanced Virus Host Classifier
466
-
467
- This tool analyzes viral sequences to predict their host (human vs. non-human) and provides detailed visualizations
468
- of the features influencing this classification. Upload or paste a FASTA sequence to begin.
469
-
470
- *Using advanced SHAP analysis and interactive visualizations for interpretable results.*
471
  """)
 
 
 
 
472
 
473
- # Input section
474
- with gr.Tab("Sequence Analysis"):
475
- with gr.Row():
476
- with gr.Column(scale=1):
477
- file_input = gr.File(
478
- label="Upload FASTA File",
479
- file_types=[".fasta", ".fa", ".txt"],
480
- type="filepath"
481
- )
482
-
483
- text_input = gr.Textbox(
484
- label="Or Paste FASTA Sequence",
485
- placeholder=">sequence_name\nACGTACGT...",
486
- lines=5
487
- )
488
-
489
- with gr.Row():
490
- window_size = gr.Slider(
491
- minimum=100,
492
- maximum=5000,
493
- value=500,
494
- step=100,
495
- label="Window Size for Region Analysis"
496
- )
497
-
498
- top_kmers = gr.Slider(
499
- minimum=5,
500
- maximum=30,
501
- value=10,
502
- step=1,
503
- label="Number of Top k-mers to Display"
504
- )
505
-
506
- analyze_btn = gr.Button(
507
- "🔍 Analyze Sequence",
508
- variant="primary"
509
- )
510
-
511
- # Results section
512
- with gr.Column(scale=2):
513
- results_text = gr.Markdown(
514
- label="Analysis Results"
515
- )
516
-
517
- # Plots
518
- genome_plot = gr.Plot(
519
- label="Genome Overview"
520
- )
521
-
522
- with gr.Row():
523
- kmer_plot = gr.Plot(
524
- label="k-mer Importance"
525
- )
526
- dist_plot = gr.Plot(
527
- label="SHAP Distribution"
528
- )
529
-
530
- # Help tab
531
- with gr.Tab("Help & Information"):
532
- gr.Markdown("""
533
- ### 📖 How to Use This Tool
534
-
535
- 1. **Input Your Sequence**
536
- - Upload a FASTA file or paste your sequence in FASTA format
537
- - The sequence should contain only ACGT bases (non-standard bases will be filtered)
538
-
539
- 2. **Adjust Parameters**
540
- - Window Size: Controls the length of regions analyzed for extreme patterns
541
- - Top k-mers: Number of most influential sequence patterns to display
542
-
543
- 3. **Interpret Results**
544
- - Classification: Predicted host (human vs. non-human)
545
- - Genome Overview: Interactive plot showing SHAP values and GC content
546
- - k-mer Importance: Most influential sequence patterns
547
- - SHAP Distribution: Overall distribution of feature importance
548
-
549
- ### 🎨 Visualization Guide
550
-
551
- - **SHAP Values**:
552
- - Positive (red) = pushing toward human classification
553
- - Negative (blue) = pushing toward non-human classification
554
- - Zero (white) = neutral impact
555
-
556
- - **Extreme Regions**:
557
- - Highlighted in the genome overview plot
558
- - Red regions = most human-like
559
- - Blue regions = most non-human-like
560
-
561
- ### 🔬 Technical Details
562
-
563
- - The classifier uses k-mer frequencies (k=4) as features
564
- - SHAP values are calculated using an ablation-based approach
565
- - GC content is calculated using a sliding window
566
- """)
567
-
568
- # Connect components
569
- sequence_state = gr.State()
570
-
571
- def process_and_update(file_obj, fasta_text, window_size, top_kmers):
572
- """Wrapper to handle plot outputs correctly"""
573
- results, plots, analysis = process_sequence(file_obj, fasta_text, window_size, top_kmers)
574
- if plots:
575
- return [
576
- results,
577
- plots[0], # genome plot
578
- plots[1], # kmer plot
579
- plots[2], # distribution plot
580
- analysis
581
- ]
582
- return [results, None, None, None, None]
583
-
584
- analyze_btn.click(
585
- process_and_update,
586
- inputs=[
587
- file_input,
588
- text_input,
589
- window_size,
590
- top_kmers
591
- ],
592
- outputs=[
593
- results_text,
594
- genome_plot,
595
- kmer_plot,
596
- dist_plot,
597
- sequence_state
598
- ]
599
  )
 
 
 
600
 
601
- return interface
602
-
603
- ###############################################################################
604
- # 7. MAIN ENTRY POINT
605
- ###############################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606
 
607
  if __name__ == "__main__":
608
- iface = create_interface()
609
- iface.launch(
610
- share=True,
611
- server_name="0.0.0.0",
612
- show_error=True
613
- )
614
- #
 
6
  import torch.nn as nn
7
  import matplotlib.pyplot as plt
8
  import matplotlib.colors as mcolors
 
 
9
  import io
10
+ from PIL import Image
 
 
 
 
 
11
 
12
  ###############################################################################
13
+ # 1. MODEL DEFINITION
14
  ###############################################################################
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  class VirusClassifier(nn.Module):
17
  def __init__(self, input_shape: int):
18
  super(VirusClassifier, self).__init__()
 
29
  nn.GELU(),
30
  nn.Linear(32, 2)
31
  )
32
+
33
  def forward(self, x):
34
  return self.network(x)
35
 
36
  ###############################################################################
37
+ # 2. FASTA PARSING & K-MER FEATURE ENGINEERING
38
  ###############################################################################
39
 
40
+ def parse_fasta(text):
41
+ """Parse FASTA formatted text into a list of (header, sequence)."""
42
  sequences = []
43
  current_header = None
44
  current_sequence = []
 
53
  current_header = line[1:]
54
  current_sequence = []
55
  else:
56
+ current_sequence.append(line.upper())
 
 
 
57
  if current_header:
58
  sequences.append((current_header, ''.join(current_sequence)))
59
  return sequences
60
 
61
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
62
+ """Convert a sequence to a k-mer frequency vector for classification."""
63
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
64
  kmer_dict = {km: i for i, km in enumerate(kmers)}
65
  vec = np.zeros(len(kmers), dtype=np.float32)
66
 
 
67
  for i in range(len(sequence) - k + 1):
68
  kmer = sequence[i:i+k]
69
+ if kmer in kmer_dict:
70
  vec[kmer_dict[kmer]] += 1
71
+
 
72
  total_kmers = len(sequence) - k + 1
73
  if total_kmers > 0:
74
  vec = vec / total_kmers
 
 
75
 
76
+ return vec
 
 
 
 
 
77
 
78
  ###############################################################################
79
+ # 3. SHAP-VALUE (ABLATION) CALCULATION
80
  ###############################################################################
81
 
82
def calculate_shap_values(model, x_tensor):
    """
    Score each input feature by single-feature ablation.

    Each column of `x_tensor` is set to zero in turn, and the drop in the
    class-1 ('human') softmax probability relative to the unmodified input
    is recorded as that feature's score.

    Returns:
        (scores, baseline_prob): a NumPy array with one score per column of
        `x_tensor`, and the class-1 probability of the unmodified input.
    """
    def human_probability(batch):
        # Softmax over the class axis; index 1 is the 'human' class.
        return torch.softmax(model(batch), dim=1)[0, 1].item()

    model.eval()
    impacts = []
    with torch.no_grad():
        baseline_prob = human_probability(x_tensor)

        # Work on a copy so the caller's tensor is never modified.
        working = x_tensor.clone()
        for col in range(x_tensor.shape[1]):
            saved = working[0, col].item()
            working[0, col] = 0.0
            impacts.append(baseline_prob - human_probability(working))
            working[0, col] = saved  # restore before the next ablation

    return np.array(impacts), baseline_prob
107
 
108
+ ###############################################################################
109
+ # 4. PER-BASE SHAP AGGREGATION
110
+ ###############################################################################
111
+
112
+ def compute_positionwise_scores(sequence, shap_values, k=4):
113
+ """
114
+ Returns an array of per-base SHAP contributions by averaging
115
+ the k-mer SHAP values of all k-mers covering that base.
116
+ """
117
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
118
  kmer_dict = {km: i for i, km in enumerate(kmers)}
119
 
 
121
  shap_sums = np.zeros(seq_len, dtype=np.float32)
122
  coverage = np.zeros(seq_len, dtype=np.float32)
123
 
 
124
  for i in range(seq_len - k + 1):
125
  kmer = sequence[i:i+k]
126
  if kmer in kmer_dict:
127
+ val = shap_values[kmer_dict[kmer]]
128
+ shap_sums[i : i + k] += val
129
+ coverage[i : i + k] += 1
130
+
131
  with np.errstate(divide='ignore', invalid='ignore'):
132
  shap_means = np.where(coverage > 0, shap_sums / coverage, 0.0)
133
+
134
  return shap_means
135
 
136
+ ###############################################################################
137
+ # 5. FIND EXTREME SHAP REGIONS
138
+ ###############################################################################
139
+
140
def find_extreme_subregion(shap_means, window_size=500, mode="max"):
    """
    Locate the window of `window_size` consecutive positions whose mean
    SHAP value is largest (mode="max") or smallest (any other mode).

    Args:
        shap_means: 1-D sequence of per-base SHAP scores.
        window_size: window length in bases. Coerced to an int of at least 1
            so that float values from UI widgets are accepted.
        mode: "max" selects the highest-average window; any other value
            selects the lowest-average window.

    Returns:
        (start, end, average) with `end` exclusive. Returns (0, 0, 0.0) for
        empty input, and the whole sequence when the window is at least as
        long as the input. On ties the earliest window wins.
    """
    # Accumulate in float64: a float32 running sum over a long genome loses
    # enough precision that differences of prefix sums become meaningless.
    values = np.asarray(shap_means, dtype=np.float64)
    n = len(values)
    if n == 0:
        return (0, 0, 0.0)

    window_size = max(1, int(window_size))
    if window_size >= n:
        # Window covers the entire sequence.
        return (0, n, float(values.mean()))

    # Prefix sums with a leading zero: sum(values[a:b]) == csum[b] - csum[a].
    csum = np.concatenate(([0.0], np.cumsum(values)))
    window_sums = csum[window_size:] - csum[:-window_size]

    # argmax/argmin return the first occurrence, i.e. the earliest window.
    if mode == "max":
        best_start = int(np.argmax(window_sums))
    else:
        best_start = int(np.argmin(window_sums))

    best_avg = float(window_sums[best_start] / window_size)
    return (best_start, best_start + window_size, best_avg)
175
 
176
  ###############################################################################
177
+ # 6. PLOTTING / UTILITIES
178
  ###############################################################################
179
 
180
def fig_to_image(fig):
    """
    Render a Matplotlib figure to a PIL Image (PNG, 150 dpi) for Gradio.

    The figure is always closed, even if rendering fails, so repeated calls
    do not accumulate open figures.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    finally:
        # Release the figure whether or not savefig succeeded.
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    # Image.open is lazy; decode now so the image no longer depends on `buf`.
    img.load()
    return img
188
+
189
def get_zero_centered_cmap():
    """
    Build the diverging colormap used for SHAP heatmaps.

    The gradient runs blue -> white -> red, with white at the midpoint, so
    that with a symmetric value range zero is drawn white, negatives blue
    and positives red.
    """
    anchor_points = [
        (0.0, 'blue'),   # low end: negative
        (0.5, 'white'),  # midpoint: zero
        (1.0, 'red'),    # high end: positive
    ]
    return mcolors.LinearSegmentedColormap.from_list("blue_white_red", anchor_points)
203
+
204
def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
    """
    Draw a one-row heatmap of per-base SHAP contributions.

    Colours are blue for negative, white for zero and red for positive. The
    colour range is forced symmetric around zero (vmin=-extent, vmax=+extent)
    so that zero always sits at the white midpoint.

    Args:
        shap_means: per-base SHAP scores.
        title: plot title.
        start, end: optional slice bounds; when both are given only
            shap_means[start:end] is plotted and the range is appended to
            the title.

    Returns:
        The Matplotlib figure (the caller is responsible for closing it).
    """
    if start is not None and end is not None:
        local_shap = shap_means[start:end]
        subtitle = f" (positions {start}-{end})"
    else:
        local_shap = shap_means
        subtitle = ""

    local_shap = np.asarray(local_shap)
    if len(local_shap) == 0:
        # Edge case: no data to plot
        local_shap = np.array([0.0])

    # imshow needs a 2D array: a single row of values.
    heatmap_data = local_shap.reshape(1, -1)

    # Symmetric range around zero.
    extent = max(abs(np.min(local_shap)), abs(np.max(local_shap)))
    if not extent > 0:
        # All-zero (or NaN) data would give vmin == vmax, which Matplotlib
        # normalises to 0.0, i.e. the blue end. Use a unit range instead so
        # zero is drawn white.
        extent = 1.0

    custom_cmap = get_zero_centered_cmap()

    fig, ax = plt.subplots(figsize=(12, 2))
    cax = ax.imshow(
        heatmap_data,
        aspect='auto',
        cmap=custom_cmap,
        vmin=-extent,
        vmax=+extent
    )

    # Place colorbar below with plenty of margin
    cbar = fig.colorbar(cax, ax=ax, orientation='horizontal', pad=0.35)
    cbar.set_label('SHAP Contribution (negative=blue, zero=white, positive=red)')

    ax.set_yticks([])
    ax.set_xlabel('Position in Sequence')
    ax.set_title(f"{title}{subtitle}")

    # Extra bottom margin so colorbar won't overlap x-axis labels
    fig.subplots_adjust(bottom=0.4)

    return fig
257
 
258
def create_importance_bar_plot(shap_values, kmers, top_k=10):
    """
    Create a horizontal bar plot of the k-mers with the largest |SHAP| value.

    Args:
        shap_values: one SHAP score per k-mer.
        kmers: k-mer labels, indexed like `shap_values`.
        top_k: number of k-mers to show. Coerced to int and clamped to
            [0, len(shap_values)].

    Returns:
        The Matplotlib figure (the caller is responsible for closing it).
    """
    plt.rcParams.update({'font.size': 10})
    fig = plt.figure(figsize=(10, 5))

    shap_values = np.asarray(shap_values)
    # Sliders may deliver floats, and a plain `[-top_k:]` slice would return
    # *every* k-mer for top_k == 0, so normalise and slice from the front.
    top_k = max(0, min(int(top_k), len(shap_values)))

    # Sort by absolute importance (ascending) and keep the last top_k.
    ranked = np.argsort(np.abs(shap_values))
    indices = ranked[len(ranked) - top_k:]
    values = shap_values[indices]
    features = [kmers[i] for i in indices]

    # negative -> blue, positive -> red
    colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]

    plt.barh(range(len(values)), values, color=colors)
    plt.yticks(range(len(values)), features)
    plt.xlabel('SHAP Value (impact on model output)')
    plt.title(f'Top {top_k} Most Influential k-mers')
    # NOTE(review): with ascending order plus an inverted axis the strongest
    # k-mer ends up at the bottom of the chart; confirm this is intended.
    plt.gca().invert_yaxis()
    plt.tight_layout()
    return fig
278
 
279
def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
    """
    Plot a 30-bin histogram of the given SHAP values.

    A dashed red vertical line marks 0.0. Returns the Matplotlib figure.
    """
    fig, axis = plt.subplots(figsize=(6, 4))
    axis.hist(shap_array, bins=30, color='gray', edgecolor='black')
    axis.axvline(0, color='red', linestyle='--', label='0.0')
    axis.set(xlabel="SHAP Value", ylabel="Count", title=title)
    axis.legend()
    fig.tight_layout()
    return fig
292
 
293
def compute_gc_content(sequence):
    """
    Return the percentage of G and C characters in `sequence`.

    Matching is case-insensitive. The denominator is the full sequence
    length, so any non-ACGT characters count toward the total. An empty
    sequence yields 0.0.
    """
    if not sequence:
        # Return a float so the result type is the same on every path.
        return 0.0
    upper = sequence.upper()
    gc_count = upper.count('G') + upper.count('C')
    return (gc_count / len(sequence)) * 100.0
299
+
300
  ###############################################################################
301
+ # 7. MAIN ANALYSIS STEP (Gradio Step 1)
302
  ###############################################################################
303
 
304
def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
    """
    Classify the first FASTA record and build the genome-wide outputs.

    Pasted `fasta_text` takes priority over the uploaded file. Only the first
    sequence in the input is analysed.

    Args:
        file_obj: path to an uploaded FASTA file, or None.
        top_kmers: number of k-mers shown in the importance bar plot.
        fasta_text: FASTA text pasted by the user (may be empty or None).
        window_size: window length used to find the extreme SHAP subregions.

    Returns:
        A 5-tuple (results_text, bar_img, heatmap_img, state, header). On any
        failure the first element is an error message and the other four are
        None. `state` is a dict with keys "seq" and "shap_means" for the
        subregion step.
    """
    kmer_size = 4

    # Handle input (tolerate None from the text widget)
    pasted = (fasta_text or "").strip()
    if pasted:
        text = pasted
    elif file_obj is not None:
        try:
            with open(file_obj, 'r') as f:
                text = f.read()
        except Exception as e:
            return (f"Error reading file: {str(e)}", None, None, None, None)
    else:
        return ("Please provide a FASTA sequence.", None, None, None, None)

    # Parse FASTA
    sequences = parse_fasta(text)
    if not sequences:
        return ("No valid FASTA sequences found.", None, None, None, None)

    header, seq = sequences[0]

    # A sequence shorter than k has no k-mers; classifying the resulting
    # all-zero vector would produce a meaningless prediction.
    if len(seq) < kmer_size:
        return (
            f"Sequence is too short for {kmer_size}-mer analysis.",
            None, None, None, None,
        )

    # Load model and scaler
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        # Use weights_only=True for safer loading
        state_dict = torch.load('model.pt', map_location=device, weights_only=True)
        model = VirusClassifier(256).to(device)
        model.load_state_dict(state_dict)

        scaler = joblib.load('scaler.pkl')
    except Exception as e:
        return (f"Error loading model/scaler: {str(e)}", None, None, None, None)

    # Vectorize + scale
    freq_vector = sequence_to_kmer_vector(seq)
    scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
    x_tensor = torch.FloatTensor(scaled_vector).to(device)

    # SHAP + classification
    shap_values, prob_human = calculate_shap_values(model, x_tensor)
    prob_nonhuman = 1.0 - prob_human

    classification = "Human" if prob_human > 0.5 else "Non-human"
    confidence = max(prob_human, prob_nonhuman)

    # Per-base SHAP
    shap_means = compute_positionwise_scores(seq, shap_values, k=kmer_size)

    # Find the most "human-pushing" region
    (max_start, max_end, max_avg) = find_extreme_subregion(shap_means, window_size, mode="max")
    # Find the most "non-human–pushing" region
    (min_start, min_end, min_avg) = find_extreme_subregion(shap_means, window_size, mode="min")

    # Build results text
    results_text = (
        f"Sequence: {header}\n"
        f"Length: {len(seq):,} bases\n"
        f"Classification: {classification}\n"
        f"Confidence: {confidence:.3f}\n"
        f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
        f"---\n"
        f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
        f"Start: {max_start}, End: {max_end}, Avg SHAP: {max_avg:.4f}\n\n"
        f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
        f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
    )

    # K-mer importance plot
    kmers = [''.join(p) for p in product("ACGT", repeat=kmer_size)]
    bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
    bar_img = fig_to_image(bar_fig)

    # Full-genome SHAP heatmap
    heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
    heatmap_img = fig_to_image(heatmap_fig)

    # Store data for subregion analysis
    state_dict_out = {
        "seq": seq,
        "shap_means": shap_means
    }

    return (results_text, bar_img, heatmap_img, state_dict_out, header)
390
+
391
  ###############################################################################
392
+ # 8. SUBREGION ANALYSIS (Gradio Step 2)
393
  ###############################################################################
394
 
395
def analyze_subregion(state, header, region_start, region_end):
    """
    Analyze a user-chosen window of the sequence stored by Step 1.

    Args:
        state: Mapping saved by the full-sequence step. Must contain the keys
            "seq" (the sequence) and "shap_means" (scores indexed with the same
            positions as "seq" -- presumably one per base; confirm against
            compute_positionwise_scores).
        header: Sequence label; used only inside the report text.
        region_start: Start index of the window (any value accepted by int()).
        region_end: End index of the window, exclusive.

    Returns:
        A 3-tuple (report_text, heatmap_image, histogram_image). Whenever the
        inputs are unusable, the two images are None and report_text holds
        the error message, so the UI callback never raises on bad input.
    """
    if not state or "seq" not in state or "shap_means" not in state:
        return ("No sequence data found. Please run Step 1 first.", None, None)

    seq = state["seq"]
    shap_means = state["shap_means"]

    # A bound may be None or otherwise non-numeric (e.g. an emptied input
    # field); int() would raise, so report the problem like the other
    # validation failures instead of crashing the callback.
    try:
        region_start = int(region_start)
        region_end = int(region_end)
    except (TypeError, ValueError, OverflowError):
        return ("Invalid region bounds. Start and End must be numbers.", None, None)

    # Clamp both bounds into [0, len(seq)] before checking their order.
    region_start = max(0, min(region_start, len(seq)))
    region_end = max(0, min(region_end, len(seq)))
    if region_end <= region_start:
        return ("Invalid region range. End must be > Start.", None, None)

    # Subsequence and its matching score slice. np.asarray keeps the
    # element-wise comparisons below valid even if the stored scores are a
    # plain list rather than an ndarray.
    region_seq = seq[region_start:region_end]
    region_shap = np.asarray(shap_means[region_start:region_end])

    # Some stats
    gc_percent = compute_gc_content(region_seq)
    avg_shap = float(np.mean(region_shap))

    # Fraction of positions pushing toward human vs. non-human
    positive_fraction = np.mean(region_shap > 0)
    negative_fraction = np.mean(region_shap < 0)

    # Simple threshold-based interpretation of the mean score
    if avg_shap > 0.05:
        region_classification = "Likely pushing toward human"
    elif avg_shap < -0.05:
        region_classification = "Likely pushing toward non-human"
    else:
        region_classification = "Near neutral (no strong push)"

    region_info = (
        f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
        f"Region length: {len(region_seq)} bases\n"
        f"GC content: {gc_percent:.2f}%\n"
        f"Average SHAP in region: {avg_shap:.4f}\n"
        f"Fraction with SHAP > 0 (toward human): {positive_fraction:.2f}\n"
        f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
        f"Subregion interpretation: {region_classification}\n"
    )

    # Heatmap restricted to the selected window; the full score vector is
    # passed and the plotting helper receives the window bounds.
    heatmap_fig = plot_linear_heatmap(
        shap_means,
        title="Subregion SHAP",
        start=region_start,
        end=region_end
    )
    heatmap_img = fig_to_image(heatmap_fig)

    # Histogram of the scores inside the window
    hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
    hist_img = fig_to_image(hist_fig)

    return (region_info, heatmap_img, hist_img)
459
+
460
+
461
+ ###############################################################################
462
+ # 9. BUILD GRADIO INTERFACE
463
+ ###############################################################################
464
+
465
# Global stylesheet handed to gr.Blocks: only overrides the container font.
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
"""

# Two-tab application: tab 1 runs the full-sequence analysis and stores its
# results in gr.State holders; tab 2 reads those holders to inspect a window.
with gr.Blocks(css=css) as iface:
    gr.Markdown("""
    # Virus Host Classifier with White-Centered Gradient
    **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
    **Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.

    **Color Scale**: Negative SHAP = Blue, Zero = White, Positive = Red.
    """)

    # ------------------------------------------------------------------ #
    # Tab 1: whole-sequence classification
    # ------------------------------------------------------------------ #
    with gr.Tab("1) Full-Sequence Analysis"):
        with gr.Row():
            # Left column: inputs and tuning controls.
            with gr.Column(scale=1):
                file_input = gr.File(
                    type="filepath",
                    file_types=[".fasta", ".fa", ".txt"],
                    label="Upload FASTA file"
                )
                text_input = gr.Textbox(
                    lines=5,
                    placeholder=">sequence_name\nACGTACGT...",
                    label="Or paste FASTA sequence"
                )
                top_k = gr.Slider(
                    label="Number of top k-mers to display",
                    minimum=5, maximum=30, step=1, value=10
                )
                win_size = gr.Slider(
                    label="Window size for 'most pushing' subregions",
                    minimum=100, maximum=5000, step=100, value=500
                )
                analyze_btn = gr.Button("Analyze Sequence", variant="primary")

            # Right column: text report and the two figures.
            with gr.Column(scale=2):
                results_box = gr.Textbox(
                    label="Classification Results", lines=12, interactive=False
                )
                kmer_img = gr.Image(label="Top k-mer SHAP")
                genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")

        # Holders that carry Step 1 results over to the subregion tab.
        seq_state = gr.State()
        header_state = gr.State()

        # analyze_sequence(...) yields five values, matched one-to-one with
        # the five outputs listed here.
        analyze_btn.click(
            analyze_sequence,
            inputs=[file_input, top_k, text_input, win_size],
            outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
        )

    # ------------------------------------------------------------------ #
    # Tab 2: inspection of a user-selected window
    # ------------------------------------------------------------------ #
    with gr.Tab("2) Subregion Exploration"):
        gr.Markdown("""
        **Subregion Analysis**
        Select start/end positions to view local SHAP signals, distribution, and GC content.
        The heatmap also uses the same Blue-White-Red scale.
        """)
        with gr.Row():
            region_start = gr.Number(label="Region Start", value=0)
            region_end = gr.Number(label="Region End", value=500)
            region_btn = gr.Button("Analyze Subregion")

        subregion_info = gr.Textbox(
            label="Subregion Analysis",
            lines=7,
            interactive=False
        )
        with gr.Row():
            subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
            subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")

        # analyze_subregion consumes the stored state plus the chosen bounds
        # and fills the text box and both images.
        region_btn.click(
            analyze_subregion,
            inputs=[seq_state, header_state, region_start, region_end],
            outputs=[subregion_info, subregion_img, subregion_hist_img]
        )

    gr.Markdown("""
    ### Interface Features
    - **Overall Classification** (human vs non-human) using k-mer frequencies.
    - **SHAP Analysis** to see which k-mers push classification toward or away from human.
    - **White-Centered SHAP Gradient**:
      - Negative (blue), 0 (white), Positive (red), with symmetrical color range around 0.
    - **Identify Subregions** with the strongest push for human or non-human.
    - **Subregion Exploration**:
      - Local SHAP heatmap & histogram
      - GC content
      - Fraction of positions pushing human vs. non-human
      - Simple logic-based classification
    """)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()