hiyata committed on
Commit
d192dd4
·
verified ·
1 Parent(s): 0e7de0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +334 -346
app.py CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
7
  import matplotlib.pyplot as plt
8
  import io
9
  from PIL import Image
 
10
 
11
  ###############################################################################
12
  # Model Definition
@@ -30,30 +31,16 @@ class VirusClassifier(nn.Module):
30
 
31
  def forward(self, x):
32
  return self.network(x)
33
-
34
- def get_feature_importance(self, x):
35
- """
36
- Calculate gradient-based feature importance, specifically for the
37
- 'human' class (index=1) by computing gradient of that probability wrt x.
38
- """
39
- x.requires_grad_(True)
40
- output = self.network(x)
41
- probs = torch.softmax(output, dim=1)
42
-
43
- # Probability of 'human' class (index=1)
44
- human_prob = probs[..., 1]
45
- if x.grad is not None:
46
- x.grad.zero_()
47
- human_prob.backward()
48
- importance = x.grad # shape: (batch_size, n_features)
49
-
50
- return importance, float(human_prob)
51
 
52
  ###############################################################################
53
  # Utility Functions
54
  ###############################################################################
55
  def parse_fasta(text):
56
- """Parses text input in FASTA format into a list of (header, sequence)."""
 
 
 
57
  sequences = []
58
  current_header = None
59
  current_sequence = []
@@ -74,7 +61,10 @@ def parse_fasta(text):
74
  return sequences
75
 
76
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
77
- """Convert a single nucleotide sequence to a k-mer frequency vector."""
 
 
 
78
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
79
  kmer_dict = {km: i for i, km in enumerate(kmers)}
80
  vec = np.zeros(len(kmers), dtype=np.float32)
@@ -92,377 +82,375 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
92
 
93
 
94
  ###############################################################################
95
- # Visualization
96
  ###############################################################################
97
- def create_shap_waterfall_plot(important_kmers, all_kmer_importance, human_prob, title):
98
- """
99
- Create a SHAP-like waterfall plot:
100
- - Start at baseline = 0.5
101
- - Add a bar for "Other" which is the combined effect of all less-important k-mers
102
- - Then apply each of the top k-mers in descending order of absolute importance
103
- - Show final predicted human probability as the endpoint
104
  """
105
-
106
- # 1) Sort 'important_kmers' by absolute impact descending
107
- sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
108
-
109
- # 2) Compute the total effect of "other" k-mers
110
- # We have 256 total features. We selected top 10. Sum the rest.
111
- top_ids = set([km['idx'] for km in sorted_kmers])
112
- other_contributions = []
113
- for i, val in enumerate(all_kmer_importance):
114
- if i not in top_ids:
115
- other_contributions.append(val)
116
- # sum up those "other" contributions
117
- other_sum = np.sum(other_contributions)
118
- # The "impact" for "other" will be the absolute value, direction depends on sign
119
- other_impact = float(abs(other_sum))
120
- other_direction = "human" if other_sum > 0 else "non-human"
121
-
122
- # 3) Build a list of all bars: first "other", then each top k-mer
123
- # Each bar needs: name, raw_contribution_value
124
- # We'll store (label, contribution). The sign indicates direction.
125
- bars = []
126
- bars.append(("Other", other_sum)) # lumps the leftover k-mers
127
-
128
- for km in sorted_kmers:
129
- # We re-inject the sign on the raw gradient
130
- # (We stored only the absolute in "impact," so let's create a signed value)
131
- signed_val = km['impact'] if km['direction'] == 'human' else -km['impact']
132
- bars.append((km['kmer'], signed_val))
133
-
134
- # 4) Waterfall plot data:
135
- # We'll accumulate partial sums from baseline=0.5
136
- baseline = 0.5
137
- running_val = baseline
138
- x_labels = []
139
- y_vals = []
140
- bar_colors = []
141
-
142
- # We'll use green for positive contributions (pushing toward 'human'),
143
- # red for negative contributions (pushing away from 'human')
144
- for (label, contrib) in bars:
145
- x_labels.append(label)
146
- # new value after adding this contribution
147
- new_val = running_val + (0.05 * contrib)
148
- # ^ scaled by 0.05 for better display. Adjust as desired.
149
-
150
- y_vals.append((running_val, new_val))
151
- running_val = new_val
152
- if contrib >= 0:
153
- bar_colors.append("green")
154
- else:
155
- bar_colors.append("red")
156
-
157
- final_prob = running_val
158
- # Final point is the model's predicted probability (not always exact, but this is a shap-like idea).
159
- # If we want to forcibly ensure final_prob = human_prob, we could do:
160
- # correction = human_prob - running_val
161
- # running_val += correction
162
- # but for now let's keep the "waterfall" purely additive from the gradient.
163
-
164
- # Let's plot:
165
- fig, ax = plt.subplots(figsize=(10, 6))
166
 
