hiyata committed on
Commit
f1d4be6
·
verified ·
1 Parent(s): 8731787

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +315 -371
app.py CHANGED
@@ -2,13 +2,15 @@ import gradio as gr
2
  import torch
3
  import joblib
4
  import numpy as np
 
5
  import torch.nn as nn
6
  import matplotlib.pyplot as plt
7
  import io
8
  from PIL import Image
9
- from itertools import product
10
 
11
- # --------------- Model Definition ---------------
 
 
12
 
13
  class VirusClassifier(nn.Module):
14
  def __init__(self, input_shape: int):
@@ -29,46 +31,20 @@ class VirusClassifier(nn.Module):
29
 
30
  def forward(self, x):
31
  return self.network(x)
32
-
33
- def get_gradient_importance(self, x, class_index=1):
34
- """
35
- Calculate gradient-based importance for each input feature.
36
- By default, we compute the gradient wrt the 'human' class (index=1).
37
- This method is akin to a raw gradient or 'saliency' approach.
38
- """
39
- x = x.clone().detach().requires_grad_(True)
40
- output = self.network(x)
41
- probs = torch.softmax(output, dim=1)
42
-
43
- # Probability of the specified class
44
- target_prob = probs[..., class_index]
45
-
46
- # Zero existing gradients if any
47
- if x.grad is not None:
48
- x.grad.zero_()
49
-
50
- # Backprop on that probability
51
- target_prob.backward()
52
-
53
- # Raw gradient is now in x.grad
54
- importance = x.grad.detach()
55
-
56
- # Optional: Multiply by input to get a more "integrated gradients"-like measure
57
- # importance = importance * x.detach()
58
-
59
- return importance, float(target_prob)
60
 
61
- # --------------- Utility Functions ---------------
 
 
62
 
63
- def parse_fasta(text: str):
64
  """
65
- Parse a FASTA string and return a list of (header, sequence) pairs.
66
  """
67
  sequences = []
68
  current_header = None
69
  current_sequence = []
70
 
71
- for line in text.split('\n'):
72
  line = line.strip()
73
  if not line:
74
  continue
@@ -85,10 +61,8 @@ def parse_fasta(text: str):
85
 
86
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
87
  """
88
- Convert a nucleotide sequence into a k-mer frequency vector.
89
- Defaults to k=4.
90
  """
91
- # Generate all possible k-mers
92
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
93
  kmer_dict = {km: i for i, km in enumerate(kmers)}
94
  vec = np.zeros(len(kmers), dtype=np.float32)
@@ -104,385 +78,355 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
104
 
105
  return vec
106
 
107
- def compute_sequence_stats(sequence: str):
108
- """
109
- Compute various statistics for a given sequence:
110
- - Length
111
- - GC content (%)
112
- - A/C/G/T counts
113
  """
114
- length = len(sequence)
115
- if length == 0:
116
- return {
117
- 'length': 0,
118
- 'gc_content': 0,
119
- 'counts': {'A': 0, 'C': 0, 'G': 0, 'T': 0}
120
- }
121
-
122
- counts = {
123
- 'A': sequence.count('A'),
124
- 'C': sequence.count('C'),
125
- 'G': sequence.count('G'),
126
- 'T': sequence.count('T')
127
- }
128
- gc_content = (counts['G'] + counts['C']) / length * 100.0
129
-
130
- return {
131
- 'length': length,
132
- 'gc_content': gc_content,
133
- 'counts': counts
134
- }
135
-
136
- # --------------- Visualization Functions ---------------
137
-
138
- def plot_shap_like_bars(kmers, importance_values, top_k=10):
139
  """
140
- Create a bar chart that mimics a SHAP summary plot:
141
- - k-mers on y-axis
142
- - importance magnitude on x-axis
143
- - color indicating positive (push towards human) vs negative (push towards non-human)
144
- """
145
- abs_importance = np.abs(importance_values)
146
- # Sort by absolute importance
147
- sorted_indices = np.argsort(abs_importance)[::-1]
148
- top_indices = sorted_indices[:top_k]
149
-
150
- # Prepare data
151
- top_kmers = [kmers[i] for i in top_indices]
152
- top_importances = importance_values[top_indices]
 
 
 
 
 
 
 
153
 
