Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
import io
|
9 |
from PIL import Image
|
|
|
10 |
|
11 |
###############################################################################
|
12 |
# Model Definition
|
@@ -30,30 +31,16 @@ class VirusClassifier(nn.Module):
|
|
30 |
|
31 |
def forward(self, x):
|
32 |
return self.network(x)
|
33 |
-
|
34 |
-
def get_feature_importance(self, x):
|
35 |
-
"""
|
36 |
-
Calculate gradient-based feature importance, specifically for the
|
37 |
-
'human' class (index=1) by computing gradient of that probability wrt x.
|
38 |
-
"""
|
39 |
-
x.requires_grad_(True)
|
40 |
-
output = self.network(x)
|
41 |
-
probs = torch.softmax(output, dim=1)
|
42 |
-
|
43 |
-
# Probability of 'human' class (index=1)
|
44 |
-
human_prob = probs[..., 1]
|
45 |
-
if x.grad is not None:
|
46 |
-
x.grad.zero_()
|
47 |
-
human_prob.backward()
|
48 |
-
importance = x.grad # shape: (batch_size, n_features)
|
49 |
-
|
50 |
-
return importance, float(human_prob)
|
51 |
|
52 |
###############################################################################
|
53 |
# Utility Functions
|
54 |
###############################################################################
|
55 |
def parse_fasta(text):
|
56 |
-
"""
|
|
|
|
|
|
|
57 |
sequences = []
|
58 |
current_header = None
|
59 |
current_sequence = []
|
@@ -74,7 +61,10 @@ def parse_fasta(text):
|
|
74 |
return sequences
|
75 |
|
76 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
77 |
-
"""
|
|
|
|
|
|
|
78 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
79 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
80 |
vec = np.zeros(len(kmers), dtype=np.float32)
|
@@ -92,377 +82,375 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
|
92 |
|
93 |
|
94 |
###############################################################################
|
95 |
-
# Visualization
|
96 |
###############################################################################
|
97 |
-
def
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
"""
|
105 |
-
|
106 |
-
|
107 |
-
sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
|
108 |
-
|
109 |
-
# 2) Compute the total effect of "other" k-mers
|
110 |
-
# We have 256 total features. We selected top 10. Sum the rest.
|
111 |
-
top_ids = set([km['idx'] for km in sorted_kmers])
|
112 |
-
other_contributions = []
|
113 |
-
for i, val in enumerate(all_kmer_importance):
|
114 |
-
if i not in top_ids:
|
115 |
-
other_contributions.append(val)
|
116 |
-
# sum up those "other" contributions
|
117 |
-
other_sum = np.sum(other_contributions)
|
118 |
-
# The "impact" for "other" will be the absolute value, direction depends on sign
|
119 |
-
other_impact = float(abs(other_sum))
|
120 |
-
other_direction = "human" if other_sum > 0 else "non-human"
|
121 |
-
|
122 |
-
# 3) Build a list of all bars: first "other", then each top k-mer
|
123 |
-
# Each bar needs: name, raw_contribution_value
|
124 |
-
# We'll store (label, contribution). The sign indicates direction.
|
125 |
-
bars = []
|
126 |
-
bars.append(("Other", other_sum)) # lumps the leftover k-mers
|
127 |
-
|
128 |
-
for km in sorted_kmers:
|
129 |
-
# We re-inject the sign on the raw gradient
|
130 |
-
# (We stored only the absolute in "impact," so let's create a signed value)
|
131 |
-
signed_val = km['impact'] if km['direction'] == 'human' else -km['impact']
|
132 |
-
bars.append((km['kmer'], signed_val))
|
133 |
-
|
134 |
-
# 4) Waterfall plot data:
|
135 |
-
# We'll accumulate partial sums from baseline=0.5
|
136 |
-
baseline = 0.5
|
137 |
-
running_val = baseline
|
138 |
-
x_labels = []
|
139 |
-
y_vals = []
|
140 |
-
bar_colors = []
|
141 |
-
|
142 |
-
# We'll use green for positive contributions (pushing toward 'human'),
|
143 |
-
# red for negative contributions (pushing away from 'human')
|
144 |
-
for (label, contrib) in bars:
|
145 |
-
x_labels.append(label)
|
146 |
-
# new value after adding this contribution
|
147 |
-
new_val = running_val + (0.05 * contrib)
|
148 |
-
# ^ scaled by 0.05 for better display. Adjust as desired.
|
149 |
-
|
150 |
-
y_vals.append((running_val, new_val))
|
151 |
-
running_val = new_val
|
152 |
-
if contrib >= 0:
|
153 |
-
bar_colors.append("green")
|
154 |
-
else:
|
155 |
-
bar_colors.append("red")
|
156 |
-
|
157 |
-
final_prob = running_val
|
158 |
-
# Final point is the model's predicted probability (not always exact, but this is a shap-like idea).
|
159 |
-
# If we want to forcibly ensure final_prob = human_prob, we could do:
|
160 |
-
# correction = human_prob - running_val
|
161 |
-
# running_val += correction
|
162 |
-
# but for now let's keep the "waterfall" purely additive from the gradient.
|
163 |
-
|
164 |
-
# Let's plot:
|
165 |
-
fig, ax = plt.subplots(figsize=(10, 6))
|
166 |
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
-
def create_frequency_sigma_plot(important_kmers, title):
|
188 |
-
"""Creates a bar plot of the top k-mers (by importance) showing frequency (%) and σ from mean."""
|
189 |
-
# Sort by absolute impact
|
190 |
-
sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
|
191 |
-
kmers = [k["kmer"] for k in sorted_kmers]
|
192 |
-
frequencies = [k["occurrence"] for k in sorted_kmers] # in %
|
193 |
-
sigmas = [k["sigma"] for k in sorted_kmers]
|
194 |
-
directions = [k["direction"] for k in sorted_kmers]
|
195 |
-
|
196 |
x = np.arange(len(kmers))
|
197 |
width = 0.4
|
198 |
|
199 |
-
fig,
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
color=["green" if d=="human" else "red" for d in directions],
|
205 |
-
label="Frequency (%)"
|
206 |
-
)
|
207 |
-
ax_bar.set_ylabel("Frequency (%)")
|
208 |
-
ax_bar.set_ylim(0, max(frequencies) * 1.2 if frequencies else 1)
|
209 |
-
|
210 |
-
# Twin axis for σ
|
211 |
-
ax_bar_twin = ax_bar.twinx()
|
212 |
-
bars_sigma = ax_bar_twin.bar(
|
213 |
-
x + width/2, sigmas, width, alpha=0.5, color="gray", label="σ from Mean"
|
214 |
-
)
|
215 |
-
ax_bar_twin.set_ylabel("Standard Deviations (σ)")
|
216 |
-
|
217 |
-
ax_bar.set_title(f"Frequency & σ from Mean for Top k-mers — {title}")
|
218 |
-
ax_bar.set_xticks(x)
|
219 |
-
ax_bar.set_xticklabels(kmers, rotation=45, ha='right')
|
220 |
|
221 |
-
#
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
|
226 |
-
plt.tight_layout()
|
227 |
-
return fig
|
228 |
-
|
229 |
-
def create_importance_bar_plot(important_kmers, title):
|
230 |
-
"""
|
231 |
-
Create a simple bar chart showing the absolute gradient magnitude
|
232 |
-
for the top k-mers, sorted descending.
|
233 |
-
"""
|
234 |
-
sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
|
235 |
-
kmers = [k['kmer'] for k in sorted_kmers]
|
236 |
-
impacts = [k['impact'] for k in sorted_kmers]
|
237 |
-
directions = [k["direction"] for k in sorted_kmers]
|
238 |
-
|
239 |
-
x = np.arange(len(kmers))
|
240 |
-
|
241 |
-
fig, ax = plt.subplots(figsize=(10, 6))
|
242 |
-
bar_colors = ["green" if d=="human" else "red" for d in directions]
|
243 |
-
|
244 |
-
ax.bar(x, impacts, color=bar_colors, alpha=0.7)
|
245 |
ax.set_xticks(x)
|
246 |
ax.set_xticklabels(kmers, rotation=45, ha='right')
|
247 |
-
ax.set_title(f"
|
248 |
-
|
249 |
-
|
|
|
|
|
|
|
250 |
|
251 |
plt.tight_layout()
|
252 |
return fig
|
253 |
|
254 |
|
255 |
###############################################################################
|
256 |
-
#
|
257 |
###############################################################################
|
258 |
-
def
|
259 |
"""
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
- Frequency & sigma plot.
|
268 |
-
- Absolute importance bar plot.
|
269 |
"""
|
270 |
-
#
|
271 |
-
if file_obj
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
# If user provided raw text, use that
|
281 |
-
if isinstance(file_obj, str):
|
282 |
-
text = file_obj
|
283 |
-
else:
|
284 |
-
# If user uploaded a file, decode it
|
285 |
-
text = file_obj.decode('utf-8')
|
286 |
-
except Exception as e:
|
287 |
-
return (
|
288 |
-
f"Error reading file: {str(e)}",
|
289 |
-
None,
|
290 |
-
None,
|
291 |
-
None
|
292 |
-
)
|
293 |
-
|
294 |
-
# 1. Parse FASTA
|
295 |
sequences = parse_fasta(text)
|
296 |
if len(sequences) == 0:
|
297 |
-
return
|
298 |
-
"No valid FASTA sequences found. Please check your input.",
|
299 |
-
None,
|
300 |
-
None,
|
301 |
-
None
|
302 |
-
)
|
303 |
-
# We’ll just classify the first sequence for demonstration
|
304 |
-
header, seq = sequences[0]
|
305 |
|
306 |
-
#
|
307 |
k = 4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
308 |
try:
|
309 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
310 |
-
|
311 |
-
# Prepare raw freq vector & scale
|
312 |
-
raw_freq_vector = sequence_to_kmer_vector(seq, k=k)
|
313 |
-
|
314 |
-
# Load model & scaler
|
315 |
model = VirusClassifier(input_shape=4**k).to(device)
|
316 |
-
state_dict = torch.load(
|
317 |
model.load_state_dict(state_dict)
|
318 |
-
scaler = joblib.load('scaler.pkl')
|
319 |
model.eval()
|
320 |
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
})
|
366 |
-
|
367 |
-
# 6. Text Summary
|
368 |
-
summary_text = (
|
369 |
-
f"**Sequence Header**: {header}\n\n"
|
370 |
-
f"**Predicted Label**: {pred_label}\n"
|
371 |
-
f"**Confidence**: {confidence:.4f}\n\n"
|
372 |
-
f"**Human Probability**: {human_prob:.4f}\n"
|
373 |
-
f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
|
374 |
-
"### Most Influential k-mers:\n"
|
375 |
-
)
|
376 |
-
for km in important_kmers:
|
377 |
-
direction_text = f"(pushes toward {km['direction']})"
|
378 |
-
freq_text = f"{km['occurrence']:.2f}%"
|
379 |
-
sigma_text = f"{abs(km['sigma']):.2f}σ " + ("above" if km['sigma']>0 else "below") + " mean"
|
380 |
-
summary_text += (
|
381 |
-
f"- **{km['kmer']}**: impact={km['impact']:.4f}, {direction_text}, "
|
382 |
-
f"occurrence={freq_text}, ({sigma_text})\n"
|
383 |
-
)
|
384 |
-
|
385 |
-
# 7. Plots
|
386 |
-
# a) SHAP-like Waterfall Plot
|
387 |
-
fig_waterfall = create_shap_waterfall_plot(
|
388 |
-
important_kmers,
|
389 |
-
kmer_importances,
|
390 |
-
human_prob,
|
391 |
-
f"{header}"
|
392 |
-
)
|
393 |
-
buf1 = io.BytesIO()
|
394 |
-
fig_waterfall.savefig(buf1, format='png', bbox_inches='tight', dpi=120)
|
395 |
-
buf1.seek(0)
|
396 |
-
waterfall_img = Image.open(buf1)
|
397 |
-
plt.close(fig_waterfall)
|
398 |
-
|
399 |
-
# b) Frequency & σ Plot (top 10 k-mers)
|
400 |
-
fig_freq_sigma = create_frequency_sigma_plot(
|
401 |
-
important_kmers,
|
402 |
-
f"{header}"
|
403 |
-
)
|
404 |
-
buf2 = io.BytesIO()
|
405 |
-
fig_freq_sigma.savefig(buf2, format='png', bbox_inches='tight', dpi=120)
|
406 |
-
buf2.seek(0)
|
407 |
-
freq_sigma_img = Image.open(buf2)
|
408 |
-
plt.close(fig_freq_sigma)
|
409 |
-
|
410 |
-
# c) Absolute Importance Bar Plot
|
411 |
-
fig_imp = create_importance_bar_plot(
|
412 |
-
important_kmers,
|
413 |
-
f"{header}"
|
414 |
-
)
|
415 |
-
buf3 = io.BytesIO()
|
416 |
-
fig_imp.savefig(buf3, format='png', bbox_inches='tight', dpi=120)
|
417 |
-
buf3.seek(0)
|
418 |
-
importance_img = Image.open(buf3)
|
419 |
-
plt.close(fig_imp)
|
420 |
|
421 |
-
return summary_text, waterfall_img, freq_sigma_img, importance_img
|
422 |
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
429 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
|
431 |
|
432 |
###############################################################################
|
433 |
# Gradio Interface
|
434 |
###############################################################################
|
435 |
-
with gr.Blocks(title="
|
|
|
|
|
436 |
gr.Markdown(
|
437 |
"""
|
438 |
-
# Advanced Virus Host Classifier
|
439 |
-
**Upload a FASTA file**
|
440 |
-
|
441 |
-
|
442 |
-
|
|
|
443 |
"""
|
444 |
)
|
445 |
-
|
446 |
-
with gr.Row():
|
447 |
-
file_in = gr.File(label="Upload FASTA", type="binary")
|
448 |
-
btn = gr.Button("Run Prediction")
|
449 |
|
450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
with gr.Tabs():
|
452 |
-
with gr.Tab("
|
453 |
md_out = gr.Markdown()
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
)
|
467 |
|
468 |
if __name__ == "__main__":
|
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
import io
|
9 |
from PIL import Image
|
10 |
+
import shap # Requires: pip install shap
|
11 |
|
12 |
###############################################################################
|
13 |
# Model Definition
|
|
|
31 |
|
32 |
def forward(self, x):
|
33 |
return self.network(x)
|
34 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
###############################################################################
|
37 |
# Utility Functions
|
38 |
###############################################################################
|
39 |
def parse_fasta(text):
|
40 |
+
"""
|
41 |
+
Parses text input in FASTA format into a list of (header, sequence).
|
42 |
+
Handles multiple sequences if present.
|
43 |
+
"""
|
44 |
sequences = []
|
45 |
current_header = None
|
46 |
current_sequence = []
|
|
|
61 |
return sequences
|
62 |
|
63 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
64 |
+
"""
|
65 |
+
Convert a single nucleotide sequence to a k-mer frequency vector
|
66 |
+
of length 4^k (e.g., for k=4, length=256).
|
67 |
+
"""
|
68 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
69 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
70 |
vec = np.zeros(len(kmers), dtype=np.float32)
|
|
|
82 |
|
83 |
|
84 |
###############################################################################
|
85 |
+
# Visualization Helpers
|
86 |
###############################################################################
|
87 |
+
def create_freq_sigma_plot(
|
88 |
+
single_shap_values: np.ndarray,
|
89 |
+
raw_freq_vector: np.ndarray,
|
90 |
+
scaled_vector: np.ndarray,
|
91 |
+
kmer_list,
|
92 |
+
title: str
|
93 |
+
):
|
94 |
"""
|
95 |
+
Creates a bar plot showing top-10 k-mers (by absolute SHAP value),
|
96 |
+
with frequency (%) and sigma from mean on a twin-axis.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
+
single_shap_values: shape=(256,) shap values for this sample
|
99 |
+
raw_freq_vector: shape=(256,) original frequencies for this sample
|
100 |
+
scaled_vector: shape=(256,) scaled (Z-score) values for this sample
|
101 |
+
kmer_list: list of all k-mers (length=256)
|
102 |
+
"""
|
103 |
+
abs_vals = np.abs(single_shap_values)
|
104 |
+
top_k = 10
|
105 |
+
top_indices = np.argsort(abs_vals)[-top_k:][::-1] # top 10 by absolute shap
|
106 |
+
top_data = []
|
107 |
+
for idx in top_indices:
|
108 |
+
top_data.append({
|
109 |
+
"kmer": kmer_list[idx],
|
110 |
+
"shap": single_shap_values[idx],
|
111 |
+
"abs_shap": abs_vals[idx],
|
112 |
+
"frequency": raw_freq_vector[idx] * 100.0, # percentage
|
113 |
+
"sigma": scaled_vector[idx]
|
114 |
+
})
|
115 |
+
|
116 |
+
# Sort top_data by abs_shap descending
|
117 |
+
top_data.sort(key=lambda x: x["abs_shap"], reverse=True)
|
118 |
+
|
119 |
+
kmers = [d["kmer"] for d in top_data]
|
120 |
+
freqs = [d["frequency"] for d in top_data]
|
121 |
+
sigmas = [d["sigma"] for d in top_data]
|
122 |
+
# color by sign (positive=green, negative=red)
|
123 |
+
colors = ["green" if d["shap"] >= 0 else "red" for d in top_data]
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
x = np.arange(len(kmers))
|
126 |
width = 0.4
|
127 |
|
128 |
+
fig, ax = plt.subplots(figsize=(8, 5))
|
129 |
+
# Frequency
|
130 |
+
ax.bar(x - width/2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)")
|
131 |
+
ax.set_ylabel("Frequency (%)", color='black')
|
132 |
+
ax.set_ylim(0, max(freqs)*1.2 if len(freqs) else 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
+
# Twin axis for sigma
|
135 |
+
ax2 = ax.twinx()
|
136 |
+
ax2.bar(x + width/2, sigmas, width, color="gray", alpha=0.5, label="σ from Mean")
|
137 |
+
ax2.set_ylabel("Standard Deviations (σ)", color='black')
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
ax.set_xticks(x)
|
140 |
ax.set_xticklabels(kmers, rotation=45, ha='right')
|
141 |
+
ax.set_title(f"Top-10 K-mers (Frequency & σ)\n{title}")
|
142 |
+
|
143 |
+
# Combine legends
|
144 |
+
lines1, labels1 = ax.get_legend_handles_labels()
|
145 |
+
lines2, labels2 = ax2.get_legend_handles_labels()
|
146 |
+
ax.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
|
147 |
|
148 |
plt.tight_layout()
|
149 |
return fig
|
150 |
|
151 |
|
152 |
###############################################################################
|
153 |
+
# Main Inference & SHAP Logic
|
154 |
###############################################################################
|
155 |
+
def run_classification_and_shap(file_obj):
|
156 |
"""
|
157 |
+
Reads one or more FASTA sequences from file_obj or text.
|
158 |
+
Returns:
|
159 |
+
- Table of results (list of dicts) for each sequence
|
160 |
+
- shap_values object (SHAP values for the entire batch)
|
161 |
+
- array/batch of scaled vectors (for use in the waterfall selection)
|
162 |
+
- list of k-mers (for indexing)
|
163 |
+
- possibly the model or other context
|
|
|
|
|
164 |
"""
|
165 |
+
# 1. Basic read
|
166 |
+
if isinstance(file_obj, str):
|
167 |
+
text = file_obj
|
168 |
+
else:
|
169 |
+
try:
|
170 |
+
text = file_obj.decode("utf-8")
|
171 |
+
except Exception as e:
|
172 |
+
return None, None, f"Error reading file: {str(e)}"
|
173 |
+
|
174 |
+
# 2. Parse FASTA
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
sequences = parse_fasta(text)
|
176 |
if len(sequences) == 0:
|
177 |
+
return None, None, "No valid FASTA sequences found!"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
+
# 3. Convert each sequence to k-mer vector
|
180 |
k = 4
|
181 |
+
all_raw_vectors = []
|
182 |
+
headers = []
|
183 |
+
seqs = []
|
184 |
+
for (hdr, seq) in sequences:
|
185 |
+
raw_vec = sequence_to_kmer_vector(seq, k=k)
|
186 |
+
all_raw_vectors.append(raw_vec)
|
187 |
+
headers.append(hdr)
|
188 |
+
seqs.append(seq)
|
189 |
+
|
190 |
+
all_raw_vectors = np.stack(all_raw_vectors, axis=0) # shape=(num_seqs, 256)
|
191 |
+
|
192 |
+
# 4. Load model & scaler
|
193 |
try:
|
194 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
195 |
model = VirusClassifier(input_shape=4**k).to(device)
|
196 |
+
state_dict = torch.load("model.pt", map_location=device)
|
197 |
model.load_state_dict(state_dict)
|
|
|
198 |
model.eval()
|
199 |
|
200 |
+
scaler = joblib.load("scaler.pkl")
|
201 |
+
except Exception as e:
|
202 |
+
return None, None, f"Error loading model or scaler: {str(e)}"
|
203 |
+
|
204 |
+
# 5. Scale data
|
205 |
+
scaled_data = scaler.transform(all_raw_vectors) # shape=(num_seqs, 256)
|
206 |
+
|
207 |
+
# 6. Predictions
|
208 |
+
X_tensor = torch.FloatTensor(scaled_data).to(device)
|
209 |
+
with torch.no_grad():
|
210 |
+
logits = model(X_tensor)
|
211 |
+
probs = torch.softmax(logits, dim=1).cpu().numpy()
|
212 |
+
preds = np.argmax(probs, axis=1) # 0 or 1
|
213 |
+
|
214 |
+
results_table = []
|
215 |
+
for i, (hdr, seq) in enumerate(zip(headers, seqs)):
|
216 |
+
results_table.append({
|
217 |
+
"header": hdr,
|
218 |
+
"sequence": seq[:50] + ("..." if len(seq)>50 else ""), # truncated
|
219 |
+
"pred_label": "human" if preds[i] == 1 else "non-human",
|
220 |
+
"human_prob": float(probs[i][1]),
|
221 |
+
"non_human_prob": float(probs[i][0]),
|
222 |
+
"confidence": float(max(probs[i]))
|
223 |
+
})
|
224 |
+
|
225 |
+
# 7. SHAP Explainer
|
226 |
+
# We'll pick a background subset if there are many sequences
|
227 |
+
# (For performance, we might limit to e.g. 50 samples max)
|
228 |
+
if scaled_data.shape[0] > 50:
|
229 |
+
background_data = scaled_data[:50]
|
230 |
+
else:
|
231 |
+
background_data = scaled_data
|
232 |
+
|
233 |
+
# Use the "new" unified shap.Explainer approach
|
234 |
+
# We pass in a function that does the forward pass. Or pass the model directly.
|
235 |
+
# For PyTorch models, shap can do a direct 'model' approach with a mask.
|
236 |
+
# We'll do a simple "use shap.Explainer" with data=background_data
|
237 |
+
explainer = shap.Explainer(model, background_data)
|
238 |
+
shap_values = explainer(scaled_data) # shape=(num_samples, num_features)
|
239 |
+
|
240 |
+
# k-mer list
|
241 |
+
kmer_list = [''.join(p) for p in product("ACGT", repeat=k)]
|
242 |
+
|
243 |
+
return (results_table, shap_values, scaled_data, kmer_list, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
|
|
|
245 |
|
246 |
+
###############################################################################
|
247 |
+
# Gradio Callback Functions
|
248 |
+
###############################################################################
|
249 |
+
def main_predict(file_obj):
|
250 |
+
"""
|
251 |
+
This function is triggered by the 'Run' button in Gradio.
|
252 |
+
It returns a markdown of all sequences/predictions and stores
|
253 |
+
data needed for the subsequent SHAP visualizations.
|
254 |
+
"""
|
255 |
+
results, shap_vals, scaled_data, kmer_list, err = run_classification_and_shap(file_obj)
|
256 |
+
if err:
|
257 |
+
return (err, None, None, None, None)
|
258 |
+
|
259 |
+
if results is None or shap_vals is None:
|
260 |
+
return ("An unknown error occurred.", None, None, None, None)
|
261 |
+
|
262 |
+
# Build a summary for all sequences
|
263 |
+
md = "# Classification Results\n\n"
|
264 |
+
md += "| # | Header | Pred Label | Confidence | Human Prob | Non-human Prob |\n"
|
265 |
+
md += "|---|--------|------------|------------|------------|----------------|\n"
|
266 |
+
for i, row in enumerate(results):
|
267 |
+
md += (
|
268 |
+
f"| {i} | {row['header']} | {row['pred_label']} | "
|
269 |
+
f"{row['confidence']:.4f} | {row['human_prob']:.4f} | {row['non_human_prob']:.4f} |\n"
|
270 |
)
|
271 |
+
md += "\nSelect a sequence index below to view SHAP Waterfall & Frequency plots."
|
272 |
+
|
273 |
+
# Return the string, and also the shap values plus data needed
|
274 |
+
# We'll store these to SessionState via Gradio's "State" or we can
|
275 |
+
# pass them out as hidden fields.
|
276 |
+
return (md, shap_vals, scaled_data, kmer_list, results)
|
277 |
+
|
278 |
+
|
279 |
+
def update_waterfall_plot(selected_index, shap_values_obj):
|
280 |
+
"""
|
281 |
+
Build a waterfall plot for the user-selected sample.
|
282 |
+
"""
|
283 |
+
if shap_values_obj is None:
|
284 |
+
return None
|
285 |
+
|
286 |
+
try:
|
287 |
+
selected_index = int(selected_index)
|
288 |
+
except:
|
289 |
+
selected_index = 0
|
290 |
+
|
291 |
+
# We'll create the figure by calling shap.plots.waterfall
|
292 |
+
# Convert shap_values_obj to the new shap interface
|
293 |
+
# shap_values_obj is a shap._explanation.Explanation typically
|
294 |
+
|
295 |
+
# We can create a figure with shap.plots.waterfall and capture it as an image
|
296 |
+
shap_plots_fig = plt.figure(figsize=(8, 5))
|
297 |
+
shap.plots.waterfall(shap_values_obj[selected_index], max_display=14,
|
298 |
+
show=False) # show=False so it doesn't pop in the notebook
|
299 |
+
buf = io.BytesIO()
|
300 |
+
plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
|
301 |
+
buf.seek(0)
|
302 |
+
wf_img = Image.open(buf)
|
303 |
+
plt.close(shap_plots_fig)
|
304 |
+
|
305 |
+
return wf_img
|
306 |
+
|
307 |
+
|
308 |
+
def update_beeswarm_plot(shap_values_obj):
|
309 |
+
"""
|
310 |
+
Build a beeswarm plot across all samples.
|
311 |
+
"""
|
312 |
+
if shap_values_obj is None:
|
313 |
+
return None
|
314 |
+
|
315 |
+
beeswarm_fig = plt.figure(figsize=(8, 5))
|
316 |
+
shap.plots.beeswarm(shap_values_obj, show=False)
|
317 |
+
buf = io.BytesIO()
|
318 |
+
plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
|
319 |
+
buf.seek(0)
|
320 |
+
bs_img = Image.open(buf)
|
321 |
+
plt.close(beeswarm_fig)
|
322 |
+
|
323 |
+
return bs_img
|
324 |
+
|
325 |
+
|
326 |
+
def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
|
327 |
+
"""
|
328 |
+
Create the frequency & sigma bar chart for the selected sequence's top-10 k-mers.
|
329 |
+
We'll need to also compute the raw_freq_vector from the original unscaled data.
|
330 |
+
"""
|
331 |
+
if shap_values_obj is None or scaled_data is None or kmer_list is None:
|
332 |
+
return None
|
333 |
+
|
334 |
+
try:
|
335 |
+
selected_index = int(selected_index)
|
336 |
+
except:
|
337 |
+
selected_index = 0
|
338 |
+
|
339 |
+
# We must re-generate the raw freq vector from the original input file
|
340 |
+
# or store it from earlier. Let's just re-run parse for that single sequence:
|
341 |
+
# But simpler is: run_classification_and_shap was storing all_raw_vectors...
|
342 |
+
# Let's do a quick approach: run_classification_and_shap already computed it
|
343 |
+
# but we didn't store it. We'll re-run the parse logic to get the raw freq again.
|
344 |
+
|
345 |
+
# For memory / speed reasons, better is to store it.
|
346 |
+
# For simplicity, let's parse again quickly:
|
347 |
+
if isinstance(file_obj, str):
|
348 |
+
text = file_obj
|
349 |
+
else:
|
350 |
+
text = file_obj.decode('utf-8')
|
351 |
+
sequences = parse_fasta(text)
|
352 |
+
# the selected_index might be out of range, so let's clamp it
|
353 |
+
if selected_index >= len(sequences):
|
354 |
+
selected_index = 0
|
355 |
+
seq = sequences[selected_index][1] # get the sequence
|
356 |
+
raw_vec = sequence_to_kmer_vector(seq, k=4)
|
357 |
+
|
358 |
+
single_shap_values = shap_values_obj.values[selected_index]
|
359 |
+
freq_sigma_fig = create_freq_sigma_plot(
|
360 |
+
single_shap_values,
|
361 |
+
raw_freq_vector=raw_vec,
|
362 |
+
scaled_vector=scaled_data[selected_index],
|
363 |
+
kmer_list=kmer_list,
|
364 |
+
title=f"Sample #{selected_index} — {sequences[selected_index][0]}"
|
365 |
+
)
|
366 |
+
buf = io.BytesIO()
|
367 |
+
freq_sigma_fig.savefig(buf, format='png', bbox_inches='tight', dpi=120)
|
368 |
+
buf.seek(0)
|
369 |
+
fs_img = Image.open(buf)
|
370 |
+
plt.close(freq_sigma_fig)
|
371 |
+
|
372 |
+
return fs_img
|
373 |
|
374 |
|
375 |
###############################################################################
|
376 |
# Gradio Interface
|
377 |
###############################################################################
|
378 |
+
with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
|
379 |
+
shap.initjs() # load shap JS for interactive plots in some contexts (optional)
|
380 |
+
|
381 |
gr.Markdown(
|
382 |
"""
|
383 |
+
# **Advanced Virus Host Classifier with SHAP**
|
384 |
+
**Upload a FASTA file** with one or more nucleotide sequences.
|
385 |
+
This app will:
|
386 |
+
1. Predict each sequence's **host** (human vs. non-human).
|
387 |
+
2. Provide **SHAP** explanations (waterfall & beeswarm).
|
388 |
+
3. Let you explore **frequency & σ** for top-10 k-mers for a chosen sequence.
|
389 |
"""
|
390 |
)
|
|
|
|
|
|
|
|
|
391 |
|
392 |
+
with gr.Row():
|
393 |
+
file_input = gr.File(label="Upload FASTA", type="binary")
|
394 |
+
run_btn = gr.Button("Run Classification")
|
395 |
+
|
396 |
+
# Store intermediate results in "States" for usage in subsequent tabs
|
397 |
+
shap_values_state = gr.State()
|
398 |
+
scaled_data_state = gr.State()
|
399 |
+
kmer_list_state = gr.State()
|
400 |
+
results_state = gr.State()
|
401 |
+
# We'll also store the "raw input" so we can reconstruct freq data for each sample
|
402 |
+
file_data_state = gr.State()
|
403 |
+
|
404 |
+
# TABS for outputs
|
405 |
with gr.Tabs():
|
406 |
+
with gr.Tab("Results Table"):
|
407 |
md_out = gr.Markdown()
|
408 |
+
|
409 |
+
with gr.Tab("SHAP Waterfall"):
|
410 |
+
# We'll let user pick the sequence index from a dropdown or slider
|
411 |
+
with gr.Row():
|
412 |
+
seq_index_dropdown = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
|
413 |
+
update_wf_btn = gr.Button("Update Waterfall")
|
414 |
+
|
415 |
+
wf_plot = gr.Image(label="SHAP Waterfall Plot")
|
416 |
+
|
417 |
+
with gr.Tab("SHAP Beeswarm"):
|
418 |
+
bs_plot = gr.Image(label="Global Beeswarm Plot", height=500)
|
419 |
+
|
420 |
+
with gr.Tab("Top-10 Frequency & Sigma"):
|
421 |
+
with gr.Row():
|
422 |
+
seq_index_dropdown2 = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
|
423 |
+
update_fs_btn = gr.Button("Update Frequency Chart")
|
424 |
+
fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")
|
425 |
+
|
426 |
+
# --- Button Logic ---
|
427 |
+
run_btn.click(
|
428 |
+
fn=main_predict,
|
429 |
+
inputs=[file_input],
|
430 |
+
outputs=[md_out, shap_values_state, scaled_data_state, kmer_list_state, results_state]
|
431 |
+
)
|
432 |
+
run_btn.click( # Also store the raw file data for later freq plots
|
433 |
+
fn=lambda x: x,
|
434 |
+
inputs=file_input,
|
435 |
+
outputs=file_data_state
|
436 |
+
)
|
437 |
+
|
438 |
+
update_wf_btn.click(
|
439 |
+
fn=update_waterfall_plot,
|
440 |
+
inputs=[seq_index_dropdown, shap_values_state],
|
441 |
+
outputs=[wf_plot]
|
442 |
+
)
|
443 |
+
update_fs_btn.click(
|
444 |
+
fn=update_freq_plot,
|
445 |
+
inputs=[seq_index_dropdown2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
|
446 |
+
outputs=[fs_plot]
|
447 |
+
)
|
448 |
+
|
449 |
+
# We can auto-generate the beeswarm right after classification as well
|
450 |
+
run_btn.click(
|
451 |
+
fn=update_beeswarm_plot,
|
452 |
+
inputs=[shap_values_state],
|
453 |
+
outputs=[bs_plot]
|
454 |
)
|
455 |
|
456 |
if __name__ == "__main__":
|