167
- # We'll create the bars manually
168
- x_positions = np.arange(len(x_labels))
169
- last_end = baseline
170
-
171
- for i, ((start_val, end_val), color) in enumerate(zip(y_vals, bar_colors)):
172
- # The bar's height is the difference
173
- height = end_val - start_val
174
- ax.bar(i, height, bottom=start_val, color=color, edgecolor='black', alpha=0.7)
175
- ax.text(i, (start_val + end_val) / 2, f"{height:+.3f}", ha='center', va='center', color='white', fontsize=8)
176
-
177
- ax.axhline(y=baseline, color='black', linestyle='--', linewidth=1)
178
- ax.set_xticks(x_positions)
179
- ax.set_xticklabels(x_labels, rotation=45, ha='right')
180
- ax.set_ylim(0, 1)
181
- ax.set_ylabel("Running Probability (Human)")
182
- ax.set_title(f"SHAP-like Waterfall — Final Probability: {final_prob:.3f} (Model Probability: {human_prob:.3f})")
183
-
184
- plt.tight_layout()
185
- return fig
 
 
 
 
 
 
 
186
 
187
- def create_frequency_sigma_plot(important_kmers, title):
188
- """Creates a bar plot of the top k-mers (by importance) showing frequency (%) and σ from mean."""
189
- # Sort by absolute impact
190
- sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
191
- kmers = [k["kmer"] for k in sorted_kmers]
192
- frequencies = [k["occurrence"] for k in sorted_kmers] # in %
193
- sigmas = [k["sigma"] for k in sorted_kmers]
194
- directions = [k["direction"] for k in sorted_kmers]
195
-
196
  x = np.arange(len(kmers))
197
  width = 0.4
198
 
199
- fig, ax_bar = plt.subplots(figsize=(10, 6))
200
-
201
- # Bar for frequency
202
- bars_freq = ax_bar.bar(
203
- x - width/2, frequencies, width, alpha=0.7,
204
- color=["green" if d=="human" else "red" for d in directions],
205
- label="Frequency (%)"
206
- )
207
- ax_bar.set_ylabel("Frequency (%)")
208
- ax_bar.set_ylim(0, max(frequencies) * 1.2 if frequencies else 1)
209
-
210
- # Twin axis for σ
211
- ax_bar_twin = ax_bar.twinx()
212
- bars_sigma = ax_bar_twin.bar(
213
- x + width/2, sigmas, width, alpha=0.5, color="gray", label="σ from Mean"
214
- )
215
- ax_bar_twin.set_ylabel("Standard Deviations (σ)")
216
-
217
- ax_bar.set_title(f"Frequency & σ from Mean for Top k-mers — {title}")
218
- ax_bar.set_xticks(x)
219
- ax_bar.set_xticklabels(kmers, rotation=45, ha='right')
220
 
221
- # Combined legend
222
- lines1, labels1 = ax_bar.get_legend_handles_labels()
223
- lines2, labels2 = ax_bar_twin.get_legend_handles_labels()
224
- ax_bar.legend(lines1 + lines2, labels1 + labels2, loc="upper right")
225
 
226
- plt.tight_layout()
227
- return fig
228
-
229
- def create_importance_bar_plot(important_kmers, title):
230
- """
231
- Create a simple bar chart showing the absolute gradient magnitude
232
- for the top k-mers, sorted descending.
233
- """
234
- sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
235
- kmers = [k['kmer'] for k in sorted_kmers]
236
- impacts = [k['impact'] for k in sorted_kmers]
237
- directions = [k["direction"] for k in sorted_kmers]
238
-
239
- x = np.arange(len(kmers))
240
-
241
- fig, ax = plt.subplots(figsize=(10, 6))
242
- bar_colors = ["green" if d=="human" else "red" for d in directions]
243
-
244
- ax.bar(x, impacts, color=bar_colors, alpha=0.7)
245
  ax.set_xticks(x)
246
  ax.set_xticklabels(kmers, rotation=45, ha='right')