154
- # Create plot
155
- fig, ax = plt.subplots(figsize=(8, 6))
156
- colors = ['green' if val > 0 else 'red' for val in top_importances]
157
- ax.barh(range(len(top_kmers)), np.abs(top_importances), color=colors)
158
- ax.set_yticks(range(len(top_kmers)))
159
- ax.set_yticklabels(top_kmers)
160
- ax.invert_yaxis() # So that the highest value is at the top
161
- ax.set_xlabel("Feature Importance (Gradient Magnitude)")
162
- ax.set_title(f"Top-{top_k} SHAP-like Feature Importances")
163
- plt.tight_layout()
164
- return fig
165
 
166
- def plot_kmer_distribution(kmer_freq_vector, kmers):
167
- """
168
- Plot a histogram of k-mer frequencies for the entire vector.
169
- (Optional if you want a quick distribution overview)
170
- """
171
- fig, ax = plt.subplots(figsize=(10, 4))
172
- ax.bar(range(len(kmer_freq_vector)), kmer_freq_vector, color='blue', alpha=0.6)
173
- ax.set_xlabel("K-mer Index")
174
- ax.set_ylabel("Frequency")
175
- ax.set_title("K-mer Frequency Distribution")
176
- ax.set_xticks([])
177
- plt.tight_layout()
178
- return fig
179
 
180
- def create_step_visualization(important_kmers, human_prob):
181
  """
182
- Re-implementation of your step-wise probability plot.
183
- Shows how each top k-mer 'pushes' the probability from 0.5 to the final value.
184
  """
185
- fig = plt.figure(figsize=(8, 5))
186
- ax = fig.add_subplot(111)
 
 
187
 
188
- # Start from 0.5
 
189
  current_prob = 0.5
190
  steps = [('Start', current_prob, 0)]
191
 
192
- for kmer in important_kmers:
193
- change = kmer['impact'] * (-1 if kmer['direction'] == 'non-human' else 1)
194
  current_prob += change
195
- steps.append((kmer['kmer'], current_prob, change))
196
-
197
- x_vals = range(len(steps))
198
- y_vals = [s[1] for s in steps]
199
-
200
- ax.step(x_vals, y_vals, 'b-', where='post', label='Probability', linewidth=2)
201
- ax.plot(x_vals, y_vals, 'b.', markersize=10)
202
 
203
- # Reference line at 0.5
204
- ax.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
205
- ax.set_ylim(0, 1)
206
- ax.set_ylabel('Human Probability')
207
- ax.set_title(f'K-mer Contributions (final p={human_prob:.3f})')
208
- ax.grid(True, linestyle='--', alpha=0.7)
209
-
 
 
 
 
 
 
 
 
 
 
210
  for i, (kmer, prob, change) in enumerate(steps):
211
- ax.annotate(kmer,
212
- (i, prob),
213
- xytext=(0, 10 if i % 2 == 0 else -20),
214
- textcoords='offset points',
215
- ha='center',
216
- rotation=45)
 
217
 
 
218
  if i > 0:
219
  change_text = f'{change:+.3f}'
220
  color = 'green' if change > 0 else 'red'
221
- ax.annotate(change_text,
222
- (i, prob),
223
- xytext=(0, -20 if i % 2 == 0 else 10),
224
- textcoords='offset points',
225
- ha='center',
226
- color=color)
227
-
228
- ax.legend()
229
- plt.tight_layout()
230
- return fig
231
-
232
- def plot_kmer_freq_and_sigma(important_kmers):
233
- """
234
- Plot frequencies vs. sigma from mean for the top k-mers.
235
- This reuses logic from the original create_visualization second subplot,
236
- but as its own function for clarity.
237
- """
238
- fig, ax = plt.subplots(figsize=(8, 5))
239
 
240
  # Prepare data
241
  kmers = [k['kmer'] for k in important_kmers]
242
  frequencies = [k['occurrence'] for k in important_kmers]
243
  sigmas = [k['sigma'] for k in important_kmers]
244
- colors = ['green' if k['direction'] == 'human' else 'red' for k in important_kmers]
245
 
 
 
 
 
246
  x = np.arange(len(kmers))
247
  width = 0.35
248
 
249
- # Frequency bars
250
- ax.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)
251
 
252
- # Create a twin axis for sigma
253
- ax2 = ax.twinx()
254
- # Sigma bars
255
- ax2.bar(x + width/2, sigmas, width, label='σ from mean',
256
- color=[c if s > 0 else 'gray' for c, s in zip(colors, sigmas)], alpha=0.3)
 
 
 
 
 
 
257
 
