Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -2,17 +2,12 @@ import gradio as gr
|
|
2 |
import torch
|
3 |
import joblib
|
4 |
import numpy as np
|
5 |
-
import shap
|
6 |
-
import random
|
7 |
from itertools import product
|
8 |
import torch.nn as nn
|
9 |
import matplotlib.pyplot as plt
|
10 |
import io
|
11 |
from PIL import Image
|
12 |
|
13 |
-
###############################################################################
|
14 |
-
# Model Definition
|
15 |
-
###############################################################################
|
16 |
class VirusClassifier(nn.Module):
|
17 |
def __init__(self, input_shape: int):
|
18 |
super(VirusClassifier, self).__init__()
|
@@ -34,28 +29,38 @@ class VirusClassifier(nn.Module):
|
|
34 |
return self.network(x)
|
35 |
|
36 |
def get_feature_importance(self, x):
|
37 |
-
"""
|
38 |
-
Calculate gradient-based feature importance, specifically for the
|
39 |
-
'human' class (index=1) by computing gradient of that probability wrt x.
|
40 |
-
"""
|
41 |
x.requires_grad_(True)
|
42 |
output = self.network(x)
|
43 |
probs = torch.softmax(output, dim=1)
|
44 |
|
45 |
-
#
|
46 |
human_prob = probs[..., 1]
|
47 |
if x.grad is not None:
|
48 |
x.grad.zero_()
|
49 |
human_prob.backward()
|
50 |
-
importance = x.grad
|
51 |
|
52 |
return importance, float(human_prob)
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
def parse_fasta(text):
|
58 |
-
"""Parses text input in FASTA format into a list of (header, sequence)."""
|
59 |
sequences = []
|
60 |
current_header = None
|
61 |
current_sequence = []
|
@@ -75,213 +80,97 @@ def parse_fasta(text):
|
|
75 |
sequences.append((current_header, ''.join(current_sequence)))
|
76 |
return sequences
|
77 |
|
78 |
-
def
|
79 |
-
"""
|
80 |
-
|
81 |
-
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
82 |
-
vec = np.zeros(len(kmers), dtype=np.float32)
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
if kmer in kmer_dict:
|
87 |
-
vec[kmer_dict[kmer]] += 1
|
88 |
-
|
89 |
-
total_kmers = len(sequence) - k + 1
|
90 |
-
if total_kmers > 0:
|
91 |
-
vec = vec / total_kmers # normalize frequencies
|
92 |
-
|
93 |
-
return vec
|
94 |
-
|
95 |
-
###############################################################################
|
96 |
-
# Additional Plots
|
97 |
-
###############################################################################
|
98 |
-
def create_probability_bar_plot(prob_human, prob_nonhuman):
|
99 |
-
"""
|
100 |
-
Simple bar plot comparing human vs. non-human probabilities.
|
101 |
-
"""
|
102 |
-
labels = ["Non-human", "Human"]
|
103 |
-
probs = [prob_nonhuman, prob_human]
|
104 |
-
colors = ["red", "green"]
|
105 |
-
|
106 |
-
fig, ax = plt.subplots(figsize=(6, 4))
|
107 |
-
ax.bar(labels, probs, color=colors, alpha=0.7)
|
108 |
-
ax.set_ylim(0, 1)
|
109 |
-
for i, v in enumerate(probs):
|
110 |
-
ax.text(i, v+0.02, f"{v:.3f}", ha='center', color='black', fontsize=11)
|
111 |
-
|
112 |
-
ax.set_title("Predicted Probabilities")
|
113 |
-
ax.set_ylabel("Probability")
|
114 |
-
plt.tight_layout()
|
115 |
-
return fig
|
116 |
-
|
117 |
-
def create_frequency_sigma_plot(important_kmers, title):
|
118 |
-
"""
|
119 |
-
Creates a bar plot of the top k-mers (by importance) showing
|
120 |
-
frequency (%) and σ from mean.
|
121 |
-
"""
|
122 |
-
# Sort by absolute impact
|
123 |
-
sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
|
124 |
-
kmers = [k["kmer"] for k in sorted_kmers]
|
125 |
-
frequencies = [k["occurrence"] for k in sorted_kmers] # in %
|
126 |
-
sigmas = [k["sigma"] for k in sorted_kmers]
|
127 |
-
directions = [k["direction"] for k in sorted_kmers]
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
x = np.arange(len(kmers))
|
130 |
-
width = 0.
|
131 |
-
|
132 |
-
fig, ax_bar = plt.subplots(figsize=(10, 5))
|
133 |
-
|
134 |
-
# Bar for frequency
|
135 |
-
bars_freq = ax_bar.bar(
|
136 |
-
x - width/2, frequencies, width, alpha=0.7,
|
137 |
-
color=["green" if d=="human" else "red" for d in directions],
|
138 |
-
label="Frequency (%)"
|
139 |
-
)
|
140 |
-
ax_bar.set_ylabel("Frequency (%)")
|
141 |
-
ax_bar.set_ylim(0, max(frequencies) * 1.2 if len(frequencies) > 0 else 1)
|
142 |
-
|
143 |
-
# Twin axis for σ
|
144 |
-
ax_bar_twin = ax_bar.twinx()
|
145 |
-
bars_sigma = ax_bar_twin.bar(
|
146 |
-
x + width/2, sigmas, width, alpha=0.5, color="gray", label="σ from Mean"
|
147 |
-
)
|
148 |
-
ax_bar_twin.set_ylabel("Standard Deviations (σ)")
|
149 |
-
|
150 |
-
ax_bar.set_title(f"Frequency & σ from Mean for Top k-mers — {title}")
|
151 |
-
ax_bar.set_xticks(x)
|
152 |
-
ax_bar.set_xticklabels(kmers, rotation=45, ha='right')
|
153 |
-
|
154 |
-
# Combined legend
|
155 |
-
lines1, labels1 = ax_bar.get_legend_handles_labels()
|
156 |
-
lines2, labels2 = ax_bar_twin.get_legend_handles_labels()
|
157 |
-
ax_bar.legend(lines1 + lines2, labels1 + labels2, loc="upper right")
|
158 |
-
|
159 |
-
plt.tight_layout()
|
160 |
-
return fig
|
161 |
-
|
162 |
-
def create_importance_bar_plot(important_kmers, title):
|
163 |
-
"""
|
164 |
-
Create a simple bar chart showing the absolute gradient magnitude
|
165 |
-
for the top k-mers, sorted descending.
|
166 |
-
"""
|
167 |
-
sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
|
168 |
-
kmers = [k['kmer'] for k in sorted_kmers]
|
169 |
-
impacts = [k['impact'] for k in sorted_kmers]
|
170 |
-
directions = [k["direction"] for k in sorted_kmers]
|
171 |
-
|
172 |
-
x = np.arange(len(kmers))
|
173 |
-
|
174 |
-
fig, ax = plt.subplots(figsize=(10, 5))
|
175 |
-
bar_colors = ["green" if d=="human" else "red" for d in directions]
|
176 |
-
|
177 |
-
ax.bar(x, impacts, color=bar_colors, alpha=0.7, edgecolor='black')
|
178 |
-
ax.set_xticks(x)
|
179 |
-
ax.set_xticklabels(kmers, rotation=45, ha='right')
|
180 |
-
ax.set_title(f"Absolute Feature Importance (Top k-mers) — {title}")
|
181 |
-
ax.set_ylabel("Gradient Magnitude")
|
182 |
-
ax.grid(axis="y", alpha=0.3)
|
183 |
-
|
184 |
-
plt.tight_layout()
|
185 |
-
return fig
|
186 |
-
|
187 |
-
###############################################################################
|
188 |
-
# SHAP Beeswarm
|
189 |
-
###############################################################################
|
190 |
-
def create_shap_beeswarm_plot(
|
191 |
-
model,
|
192 |
-
input_vector: np.ndarray,
|
193 |
-
background_data: np.ndarray,
|
194 |
-
feature_names: list
|
195 |
-
):
|
196 |
-
"""
|
197 |
-
Creates a SHAP beeswarm plot using KernelExplainer for the given model and data.
|
198 |
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
Trained PyTorch model (binary classifier).
|
203 |
-
input_vector : np.ndarray
|
204 |
-
The 1-sample input (or multiple samples) we want SHAP values for.
|
205 |
-
background_data : np.ndarray
|
206 |
-
Background samples for KernelExplainer. Should have shape (N, #features).
|
207 |
-
feature_names : list
|
208 |
-
Names for each feature (k-mers).
|
209 |
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
#
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
data: shape (N, #features)
|
222 |
-
returns: shape (N, 2) for 2-class logits
|
223 |
-
"""
|
224 |
-
with torch.no_grad():
|
225 |
-
x = torch.FloatTensor(data)
|
226 |
-
logits = model(x)
|
227 |
-
return logits.detach().cpu().numpy()
|
228 |
-
|
229 |
-
# Create KernelExplainer
|
230 |
-
explainer = shap.KernelExplainer(
|
231 |
-
model=predict_fn,
|
232 |
-
data=background_data
|
233 |
-
)
|
234 |
-
|
235 |
-
# Compute SHAP values
|
236 |
-
# For a 2-class model, shap_values is a list of length 2 => [class0 array, class1 array]
|
237 |
-
# Each array is shape (N, #features).
|
238 |
-
shap_values = explainer.shap_values(input_vector)
|
239 |
-
|
240 |
-
# We’ll produce a beeswarm for the 'human' class (class index=1).
|
241 |
-
# If we have only 1 sample, the beeswarm won't be too interesting, but let's do it anyway.
|
242 |
-
class_idx = 1 # 'human'
|
243 |
|
244 |
-
# If we only have one sample, place it in an array for shap summary plotting:
|
245 |
-
# We can do shap_values[class_idx].shape => (1, #features) for a single sample
|
246 |
-
# Beeswarm typically expects multiple samples. We'll plot anyway.
|
247 |
-
shap.plots.beeswarm(
|
248 |
-
shap_values[class_idx],
|
249 |
-
feature_names=feature_names,
|
250 |
-
show=False
|
251 |
-
)
|
252 |
-
|
253 |
-
fig = plt.gcf()
|
254 |
-
fig.set_size_inches(8, 6)
|
255 |
-
plt.title("SHAP Beeswarm Plot (Class: Human)")
|
256 |
-
|
257 |
plt.tight_layout()
|
258 |
return fig
|
259 |
|
260 |
-
###############################################################################
|
261 |
-
# Prediction Function
|
262 |
-
###############################################################################
|
263 |
def predict(file_obj):
|
264 |
-
"""
|
265 |
-
Main function for Gradio:
|
266 |
-
1. Reads the uploaded FASTA file or text.
|
267 |
-
2. Loads the model and scaler.
|
268 |
-
3. Generates predictions, probabilities, and top k-mers.
|
269 |
-
4. Creates multiple outputs:
|
270 |
-
- Text summary (Markdown)
|
271 |
-
- Probability Bar Plot
|
272 |
-
- SHAP Beeswarm Plot
|
273 |
-
- Frequency & σ Plot
|
274 |
-
- Absolute Feature Importance Bar Plot
|
275 |
-
"""
|
276 |
-
# 0. Basic file read
|
277 |
if file_obj is None:
|
278 |
-
return
|
279 |
-
"Please upload a FASTA file.",
|
280 |
-
None,
|
281 |
-
None,
|
282 |
-
None,
|
283 |
-
None
|
284 |
-
)
|
285 |
|
286 |
try:
|
287 |
if isinstance(file_obj, str):
|
@@ -289,202 +178,106 @@ def predict(file_obj):
|
|
289 |
else:
|
290 |
text = file_obj.decode('utf-8')
|
291 |
except Exception as e:
|
292 |
-
return (
|
293 |
-
f"Error reading file: {str(e)}",
|
294 |
-
None,
|
295 |
-
None,
|
296 |
-
None,
|
297 |
-
None
|
298 |
-
)
|
299 |
-
|
300 |
-
# 1. Parse FASTA
|
301 |
-
sequences = parse_fasta(text)
|
302 |
-
if len(sequences) == 0:
|
303 |
-
return (
|
304 |
-
"No valid FASTA sequences found. Please check your input.",
|
305 |
-
None,
|
306 |
-
None,
|
307 |
-
None,
|
308 |
-
None
|
309 |
-
)
|
310 |
-
header, seq = sequences[0] # We'll classify only the first sequence
|
311 |
|
312 |
-
# 2. Prepare model, scaler, and input
|
313 |
k = 4
|
314 |
-
|
|
|
|
|
315 |
try:
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
model = VirusClassifier(input_shape=4**k).to(device)
|
320 |
-
state_dict = torch.load("model.pt", map_location=device)
|
321 |
model.load_state_dict(state_dict)
|
322 |
-
scaler = joblib.load(
|
323 |
model.eval()
|
|
|
|
|
324 |
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
with torch.no_grad():
|
330 |
-
|
331 |
-
probs = torch.softmax(
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
#
|
338 |
-
importance, hum_prob_grad = model.get_feature_importance(X_tensor)
|
339 |
-
importances = importance[0].cpu().numpy() # shape: (#features,)
|
340 |
-
abs_importances = np.abs(importances)
|
341 |
-
|
342 |
-
# 5. Gather k-mer strings
|
343 |
-
kmers_list = [''.join(p) for p in product("ACGT", repeat=k)]
|
344 |
-
# top 10 by absolute importance
|
345 |
top_k = 10
|
346 |
-
|
|
|
347 |
important_kmers = []
|
348 |
-
for idx in
|
349 |
-
|
350 |
-
|
351 |
-
|
|
|
|
|
|
|
352 |
important_kmers.append({
|
353 |
-
'kmer':
|
354 |
-
'
|
355 |
-
'impact': abs_importances[idx],
|
356 |
'direction': direction,
|
357 |
-
'occurrence':
|
358 |
-
'sigma':
|
359 |
})
|
360 |
-
|
361 |
-
# 6. Generate text summary
|
362 |
-
text_summary = (
|
363 |
-
f"**Sequence Header**: {header}\n\n"
|
364 |
-
f"**Predicted Label**: {pred_label}\n"
|
365 |
-
f"**Confidence**: {confidence:.4f}\n\n"
|
366 |
-
f"**Human Probability**: {human_prob:.4f}\n"
|
367 |
-
f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
|
368 |
-
"### Most Influential k-mers:\n"
|
369 |
-
)
|
370 |
-
for km in important_kmers:
|
371 |
-
direction_text = f"(pushes toward {km['direction']})"
|
372 |
-
freq_text = f"{km['occurrence']:.2f}%"
|
373 |
-
sigma_text = (
|
374 |
-
f"{abs(km['sigma']):.2f}σ "
|
375 |
-
+ ("above" if km['sigma'] > 0 else "below")
|
376 |
-
+ " mean"
|
377 |
-
)
|
378 |
-
text_summary += (
|
379 |
-
f"- **{km['kmer']}**: impact={km['impact']:.4f}, {direction_text}, "
|
380 |
-
f"occurrence={freq_text}, ({sigma_text})\n"
|
381 |
-
)
|
382 |
-
|
383 |
-
# 7. Probability Bar Plot
|
384 |
-
fig_prob = create_probability_bar_plot(human_prob, non_human_prob)
|
385 |
-
buf_prob = io.BytesIO()
|
386 |
-
fig_prob.savefig(buf_prob, format='png', bbox_inches='tight', dpi=120)
|
387 |
-
buf_prob.seek(0)
|
388 |
-
prob_img = Image.open(buf_prob)
|
389 |
-
plt.close(fig_prob)
|
390 |
-
|
391 |
-
# 8. SHAP Beeswarm Plot
|
392 |
-
# We need some background data for KernelExplainer. Let's create a small random sample
|
393 |
-
# or sample from the scaled_vector itself in a repeated manner. Real usage: choose a valid background set.
|
394 |
-
background_size = 5 # keep small for speed
|
395 |
-
# We'll pick random sequences from normal(0,1) or from scaled_vector repeated
|
396 |
-
background_data = []
|
397 |
-
for _ in range(background_size):
|
398 |
-
# Option A: random small variations around scaled_vector
|
399 |
-
# new_sample = scaled_vector[0] + np.random.normal(0, 0.5, size=scaled_vector.shape[1])
|
400 |
-
# Option B: just clone the same scaled vector multiple times
|
401 |
-
new_sample = scaled_vector[0]
|
402 |
-
background_data.append(new_sample)
|
403 |
-
background_data = np.stack(background_data, axis=0) # shape (5, #features)
|
404 |
-
|
405 |
-
fig_bee = create_shap_beeswarm_plot(
|
406 |
-
model=model,
|
407 |
-
input_vector=scaled_vector, # our single sample
|
408 |
-
background_data=background_data, # background for KernelExplainer
|
409 |
-
feature_names=kmers_list
|
410 |
-
)
|
411 |
-
buf_bee = io.BytesIO()
|
412 |
-
fig_bee.savefig(buf_bee, format='png', bbox_inches='tight', dpi=120)
|
413 |
-
buf_bee.seek(0)
|
414 |
-
bee_img = Image.open(buf_bee)
|
415 |
-
plt.close(fig_bee)
|
416 |
-
|
417 |
-
# 9. Frequency & σ Plot
|
418 |
-
fig_freq = create_frequency_sigma_plot(important_kmers, header)
|
419 |
-
buf_freq = io.BytesIO()
|
420 |
-
fig_freq.savefig(buf_freq, format='png', bbox_inches='tight', dpi=120)
|
421 |
-
buf_freq.seek(0)
|
422 |
-
freq_img = Image.open(buf_freq)
|
423 |
-
plt.close(fig_freq)
|
424 |
-
|
425 |
-
# 10. Absolute Feature Importance Bar Plot
|
426 |
-
fig_imp = create_importance_bar_plot(important_kmers, header)
|
427 |
-
buf_imp = io.BytesIO()
|
428 |
-
fig_imp.savefig(buf_imp, format='png', bbox_inches='tight', dpi=120)
|
429 |
-
buf_imp.seek(0)
|
430 |
-
imp_img = Image.open(buf_imp)
|
431 |
-
plt.close(fig_imp)
|
432 |
-
|
433 |
-
return text_summary, prob_img, bee_img, freq_img, imp_img
|
434 |
-
|
435 |
-
except Exception as e:
|
436 |
-
return (
|
437 |
-
f"Error during prediction or visualization: {str(e)}",
|
438 |
-
None,
|
439 |
-
None,
|
440 |
-
None,
|
441 |
-
None
|
442 |
-
)
|
443 |
-
|
444 |
-
|
445 |
-
###############################################################################
|
446 |
-
# Gradio Interface
|
447 |
-
###############################################################################
|
448 |
-
with gr.Blocks(title="Advanced Virus Host Classifier with SHAP Beeswarm") as demo:
|
449 |
-
gr.Markdown(
|
450 |
-
"""
|
451 |
-
# Advanced Virus Host Classifier (SHAP Beeswarm Edition)
|
452 |
-
|
453 |
-
**Upload a FASTA file** containing a single nucleotide sequence.
|
454 |
-
The model will predict whether this sequence is **human** or **non-human**,
|
455 |
-
provide a confidence score, and highlight the most influential k-mers.
|
456 |
-
We also produce a **SHAP beeswarm** plot for the features.
|
457 |
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
467 |
|
468 |
-
|
469 |
-
with gr.Tabs():
|
470 |
-
with gr.Tab("Prediction Results"):
|
471 |
-
md_out = gr.Markdown()
|
472 |
-
with gr.Tab("Probability Plot"):
|
473 |
-
prob_out = gr.Image()
|
474 |
-
with gr.Tab("SHAP Beeswarm Plot"):
|
475 |
-
bee_out = gr.Image()
|
476 |
-
with gr.Tab("Frequency & σ Plot"):
|
477 |
-
freq_out = gr.Image()
|
478 |
-
with gr.Tab("Importance Bar Plot"):
|
479 |
-
imp_out = gr.Image()
|
480 |
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
|
|
|
|
|
|
487 |
|
488 |
if __name__ == "__main__":
|
489 |
-
|
490 |
-
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
|
|
2 |
import torch
|
3 |
import joblib
|
4 |
import numpy as np
|
|
|
|
|
5 |
from itertools import product
|
6 |
import torch.nn as nn
|
7 |
import matplotlib.pyplot as plt
|
8 |
import io
|
9 |
from PIL import Image
|
10 |
|
|
|
|
|
|
|
11 |
class VirusClassifier(nn.Module):
|
12 |
def __init__(self, input_shape: int):
|
13 |
super(VirusClassifier, self).__init__()
|
|
|
29 |
return self.network(x)
|
30 |
|
31 |
def get_feature_importance(self, x):
    """Compute gradient-based feature importance for the 'human' class.

    The importance of each input feature is the gradient of the softmax
    probability of class index 1 with respect to that feature.

    Args:
        x: Input tensor fed to ``self.network``. It is not modified; the
            gradient is taken with respect to a detached copy.

    Returns:
        A tuple ``(importance, human_prob)`` where ``importance`` is a
        tensor with the same shape as ``x`` holding the gradients, and
        ``human_prob`` is the class-1 probability as a Python float.

    Raises:
        RuntimeError / ValueError: from ``backward()`` or ``float()`` when
            ``x`` holds more than one sample, because the class-1
            probability is then not a single element.
    """
    # Differentiate w.r.t. a private leaf copy: calling requires_grad_ on
    # the caller's tensor would permanently flip its requires_grad flag,
    # leave a stale .grad on it, and fail outright for non-leaf inputs.
    x = x.detach().clone().requires_grad_(True)
    output = self.network(x)
    probs = torch.softmax(output, dim=1)

    # Get importance for human class (index 1)
    human_prob = probs[..., 1]
    # The fresh leaf has no accumulated gradient, so no zeroing is needed.
    human_prob.backward()
    importance = x.grad

    return importance, float(human_prob)
|
45 |
|
46 |
+
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """Convert a nucleotide sequence into a normalized k-mer frequency vector.

    Every length-``k`` window of ``sequence`` is looked up among the
    ``4**k`` k-mers over the alphabet ``ACGT`` (in ``itertools.product``
    order). Windows containing any other character (e.g. ``N``) are not
    counted, but still contribute to the normalizing denominator.

    Args:
        sequence: Nucleotide sequence. Matching is case-insensitive, so
            soft-masked (lower-case) FASTA input is counted as well.
        k: k-mer length.

    Returns:
        A ``float32`` array of length ``4**k`` whose entries are the k-mer
        counts divided by the number of windows, or all zeros when the
        sequence is shorter than ``k``.
    """
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    kmer_dict = {km: i for i, km in enumerate(kmers)}
    vec = np.zeros(len(kmers), dtype=np.float32)

    # Lower-case bases would otherwise never match the upper-case k-mer
    # table and the whole sequence would silently map to a zero vector.
    sequence = sequence.upper()

    total_kmers = len(sequence) - k + 1
    for start in range(total_kmers):  # empty range when the sequence is too short
        idx = kmer_dict.get(sequence[start:start + k])
        if idx is not None:
            vec[idx] += 1

    if total_kmers > 0:
        vec = vec / total_kmers

    return vec
|
62 |
+
|
63 |
def parse_fasta(text):
|
|
|
64 |
sequences = []
|
65 |
current_header = None
|
66 |
current_sequence = []
|
|
|
80 |
sequences.append((current_header, ''.join(current_sequence)))
|
81 |
return sequences
|
82 |
|
83 |
+
def create_visualization(important_kmers, human_prob, title):
    """Create a two-panel figure summarizing the most influential k-mers.

    The top panel is a step plot that starts at 0.5 and adds each k-mer's
    signed impact in turn (negative when its direction is ``'non-human'``).
    The bottom panel shows, per k-mer, its frequency and its sigma value as
    side-by-side bars on twin y-axes.

    Args:
        important_kmers: Sequence of dicts with the keys ``'kmer'``,
            ``'impact'``, ``'direction'``, ``'occurrence'`` and ``'sigma'``.
        human_prob: Final predicted human probability, shown in the title.
        title: Label for the analysed sequence; when non-empty it is shown
            as the first line of the top panel's title.

    Returns:
        The matplotlib figure. The caller is responsible for closing it.
    """
    fig = plt.figure(figsize=(15, 10))

    # Two stacked panels; the step plot gets the larger share.
    gs = plt.GridSpec(2, 1, height_ratios=[1.5, 1], hspace=0.3)

    # 1. Probability step plot
    # Bind axes to this figure explicitly instead of relying on pyplot's
    # "current figure" state.
    ax1 = fig.add_subplot(gs[0])
    current_prob = 0.5
    steps = [('Start', current_prob, 0)]

    for kmer in important_kmers:
        sign = -1 if kmer['direction'] == 'non-human' else 1
        change = kmer['impact'] * sign
        current_prob += change
        steps.append((kmer['kmer'], current_prob, change))

    x = range(len(steps))
    y = [step[1] for step in steps]

    ax1.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
    ax1.plot(x, y, 'b.', markersize=10)

    # Reference line at the neutral starting probability
    ax1.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')

    ax1.grid(True, linestyle='--', alpha=0.7)
    ax1.set_ylim(0, 1)
    ax1.set_ylabel('Human Probability')

    heading = f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})'
    if title:
        # The title argument was previously accepted but never displayed.
        heading = f'{title}\n{heading}'
    ax1.set_title(heading)

    for i, (kmer, prob, change) in enumerate(steps):
        # Alternate label placement above/below the point to limit overlap.
        ax1.annotate(kmer,
                     (i, prob),
                     xytext=(0, 10 if i % 2 == 0 else -20),
                     textcoords='offset points',
                     ha='center',
                     rotation=45)

        # The synthetic 'Start' point has no change to report.
        if i > 0:
            change_text = f'{change:+.3f}'
            color = 'green' if change > 0 else 'red'
            ax1.annotate(change_text,
                         (i, prob),
                         xytext=(0, -20 if i % 2 == 0 else 10),
                         textcoords='offset points',
                         ha='center',
                         color=color)

    ax1.legend()

    # 2. K-mer frequency and sigma plot
    ax2 = fig.add_subplot(gs[1])

    kmers = [k['kmer'] for k in important_kmers]
    frequencies = [k['occurrence'] for k in important_kmers]
    sigmas = [k['sigma'] for k in important_kmers]
    colors = ['g' if k['direction'] == 'human' else 'r' for k in important_kmers]

    x = np.arange(len(kmers))
    width = 0.35

    ax2.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)
    ax2_twin = ax2.twinx()
    # Sigma bars keep the direction colour only when sigma is positive.
    sigma_colors = [c if s > 0 else 'gray' for c, s in zip(colors, sigmas)]
    ax2_twin.bar(x + width/2, sigmas, width, label='σ from mean', color=sigma_colors, alpha=0.3)

    ax2.set_xticks(x)
    ax2.set_xticklabels(kmers, rotation=45)
    ax2.set_ylabel('Frequency (%)')
    ax2_twin.set_ylabel('Standard Deviations (σ) from Mean')
    ax2.set_title('K-mer Frequencies and Statistical Significance')

    # Merge the legend entries of both y-axes into a single legend.
    lines1, labels1 = ax2.get_legend_handles_labels()
    lines2, labels2 = ax2_twin.get_legend_handles_labels()
    ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

    fig.tight_layout()
    return fig