247
- ax.set_title(f"Absolute Feature Importance (Top k-mers) {title}")
248
- ax.set_ylabel("Gradient Magnitude")
249
- ax.grid(axis="y", alpha=0.3)
 
 
 
250
 
251
  plt.tight_layout()
252
  return fig
253
 
254
 
255
  ###############################################################################
256
- # Prediction Function
257
  ###############################################################################
258
- def predict(file_obj):
259
  """
260
- Main function for Gradio:
261
- 1. Reads the uploaded FASTA file or text.
262
- 2. Loads the model and scaler.
263
- 3. Generates predictions, probabilities, and top k-mers.
264
- 4. Returns multiple outputs:
265
- - A textual summary (Markdown).
266
- - Waterfall plot.
267
- - Frequency & sigma plot.
268
- - Absolute importance bar plot.
269
  """
270
- # 0. Basic file read
271
- if file_obj is None:
272
- return (
273
- "Please upload a FASTA file.",
274
- None,
275
- None,
276
- None
277
- )
278
-
279
- try:
280
- # If user provided raw text, use that
281
- if isinstance(file_obj, str):
282
- text = file_obj
283
- else:
284
- # If user uploaded a file, decode it
285
- text = file_obj.decode('utf-8')
286
- except Exception as e:
287
- return (
288
- f"Error reading file: {str(e)}",
289
- None,
290
- None,
291
- None
292
- )
293
-
294
- # 1. Parse FASTA
295
  sequences = parse_fasta(text)
296
  if len(sequences) == 0:
297
- return (
298
- "No valid FASTA sequences found. Please check your input.",
299
- None,
300
- None,
301
- None
302
- )
303
- # We’ll just classify the first sequence for demonstration
304
- header, seq = sequences[0]
305
 
306
- # 2. Create k-mer vector & load model
307
  k = 4
 
 
 
 
 
 
 
 
 
 
 
 
308
  try:
309
  device = "cuda" if torch.cuda.is_available() else "cpu"
310
-
311
- # Prepare raw freq vector & scale
312
- raw_freq_vector = sequence_to_kmer_vector(seq, k=k)
313
-
314
- # Load model & scaler
315
  model = VirusClassifier(input_shape=4**k).to(device)
316
- state_dict = torch.load('model.pt', map_location=device)
317
  model.load_state_dict(state_dict)
318
- scaler = joblib.load('scaler.pkl')
319
  model.eval()
320
 