258
- ax.set_xticks(x)
259
- ax.set_xticklabels(kmers, rotation=45)
260
- ax.set_ylabel('Frequency (%)')
261
- ax2.set_ylabel('Standard Deviations (σ) from Mean')
262
- ax.set_title("K-mer Frequencies & Statistical Significance")
263
-
264
- lines1, labels1 = ax.get_legend_handles_labels()
265
- lines2, labels2 = ax2.get_legend_handles_labels()
266
- ax.legend(lines1 + lines2, labels1 + labels2, loc='best')
 
 
267
 
268
  plt.tight_layout()
269
  return fig
270
 
271
- # --------------- Main Prediction Logic ---------------
272
-
273
- def predict_fasta(
274
- file_obj,
275
- k_size=4,
276
- top_k=10,
277
- advanced_analysis=False
278
- ):
279
  """
280
- Main function to predict classes for each sequence in an uploaded FASTA.
281
- Returns:
282
- - Combined textual report for all sequences
283
- - A list of generated PIL Image plots
284
  """
285
- # 1. Read raw text from file or string
286
- if file_obj is None:
287
- return "Please upload a FASTA file", []
288
 
289
- try:
290
- if isinstance(file_obj, str):
291
- text = file_obj
292
- else:
293
- text = file_obj.decode('utf-8', errors='replace')
294
- except Exception as e:
295
- return f"Error reading file: {str(e)}", []
296
 
297
- # 2. Parse the FASTA
298
- sequences = parse_fasta(text)
299
- if not sequences:
300
- return "No valid FASTA sequences found!", []
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
- # 3. Load model & scaler
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  try:
304
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
305
- model = VirusClassifier(input_shape=(4 ** k_size)).to(device)
306
  state_dict = torch.load('model.pt', map_location=device)
307
  model.load_state_dict(state_dict)
308
- model.eval()
309
-
310
  scaler = joblib.load('scaler.pkl')
311
  except Exception as e:
