Update app.py
app.py
CHANGED
@@ -4,10 +4,13 @@ import joblib
 import numpy as np
 from itertools import product
 import torch.nn as nn
+import matplotlib
+matplotlib.use("Agg")  # In case we're running in a no-display environment
 import matplotlib.pyplot as plt
 import io
 from PIL import Image
-import shap
+import shap
+
 
 ###############################################################################
 # Model Definition
@@ -32,13 +35,15 @@ class VirusClassifier(nn.Module):
     def forward(self, x):
         return self.network(x)
 
+
 ###############################################################################
 # Torch Model Wrapper for SHAP
 ###############################################################################
 class TorchModelWrapper:
     """
     A simple callable that takes a PyTorch model and device,
-
+    allowing SHAP to pass in NumPy arrays. We convert them
+    to torch tensors, run the model, and return NumPy outputs.
     """
     def __init__(self, model: nn.Module, device='cpu'):
         self.model = model
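The hunk cuts off inside __init__, so the wrapper's call path isn't visible here. A minimal sketch of what such a __call__ presumably looks like (assumed, not shown in this commit; it also assumes __init__ stores the device as self.device):

    def __call__(self, x_np):
        # Assumed body: SHAP hands us a NumPy batch; convert it to a tensor,
        # run the network without gradients, and return NumPy logits.
        x_t = torch.FloatTensor(x_np).to(self.device)
        with torch.no_grad():
            out = self.model(x_t)
        return out.cpu().numpy()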
@@ -87,7 +92,6 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
     Convert a single nucleotide sequence to a k-mer frequency vector
     of length 4^k (e.g., for k=4, length=256).
     """
-    from itertools import product
     kmers = [''.join(p) for p in product("ACGT", repeat=k)]
     kmer_dict = {km: i for i, km in enumerate(kmers)}
     vec = np.zeros(len(kmers), dtype=np.float32)
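The counting loop that fills vec is outside this hunk; a typical continuation (assumed, not part of the diff) slides a window of length k over the sequence and normalizes counts to frequencies:

    # Assumed tail of sequence_to_kmer_vector: count each k-length window that
    # maps to a known k-mer, then normalize so the vector sums to 1.
    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i + k]
        if kmer in kmer_dict:
            vec[kmer_dict[kmer]] += 1
    total = vec.sum()
    if total > 0:
        vec = vec / total
    return vec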
@@ -118,47 +122,54 @@ def create_freq_sigma_plot(
     Creates a bar plot showing top-10 k-mers (by absolute SHAP value),
     with frequency (%) and sigma from mean on a twin-axis.
 
-    single_shap_values: shape=(256,)
-    raw_freq_vector:
-    scaled_vector:
-    kmer_list:
+    single_shap_values: shape=(256,) SHAP values for the "human" class
+    raw_freq_vector: shape=(256,) original frequencies for this sample
+    scaled_vector: shape=(256,) scaled (Z-score) values for this sample
+    kmer_list: list of length=256 of all k-mers
     """
-
+    # Identify the top 10 k-mers by absolute shap
+    abs_vals = np.abs(single_shap_values)  # shape=(256,)
     top_k = 10
-    top_indices = np.argsort(abs_vals)[-top_k:][::-1]  #
+    top_indices = np.argsort(abs_vals)[-top_k:][::-1]  # indices of largest -> smallest
+
     top_data = []
     for idx in top_indices:
+        idx_int = int(idx)  # ensure integer
         top_data.append({
-            "kmer": kmer_list[
-            "shap": single_shap_values[
-            "abs_shap": abs_vals[
-            "frequency": raw_freq_vector[
-            "sigma": scaled_vector[
+            "kmer": kmer_list[idx_int],
+            "shap": single_shap_values[idx_int],
+            "abs_shap": abs_vals[idx_int],
+            "frequency": raw_freq_vector[idx_int] * 100.0,  # percentage
+            "sigma": scaled_vector[idx_int]
         })
 
     # Sort top_data by abs_shap descending
     top_data.sort(key=lambda x: x["abs_shap"], reverse=True)
 
+    # Prepare for plotting
     kmers = [d["kmer"] for d in top_data]
     freqs = [d["frequency"] for d in top_data]
     sigmas = [d["sigma"] for d in top_data]
-    # color by sign (positive=green, negative=red)
+    # color by sign (positive=green => pushes "human", negative=red => pushes "non-human")
     colors = ["green" if d["shap"] >= 0 else "red" for d in top_data]
 
-    import matplotlib.pyplot as plt
     x = np.arange(len(kmers))
     width = 0.4
 
     fig, ax = plt.subplots(figsize=(8, 5))
     # Frequency
-    ax.bar(
+    ax.bar(
+        x - width/2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)"
+    )
     ax.set_ylabel("Frequency (%)", color='black')
-    if freqs:
+    if len(freqs) > 0:
         ax.set_ylim(0, max(freqs)*1.2)
 
     # Twin axis for sigma
     ax2 = ax.twinx()
-    ax2.bar(
+    ax2.bar(
+        x + width/2, sigmas, width, color="gray", alpha=0.5, label="σ from Mean"
+    )
     ax2.set_ylabel("Standard Deviations (σ)", color='black')
 
     ax.set_xticks(x)
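The hunk stops at ax.set_xticks(x). Since update_freq_plot later calls freq_sigma_fig.savefig(...), the function has to return the figure; its tail presumably finishes roughly like this (assumed, not shown in the diff):

    # Assumed finishing touches: label ticks with the k-mers, add the title
    # passed in by the caller, and hand the figure back.
    ax.set_xticklabels(kmers, rotation=45, ha='right')
    ax.set_title(title)
    fig.tight_layout()
    return fig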
@@ -182,9 +193,9 @@ def run_classification_and_shap(file_obj):
     Reads one or more FASTA sequences from file_obj or text.
     Returns:
       - Table of results (list of dicts) for each sequence
-      - shap_values object (SHAP values for the entire batch)
-      - array
-      - list of k-mers
+      - shap_values object (SHAP values for the entire batch, shape=(num_samples, 2, num_features))
+      - array of scaled vectors
+      - list of k-mers
       - error message or None
     """
     # 1. Basic read
@@ -194,12 +205,12 @@ def run_classification_and_shap(file_obj):
     try:
         text = file_obj.decode("utf-8")
     except Exception as e:
-        return None, None, f"Error reading file: {str(e)}"
+        return None, None, None, None, f"Error reading file: {str(e)}"
 
     # 2. Parse FASTA
     sequences = parse_fasta(text)
     if len(sequences) == 0:
-        return None, None, "No valid FASTA sequences found!"
+        return None, None, None, None, "No valid FASTA sequences found!"
 
     # 3. Convert each sequence to k-mer vector
     k = 4
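parse_fasta itself isn't touched by this commit, but from the way it is used (a list of (header, sequence) tuples indexed as sequences[i][0] / sequences[i][1]), it presumably looks something like this minimal sketch:

def parse_fasta(text):
    # Assumed helper (not in this diff): return (header, sequence) tuples
    # from FASTA-formatted text, concatenating wrapped sequence lines.
    sequences, header, chunks = [], None, []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                sequences.append((header, "".join(chunks).upper()))
            header, chunks = line[1:], []
        else:
            chunks.append(line)
    if header is not None:
        sequences.append((header, "".join(chunks).upper()))
    return sequences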
@@ -219,15 +230,14 @@ def run_classification_and_shap(file_obj):
         device = "cuda" if torch.cuda.is_available() else "cpu"
 
         model = VirusClassifier(input_shape=4**k).to(device)
-        #
+        # Use weights_only=True to suppress future warnings about untrusted pickles
         state_dict = torch.load("model.pt", map_location=device, weights_only=True)
         model.load_state_dict(state_dict)
         model.eval()
 
         scaler = joblib.load("scaler.pkl")
-
     except Exception as e:
-        return None, None, f"Error loading model or scaler: {str(e)}"
+        return None, None, None, None, f"Error loading model or scaler: {str(e)}"
 
     # 5. Scale data
     scaled_data = scaler.transform(all_raw_vectors)  # shape=(num_seqs, 256)
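weights_only=True is only accepted by reasonably recent PyTorch releases; if the Space ever has to run on an older torch, a small compatibility shim along these lines (hypothetical, not part of the commit) keeps the load working:

        try:
            # Newer torch: refuse arbitrary pickled objects in the checkpoint.
            state_dict = torch.load("model.pt", map_location=device, weights_only=True)
        except TypeError:
            # Older torch without the weights_only keyword.
            state_dict = torch.load("model.pt", map_location=device)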
@@ -236,6 +246,7 @@ def run_classification_and_shap(file_obj):
     X_tensor = torch.FloatTensor(scaled_data).to(device)
     with torch.no_grad():
         logits = model(X_tensor)
+        # shape=(num_seqs, 2)
         probs = torch.softmax(logits, dim=1).cpu().numpy()
     preds = np.argmax(probs, axis=1)  # 0 or 1
 
@@ -243,29 +254,30 @@ def run_classification_and_shap(file_obj):
     for i, (hdr, seq) in enumerate(zip(headers, seqs)):
         results_table.append({
             "header": hdr,
-            "sequence": seq[:50] + ("..." if len(seq)>50 else ""),
+            "sequence": seq[:50] + ("..." if len(seq) > 50 else ""),
             "pred_label": "human" if preds[i] == 1 else "non-human",
             "human_prob": float(probs[i][1]),
             "non_human_prob": float(probs[i][0]),
-            "confidence": float(max(probs[i]))
+            "confidence": float(np.max(probs[i]))
         })
 
     # 7. SHAP Explainer
-    #
+    # For large data, pick a smaller background subset
     if scaled_data.shape[0] > 50:
         background_data = scaled_data[:50]
     else:
         background_data = scaled_data
 
-    # Wrap the model so it can handle numpy -> tensor
     wrapped_model = TorchModelWrapper(model, device)
     explainer = shap.Explainer(wrapped_model, background_data)
-    shap_values
+    # shap_values shape=(num_samples, num_features) if single-output
+    # but here we have 2 outputs => shape=(num_samples, 2, num_features).
+    shap_values = explainer(scaled_data)
 
-    # k-mer list
-    from itertools import product
+    # Prepare k-mer list
     kmer_list = [''.join(p) for p in product("ACGT", repeat=k)]
 
+    # Return everything
     return (results_table, shap_values, scaled_data, kmer_list, None)
 
 
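Given the shape comments above, callers can sanity-check the returned Explanation before plotting; the slicing pattern below is the same one the plot helpers later in the file rely on (shapes assume N sequences and 256 k-mer features):

# results, shap_values, scaled_data, kmer_list, err = run_classification_and_shap(...)
# shap_values.values.shape       -> (N, 2, 256)
# shap_values.base_values.shape  -> (N, 2)
# shap_values.data.shape         -> (N, 256)
human_shap = shap_values.values[:, 1, :]  # class=1 ("human") attributions only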
@@ -274,9 +286,8 @@ def run_classification_and_shap(file_obj):
 ###############################################################################
 def main_predict(file_obj):
     """
-
-
-    the shap values plus data needed for subsequent plots.
+    Triggered by the 'Run Classification' button in Gradio.
+    Returns a markdown table plus states for subsequent plots.
     """
     results, shap_vals, scaled_data, kmer_list, err = run_classification_and_shap(file_obj)
     if err:
@@ -294,28 +305,44 @@ def main_predict(file_obj):
             f"| {i} | {row['header']} | {row['pred_label']} | "
             f"{row['confidence']:.4f} | {row['human_prob']:.4f} | {row['non_human_prob']:.4f} |\n"
         )
-    md += "\nSelect a sequence index below to view SHAP Waterfall & Frequency plots."
+    md += "\nSelect a sequence index below to view SHAP Waterfall & Frequency plots (class=1/human)."
 
     return (md, shap_vals, scaled_data, kmer_list, results)
 
+
 def update_waterfall_plot(selected_index, shap_values_obj):
     """
-    Build a waterfall plot for the user-selected sample
+    Build a waterfall plot for the user-selected sample, but ONLY for class=1 (human).
+    shap_values_obj has shape=(num_samples, 2, num_features).
+    We do shap_values_obj[selected_index, 1] => shape=(num_features,)
+    for a single-sample single-class explanation.
     """
     if shap_values_obj is None:
         return None
 
     import matplotlib.pyplot as plt
-    import shap
 
     try:
         selected_index = int(selected_index)
     except:
         selected_index = 0
 
-    #
+    # We only visualize class=1 ("human") SHAP values
+    # shap_values_obj.values shape => (num_samples, 2, num_features)
+    single_ex_values = shap_values_obj.values[selected_index, 1, :]  # shape=(256,)
+    single_ex_base = shap_values_obj.base_values[selected_index, 1]  # scalar
+    single_ex_data = shap_values_obj.data[selected_index]  # shape=(256,)
+
+    # Construct a shap.Explanation object for just this one sample & class
+    single_expl = shap.Explanation(
+        values=single_ex_values,
+        base_values=single_ex_base,
+        data=single_ex_data,
+        feature_names=[f"feat_{i}" for i in range(single_ex_values.shape[0])]
+    )
+
     shap_plots_fig = plt.figure(figsize=(8, 5))
-    shap.plots.waterfall(
+    shap.plots.waterfall(single_expl, max_display=14, show=False)
     buf = io.BytesIO()
     plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
     buf.seek(0)
@@ -324,18 +351,35 @@ def update_waterfall_plot(selected_index, shap_values_obj):
 
     return wf_img
 
+
 def update_beeswarm_plot(shap_values_obj):
     """
-    Build a beeswarm plot across all samples.
+    Build a beeswarm plot across all samples, but only for class=1 (human).
+    We slice shap_values_obj to pick shap_values_obj.values[:, 1, :]
+    => shape=(num_samples, num_features).
     """
     if shap_values_obj is None:
         return None
 
     import matplotlib.pyplot as plt
-
+
+    # For multi-output, shap_values_obj.values shape => (num_samples, 2, num_features)
+    # We'll create a new Explanation object for class=1:
+    class1_vals = shap_values_obj.values[:, 1, :]  # shape=(num_samples, num_features)
+    class1_base = shap_values_obj.base_values[:, 1]  # shape=(num_samples,)
+    class1_data = shap_values_obj.data  # shape=(num_samples, num_features)
+
+    # Some versions of shap store data in a 2D array, which is fine
+    # We'll re-wrap them in a shap.Explanation:
+    class1_expl = shap.Explanation(
+        values=class1_vals,
+        base_values=class1_base,
+        data=class1_data,
+        feature_names=[f"feat_{i}" for i in range(class1_vals.shape[1])]
+    )
 
     beeswarm_fig = plt.figure(figsize=(8, 5))
-    shap.plots.beeswarm(
+    shap.plots.beeswarm(class1_expl, show=False)
     buf = io.BytesIO()
     plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
     buf.seek(0)
@@ -344,14 +388,17 @@ def update_beeswarm_plot(shap_values_obj):
 
     return bs_img
 
+
 def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
     """
-    Create the frequency &
-
+    Create the frequency & σ bar chart for the selected sequence's top-10 k-mers (by abs SHAP).
+    Again, we'll use class=1 SHAP values only.
     """
     if shap_values_obj is None or scaled_data is None or kmer_list is None:
         return None
 
+    import matplotlib.pyplot as plt
+
     try:
         selected_index = int(selected_index)
     except:
@@ -364,20 +411,23 @@ def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
     text = file_obj.decode('utf-8')
 
     sequences = parse_fasta(text)
+    # If out of range, clamp to 0
     if selected_index >= len(sequences):
         selected_index = 0
 
     seq = sequences[selected_index][1]
-    raw_vec = sequence_to_kmer_vector(seq, k=4)
+    raw_vec = sequence_to_kmer_vector(seq, k=4)  # shape=(256,)
 
-
+    # SHAP for class=1 => shape=(num_samples, 2, 256)
+    single_shap_values = shap_values_obj.values[selected_index, 1, :]
     freq_sigma_fig = create_freq_sigma_plot(
-        single_shap_values,
-        raw_freq_vector=raw_vec,
+        single_shap_values,
+        raw_freq_vector=raw_vec,
         scaled_vector=scaled_data[selected_index],
         kmer_list=kmer_list,
         title=f"Sample #{selected_index} — {sequences[selected_index][0]}"
     )
+
     buf = io.BytesIO()
     freq_sigma_fig.savefig(buf, format='png', bbox_inches='tight', dpi=120)
     buf.seek(0)
@@ -391,16 +441,19 @@ def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
 # Gradio Interface
 ###############################################################################
 with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
-    shap.initjs()  # load shap JS if needed for
+    shap.initjs()  # load shap JS if needed for HTML-based plots (optional)
 
     gr.Markdown(
         """
-        # **
-
+        # **Virus Host Classifier**
+        Upload a FASTA file with one or more nucleotide sequences.
         This app will:
         1. Predict each sequence's **host** (human vs. non-human).
-        2. Provide **SHAP** explanations
-        3.
+        2. Provide **SHAP** explanations focusing on the 'human' class (index=1).
+        3. Display:
+           - A **waterfall** plot per-sequence (top features).
+           - A **beeswarm** plot across all sequences (global summary).
+           - A **frequency & σ** bar chart for the top-10 k-mers of any selected sequence.
         """
     )
 
@@ -408,23 +461,20 @@ with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
     file_input = gr.File(label="Upload FASTA", type="binary")
     run_btn = gr.Button("Run Classification")
 
-    # Store intermediate results in
+    # Store intermediate results in Gradio states
     shap_values_state = gr.State()
     scaled_data_state = gr.State()
     kmer_list_state = gr.State()
     results_state = gr.State()
-    # We'll also store the "raw input" so we can reconstruct freq data for each sample
     file_data_state = gr.State()
 
-    # TABS for outputs
     with gr.Tabs():
         with gr.Tab("Results Table"):
             md_out = gr.Markdown()
 
         with gr.Tab("SHAP Waterfall"):
-            # We'll let user pick the sequence index from a dropdown or input
             with gr.Row():
-
+                seq_index_input = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
                 update_wf_btn = gr.Button("Update Waterfall")
 
             wf_plot = gr.Image(label="SHAP Waterfall Plot")
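The context between this hunk and the next (old lines 431-433 / new lines 481-483) is not shown, but it must define bs_plot, which the beeswarm callback below writes to; presumably something like:

        with gr.Tab("SHAP Beeswarm"):
            bs_plot = gr.Image(label="SHAP Beeswarm Plot")  # assumed, not visible in the diff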
@@ -434,44 +484,43 @@ with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
 
         with gr.Tab("Top-10 Frequency & Sigma"):
             with gr.Row():
-
+                seq_index_input2 = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
                 update_fs_btn = gr.Button("Update Frequency Chart")
             fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")
 
-    #
-    # 1) The main classification run
+    # 1) Main classification
     run_btn.click(
         fn=main_predict,
         inputs=[file_input],
         outputs=[md_out, shap_values_state, scaled_data_state, kmer_list_state, results_state]
     )
-    # Also store raw file data for subsequent freq usage
     run_btn.click(
        fn=lambda x: x,
        inputs=file_input,
        outputs=file_data_state
     )
 
-    # 2) Waterfall
+    # 2) Update Waterfall
     update_wf_btn.click(
         fn=update_waterfall_plot,
-        inputs=[
+        inputs=[seq_index_input, shap_values_state],
         outputs=[wf_plot]
     )
 
-    # 3) Beeswarm
+    # 3) Update Beeswarm right after classification
     run_btn.click(
         fn=update_beeswarm_plot,
         inputs=[shap_values_state],
         outputs=[bs_plot]
     )
 
-    # 4) Frequency
+    # 4) Update Frequency & σ
     update_fs_btn.click(
         fn=update_freq_plot,
-        inputs=[
+        inputs=[seq_index_input2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
        outputs=[fs_plot]
     )
 
 if __name__ == "__main__":
     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
+
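For completeness, a hypothetical smoke test (not part of the commit) that exercises the pipeline without the Gradio UI, assuming model.pt and scaler.pkl sit next to app.py:

from app import run_classification_and_shap

fasta_bytes = b">seq1\nACGTACGTACGTACGTACGT\n>seq2\nTTTTGGGGCCCCAAAATTTT\n"
results, shap_values, scaled_data, kmer_list, err = run_classification_and_shap(fasta_bytes)
if err:
    print("Error:", err)
else:
    for row in results:
        print(row["header"], row["pred_label"], f"{row['human_prob']:.3f}")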