321
- scaled_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
322
- X_tensor = torch.FloatTensor(scaled_vector).to(device)
323
-
324
- # 3. Inference
325
- with torch.no_grad():
326
- logits = model(X_tensor)
327
- probs = torch.softmax(logits, dim=1)
328
- human_prob = float(probs[0][1])
329
- non_human_prob = float(probs[0][0])
330
- pred_class = 1 if human_prob >= non_human_prob else 0
331
- pred_label = "human" if pred_class == 1 else "non-human"
332
- confidence = float(max(probs[0]))
333
-
334
- # 4. Feature importance
335
- importance, hum_prob_grad = model.get_feature_importance(X_tensor)
336
- # shape: [1, 256]
337
- kmer_importances = importance[0].cpu().numpy()
338
-
339
- # We’ll store them as a dictionary: index -> (k-mer, importance)
340
- # Build up a dict for k-mer strings
341
- kmers_list = [''.join(p) for p in product("ACGT", repeat=k)]
342
- kmer_dict = {km: i for i, km in enumerate(kmers_list)}
343
-
344
- # 5. Get the top 10 k-mers by absolute importance
345
- abs_importance = np.abs(kmer_importances)
346
- top_k = 10
347
- top_idxs = np.argsort(abs_importance)[-top_k:][::-1] # descending
348
- important_kmers = []
349
- for idx in top_idxs:
350
- # Find the k-mer by index
351
- kmer_str = kmers_list[idx]
352
- # direction
353
- direction = "human" if kmer_importances[idx] > 0 else "non-human"
354
- # frequency in % from raw_freq_vector
355
- freq_percent = float(raw_freq_vector[idx] * 100)
356
- # sigma from scaled vector
357
- sigma_val = float(scaled_vector[0][idx])
358
- important_kmers.append({
359
- 'kmer': kmer_str,
360
- 'idx': idx,
361
- 'impact': float(abs_importance[idx]),
362
- 'direction': direction,
363
- 'occurrence': freq_percent,
364
- 'sigma': sigma_val
365
- })
366
-
367
- # 6. Text Summary
368
- summary_text = (
369
- f"**Sequence Header**: {header}\n\n"
370
- f"**Predicted Label**: {pred_label}\n"
371
- f"**Confidence**: {confidence:.4f}\n\n"
372
- f"**Human Probability**: {human_prob:.4f}\n"
373
- f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
374
- "### Most Influential k-mers:\n"
375
- )
376
- for km in important_kmers:
377
- direction_text = f"(pushes toward {km['direction']})"
378
- freq_text = f"{km['occurrence']:.2f}%"
379
- sigma_text = f"{abs(km['sigma']):.2f}σ " + ("above" if km['sigma']>0 else "below") + " mean"
380
- summary_text += (
381
- f"- **{km['kmer']}**: impact={km['impact']:.4f}, {direction_text}, "
382
- f"occurrence={freq_text}, ({sigma_text})\n"
383
- )
384
-
385
- # 7. Plots
386
- # a) SHAP-like Waterfall Plot
387
- fig_waterfall = create_shap_waterfall_plot(
388
- important_kmers,
389
- kmer_importances,
390
- human_prob,
391
- f"{header}"
392
- )
393
- buf1 = io.BytesIO()
394
- fig_waterfall.savefig(buf1, format='png', bbox_inches='tight', dpi=120)
395
- buf1.seek(0)
396
- waterfall_img = Image.open(buf1)
397
- plt.close(fig_waterfall)
398
-
399
- # b) Frequency & σ Plot (top 10 k-mers)
400
- fig_freq_sigma = create_frequency_sigma_plot(
401
- important_kmers,
402
- f"{header}"
403
- )
404
- buf2 = io.BytesIO()
405
- fig_freq_sigma.savefig(buf2, format='png', bbox_inches='tight', dpi=120)
406
- buf2.seek(0)
407
- freq_sigma_img = Image.open(buf2)
408
- plt.close(fig_freq_sigma)
409
-
410
- # c) Absolute Importance Bar Plot
411
- fig_imp = create_importance_bar_plot(
412
- important_kmers,
413
- f"{header}"
414
- )
415
- buf3 = io.BytesIO()
416
- fig_imp.savefig(buf3, format='png', bbox_inches='tight', dpi=120)
417
- buf3.seek(0)
418
- importance_img = Image.open(buf3)
419
- plt.close(fig_imp)
420
 
421
- return summary_text, waterfall_img, freq_sigma_img, importance_img
422
 