312
- return f"Error loading model/scaler: {str(e)}", []
313
-
314
- # 4. Prepare k-mer dictionary for reference
315
- all_kmers = [''.join(p) for p in product("ACGT", repeat=k_size)]
316
- kmer_dict = {km: i for i, km in enumerate(all_kmers)}
317
-
318
- # 5. Iterate over sequences and build output
319
- final_text_report = []
320
- plots = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
- for idx, (header, seq) in enumerate(sequences, start=1):
323
- seq_stats = compute_sequence_stats(seq)
324
-
325
- # Convert sequence -> raw freq -> scaled freq
326
- raw_kmer_freq = sequence_to_kmer_vector(seq, k=k_size)
327
- scaled_kmer_freq = scaler.transform(raw_kmer_freq.reshape(1, -1))
328
- X_tensor = torch.FloatTensor(scaled_kmer_freq).to(device)
329
-
330
- # Predict
331
- with torch.no_grad():
332
- output = model(X_tensor)
333
- probs = torch.softmax(output, dim=1)
334
-
335
- # Determine class
336
- pred_class = torch.argmax(probs, dim=1).item()
337
- pred_label = 'human' if pred_class == 1 else 'non-human'
338
- human_prob = float(probs[0][1])
339
- non_human_prob = float(probs[0][0])
340
- confidence = float(torch.max(probs[0]).item())
341
-
342
- # Compute gradient-based importance
343
- importance, target_prob = model.get_gradient_importance(X_tensor, class_index=1)
344
- importance = importance[0].cpu().numpy() # shape: (num_features,)
345
-
346
- # Identify top-k features (by absolute gradient)
347
- abs_importance = np.abs(importance)
348
- sorted_indices = np.argsort(abs_importance)[::-1]
349
- top_indices = sorted_indices[:top_k]
350
-
351
- # Build a list of top k-mers
352
- top_kmers_info = []
353
- for i in top_indices:
354
- kmer_name = all_kmers[i]
355
- imp_val = float(importance[i])
356
- direction = 'human' if imp_val > 0 else 'non-human'
357
- freq_perc = float(raw_kmer_freq[i] * 100.0) # in percent
358
- sigma = float(scaled_kmer_freq[0][i]) # This is the scaled value (stdev from mean if the scaler is StandardScaler)
359
-
360
- top_kmers_info.append({
361
- 'kmer': kmer_name,
362
- 'impact': abs(imp_val),
363
- 'direction': direction,
364
- 'occurrence': freq_perc,
365
- 'sigma': sigma
366
- })
367
-
368
- # Text summary for this sequence
369
- seq_report = []
370
- seq_report.append(f"=== Sequence {idx} ===")
371
- seq_report.append(f"Header: {header}")
372
- seq_report.append(f"Length: {seq_stats['length']}")
373
- seq_report.append(f"GC Content: {seq_stats['gc_content']:.2f}%")
374
- seq_report.append(f"A: {seq_stats['counts']['A']}, C: {seq_stats['counts']['C']}, G: {seq_stats['counts']['G']}, T: {seq_stats['counts']['T']}")
375
- seq_report.append(f"Prediction: {pred_label} (Confidence: {confidence:.4f})")
376
- seq_report.append(f" Human Probability: {human_prob:.4f}")
377
- seq_report.append(f" Non-human Probability: {non_human_prob:.4f}")
378
- seq_report.append(f"\nTop-{top_k} Influential k-mers (by gradient magnitude):")
379
- for tkm in top_kmers_info:
380
- seq_report.append(
381
- f" {tkm['kmer']}: pushes towards {tkm['direction']} "
382
- f"(impact={tkm['impact']:.4f}), occurrence={tkm['occurrence']:.2f}%, "
383
- f"sigma={tkm['sigma']:.2f}"
384
- )
385
-
386
- final_text_report.append("\n".join(seq_report))
387
-
388
- # 6. Generate Plots (for each sequence)
389
- if advanced_analysis:
390
- # 6A. SHAP-like bar chart
391
- fig_shap = plot_shap_like_bars(
392
- kmers=all_kmers,
393
- importance_values=importance,
394
- top_k=top_k
395
- )
396
- buf_shap = io.BytesIO()
397
- fig_shap.savefig(buf_shap, format='png', bbox_inches='tight', dpi=150)
398
- buf_shap.seek(0)
399
- plots.append(Image.open(buf_shap))
400
- plt.close(fig_shap)
401
-
402
- # 6B. k-mer distribution histogram
403
- fig_kmer_dist = plot_kmer_distribution(raw_kmer_freq, all_kmers)
404
- buf_dist = io.BytesIO()
405
- fig_kmer_dist.savefig(buf_dist, format='png', bbox_inches='tight', dpi=150)
406
- buf_dist.seek(0)
407
- plots.append(Image.open(buf_dist))
408
- plt.close(fig_kmer_dist)
409
-
410
- # 6C. Original step visualization for top k k-mers
411
- # Sort by actual 'impact' to preserve that step logic
412
- # (largest absolute impact first)
413
- top_kmers_info_step = sorted(top_kmers_info, key=lambda x: x['impact'], reverse=True)
414
- fig_step = create_step_visualization(top_kmers_info_step, human_prob)
415
- buf_step = io.BytesIO()
416
- fig_step.savefig(buf_step, format='png', bbox_inches='tight', dpi=150)
417
- buf_step.seek(0)
418
- plots.append(Image.open(buf_step))
419
- plt.close(fig_step)
420
-
421
- # 6D. Frequency vs. sigma bar chart
422
- fig_freq_sigma = plot_kmer_freq_and_sigma(top_kmers_info_step)
423
- buf_freq_sigma = io.BytesIO()
424
- fig_freq_sigma.savefig(buf_freq_sigma, format='png', bbox_inches='tight', dpi=150)
425
- buf_freq_sigma.seek(0)
426
- plots.append(Image.open(buf_freq_sigma))
427
- plt.close(fig_freq_sigma)
428
 
429
- # Combine all text results
430
- combined_text = "\n\n".join(final_text_report)
431
- return combined_text, plots
432
-
433
- # --------------- Gradio Interface ---------------
434
-
435
- def run_prediction(
436
- file_obj,
437
- k_size,
438
- top_k,
439
- advanced_analysis
440
- ):
441
- """
442
- Wrapper for Gradio to handle the outputs in (text, List[Image]) form.
443
- """
444
- text_output, pil_images = predict_fasta(
445
- file_obj=file_obj,
446
- k_size=k_size,
447
- top_k=top_k,
448
- advanced_analysis=advanced_analysis
449
- )
450
-
451
-
452
- return text_output, pil_images
453
 