|
170 |
|
|
|
|
|
|
|
171 |
def predict(file_obj):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
if file_obj is None:
|
173 |
+
return "Please upload a FASTA file", None
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
try:
|
176 |
if isinstance(file_obj, str):
|
|
|
178 |
else:
|
179 |
text = file_obj.decode('utf-8')
|
180 |
except Exception as e:
|
181 |
+
return f"Error reading file: {str(e)}", None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
|
|
183 |
k = 4
|
184 |
+
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
185 |
+
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
186 |
+
|
187 |
try:
|
188 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
189 |
+
model = VirusClassifier(256).to(device)
|
190 |
+
state_dict = torch.load('model.pt', map_location=device)
|
|
|
|
|
191 |
model.load_state_dict(state_dict)
|
192 |
+
scaler = joblib.load('scaler.pkl')
|
193 |
model.eval()
|
194 |
+
except Exception as e:
|
195 |
+
return f"Error loading model: {str(e)}", None
|
196 |
|
197 |
+
results_text = ""
|
198 |
+
plot_image = None
|
199 |
+
|
200 |
+
try:
|
201 |
+
sequences = parse_fasta(text)
|
202 |
+
header, seq = sequences[0]
|
203 |
+
|
204 |
+
raw_freq_vector = sequence_to_kmer_vector(seq)
|
205 |
+
kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
|
206 |
+
X_tensor = torch.FloatTensor(kmer_vector).to(device)
|
207 |
+
|
208 |
+
# Get model predictions
|
209 |
with torch.no_grad():
|
210 |
+
output = model(X_tensor)
|
211 |
+
probs = torch.softmax(output, dim=1)
|
212 |
+
|
213 |
+
# Get feature importance
|
214 |
+
importance, _ = model.get_feature_importance(X_tensor)
|
215 |
+
kmer_importance = importance[0].cpu().numpy()
|
216 |
+
|
217 |
+
# Get top k-mers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
top_k = 10
|
219 |
+
top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
|
220 |
+
|
221 |
important_kmers = []
|
222 |
+
for idx in top_indices:
|
223 |
+
kmer = list(kmer_dict.keys())[list(kmer_dict.values()).index(idx)]
|
224 |
+
imp = float(abs(kmer_importance[idx]))
|
225 |
+
direction = 'human' if kmer_importance[idx] > 0 else 'non-human'
|
226 |
+
freq = float(raw_freq_vector[idx] * 100) # Convert to percentage
|
227 |
+
sigma = float(kmer_vector[0][idx])
|
228 |
+
|
229 |
important_kmers.append({
|
230 |
+
'kmer': kmer,
|
231 |
+
'impact': imp,
|
|
|
232 |
'direction': direction,
|
233 |
+
'occurrence': freq,
|
234 |
+
'sigma': sigma
|
235 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
237 |
+
# Generate text results
|
238 |
+
pred_class = 1 if probs[0][1] > probs[0][0] else 0
|
239 |
+
pred_label = 'human' if pred_class == 1 else 'non-human'
|
240 |
+
human_prob = float(probs[0][1])
|
241 |
+
|
242 |
+
results_text = f"""Sequence: {header}
|
243 |
+
Prediction: {pred_label}
|
244 |
+
Confidence: {float(max(probs[0])):0.4f}
|
245 |
+
Human probability: {human_prob:0.4f}
|
246 |
+
Non-human probability: {float(probs[0][0]):0.4f}
|
247 |
+
Most influential k-mers (ranked by importance):"""
|
248 |
+
|
249 |
+
for kmer in important_kmers:
|
250 |
+
results_text += f"\n {kmer['kmer']}: "
|
251 |
+
results_text += f"pushes toward {kmer['direction']} (impact={kmer['impact']:.4f}), "
|
252 |
+
results_text += f"occurrence={kmer['occurrence']:.2f}% of sequence "
|
253 |
+
results_text += f"(appears {abs(kmer['sigma']):.2f}σ "
|
254 |
+
results_text += "more" if kmer['sigma'] > 0 else "less"
|
255 |
+
results_text += " than average)"
|
256 |
+
|
257 |
+
# Create visualization
|
258 |
+
fig = create_visualization(important_kmers, human_prob, header)
|
259 |
+
|
260 |
+
# Save plot
|
261 |
+
buf = io.BytesIO()
|
262 |
+
fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
|
263 |
+
buf.seek(0)
|
264 |
+
plot_image = Image.open(buf)
|
265 |
+
plt.close(fig)
|
266 |
+
|
267 |
+
except Exception as e:
|
268 |
+
return f"Error processing sequences: {str(e)}", None
|
269 |
|
270 |
+
return results_text, plot_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
|
272 |
+
iface = gr.Interface(
|
273 |
+
fn=predict,
|
274 |
+
inputs=gr.File(label="Upload FASTA file", type="binary"),
|
275 |
+
outputs=[
|
276 |
+
gr.Textbox(label="Results"),
|
277 |
+
gr.Image(label="K-mer Analysis Visualization")
|
278 |
+
],
|
279 |
+
title="Virus Host Classifier"
|
280 |
+
)
|
281 |
|
282 |
if __name__ == "__main__":
|
283 |
+
iface.launch(share=True)
|
|