Spaces: Update app.py
app.py CHANGED
--- a/app.py
@@ -2,13 +2,15 @@ import gradio as gr
 import torch
 import joblib
 import numpy as np
 import torch.nn as nn
 import matplotlib.pyplot as plt
 import io
 from PIL import Image
-from itertools import product
 
-
 
 class VirusClassifier(nn.Module):
     def __init__(self, input_shape: int):
@@ -29,46 +31,20 @@ class VirusClassifier(nn.Module):
 
     def forward(self, x):
         return self.network(x)
-
-    def get_gradient_importance(self, x, class_index=1):
-        """
-        Calculate gradient-based importance for each input feature.
-        By default, we compute the gradient wrt the 'human' class (index=1).
-        This method is akin to a raw gradient or 'saliency' approach.
-        """
-        x = x.clone().detach().requires_grad_(True)
-        output = self.network(x)
-        probs = torch.softmax(output, dim=1)
-
-        # Probability of the specified class
-        target_prob = probs[..., class_index]
-
-        # Zero existing gradients if any
-        if x.grad is not None:
-            x.grad.zero_()
-
-        # Backprop on that probability
-        target_prob.backward()
-
-        # Raw gradient is now in x.grad
-        importance = x.grad.detach()
-
-        # Optional: Multiply by input to get a more "integrated gradients"-like measure
-        # importance = importance * x.detach()
-
-        return importance, float(target_prob)
 
-
 
-def parse_fasta(text: str):
     """
-    ...
     """
     sequences = []
     current_header = None
     current_sequence = []
 
-    for line in text.split('\n'):
         line = line.strip()
         if not line:
             continue
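
The method removed above is plain gradient saliency: backpropagate the probability of the target class to the input and read each feature's gradient off as its importance. A minimal sketch of how it was meant to be called, assuming the elided __init__ builds self.network as the rest of the file implies (the input tensor below is a stand-in, not data from this commit):

    model = VirusClassifier(input_shape=256)
    x = torch.randn(1, 256)   # stand-in for one scaled k-mer vector
    grads, p_human = model.get_gradient_importance(x, class_index=1)
    print(grads.shape)        # torch.Size([1, 256]): one gradient per k-mer feature
    print(p_human)            # P(human) at the point where the gradient was taken
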
@@ -85,10 +61,8 @@ def parse_fasta(text: str):
 
 def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
     """
-    Convert a ...
-    Defaults to k=4.
     """
-    # Generate all possible k-mers
     kmers = [''.join(p) for p in product("ACGT", repeat=k)]
     kmer_dict = {km: i for i, km in enumerate(kmers)}
     vec = np.zeros(len(kmers), dtype=np.float32)
@@ -104,385 +78,355 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
 
     return vec
 
-def ...
-    """
-    Compute various statistics for a given sequence:
-    - Length
-    - GC content (%)
-    - A/C/G/T counts
-    """
-    ...
-    }
-
-    counts = {
-        'A': sequence.count('A'),
-        'C': sequence.count('C'),
-        'G': sequence.count('G'),
-        'T': sequence.count('T')
-    }
-    gc_content = (counts['G'] + counts['C']) / length * 100.0
-
-    return {
-        'length': length,
-        'gc_content': gc_content,
-        'counts': counts
-    }
-
-# --------------- Visualization Functions ---------------
-
-def plot_shap_like_bars(kmers, importance_values, top_k=10):
-    """
-    ...
-    """
-    ...
-
-    fig, ax = plt.subplots(figsize=(8, 6))
-    colors = ['green' if val > 0 else 'red' for val in top_importances]
-    ax.barh(range(len(top_kmers)), np.abs(top_importances), color=colors)
-    ax.set_yticks(range(len(top_kmers)))
-    ax.set_yticklabels(top_kmers)
-    ax.invert_yaxis()  # So that the highest value is at the top
-    ax.set_xlabel("Feature Importance (Gradient Magnitude)")
-    ax.set_title(f"Top-{top_k} SHAP-like Feature Importances")
-    plt.tight_layout()
-    return fig
 
-def plot_kmer_distribution(kmer_freq_vector, ...):
-    """
-    (Optional if you want a quick distribution overview)
-    """
-    fig, ax = plt.subplots(figsize=(10, 4))
-    ax.bar(range(len(kmer_freq_vector)), kmer_freq_vector, color='blue', alpha=0.6)
-    ax.set_xlabel("K-mer Index")
-    ax.set_ylabel("Frequency")
-    ax.set_title("K-mer Frequency Distribution")
-    ax.set_xticks([])
-    plt.tight_layout()
-    return fig
 
-def create_step_visualization(...):
-    """
-    ...
-    """
-    fig = plt.figure(figsize=(...
-    ...
 
-    # ...
     current_prob = 0.5
     steps = [('Start', current_prob, 0)]
 
-    for ...
-        change = ...
         current_prob += change
-        steps.append((...
-
-    x_vals = range(len(steps))
-    y_vals = [s[1] for s in steps]
-
-    ax.step(x_vals, y_vals, 'b-', where='post', label='Probability', linewidth=2)
-    ax.plot(x_vals, y_vals, 'b.', markersize=10)
 
-    ...
     for i, (kmer, prob, change) in enumerate(steps):
-        ...
 
         if i > 0:
             change_text = f'{change:+.3f}'
             color = 'green' if change > 0 else 'red'
-            ...
-
-def plot_kmer_freq_and_sigma(important_kmers):
-    """
-    Plot frequencies vs. sigma from mean for the top k-mers.
-    This reuses logic from the original create_visualization second subplot,
-    but as its own function for clarity.
-    """
-    fig, ax = plt.subplots(figsize=(8, 5))
 
     # Prepare data
     kmers = [k['kmer'] for k in important_kmers]
     frequencies = [k['occurrence'] for k in important_kmers]
     sigmas = [k['sigma'] for k in important_kmers]
-    colors = ['green' if k['direction'] == 'human' else 'red' for k in important_kmers]
 
     x = np.arange(len(kmers))
     width = 0.35
 
-    ax.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)
 
-    ...
-    ax2.set_ylabel('...
-    ...
 
     plt.tight_layout()
     return fig
 
-def predict_fasta(
-    file_obj,
-    k_size=4,
-    top_k=10,
-    advanced_analysis=False
-):
     """
-    ...
     """
-    ...
-            text = file_obj
-        else:
-            text = file_obj.decode('utf-8', errors='replace')
-    except Exception as e:
-        return f"Error reading file: {str(e)}", []
 
-    ...
 
     try:
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        model = VirusClassifier(input_shape=(4 ** k_size)).to(device)
         state_dict = torch.load('model.pt', map_location=device)
         model.load_state_dict(state_dict)
-        model.eval()
-
         scaler = joblib.load('scaler.pkl')
     except Exception as e:
-        return f"Error loading model ...
-
-    ...
-
-        non_human_prob = float(probs[0][0])
-        confidence = float(torch.max(probs[0]).item())
-
-        # Compute gradient-based importance
-        importance, target_prob = model.get_gradient_importance(X_tensor, class_index=1)
-        importance = importance[0].cpu().numpy()  # shape: (num_features,)
-
-        # Identify top-k features (by absolute gradient)
-        abs_importance = np.abs(importance)
-        sorted_indices = np.argsort(abs_importance)[::-1]
-        top_indices = sorted_indices[:top_k]
-
-        # Build a list of top k-mers
-        top_kmers_info = []
-        for i in top_indices:
-            kmer_name = all_kmers[i]
-            imp_val = float(importance[i])
-            direction = 'human' if imp_val > 0 else 'non-human'
-            freq_perc = float(raw_kmer_freq[i] * 100.0)  # in percent
-            sigma = float(scaled_kmer_freq[0][i])  # scaled value (stdev from mean if the scaler is StandardScaler)
-
-            top_kmers_info.append({
-                'kmer': kmer_name,
-                'impact': abs(imp_val),
-                'direction': direction,
-                'occurrence': freq_perc,
-                'sigma': sigma
-            })
-
-        # Text summary for this sequence
-        seq_report = []
-        seq_report.append(f"=== Sequence {idx} ===")
-        seq_report.append(f"Header: {header}")
-        seq_report.append(f"Length: {seq_stats['length']}")
-        seq_report.append(f"GC Content: {seq_stats['gc_content']:.2f}%")
-        seq_report.append(f"A: {seq_stats['counts']['A']}, C: {seq_stats['counts']['C']}, G: {seq_stats['counts']['G']}, T: {seq_stats['counts']['T']}")
-        seq_report.append(f"Prediction: {pred_label} (Confidence: {confidence:.4f})")
-        seq_report.append(f"  Human Probability: {human_prob:.4f}")
-        seq_report.append(f"  Non-human Probability: {non_human_prob:.4f}")
-        seq_report.append(f"\nTop-{top_k} Influential k-mers (by gradient magnitude):")
-        for tkm in top_kmers_info:
-            seq_report.append(
-                f"  {tkm['kmer']}: pushes towards {tkm['direction']} "
-                f"(impact={tkm['impact']:.4f}), occurrence={tkm['occurrence']:.2f}%, "
-                f"sigma={tkm['sigma']:.2f}"
-            )
-
-        final_text_report.append("\n".join(seq_report))
-
-        # 6. Generate Plots (for each sequence)
-        if advanced_analysis:
-            # 6A. SHAP-like bar chart
-            fig_shap = plot_shap_like_bars(
-                kmers=all_kmers,
-                importance_values=importance,
-                top_k=top_k
-            )
-            buf_shap = io.BytesIO()
-            fig_shap.savefig(buf_shap, format='png', bbox_inches='tight', dpi=150)
-            buf_shap.seek(0)
-            plots.append(Image.open(buf_shap))
-            plt.close(fig_shap)
-
-            # 6B. k-mer distribution histogram
-            fig_kmer_dist = plot_kmer_distribution(raw_kmer_freq, all_kmers)
-            buf_dist = io.BytesIO()
-            fig_kmer_dist.savefig(buf_dist, format='png', bbox_inches='tight', dpi=150)
-            buf_dist.seek(0)
-            plots.append(Image.open(buf_dist))
-            plt.close(fig_kmer_dist)
-
-            # 6C. Original step visualization for top k k-mers
-            # Sort by actual 'impact' to preserve that step logic
-            # (largest absolute impact first)
-            top_kmers_info_step = sorted(top_kmers_info, key=lambda x: x['impact'], reverse=True)
-            fig_step = create_step_visualization(top_kmers_info_step, human_prob)
-            buf_step = io.BytesIO()
-            fig_step.savefig(buf_step, format='png', bbox_inches='tight', dpi=150)
-            buf_step.seek(0)
-            plots.append(Image.open(buf_step))
-            plt.close(fig_step)
-
-            # 6D. Frequency vs. sigma bar chart
-            fig_freq_sigma = plot_kmer_freq_and_sigma(top_kmers_info_step)
-            buf_freq_sigma = io.BytesIO()
-            fig_freq_sigma.savefig(buf_freq_sigma, format='png', bbox_inches='tight', dpi=150)
-            buf_freq_sigma.seek(0)
-            plots.append(Image.open(buf_freq_sigma))
-            plt.close(fig_freq_sigma)
 
-...
-):
-    """
-    Wrapper for Gradio to handle the outputs in (text, List[Image]) form.
-    """
-    text_output, pil_images = predict_fasta(
-        file_obj=file_obj,
-        k_size=k_size,
-        top_k=top_k,
-        advanced_analysis=advanced_analysis
-    )
-
-    return text_output, pil_images
 
-    ...
-)
-    ...
 
 if __name__ == "__main__":
-    ...
+++ b/app.py
 import torch
 import joblib
 import numpy as np
+from itertools import product
 import torch.nn as nn
 import matplotlib.pyplot as plt
 import io
 from PIL import Image
 
+##############################################################################
+# MODEL DEFINITION
+##############################################################################
 
 class VirusClassifier(nn.Module):
     def __init__(self, input_shape: int):
         ...
 
     def forward(self, x):
         return self.network(x)
 
+##############################################################################
+# UTILITIES
+##############################################################################
 
+def parse_fasta(text):
     """
+    Parses FASTA formatted text into a list of (header, sequence).
     """
     sequences = []
     current_header = None
     current_sequence = []
 
+    for line in text.strip().split('\n'):
         line = line.strip()
         if not line:
             continue
         ...
 
 def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
     """
+    Convert a sequence to a k-mer frequency vector of size len(ACGT^k).
     """
     kmers = [''.join(p) for p in product("ACGT", repeat=k)]
     kmer_dict = {km: i for i, km in enumerate(kmers)}
     vec = np.zeros(len(kmers), dtype=np.float32)
     ...
 
     return vec
 
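
Together, the two helpers above turn raw FASTA text into the model's 256-dimensional input. A small usage sketch, assuming the folded function bodies behave as their docstrings state (the record is made up for illustration):

    fasta = ">toy_record\nACGTACGTACGTACGT"
    header, seq = parse_fasta(fasta)[0]
    vec = sequence_to_kmer_vector(seq, k=4)
    print(header, vec.shape)  # header text and (256,), one slot per 4-mer
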
+def ablation_importance(model, x_tensor):
+    """
+    Calculates a simple ablation-based importance measure for each feature:
+    1. Compute baseline human probability p_base.
+    2. For each feature i, set x[i] = 0, re-run inference, compute new p, and
+       measure delta = p_base - p.
+    3. Return array of deltas (positive means that removing that feature
+       *decreases* the probability => that feature was pushing it higher).
+    """
+    model.eval()
+    with torch.no_grad():
+        # Baseline probability
+        output = model(x_tensor)
+        probs = torch.softmax(output, dim=1)
+        p_base = probs[0, 1].item()
+
+    # Store the delta importances
+    importances = np.zeros(x_tensor.shape[1], dtype=np.float32)
+
+    # For efficiency, we do ablation one feature at a time
+    for i in range(x_tensor.shape[1]):
+        x_copy = x_tensor.clone()
+        x_copy[0, i] = 0.0  # Ablate this feature
+        with torch.no_grad():
+            output_ablation = model(x_copy)
+            probs_ablation = torch.softmax(output_ablation, dim=1)
+            p_ablation = probs_ablation[0, 1].item()
+        # Delta
+        importances[i] = p_base - p_ablation
 
+    return importances, p_base
 
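
Because ablation needs only forward passes, the helper runs against any two-class torch module. A self-contained sketch on a stand-in linear model, used here only so the example runs without model.pt:

    toy_model = nn.Linear(256, 2)   # hypothetical stand-in for the trained classifier
    x = torch.randn(1, 256)
    imp, p_base = ablation_importance(toy_model, x)
    print(p_base)                   # baseline P(human) before any ablation
    print(imp.shape, imp.dtype)     # (256,) float32: signed delta per zeroed feature
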
+##############################################################################
+# PLOTTING
+##############################################################################
 
+def create_step_and_frequency_plot(important_kmers, human_prob, title):
+    """
+    Creates a combined step plot (showing how each k-mer modifies the probability)
+    and a frequency vs. sigma bar chart.
+    """
+    fig = plt.figure(figsize=(15, 10))
+
+    # Create grid for subplots
+    gs = plt.GridSpec(2, 1, height_ratios=[1.5, 1], hspace=0.3)
 
+    # 1. Probability Step Plot
+    ax1 = plt.subplot(gs[0])
     current_prob = 0.5
     steps = [('Start', current_prob, 0)]
 
+    for kmer_info in important_kmers:
+        change = kmer_info['impact']  # positive => pushes up, negative => pushes down
         current_prob += change
+        steps.append((kmer_info['kmer'], current_prob, change))
 
+    x = range(len(steps))
+    y = [step[1] for step in steps]
+
+    # Plot steps
+    ax1.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
+    ax1.plot(x, y, 'b.', markersize=10)
+
+    # Add reference line
+    ax1.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
+
+    # Customize plot
+    ax1.grid(True, linestyle='--', alpha=0.7)
+    ax1.set_ylim(0, 1)
+    ax1.set_ylabel('Human Probability')
+    ax1.set_title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')
+
+    # Add labels for each point
     for i, (kmer, prob, change) in enumerate(steps):
+        # Add k-mer label
+        ax1.annotate(kmer,
+                     (i, prob),
+                     xytext=(0, 10 if i % 2 == 0 else -20),
+                     textcoords='offset points',
+                     ha='center',
+                     rotation=45)
 
+        # Add change value
         if i > 0:
             change_text = f'{change:+.3f}'
             color = 'green' if change > 0 else 'red'
+            ax1.annotate(change_text,
+                         (i, prob),
+                         xytext=(0, -20 if i % 2 == 0 else 10),
+                         textcoords='offset points',
+                         ha='center',
+                         color=color)
+
+    ax1.legend()
+
+    # 2. K-mer Frequency and Sigma Plot
+    ax2 = plt.subplot(gs[1])
 
     # Prepare data
     kmers = [k['kmer'] for k in important_kmers]
     frequencies = [k['occurrence'] for k in important_kmers]
     sigmas = [k['sigma'] for k in important_kmers]
 
+    # Color the bars: if impact>0 => green, else red
+    colors = ['g' if k['impact'] > 0 else 'r' for k in important_kmers]
+
+    # Create bar plot for frequencies
     x = np.arange(len(kmers))
     width = 0.35
 
+    ax2.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)
 
+    # Twin axis for sigma
+    ax2_twin = ax2.twinx()
+    # To highlight positive or negative sigma, pick color accordingly
+    sigma_colors = []
+    for s, c in zip(sigmas, colors):
+        if s >= 0:
+            sigma_colors.append('blue')   # above average
+        else:
+            sigma_colors.append('gray')   # below average
+
+    ax2_twin.bar(x + width/2, sigmas, width, label='σ from Mean', color=sigma_colors, alpha=0.3)
 
+    # Customize plot
+    ax2.set_xticks(x)
+    ax2.set_xticklabels(kmers, rotation=45)
+    ax2.set_ylabel('Frequency (%)')
+    ax2_twin.set_ylabel('Standard Deviations (σ) from Mean')
+    ax2.set_title('K-mer Frequencies and Statistical Significance')
+
+    # Add legends
+    lines1, labels1 = ax2.get_legend_handles_labels()
+    lines2, labels2 = ax2_twin.get_legend_handles_labels()
+    ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
 
     plt.tight_layout()
     return fig
 
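
The step-and-frequency figure above consumes the list-of-dicts record that predict() assembles further down; each entry needs the keys 'kmer', 'impact', 'occurrence', and 'sigma'. A hand-built example with illustrative values:

    demo_kmers = [
        {'kmer': 'ACGT', 'impact': 0.12, 'occurrence': 1.8, 'sigma': 2.1},
        {'kmer': 'TTTA', 'impact': -0.07, 'occurrence': 0.9, 'sigma': -0.4},
    ]
    fig = create_step_and_frequency_plot(demo_kmers, human_prob=0.55, title="demo")
    fig.savefig("step_demo.png")
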
+def create_shap_like_bar_plot(impact_values, kmer_list, top_k):
+    """
+    Creates a horizontal bar plot showing the top_k features by absolute impact.
+    impact_values: array of float (length=256).
+    kmer_list: list of all k=4 kmers in order.
+    top_k: integer, how many top features to display.
+    """
+    # Sort by absolute impact
+    indices_sorted = np.argsort(np.abs(impact_values))[::-1]
+    top_indices = indices_sorted[:top_k]
 
+    top_impacts = impact_values[top_indices]
+    top_kmers = [kmer_list[i] for i in top_indices]
 
+    fig = plt.figure(figsize=(8, 6))
+    plt.barh(range(len(top_impacts)), top_impacts, color=['green' if i > 0 else 'red' for i in top_impacts])
+    plt.yticks(range(len(top_impacts)), top_kmers)
+    plt.xlabel("Impact on Human Probability (Ablation)")
+    plt.title(f"Top {top_k} K-mers by Absolute Impact")
+    plt.gca().invert_yaxis()  # Highest at top
+    plt.tight_layout()
+    return fig
+
+def create_global_bar_plot(impact_values, kmer_list):
+    """
+    Creates a bar plot for ALL features (256) to see the global distribution.
+    """
+    fig = plt.figure(figsize=(12, 6))
+    indices_sorted = np.argsort(np.abs(impact_values))[::-1]
+    sorted_impacts = impact_values[indices_sorted]
+    sorted_kmers = [kmer_list[i] for i in indices_sorted]
 
+    plt.bar(range(len(sorted_impacts)), sorted_impacts,
+            color=['green' if i > 0 else 'red' for i in sorted_impacts])
+    plt.title("Global Impact of All 256 K-mers (Ablation Method)")
+    plt.xlabel("K-mer (sorted by |impact|)")
+    plt.ylabel("Impact on Human Probability")
+    # Optionally, we can skip labeling all 256 on x-axis.
+    # But we can show only the top/bottom or none for clarity.
+    plt.tight_layout()
+    return fig
+
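
Both bar-plot helpers take the full 256-entry impact array plus the ordered k-mer list. A quick preview with random stand-in impacts (illustrative data, not model output):

    rng = np.random.default_rng(0)
    fake_impacts = rng.normal(scale=0.05, size=256).astype(np.float32)
    all_kmers = [''.join(p) for p in product("ACGT", repeat=4)]
    create_shap_like_bar_plot(fake_impacts, all_kmers, top_k=10).savefig("top10.png")
    create_global_bar_plot(fake_impacts, all_kmers).savefig("global.png")
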
+##############################################################################
+# MAIN PREDICTION FUNCTION
+##############################################################################
+
+def predict(file_obj, top_kmers=10, advanced_plots=False, fasta_text=""):
+    """
+    Main prediction function called by Gradio.
+    - file_obj: optional uploaded FASTA file
+    - top_kmers: number of top k-mers to display in the main SHAP-like plot
+    - advanced_plots: bool, whether to return global bar plots
+    - fasta_text: optional direct-pasted FASTA text
+    """
+    # Priority: If user pasted text, use that; otherwise use uploaded file.
+    if fasta_text.strip():
+        text = fasta_text.strip()
+    else:
+        if file_obj is None:
+            return "No FASTA input provided", None, None, None
+        try:
+            if isinstance(file_obj, str):
+                text = file_obj
+            else:
+                text = file_obj.decode('utf-8')
+        except Exception as e:
+            return f"Error reading file: {str(e)}", None, None, None
+
+    # Parse FASTA
+    sequences = parse_fasta(text)
+    if len(sequences) == 0:
+        return "No valid FASTA sequences found", None, None, None
+    header, seq = sequences[0]
+
+    # Load model + scaler
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    model = VirusClassifier(256).to(device)
     try:
         state_dict = torch.load('model.pt', map_location=device)
         model.load_state_dict(state_dict)
         scaler = joblib.load('scaler.pkl')
     except Exception as e:
+        return f"Error loading model or scaler: {str(e)}", None, None, None
+
+    # Prepare the vector
+    raw_freq_vector = sequence_to_kmer_vector(seq, k=4)
+    scaled_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
+    X_tensor = torch.FloatTensor(scaled_vector).to(device)
+
+    # Compute ablation-based importances
+    importances, p_base = ablation_importance(model, X_tensor)
+    # p_base is baseline human probability
+
+    # We also want frequency in % and sigma from mean.
+    # If your scaler is e.g. StandardScaler, then "scaled_vector[0][i]" is
+    # how many std devs from the mean that feature is.
+    # We'll gather info in a list of dicts for each k-mer.
+    kmers_4 = [''.join(p) for p in product("ACGT", repeat=4)]
+    kmer_dict = {km: i for i, km in enumerate(kmers_4)}
+
+    # We'll sort by absolute impact to get the top 10 by default.
+    abs_sorted_idx = np.argsort(np.abs(importances))[::-1]
+    # But for the final step/frequency plot we only show top_kmers
+    top_indices = abs_sorted_idx[:top_kmers]
+
+    # Build a list of the top k-mers
+    important_kmers = []
+    for idx in top_indices:
+        # "impact" is how much that feature changed the probability
+        impact = importances[idx]
+        # raw frequency => raw_freq_vector[idx] * 100 for %
+        freq_pct = float(raw_freq_vector[idx] * 100.0)
+        # sigma => scaled_vector[0][idx]
+        sigma_val = float(scaled_vector[0][idx])
+
+        important_kmers.append({
+            'kmer': kmers_4[idx],
+            'impact': impact,
+            'occurrence': freq_pct,
+            'sigma': sigma_val
+        })
 
+    # For text output
+    # We decide final class based on model's direct output
+    with torch.no_grad():
+        output = model(X_tensor)
+        probs = torch.softmax(output, dim=1)
+    pred_class = 1 if probs[0, 1] > probs[0, 0] else 0
+    pred_label = 'human' if pred_class == 1 else 'non-human'
+    human_prob = probs[0, 1].item()
+    nonhuman_prob = probs[0, 0].item()
+    confidence = max(human_prob, nonhuman_prob)
+
+    results_text = (f"Sequence: {header}\n"
+                    f"Prediction: {pred_label}\n"
+                    f"Confidence: {confidence:.4f}\n"
+                    f"Human probability: {human_prob:.4f}\n"
+                    f"Non-human probability: {nonhuman_prob:.4f}\n"
+                    f"Most influential k-mers (by ablation impact):\n")
 
+    for kmer_info in important_kmers:
+        # sign => if impact>0 => removing it lowers p(human), so it was pushing p(human) up
+        direction = "UP (toward human)" if kmer_info['impact'] > 0 else "DOWN (toward non-human)"
+        results_text += (
+            f"  {kmer_info['kmer']}: {direction}, "
+            f"Impact={kmer_info['impact']:.4f}, "
+            f"Occ={kmer_info['occurrence']:.2f}% of seq, "
+            f"{abs(kmer_info['sigma']):.2f}σ "
+            + ("above" if kmer_info['sigma'] > 0 else "below")
+            + " mean\n"
+        )
 
+    # PLOT 1: A SHAP-like bar plot for the top K features
+    shap_fig = create_shap_like_bar_plot(importances, kmers_4, top_kmers)
+
+    # PLOT 2: Step + frequency plot for the top K features
+    step_fig = create_step_and_frequency_plot(important_kmers, human_prob, header)
+
+    # PLOT 3 (optional advanced): global bar plot of all 256 features
+    global_fig = None
+    if advanced_plots:
+        global_fig = create_global_bar_plot(importances, kmers_4)
+
+    # Convert figures to PIL Images
+    def fig_to_image(fig):
+        buf = io.BytesIO()
+        fig.savefig(buf, format='png', bbox_inches='tight', dpi=200)
+        buf.seek(0)
+        im = Image.open(buf)
+        plt.close(fig)
+        return im
+
+    shap_img = fig_to_image(shap_fig)
+    step_img = fig_to_image(step_fig)
+    if global_fig is not None:
+        global_img = fig_to_image(global_fig)
+    else:
+        global_img = None
+
+    return results_text, shap_img, step_img, global_img
 
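
predict() can also be exercised without the web UI, which makes for a quick smoke test. A sketch assuming model.pt and scaler.pkl sit in the working directory (the FASTA text is made up):

    text, shap_img, step_img, global_img = predict(
        file_obj=None,
        top_kmers=10,
        advanced_plots=True,
        fasta_text=">query\nACGTACGTGGCCTTAAACGT",
    )
    print(text)                  # plain-text prediction summary
    if global_img is not None:
        global_img.save("global_impacts.png")   # outputs are PIL images
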
+##############################################################################
+# GRADIO INTERFACE
+##############################################################################
 
+title_text = "Virus Host Classifier"
+description_text = """
+Upload or paste a FASTA sequence to predict if it's likely **human** or **non-human** origin.
+- **k=4** k-mers are used as features.
+- We display ablation-based feature importance for interpretability.
+- Advanced plots can be toggled to see the global distribution of all 256 k-mer impacts.
+"""
+
+iface = gr.Interface(
+    fn=predict,
+    inputs=[
+        gr.File(label="Upload FASTA file", type="binary", optional=True),
+        gr.Slider(label="Number of top k-mers to show", minimum=1, maximum=50, value=10, step=1),
+        gr.Checkbox(label="Show advanced (global) plots?", value=False),
+        gr.Textbox(label="Or paste FASTA text here", lines=5, placeholder=">header\nACGTACGT...")
+    ],
+    outputs=[
+        gr.Textbox(label="Results", lines=10),
+        gr.Image(label="SHAP-like Top-k K-mer Bar Plot"),
+        gr.Image(label="Step & Frequency Plot (Top-k)"),
+        gr.Image(label="Global 256-K-mer Plot (advanced)", optional=True)
+    ],
+    title=title_text,
+    description=description_text
+)
 
 if __name__ == "__main__":
+    iface.launch(share=True)