454
-
455
- with gr.Blocks() as demo:
456
- gr.Markdown("# Virus Host Classifier (Improved!)")
457
- gr.Markdown(
458
- "Upload a FASTA file and configure k-mer size, number of top features, "
459
- "and whether to run advanced analysis (plots of SHAP-like bars & k-mer distribution)."
460
- )
461
-
462
- with gr.Row():
463
- with gr.Column():
464
- fasta_file = gr.File(label="Upload FASTA file", type="binary")
465
- kmer_slider = gr.Slider(minimum=2, maximum=6, value=4, step=1, label="K-mer Size")
466
- topk_slider = gr.Slider(minimum=5, maximum=20, value=10, step=1, label="Top-k Features")
467
- advanced_check = gr.Checkbox(value=False, label="Advanced Analysis")
468
- predict_button = gr.Button("Predict")
469
-
470
- with gr.Column():
471
- results_text = gr.Textbox(
472
- label="Results",
473
- lines=20,
474
- placeholder="Prediction results will appear here..."
475
- )
476
-
477
- # We can display multiple images in a Gallery or as separate outputs.
478
- plots_gallery = gr.Gallery(label="Analysis Plots", columns=2)
479
-
480
- predict_button.click(
481
- fn=run_prediction,
482
- inputs=[fasta_file, kmer_slider, topk_slider, advanced_check],
483
- outputs=[results_text, plots_gallery]
484
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
  if __name__ == "__main__":
487
- demo.launch(share=True)
488
-
 
2
  import torch
3
  import joblib
4
  import numpy as np
5
+ from itertools import product
6
  import torch.nn as nn
7
  import matplotlib.pyplot as plt
8
  import io
9
  from PIL import Image
 
10
 
11
+ ##############################################################################
12
+ # MODEL DEFINITION
13
+ ##############################################################################
14
 
15
  class VirusClassifier(nn.Module):
16
  def __init__(self, input_shape: int):
 
31
 
32
  def forward(self, x):
33
  return self.network(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ ##############################################################################
36
+ # UTILITIES
37
+ ##############################################################################
38
 
39
+ def parse_fasta(text):
40
  """
41
+ Parses FASTA formatted text into a list of (header, sequence).
42
  """
43
  sequences = []
44
  current_header = None
45
  current_sequence = []
46
 
47
+ for line in text.strip().split('\n'):
48
  line = line.strip()
49
  if not line:
50
  continue
 
61
 
62
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
63
  """
64
+ Convert a sequence to a k-mer frequency vector of size len(ACGT^k).
 
65
  """
 
66
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
67
  kmer_dict = {km: i for i, km in enumerate(kmers)}
68
  vec = np.zeros(len(kmers), dtype=np.float32)
 
78
 
79
  return vec
80
 
81
def ablation_importance(model, x_tensor):
    """
    Estimate per-feature importance by single-feature ablation.

    The baseline probability of class index 1 (the "human" class) is taken
    from a softmax over the model output for row 0 of ``x_tensor``. Each
    feature is then set to 0.0 in turn and the probability is recomputed;
    the importance of feature ``i`` is ``p_base - p_ablated_i``. A positive
    value means zeroing the feature lowered the probability, i.e. the
    feature was pushing it higher.

    All ablated variants are evaluated in a single batched forward pass
    (an ``n_features x n_features`` batch) rather than one pass per feature.

    NOTE(review): the caller passes scaler-transformed features, so 0.0
    presumably means "training mean" rather than "k-mer absent" if the
    scaler is a StandardScaler -- confirm against the saved scaler.

    Args:
        model: torch module whose output is softmaxed over dim 1. It is
            switched to eval mode and left in that mode.
        x_tensor: 2-D tensor; only row 0 is analysed and the tensor itself
            is not modified.

    Returns:
        tuple: ``(importances, p_base)`` where ``importances`` is a float32
        NumPy array with one delta per feature and ``p_base`` is the
        baseline probability as a Python float.
    """
    model.eval()
    n_features = x_tensor.shape[1]

    with torch.no_grad():
        # Baseline probability for the unmodified input.
        p_base = torch.softmax(model(x_tensor), dim=1)[0, 1].item()

        # Row i of the batch is the input with feature i zeroed out.
        # repeat() allocates a fresh tensor, so x_tensor stays untouched.
        ablated = x_tensor[:1].repeat(n_features, 1)
        ablated.fill_diagonal_(0.0)
        p_ablated = torch.softmax(model(ablated), dim=1)[:, 1]

    importances = (p_base - p_ablated).cpu().numpy().astype(np.float32)
    return importances, p_base
 
 
 
 
 
 
 
 
 
 
112
 
113
+ ##############################################################################
114
+ # PLOTTING
115
+ ##############################################################################
 
 
 
 
 
 
 
 
 
 
116
 
117
def create_step_and_frequency_plot(important_kmers, human_prob, title):
    """
    Build a two-panel figure for the most influential k-mers.

    Top panel: a step plot that starts at 0.5 and adds each k-mer's
    ``'impact'`` in the given order, annotating every step with the k-mer
    name and its signed change.
    Bottom panel: per-k-mer frequency bars (left axis) next to sigma bars
    (twin right axis).

    Args:
        important_kmers: list of dicts with keys ``'kmer'``, ``'impact'``,
            ``'occurrence'`` and ``'sigma'``.
        human_prob: final human probability, shown in the top panel title.
        title: figure-level title (the caller passes the FASTA header);
            skipped when empty.

    Returns:
        The Matplotlib figure. The caller is responsible for closing it.
    """
    fig = plt.figure(figsize=(15, 10))
    gs = plt.GridSpec(2, 1, height_ratios=[1.5, 1], hspace=0.3)

    # ---- 1. Probability step plot -------------------------------------
    ax1 = fig.add_subplot(gs[0])
    current_prob = 0.5
    steps = [('Start', current_prob, 0)]

    for kmer_info in important_kmers:
        change = kmer_info['impact']  # positive => pushes up, negative => down
        current_prob += change
        steps.append((kmer_info['kmer'], current_prob, change))

    x = range(len(steps))
    y = [step[1] for step in steps]

    ax1.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
    ax1.plot(x, y, 'b.', markersize=10)
    ax1.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')

    ax1.grid(True, linestyle='--', alpha=0.7)
    ax1.set_ylim(0, 1)
    ax1.set_ylabel('Human Probability')
    ax1.set_title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')

    for i, (kmer, prob, change) in enumerate(steps):
        # Alternate label offsets above/below the point to reduce overlap.
        ax1.annotate(kmer,
                     (i, prob),
                     xytext=(0, 10 if i % 2 == 0 else -20),
                     textcoords='offset points',
                     ha='center',
                     rotation=45)

        # The synthetic 'Start' point has no change to report.
        if i > 0:
            change_text = f'{change:+.3f}'
            color = 'green' if change > 0 else 'red'
            ax1.annotate(change_text,
                         (i, prob),
                         xytext=(0, -20 if i % 2 == 0 else 10),
                         textcoords='offset points',
                         ha='center',
                         color=color)

    ax1.legend()

    # ---- 2. K-mer frequency and sigma plot ----------------------------
    ax2 = fig.add_subplot(gs[1])

    kmers = [k['kmer'] for k in important_kmers]
    frequencies = [k['occurrence'] for k in important_kmers]
    sigmas = [k['sigma'] for k in important_kmers]

    # Frequency bars: green when impact > 0, red otherwise.
    colors = ['g' if k['impact'] > 0 else 'r' for k in important_kmers]

    x = np.arange(len(kmers))
    width = 0.35

    ax2.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)

    # Sigma bars on a twin axis: blue for >= 0, gray for negative.
    ax2_twin = ax2.twinx()
    sigma_colors = ['blue' if s >= 0 else 'gray' for s in sigmas]
    ax2_twin.bar(x + width/2, sigmas, width, label='σ from Mean', color=sigma_colors, alpha=0.3)

    ax2.set_xticks(x)
    ax2.set_xticklabels(kmers, rotation=45)
    ax2.set_ylabel('Frequency (%)')
    ax2_twin.set_ylabel('Standard Deviations (σ) from Mean')
    ax2.set_title('K-mer Frequencies and Statistical Significance')

    # Merge the legends of both y-axes into one box.
    lines1, labels1 = ax2.get_legend_handles_labels()
    lines2, labels2 = ax2_twin.get_legend_handles_labels()
    ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

    if title:
        # Previously the title argument was accepted but never displayed.
        fig.suptitle(str(title))
        # Reserve a strip at the top so the suptitle does not overlap ax1.
        fig.tight_layout(rect=(0, 0, 1, 0.96))
    else:
        fig.tight_layout()
    return fig
219
 
220
def create_shap_like_bar_plot(impact_values, kmer_list, top_k):
    """
    Draw a horizontal bar chart of the ``top_k`` features with the largest
    absolute impact, strongest at the top.

    impact_values: NumPy array of signed impacts, one per k-mer.
    kmer_list: k-mer names, index-aligned with ``impact_values``.
    top_k: how many features to show.

    Bars are green for impact > 0 and red otherwise. Returns the figure.
    """
    # Rank feature indices by |impact|, descending, and keep the first top_k.
    ranked = np.argsort(np.abs(impact_values))[::-1][:top_k]
    selected_impacts = impact_values[ranked]
    selected_names = [kmer_list[j] for j in ranked]

    bar_colors = []
    for value in selected_impacts:
        bar_colors.append('green' if value > 0 else 'red')

    fig, ax = plt.subplots(figsize=(8, 6))
    positions = range(len(selected_impacts))
    ax.barh(positions, selected_impacts, color=bar_colors)
    ax.set_yticks(positions)
    ax.set_yticklabels(selected_names)
    ax.set_xlabel("Impact on Human Probability (Ablation)")
    ax.set_title(f"Top {top_k} K-mers by Absolute Impact")
    ax.invert_yaxis()  # Highest at top
    fig.tight_layout()
    return fig
242
+
243
def create_global_bar_plot(impact_values, kmer_list):
    """
    Draw a bar chart of every feature's impact, sorted by absolute impact
    (largest first), to show the global distribution.

    Args:
        impact_values: NumPy array of signed impacts, one per k-mer.
        kmer_list: k-mer names aligned with ``impact_values``. Kept for
            interface compatibility; the x-axis is deliberately left
            unlabeled per bar, so the names are not used here.

    Returns:
        The Matplotlib figure. Bars are green for impact > 0, red otherwise.
    """
    order = np.argsort(np.abs(impact_values))[::-1]
    sorted_impacts = impact_values[order]

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(range(len(sorted_impacts)), sorted_impacts,
           color=['green' if v > 0 else 'red' for v in sorted_impacts])
    ax.set_title("Global Impact of All 256 K-mers (Ablation Method)")
    ax.set_xlabel("K-mer (sorted by |impact|)")
    ax.set_ylabel("Impact on Human Probability")
    fig.tight_layout()
    return fig
261
+
262
+ ##############################################################################
263
+ # MAIN PREDICTION FUNCTION
264
+ ##############################################################################
265
+
266
def predict(file_obj, top_kmers=10, advanced_plots=False, fasta_text=""):
    """
    Main prediction function called by Gradio.

    Only the first record of the FASTA input is analysed. Pasted text takes
    priority over the uploaded file.

    Args:
        file_obj: uploaded FASTA content as ``bytes`` (decoded as UTF-8) or
            ``str``; may be ``None`` when text is pasted instead.
        top_kmers: number of top k-mers (by absolute ablation impact) to
            report and plot. Coerced to ``int`` because UI sliders may
            deliver a float.
        advanced_plots: when true, also return the global all-k-mer plot.
        fasta_text: FASTA text pasted directly; ``None`` is treated as empty.

    Returns:
        tuple: ``(results_text, shap_img, step_img, global_img)``. On any
        input or loading error the first element is the error message and
        the three image slots are ``None``. ``global_img`` is also ``None``
        when ``advanced_plots`` is false.
    """
    # Priority: pasted text first, then the uploaded file.
    pasted = (fasta_text or "").strip()
    if pasted:
        text = pasted
    else:
        if file_obj is None:
            return "No FASTA input provided", None, None, None
        try:
            if isinstance(file_obj, str):
                text = file_obj
            else:
                text = file_obj.decode('utf-8')
        except Exception as e:
            return f"Error reading file: {str(e)}", None, None, None

    # Parse FASTA; only the first sequence is used.
    sequences = parse_fasta(text)
    if not sequences:
        return "No valid FASTA sequences found", None, None, None
    header, seq = sequences[0]

    # Load model + scaler (256 inputs = 4**4 k-mers for k=4).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = VirusClassifier(256).to(device)
    try:
        state_dict = torch.load('model.pt', map_location=device)
        model.load_state_dict(state_dict)
        scaler = joblib.load('scaler.pkl')
    except Exception as e:
        return f"Error loading model or scaler: {str(e)}", None, None, None
    # Inference only: make eval mode explicit instead of relying on a helper.
    model.eval()

    # Prepare the feature vector: raw frequencies -> scaled -> tensor.
    raw_freq_vector = sequence_to_kmer_vector(seq, k=4)
    scaled_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
    X_tensor = torch.FloatTensor(scaled_vector).to(device)

    # Ablation-based importances (p_base is the baseline human probability).
    importances, p_base = ablation_importance(model, X_tensor)

    kmers_4 = [''.join(p) for p in product("ACGT", repeat=4)]

    # Indices of the top-k features by absolute impact, strongest first.
    top_k = int(top_kmers)
    top_indices = np.argsort(np.abs(importances))[::-1][:top_k]

    # If the scaler is e.g. a StandardScaler, scaled_vector[0][i] is how many
    # standard deviations feature i lies from the training mean.
    important_kmers = [
        {
            'kmer': kmers_4[idx],
            'impact': float(importances[idx]),
            'occurrence': float(raw_freq_vector[idx] * 100.0),
            'sigma': float(scaled_vector[0][idx]),
        }
        for idx in top_indices
    ]

    # Final class comes from the model's direct output.
    with torch.no_grad():
        output = model(X_tensor)
        probs = torch.softmax(output, dim=1)
    pred_class = 1 if probs[0, 1] > probs[0, 0] else 0
    pred_label = 'human' if pred_class == 1 else 'non-human'
    human_prob = probs[0, 1].item()
    nonhuman_prob = probs[0, 0].item()
    confidence = max(human_prob, nonhuman_prob)

    report_parts = [
        f"Sequence: {header}\n"
        f"Prediction: {pred_label}\n"
        f"Confidence: {confidence:.4f}\n"
        f"Human probability: {human_prob:.4f}\n"
        f"Non-human probability: {nonhuman_prob:.4f}\n"
        f"Most influential k-mers (by ablation impact):\n"
    ]
    for kmer_info in important_kmers:
        # impact > 0 => removing the k-mer lowers p(human), so it pushed it up.
        direction = "UP (toward human)" if kmer_info['impact'] > 0 else "DOWN (toward non-human)"
        report_parts.append(
            f" {kmer_info['kmer']}: {direction}, "
            f"Impact={kmer_info['impact']:.4f}, "
            f"Occ={kmer_info['occurrence']:.2f}% of seq, "
            f"{abs(kmer_info['sigma']):.2f}σ "
            + ("above" if kmer_info['sigma'] > 0 else "below")
            + " mean\n"
        )
    results_text = "".join(report_parts)

    def fig_to_image(fig):
        """Render a figure to a PIL image and close the figure."""
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=200)
        buf.seek(0)
        im = Image.open(buf)
        plt.close(fig)
        return im

    # PLOT 1: SHAP-like bar plot for the top-k features.
    shap_img = fig_to_image(create_shap_like_bar_plot(importances, kmers_4, top_k))

    # PLOT 2: step + frequency plot for the top-k features.
    step_img = fig_to_image(
        create_step_and_frequency_plot(important_kmers, human_prob, header)
    )

    # PLOT 3 (optional): global bar plot over all features.
    global_img = None
    if advanced_plots:
        global_img = fig_to_image(create_global_bar_plot(importances, kmers_4))

    return results_text, shap_img, step_img, global_img