423
- except Exception as e:
424
- return (
425
- f"Error during prediction or visualization: {str(e)}",
426
- None,
427
- None,
428
- None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
 
432
  ###############################################################################
433
  # Gradio Interface
434
  ###############################################################################
435
- with gr.Blocks(title="Advanced Virus Host Classifier") as demo:
 
 
436
  gr.Markdown(
437
  """
438
- # Advanced Virus Host Classifier
439
- **Upload a FASTA file** containing a single nucleotide sequence.
440
- The model will predict whether this sequence is **human** or **non-human**,
441
- provide a confidence score, and highlight the most influential k-mers
442
- (using a SHAP-like waterfall plot) along with two additional plots.
 
443
  """
444
  )
445
-
446
- with gr.Row():
447
- file_in = gr.File(label="Upload FASTA", type="binary")
448
- btn = gr.Button("Run Prediction")
449
 
450
- # We will create multiple tabs for our outputs
 
 
 
 
 
 
 
 
 
 
 
 
451
  with gr.Tabs():
452
- with gr.Tab("Prediction Results"):
453
  md_out = gr.Markdown()
454
- with gr.Tab("SHAP-like Waterfall Plot"):
455
- water_out = gr.Image()
456
- with gr.Tab("Frequency & σ Plot"):
457
- freq_out = gr.Image()
458
- with gr.Tab("Importance Bar Plot"):
459
- imp_out = gr.Image()
460
-
461
- # Link the button
462
- btn.click(
463
- fn=predict,
464
- inputs=[file_in],
465
- outputs=[md_out, water_out, freq_out, imp_out]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  )
467
 
468
  if __name__ == "__main__":
 
7
  import matplotlib.pyplot as plt
8
  import io
9
  from PIL import Image
10
+ import shap # Requires: pip install shap
11
 
12
  ###############################################################################
13
  # Model Definition
 
31
 
32
  def forward(self, x):
33
  return self.network(x)
34
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  ###############################################################################
37
  # Utility Functions
38
  ###############################################################################
39
  def parse_fasta(text):
40
+ """
41
+ Parses text input in FASTA format into a list of (header, sequence).
42
+ Handles multiple sequences if present.
43
+ """
44
  sequences = []
45
  current_header = None
46
  current_sequence = []
 
61
  return sequences
62
 
63
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
64
+ """
65
+ Convert a single nucleotide sequence to a k-mer frequency vector
66
+ of length 4^k (e.g., for k=4, length=256).
67
+ """
68
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
69
  kmer_dict = {km: i for i, km in enumerate(kmers)}
70
  vec = np.zeros(len(kmers), dtype=np.float32)
 
82
 
83
 
84
  ###############################################################################
85
+ # Visualization Helpers
86
  ###############################################################################
87
def create_freq_sigma_plot(
    single_shap_values: np.ndarray,
    raw_freq_vector: np.ndarray,
    scaled_vector: np.ndarray,
    kmer_list,
    title: str,
    top_k: int = 10,
):
    """
    Plot frequency (%) and scaled value (sigma) for the most influential k-mers.

    The k-mers are ranked by absolute SHAP value (largest first) and the first
    ``top_k`` are drawn as paired bars: frequency on the left axis, coloured by
    the sign of the SHAP value (green for >= 0, red for < 0), and the scaled
    value on a twin right axis in gray.

    Args:
        single_shap_values: 1-D array of SHAP values for one sample.
        raw_freq_vector: 1-D array of unscaled k-mer frequencies for the same
            sample; multiplied by 100 for display (assumes the values are
            fractions -- TODO confirm against sequence_to_kmer_vector).
        scaled_vector: 1-D array of scaled values for the same sample.
        kmer_list: sequence of k-mer labels, indexed like the arrays above.
        title: text appended on a second line of the plot title.
        top_k: number of k-mers to display; clamped to the number of features.

    Returns:
        The matplotlib Figure. The caller is responsible for closing it.
    """
    abs_vals = np.abs(single_shap_values)
    top_k = max(0, min(int(top_k), abs_vals.shape[0]))

    # argsort is ascending, so reverse it and keep the head: this yields the
    # indices already ordered by descending |SHAP|, no second sort needed.
    top_indices = np.argsort(abs_vals)[::-1][:top_k]

    kmers = [kmer_list[idx] for idx in top_indices]
    freqs = [raw_freq_vector[idx] * 100.0 for idx in top_indices]  # percentage
    sigmas = [scaled_vector[idx] for idx in top_indices]
    # color by sign (positive=green, negative=red)
    colors = ["green" if single_shap_values[idx] >= 0 else "red" for idx in top_indices]

    x = np.arange(len(kmers))
    width = 0.4

    fig, ax = plt.subplots(figsize=(8, 5))

    # Frequency bars on the primary axis
    ax.bar(x - width/2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)")
    ax.set_ylabel("Frequency (%)", color='black')
    # Leave 20% headroom; fall back to 1 when there is nothing to show or every
    # selected frequency is zero, so the axis never gets identical limits.
    max_freq = max(freqs) if freqs else 0
    ax.set_ylim(0, max_freq * 1.2 if max_freq > 0 else 1)

    # Twin axis for sigma
    ax2 = ax.twinx()
    ax2.bar(x + width/2, sigmas, width, color="gray", alpha=0.5, label="σ from Mean")
    ax2.set_ylabel("Standard Deviations (σ)", color='black')

    ax.set_xticks(x)
    ax.set_xticklabels(kmers, rotation=45, ha='right')
    ax.set_title(f"Top-{top_k} K-mers (Frequency & σ)\n{title}")

    # Combine the legends of both axes into a single box
    lines1, labels1 = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

    plt.tight_layout()
    return fig
150
 
151
 
152
  ###############################################################################
153
+ # Main Inference & SHAP Logic
154
  ###############################################################################
155
def run_classification_and_shap(file_obj):
    """
    Classify every FASTA sequence in the input and compute SHAP values.

    Args:
        file_obj: FASTA content, either as a ``str`` or as an object with a
            ``decode("utf-8")`` method (e.g. ``bytes``).

    Returns:
        Always a 5-tuple ``(results_table, shap_values, scaled_data,
        kmer_list, error)``:

        - results_table: list of dicts, one per sequence, with keys
          ``header``, ``sequence`` (truncated to 50 chars), ``pred_label``,
          ``human_prob``, ``non_human_prob`` and ``confidence``.
        - shap_values: the object returned by the SHAP explainer for the
          whole batch.
        - scaled_data: the scaler output for the batch, one row per sequence.
        - kmer_list: all k-mers over "ACGT" for k=4, in feature order.
        - error: ``None`` on success.

        On failure the first four items are ``None`` and ``error`` holds a
        human-readable message, so callers can unpack five values in every
        case.
    """
    def _failure(message):
        # Same arity as the success return so callers can always unpack 5 items.
        return None, None, None, None, message

    # 1. Basic read
    if file_obj is None:
        return _failure("Please upload a FASTA file.")
    if isinstance(file_obj, str):
        text = file_obj
    else:
        try:
            text = file_obj.decode("utf-8")
        except Exception as e:
            return _failure(f"Error reading file: {str(e)}")

    # 2. Parse FASTA
    sequences = parse_fasta(text)
    if len(sequences) == 0:
        return _failure("No valid FASTA sequences found!")

    # 3. Convert each sequence to k-mer vector
    k = 4
    headers = [hdr for hdr, _ in sequences]
    seqs = [seq for _, seq in sequences]
    # One row per sequence, one column per k-mer
    all_raw_vectors = np.stack(
        [sequence_to_kmer_vector(seq, k=k) for seq in seqs], axis=0
    )

    # 4. Load model & scaler
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = VirusClassifier(input_shape=4**k).to(device)
        state_dict = torch.load("model.pt", map_location=device)
        model.load_state_dict(state_dict)
        model.eval()

        scaler = joblib.load("scaler.pkl")
    except Exception as e:
        return _failure(f"Error loading model or scaler: {str(e)}")

    try:
        # 5. Scale data
        scaled_data = scaler.transform(all_raw_vectors)

        # 6. Predictions
        X_tensor = torch.FloatTensor(scaled_data).to(device)
        with torch.no_grad():
            logits = model(X_tensor)
            probs = torch.softmax(logits, dim=1).cpu().numpy()
        preds = np.argmax(probs, axis=1)  # 0 or 1

        results_table = []
        for i, (hdr, seq) in enumerate(zip(headers, seqs)):
            results_table.append({
                "header": hdr,
                "sequence": seq[:50] + ("..." if len(seq) > 50 else ""),  # truncated
                "pred_label": "human" if preds[i] == 1 else "non-human",
                "human_prob": float(probs[i][1]),
                "non_human_prob": float(probs[i][0]),
                "confidence": float(max(probs[i]))
            })

        # 7. SHAP Explainer
        # Use at most the first 50 samples as background, for performance.
        # NOTE(review): the background is drawn from the uploaded batch itself,
        # so a single-sequence upload is explained against itself -- confirm
        # whether a fixed reference background should be shipped instead.
        background_data = scaled_data[:50]

        def human_probability(batch):
            # SHAP calls the model with NumPy arrays; convert to a tensor for
            # the torch model and return only the 'human' (index 1) probability
            # so each sample gets a single vector of per-feature SHAP values.
            with torch.no_grad():
                inputs = torch.as_tensor(
                    np.asarray(batch), dtype=torch.float32, device=device
                )
                return torch.softmax(model(inputs), dim=1)[:, 1].cpu().numpy()

        explainer = shap.Explainer(human_probability, background_data)
        # NOTE(review): for a plain callable with this many features SHAP is
        # expected to pick its permutation algorithm, which needs at least
        # 2 * n_features + 1 evaluations per sample -- confirm against the
        # installed SHAP version and raise the budget if smoother values are
        # wanted.
        shap_values = explainer(
            scaled_data, max_evals=2 * scaled_data.shape[1] + 1
        )
    except Exception as e:
        return _failure(f"Error during prediction or SHAP computation: {str(e)}")

    # k-mer list
    kmer_list = [''.join(p) for p in product("ACGT", repeat=k)]

    return (results_table, shap_values, scaled_data, kmer_list, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
 
245
 
246
+ ###############################################################################
247
+ # Gradio Callback Functions
248
+ ###############################################################################
249
def main_predict(file_obj):
    """
    Gradio callback for the 'Run' button.

    Runs classification + SHAP on the uploaded input and builds a Markdown
    table with one row per sequence.

    Args:
        file_obj: the uploaded FASTA content, passed straight to
            run_classification_and_shap.

    Returns:
        A 5-tuple ``(markdown, shap_values, scaled_data, kmer_list, results)``.
        On failure the first item is the error message and the remaining four
        are ``None``.
    """
    outcome = run_classification_and_shap(file_obj)

    # The error message is always the LAST element of the helper's result,
    # whatever its length, so read it before unpacking to avoid a ValueError
    # on shorter error tuples.
    err = outcome[-1]
    if err:
        return (err, None, None, None, None)
    if len(outcome) != 5:
        return ("An unknown error occurred.", None, None, None, None)

    results, shap_vals, scaled_data, kmer_list, _ = outcome
    if results is None or shap_vals is None:
        return ("An unknown error occurred.", None, None, None, None)

    # Build a summary for all sequences
    parts = [
        "# Classification Results\n\n",
        "| # | Header | Pred Label | Confidence | Human Prob | Non-human Prob |\n",
        "|---|--------|------------|------------|------------|----------------|\n",
    ]
    for i, row in enumerate(results):
        # FASTA headers often contain '|', which would otherwise be read as a
        # column separator and break the Markdown table.
        header = str(row['header']).replace("|", "\\|")
        parts.append(
            f"| {i} | {header} | {row['pred_label']} | "
            f"{row['confidence']:.4f} | {row['human_prob']:.4f} | {row['non_human_prob']:.4f} |\n"
        )
    parts.append("\nSelect a sequence index below to view SHAP Waterfall & Frequency plots.")
    md = "".join(parts)

    # The extra return values are routed into Gradio State components by the
    # caller so later callbacks can reuse them.
    return (md, shap_vals, scaled_data, kmer_list, results)
277
+
278
+
279
def update_waterfall_plot(selected_index, shap_values_obj):
    """
    Build a SHAP waterfall plot for the user-selected sample.

    Args:
        selected_index: 0-based sample index; anything that cannot be
            converted to ``int`` or is out of range falls back to 0.
        shap_values_obj: the SHAP result for the whole batch (indexable per
            sample, with a ``.values`` array), or ``None``.

    Returns:
        A PIL image of the plot, or ``None`` when there is nothing to plot.
    """
    if shap_values_obj is None:
        return None

    try:
        selected_index = int(selected_index)
    except (TypeError, ValueError):
        # e.g. an empty Number field arrives as None
        selected_index = 0

    n_samples = len(shap_values_obj.values)
    if n_samples == 0:
        return None
    # Out-of-range (including negative) indices fall back to the first sample
    # instead of raising or silently wrapping around.
    if not 0 <= selected_index < n_samples:
        selected_index = 0

    shap_plots_fig = plt.figure(figsize=(8, 5))
    try:
        # show=False keeps SHAP from displaying the figure so it can be
        # captured from the current pyplot state below.
        shap.plots.waterfall(shap_values_obj[selected_index], max_display=14,
                             show=False)
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
        buf.seek(0)
        wf_img = Image.open(buf)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(shap_plots_fig)

    return wf_img
306
+
307
+
308
def update_beeswarm_plot(shap_values_obj):
    """
    Render a SHAP beeswarm plot over every sample in the batch.

    Args:
        shap_values_obj: the SHAP result for the whole batch, or ``None``.

    Returns:
        A PIL image of the plot, or ``None`` when no SHAP values are given.
    """
    if shap_values_obj is None:
        return None

    figure = plt.figure(figsize=(8, 5))
    # show=False so the plot stays on the current figure and can be saved.
    shap.plots.beeswarm(shap_values_obj, show=False)

    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format='png', bbox_inches='tight', dpi=120)
    png_buffer.seek(0)
    image = Image.open(png_buffer)

    plt.close(figure)
    return image
324
+
325
+
326
def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
    """
    Create the frequency & sigma bar chart for the selected sequence's top k-mers.

    The unscaled k-mer frequencies are not kept from the classification run,
    so the original input is parsed again here to rebuild them for the
    selected sequence.

    Args:
        selected_index: 0-based sample index; anything that cannot be
            converted to ``int`` or is out of range falls back to 0.
        shap_values_obj: SHAP result for the batch (must expose ``.values``).
        scaled_data: scaled feature rows, indexed like the SHAP values.
        kmer_list: k-mer labels in feature order.
        file_obj: the original FASTA input, ``str`` or an object with
            ``decode('utf-8')``.

    Returns:
        A PIL image of the chart, or ``None`` when any input is missing or no
        sequence is available.
    """
    if shap_values_obj is None or scaled_data is None or kmer_list is None:
        return None
    # The stored upload can still be empty if classification has not run yet.
    if file_obj is None:
        return None

    try:
        selected_index = int(selected_index)
    except (TypeError, ValueError):
        # e.g. an empty Number field arrives as None
        selected_index = 0

    # Re-parse the input to recover the raw (unscaled) frequencies.
    if isinstance(file_obj, str):
        text = file_obj
    else:
        text = file_obj.decode('utf-8')
    sequences = parse_fasta(text)

    # Only indices valid for every per-sample collection may be used.
    n_available = min(len(sequences), len(shap_values_obj.values), len(scaled_data))
    if n_available == 0:
        return None
    # Out-of-range (including negative) indices fall back to the first sample.
    if not 0 <= selected_index < n_available:
        selected_index = 0

    header, seq = sequences[selected_index]
    raw_vec = sequence_to_kmer_vector(seq, k=4)

    single_shap_values = shap_values_obj.values[selected_index]
    freq_sigma_fig = create_freq_sigma_plot(
        single_shap_values,
        raw_freq_vector=raw_vec,
        scaled_vector=scaled_data[selected_index],
        kmer_list=kmer_list,
        title=f"Sample #{selected_index} — {header}"
    )
    try:
        buf = io.BytesIO()
        freq_sigma_fig.savefig(buf, format='png', bbox_inches='tight', dpi=120)
        buf.seek(0)
        fs_img = Image.open(buf)
    finally:
        # Always release the figure, even if saving failed.
        plt.close(freq_sigma_fig)

    return fs_img
373
 
374
 
375
  ###############################################################################
376
  # Gradio Interface
377
  ###############################################################################
378
# Gradio UI: upload a FASTA file, run classification + SHAP once, then explore
# the stored results through the per-tab plots.
with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
    shap.initjs()  # load shap JS for interactive plots in some contexts (optional)

    gr.Markdown(
        """
        # **Advanced Virus Host Classifier with SHAP**
        **Upload a FASTA file** with one or more nucleotide sequences.
        This app will:
        1. Predict each sequence's **host** (human vs. non-human).
        2. Provide **SHAP** explanations (waterfall & beeswarm).
        3. Let you explore **frequency & σ** for top-10 k-mers for a chosen sequence.
        """
    )

    with gr.Row():
        file_input = gr.File(label="Upload FASTA", type="binary")
        run_btn = gr.Button("Run Classification")

    # Store intermediate results in "States" for usage in subsequent tabs
    shap_values_state = gr.State()
    scaled_data_state = gr.State()
    kmer_list_state = gr.State()
    results_state = gr.State()
    # We'll also store the "raw input" so we can reconstruct freq data for each sample
    file_data_state = gr.State()

    # TABS for outputs
    with gr.Tabs():
        with gr.Tab("Results Table"):
            md_out = gr.Markdown()

        with gr.Tab("SHAP Waterfall"):
            # The user picks the sequence index, then refreshes the plot
            with gr.Row():
                seq_index_dropdown = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
                update_wf_btn = gr.Button("Update Waterfall")

            wf_plot = gr.Image(label="SHAP Waterfall Plot")

        with gr.Tab("SHAP Beeswarm"):
            bs_plot = gr.Image(label="Global Beeswarm Plot", height=500)

        with gr.Tab("Top-10 Frequency & Sigma"):
            with gr.Row():
                seq_index_dropdown2 = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
                update_fs_btn = gr.Button("Update Frequency Chart")
            fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")

    # --- Button Logic ---
    # Also store the raw file data for later freq plots (independent of the
    # classification result, so it can run as its own listener).
    run_btn.click(
        fn=lambda x: x,
        inputs=file_input,
        outputs=file_data_state
    )

    # The beeswarm plot reads shap_values_state, which main_predict writes.
    # Chain it with .then() so it only runs AFTER main_predict has finished;
    # as a separate click listener it would not wait for the state to be set.
    run_btn.click(
        fn=main_predict,
        inputs=[file_input],
        outputs=[md_out, shap_values_state, scaled_data_state, kmer_list_state, results_state]
    ).then(
        fn=update_beeswarm_plot,
        inputs=[shap_values_state],
        outputs=[bs_plot]
    )

    update_wf_btn.click(
        fn=update_waterfall_plot,
        inputs=[seq_index_dropdown, shap_values_state],
        outputs=[wf_plot]
    )
    update_fs_btn.click(
        fn=update_freq_plot,
        inputs=[seq_index_dropdown2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
        outputs=[fs_plot]
    )
455
 
456
  if __name__ == "__main__":