Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -32,6 +32,28 @@ class VirusClassifier(nn.Module):
|
|
32 |
def forward(self, x):
|
33 |
return self.network(x)
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
###############################################################################
|
37 |
# Utility Functions
|
@@ -65,6 +87,7 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
|
65 |
Convert a single nucleotide sequence to a k-mer frequency vector
|
66 |
of length 4^k (e.g., for k=4, length=256).
|
67 |
"""
|
|
|
68 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
69 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
70 |
vec = np.zeros(len(kmers), dtype=np.float32)
|
@@ -122,6 +145,7 @@ def create_freq_sigma_plot(
|
|
122 |
# color by sign (positive=green, negative=red)
|
123 |
colors = ["green" if d["shap"] >= 0 else "red" for d in top_data]
|
124 |
|
|
|
125 |
x = np.arange(len(kmers))
|
126 |
width = 0.4
|
127 |
|
@@ -129,7 +153,8 @@ def create_freq_sigma_plot(
|
|
129 |
# Frequency
|
130 |
ax.bar(x - width/2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)")
|
131 |
ax.set_ylabel("Frequency (%)", color='black')
|
132 |
-
|
|
|
133 |
|
134 |
# Twin axis for sigma
|
135 |
ax2 = ax.twinx()
|
@@ -160,7 +185,7 @@ def run_classification_and_shap(file_obj):
|
|
160 |
- shap_values object (SHAP values for the entire batch)
|
161 |
- array/batch of scaled vectors (for use in the waterfall selection)
|
162 |
- list of k-mers (for indexing)
|
163 |
-
-
|
164 |
"""
|
165 |
# 1. Basic read
|
166 |
if isinstance(file_obj, str):
|
@@ -192,12 +217,15 @@ def run_classification_and_shap(file_obj):
|
|
192 |
# 4. Load model & scaler
|
193 |
try:
|
194 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
195 |
model = VirusClassifier(input_shape=4**k).to(device)
|
196 |
-
|
|
|
197 |
model.load_state_dict(state_dict)
|
198 |
model.eval()
|
199 |
|
200 |
scaler = joblib.load("scaler.pkl")
|
|
|
201 |
except Exception as e:
|
202 |
return None, None, f"Error loading model or scaler: {str(e)}"
|
203 |
|
@@ -224,20 +252,18 @@ def run_classification_and_shap(file_obj):
|
|
224 |
|
225 |
# 7. SHAP Explainer
|
226 |
# We'll pick a background subset if there are many sequences
|
227 |
-
# (For performance, we might limit to e.g. 50 samples max)
|
228 |
if scaled_data.shape[0] > 50:
|
229 |
background_data = scaled_data[:50]
|
230 |
else:
|
231 |
background_data = scaled_data
|
232 |
|
233 |
-
#
|
234 |
-
|
235 |
-
|
236 |
-
# We'll do a simple "use shap.Explainer" with data=background_data
|
237 |
-
explainer = shap.Explainer(model, background_data)
|
238 |
shap_values = explainer(scaled_data) # shape=(num_samples, num_features)
|
239 |
|
240 |
# k-mer list
|
|
|
241 |
kmer_list = [''.join(p) for p in product("ACGT", repeat=k)]
|
242 |
|
243 |
return (results_table, shap_values, scaled_data, kmer_list, None)
|
@@ -249,8 +275,8 @@ def run_classification_and_shap(file_obj):
|
|
249 |
def main_predict(file_obj):
|
250 |
"""
|
251 |
This function is triggered by the 'Run' button in Gradio.
|
252 |
-
It returns a markdown of all sequences/predictions and
|
253 |
-
data needed for
|
254 |
"""
|
255 |
results, shap_vals, scaled_data, kmer_list, err = run_classification_and_shap(file_obj)
|
256 |
if err:
|
@@ -270,32 +296,26 @@ def main_predict(file_obj):
|
|
270 |
)
|
271 |
md += "\nSelect a sequence index below to view SHAP Waterfall & Frequency plots."
|
272 |
|
273 |
-
# Return the string, and also the shap values plus data needed
|
274 |
-
# We'll store these to SessionState via Gradio's "State" or we can
|
275 |
-
# pass them out as hidden fields.
|
276 |
return (md, shap_vals, scaled_data, kmer_list, results)
|
277 |
|
278 |
-
|
279 |
def update_waterfall_plot(selected_index, shap_values_obj):
|
280 |
"""
|
281 |
-
Build a waterfall plot for the user-selected sample.
|
282 |
"""
|
283 |
if shap_values_obj is None:
|
284 |
return None
|
285 |
|
|
|
|
|
|
|
286 |
try:
|
287 |
selected_index = int(selected_index)
|
288 |
except:
|
289 |
selected_index = 0
|
290 |
|
291 |
-
#
|
292 |
-
# Convert shap_values_obj to the new shap interface
|
293 |
-
# shap_values_obj is a shap._explanation.Explanation typically
|
294 |
-
|
295 |
-
# We can create a figure with shap.plots.waterfall and capture it as an image
|
296 |
shap_plots_fig = plt.figure(figsize=(8, 5))
|
297 |
-
shap.plots.waterfall(shap_values_obj[selected_index], max_display=14,
|
298 |
-
show=False) # show=False so it doesn't pop in the notebook
|
299 |
buf = io.BytesIO()
|
300 |
plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
|
301 |
buf.seek(0)
|
@@ -304,7 +324,6 @@ def update_waterfall_plot(selected_index, shap_values_obj):
|
|
304 |
|
305 |
return wf_img
|
306 |
|
307 |
-
|
308 |
def update_beeswarm_plot(shap_values_obj):
|
309 |
"""
|
310 |
Build a beeswarm plot across all samples.
|
@@ -312,6 +331,9 @@ def update_beeswarm_plot(shap_values_obj):
|
|
312 |
if shap_values_obj is None:
|
313 |
return None
|
314 |
|
|
|
|
|
|
|
315 |
beeswarm_fig = plt.figure(figsize=(8, 5))
|
316 |
shap.plots.beeswarm(shap_values_obj, show=False)
|
317 |
buf = io.BytesIO()
|
@@ -322,11 +344,10 @@ def update_beeswarm_plot(shap_values_obj):
|
|
322 |
|
323 |
return bs_img
|
324 |
|
325 |
-
|
326 |
def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
|
327 |
"""
|
328 |
Create the frequency & sigma bar chart for the selected sequence's top-10 k-mers.
|
329 |
-
We
|
330 |
"""
|
331 |
if shap_values_obj is None or scaled_data is None or kmer_list is None:
|
332 |
return None
|
@@ -336,23 +357,17 @@ def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, fi
|
|
336 |
except:
|
337 |
selected_index = 0
|
338 |
|
339 |
-
#
|
340 |
-
# or store it from earlier. Let's just re-run parse for that single sequence:
|
341 |
-
# But simpler is: run_classification_and_shap was storing all_raw_vectors...
|
342 |
-
# Let's do a quick approach: run_classification_and_shap already computed it
|
343 |
-
# but we didn't store it. We'll re-run the parse logic to get the raw freq again.
|
344 |
-
|
345 |
-
# For memory / speed reasons, better is to store it.
|
346 |
-
# For simplicity, let's parse again quickly:
|
347 |
if isinstance(file_obj, str):
|
348 |
text = file_obj
|
349 |
else:
|
350 |
text = file_obj.decode('utf-8')
|
|
|
351 |
sequences = parse_fasta(text)
|
352 |
-
# the selected_index might be out of range, so let's clamp it
|
353 |
if selected_index >= len(sequences):
|
354 |
selected_index = 0
|
355 |
-
|
|
|
356 |
raw_vec = sequence_to_kmer_vector(seq, k=4)
|
357 |
|
358 |
single_shap_values = shap_values_obj.values[selected_index]
|
@@ -376,11 +391,11 @@ def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, fi
|
|
376 |
# Gradio Interface
|
377 |
###############################################################################
|
378 |
with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
|
379 |
-
shap.initjs() # load shap JS for interactive
|
380 |
|
381 |
gr.Markdown(
|
382 |
"""
|
383 |
-
# **
|
384 |
**Upload a FASTA file** with one or more nucleotide sequences.
|
385 |
This app will:
|
386 |
1. Predict each sequence's **host** (human vs. non-human).
|
@@ -407,7 +422,7 @@ with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
|
|
407 |
md_out = gr.Markdown()
|
408 |
|
409 |
with gr.Tab("SHAP Waterfall"):
|
410 |
-
# We'll let user pick the sequence index from a dropdown or
|
411 |
with gr.Row():
|
412 |
seq_index_dropdown = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
|
413 |
update_wf_btn = gr.Button("Update Waterfall")
|
@@ -424,34 +439,39 @@ with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
|
|
424 |
fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")
|
425 |
|
426 |
# --- Button Logic ---
|
|
|
427 |
run_btn.click(
|
428 |
fn=main_predict,
|
429 |
inputs=[file_input],
|
430 |
outputs=[md_out, shap_values_state, scaled_data_state, kmer_list_state, results_state]
|
431 |
)
|
432 |
-
|
|
|
433 |
fn=lambda x: x,
|
434 |
inputs=file_input,
|
435 |
outputs=file_data_state
|
436 |
)
|
437 |
|
|
|
438 |
update_wf_btn.click(
|
439 |
fn=update_waterfall_plot,
|
440 |
inputs=[seq_index_dropdown, shap_values_state],
|
441 |
outputs=[wf_plot]
|
442 |
)
|
443 |
-
update_fs_btn.click(
|
444 |
-
fn=update_freq_plot,
|
445 |
-
inputs=[seq_index_dropdown2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
|
446 |
-
outputs=[fs_plot]
|
447 |
-
)
|
448 |
|
449 |
-
#
|
450 |
run_btn.click(
|
451 |
fn=update_beeswarm_plot,
|
452 |
inputs=[shap_values_state],
|
453 |
outputs=[bs_plot]
|
454 |
)
|
455 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
if __name__ == "__main__":
|
457 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
|
|
32 |
def forward(self, x):
|
33 |
return self.network(x)
|
34 |
|
35 |
+
###############################################################################
|
36 |
+
# Torch Model Wrapper for SHAP
|
37 |
+
###############################################################################
|
38 |
+
class TorchModelWrapper:
    """
    Callable adapter that lets SHAP evaluate a PyTorch model on numpy input.

    SHAP calls the wrapped object with plain arrays; this class converts each
    batch to a float32 tensor on the configured device, runs the model with
    gradient tracking disabled, and hands the output back as a numpy array.
    """
    def __init__(self, model: nn.Module, device='cpu'):
        """
        model:  the PyTorch module to evaluate
        device: device string/object the input batch is moved to before the call
        """
        self.model = model
        self.device = device

    def __call__(self, x_np: np.ndarray):
        """
        x_np: array-like of shape=(batch_size, num_features)
        Returns: numpy array holding the model output for the batch
        """
        # torch.from_numpy only accepts real ndarrays with non-negative
        # strides. Normalising first lets lists, float64 data and reversed /
        # sliced views through, and performs the float32 cast in one step.
        x_arr = np.ascontiguousarray(x_np, dtype=np.float32)
        x_torch = torch.from_numpy(x_arr).to(self.device)
        with torch.no_grad():
            out = self.model(x_torch).cpu().numpy()
        return out
|
56 |
+
|
57 |
|
58 |
###############################################################################
|
59 |
# Utility Functions
|
|
|
87 |
Convert a single nucleotide sequence to a k-mer frequency vector
|
88 |
of length 4^k (e.g., for k=4, length=256).
|
89 |
"""
|
90 |
+
from itertools import product
|
91 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
92 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
93 |
vec = np.zeros(len(kmers), dtype=np.float32)
|
|
|
145 |
# color by sign (positive=green, negative=red)
|
146 |
colors = ["green" if d["shap"] >= 0 else "red" for d in top_data]
|
147 |
|
148 |
+
import matplotlib.pyplot as plt
|
149 |
x = np.arange(len(kmers))
|
150 |
width = 0.4
|
151 |
|
|
|
153 |
# Frequency
|
154 |
ax.bar(x - width/2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)")
|
155 |
ax.set_ylabel("Frequency (%)", color='black')
|
156 |
+
if freqs:
|
157 |
+
ax.set_ylim(0, max(freqs)*1.2)
|
158 |
|
159 |
# Twin axis for sigma
|
160 |
ax2 = ax.twinx()
|
|
|
185 |
- shap_values object (SHAP values for the entire batch)
|
186 |
- array/batch of scaled vectors (for use in the waterfall selection)
|
187 |
- list of k-mers (for indexing)
|
188 |
+
- error message or None
|
189 |
"""
|
190 |
# 1. Basic read
|
191 |
if isinstance(file_obj, str):
|
|
|
217 |
# 4. Load model & scaler
|
218 |
try:
|
219 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
220 |
+
|
221 |
model = VirusClassifier(input_shape=4**k).to(device)
|
222 |
+
# weights_only=True restricts unpickling to tensors and plain containers (safer for checkpoints) and also avoids the torch.load FutureWarning
|
223 |
+
state_dict = torch.load("model.pt", map_location=device, weights_only=True)
|
224 |
model.load_state_dict(state_dict)
|
225 |
model.eval()
|
226 |
|
227 |
scaler = joblib.load("scaler.pkl")
|
228 |
+
|
229 |
except Exception as e:
|
230 |
return None, None, f"Error loading model or scaler: {str(e)}"
|
231 |
|
|
|
252 |
|
253 |
# 7. SHAP Explainer
|
254 |
# We'll pick a background subset if there are many sequences
|
|
|
255 |
if scaled_data.shape[0] > 50:
|
256 |
background_data = scaled_data[:50]
|
257 |
else:
|
258 |
background_data = scaled_data
|
259 |
|
260 |
+
# Wrap the model so it can handle numpy -> tensor
|
261 |
+
wrapped_model = TorchModelWrapper(model, device)
|
262 |
+
explainer = shap.Explainer(wrapped_model, background_data)
|
|
|
|
|
263 |
shap_values = explainer(scaled_data) # shape=(num_samples, num_features)
|
264 |
|
265 |
# k-mer list
|
266 |
+
from itertools import product
|
267 |
kmer_list = [''.join(p) for p in product("ACGT", repeat=k)]
|
268 |
|
269 |
return (results_table, shap_values, scaled_data, kmer_list, None)
|
|
|
275 |
def main_predict(file_obj):
|
276 |
"""
|
277 |
This function is triggered by the 'Run' button in Gradio.
|
278 |
+
It returns a markdown of all sequences/predictions and
|
279 |
+
the shap values plus data needed for subsequent plots.
|
280 |
"""
|
281 |
results, shap_vals, scaled_data, kmer_list, err = run_classification_and_shap(file_obj)
|
282 |
if err:
|
|
|
296 |
)
|
297 |
md += "\nSelect a sequence index below to view SHAP Waterfall & Frequency plots."
|
298 |
|
|
|
|
|
|
|
299 |
return (md, shap_vals, scaled_data, kmer_list, results)
|
300 |
|
|
|
301 |
def update_waterfall_plot(selected_index, shap_values_obj):
|
302 |
"""
|
303 |
+
Build a waterfall plot for the user-selected sample using shap.plots.waterfall.
|
304 |
"""
|
305 |
if shap_values_obj is None:
|
306 |
return None
|
307 |
|
308 |
+
import matplotlib.pyplot as plt
|
309 |
+
import shap
|
310 |
+
|
311 |
try:
|
312 |
selected_index = int(selected_index)
|
313 |
except:
|
314 |
selected_index = 0
|
315 |
|
316 |
+
# Create the figure by calling shap.plots.waterfall
|
|
|
|
|
|
|
|
|
317 |
shap_plots_fig = plt.figure(figsize=(8, 5))
|
318 |
+
shap.plots.waterfall(shap_values_obj[selected_index], max_display=14, show=False)
|
|
|
319 |
buf = io.BytesIO()
|
320 |
plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
|
321 |
buf.seek(0)
|
|
|
324 |
|
325 |
return wf_img
|
326 |
|
|
|
327 |
def update_beeswarm_plot(shap_values_obj):
|
328 |
"""
|
329 |
Build a beeswarm plot across all samples.
|
|
|
331 |
if shap_values_obj is None:
|
332 |
return None
|
333 |
|
334 |
+
import matplotlib.pyplot as plt
|
335 |
+
import shap
|
336 |
+
|
337 |
beeswarm_fig = plt.figure(figsize=(8, 5))
|
338 |
shap.plots.beeswarm(shap_values_obj, show=False)
|
339 |
buf = io.BytesIO()
|
|
|
344 |
|
345 |
return bs_img
|
346 |
|
|
|
347 |
def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
|
348 |
"""
|
349 |
Create the frequency & sigma bar chart for the selected sequence's top-10 k-mers.
|
350 |
+
We must re-parse the raw freq vector for that sequence, or store it from earlier.
|
351 |
"""
|
352 |
if shap_values_obj is None or scaled_data is None or kmer_list is None:
|
353 |
return None
|
|
|
357 |
except:
|
358 |
selected_index = 0
|
359 |
|
360 |
+
# Re-parse the FASTA to get the corresponding sequence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
if isinstance(file_obj, str):
|
362 |
text = file_obj
|
363 |
else:
|
364 |
text = file_obj.decode('utf-8')
|
365 |
+
|
366 |
sequences = parse_fasta(text)
|
|
|
367 |
if selected_index >= len(sequences):
|
368 |
selected_index = 0
|
369 |
+
|
370 |
+
seq = sequences[selected_index][1]
|
371 |
raw_vec = sequence_to_kmer_vector(seq, k=4)
|
372 |
|
373 |
single_shap_values = shap_values_obj.values[selected_index]
|
|
|
391 |
# Gradio Interface
|
392 |
###############################################################################
|
393 |
with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
|
394 |
+
shap.initjs() # load shap JS if needed for interactive HTML (optional)
|
395 |
|
396 |
gr.Markdown(
|
397 |
"""
|
398 |
+
# **Virus Host Classifier with SHAP**
|
399 |
**Upload a FASTA file** with one or more nucleotide sequences.
|
400 |
This app will:
|
401 |
1. Predict each sequence's **host** (human vs. non-human).
|
|
|
422 |
md_out = gr.Markdown()
|
423 |
|
424 |
with gr.Tab("SHAP Waterfall"):
|
425 |
+
# Let the user pick the (0-based) sequence index with a numeric input
|
426 |
with gr.Row():
|
427 |
seq_index_dropdown = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
|
428 |
update_wf_btn = gr.Button("Update Waterfall")
|
|
|
439 |
fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")
|
440 |
|
441 |
# --- Button Logic ---
|
442 |
+
# 1) The main classification run
|
443 |
run_btn.click(
|
444 |
fn=main_predict,
|
445 |
inputs=[file_input],
|
446 |
outputs=[md_out, shap_values_state, scaled_data_state, kmer_list_state, results_state]
|
447 |
)
|
448 |
+
# Also store raw file data for subsequent freq usage
|
449 |
+
run_btn.click(
|
450 |
fn=lambda x: x,
|
451 |
inputs=file_input,
|
452 |
outputs=file_data_state
|
453 |
)
|
454 |
|
455 |
+
# 2) Waterfall update
|
456 |
update_wf_btn.click(
|
457 |
fn=update_waterfall_plot,
|
458 |
inputs=[seq_index_dropdown, shap_values_state],
|
459 |
outputs=[wf_plot]
|
460 |
)
|
|
|
|
|
|
|
|
|
|
|
461 |
|
462 |
+
# 3) Beeswarm update
|
463 |
run_btn.click(
|
464 |
fn=update_beeswarm_plot,
|
465 |
inputs=[shap_values_state],
|
466 |
outputs=[bs_plot]
|
467 |
)
|
468 |
|
469 |
+
# 4) Frequency top-10 update
|
470 |
+
update_fs_btn.click(
|
471 |
+
fn=update_freq_plot,
|
472 |
+
inputs=[seq_index_dropdown2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
|
473 |
+
outputs=[fs_plot]
|
474 |
+
)
|
475 |
+
|
476 |
if __name__ == "__main__":
|
477 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|