400
+
401
+ ##############################################################################
402
+ # GRADIO INTERFACE
403
+ ##############################################################################
404
+
405
title_text = "Virus Host Classifier"
description_text = """
Upload or paste a FASTA sequence to predict if it's likely **human** or **non-human** origin.
- **k=4** k-mers are used as features.
- We display ablation-based feature importance for interpretability.
- Advanced plots can be toggled to see the global distribution of all 256 k-mer impacts.
"""

# Input order must match predict(file_obj, top_kmers, advanced_plots, fasta_text).
# The `optional=True` keyword was dropped from the File and Image components:
# it is not accepted by current Gradio releases, and components already
# tolerate empty values (predict() handles a missing file and returns None
# for the advanced image slot).
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label="Upload FASTA file", type="binary"),
        gr.Slider(label="Number of top k-mers to show", minimum=1, maximum=50, value=10, step=1),
        gr.Checkbox(label="Show advanced (global) plots?", value=False),
        gr.Textbox(label="Or paste FASTA text here", lines=5, placeholder=">header\nACGTACGT...")
    ],
    outputs=[
        gr.Textbox(label="Results", lines=10),
        gr.Image(label="SHAP-like Top-k K-mer Bar Plot"),
        gr.Image(label="Step & Frequency Plot (Top-k)"),
        gr.Image(label="Global 256-K-mer Plot (advanced)")
    ],
    title=title_text,
    description=description_text
)

if __name__ == "__main__":
    iface.launch(share=True)