# Hugging Face Space status: Running
import gradio as gr | |
import torch | |
import joblib | |
import numpy as np | |
from itertools import product | |
import torch.nn as nn | |
import matplotlib.pyplot as plt | |
import io | |
from PIL import Image | |
import shap # Requires: pip install shap | |
############################################################################### | |
# Model Definition | |
############################################################################### | |
class VirusClassifier(nn.Module):
    """
    Feed-forward binary classifier over k-mer feature vectors.

    The network maps an ``input_shape``-dimensional vector to two logits
    (index 0 and index 1). Layer order is significant: the submodules live in
    ``self.network`` and their positions define the ``network.<i>.*`` keys of
    the state_dict that a saved checkpoint must match.
    """

    def __init__(self, input_shape: int):
        super().__init__()
        # Indices 0-10 inside the Sequential; keep this order stable so that
        # previously saved state_dicts still load.
        stack = [
            nn.Linear(input_shape, 64),  # 0
            nn.GELU(),                   # 1
            nn.BatchNorm1d(64),          # 2
            nn.Dropout(0.3),             # 3
            nn.Linear(64, 32),           # 4
            nn.GELU(),                   # 5
            nn.BatchNorm1d(32),          # 6
            nn.Dropout(0.3),             # 7
            nn.Linear(32, 32),           # 8
            nn.GELU(),                   # 9
            nn.Linear(32, 2),            # 10 -> two output logits
        ]
        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Return the two raw logits for each row of ``x``."""
        return self.network(x)
############################################################################### | |
# Utility Functions | |
############################################################################### | |
def parse_fasta(text):
    """
    Parse FASTA-formatted text into a list of ``(header, sequence)`` tuples.

    Rules:
    - Each line is whitespace-stripped; blank lines are skipped.
    - A line starting with ``>`` opens a new record. The header is the rest
      of the stripped line after the ``>``.
    - Every other line is upper-cased and appended to the current record's
      sequence.
    - Sequence lines before the first header are ignored.
    - A record with an empty header (a bare ``>`` line) is kept.

    Handles multiple records. Returns an empty list when no header is found.
    """
    sequences = []
    current_header = None  # None means "no record open yet"
    current_sequence = []
    for raw_line in text.split('\n'):
        line = raw_line.strip()  # also drops a trailing '\r' from CRLF input
        if not line:
            continue
        if line.startswith('>'):
            # Compare against None, not truthiness: an empty header "" is
            # still a real record and must not lose its sequence.
            if current_header is not None:
                sequences.append((current_header, ''.join(current_sequence)))
            current_header = line[1:]
            current_sequence = []
        elif current_header is not None:
            current_sequence.append(line.upper())
    if current_header is not None:
        sequences.append((current_header, ''.join(current_sequence)))
    return sequences
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """
    Count overlapping k-mers of ``sequence`` over the alphabet ACGT.

    Returns a float32 vector of length ``4**k``, ordered like
    ``itertools.product("ACGT", repeat=k)`` (for k=4: 256 entries). Each count
    is divided by the number of length-k windows in the sequence, including
    windows that contain other characters (e.g. ``N``) and are therefore not
    counted. If the sequence is shorter than ``k``, the all-zero vector is
    returned.
    """
    index_of = {"".join(combo): pos for pos, combo in enumerate(product("ACGT", repeat=k))}
    counts = np.zeros(len(index_of), dtype=np.float32)

    n_windows = len(sequence) - k + 1
    for start in range(n_windows):
        pos = index_of.get(sequence[start:start + k])
        if pos is not None:  # skip windows with non-ACGT characters
            counts[pos] += 1

    if n_windows > 0:
        counts = counts / n_windows
    return counts
############################################################################### | |
# Visualization Helpers | |
############################################################################### | |
def create_freq_sigma_plot(
    single_shap_values: np.ndarray,
    raw_freq_vector: np.ndarray,
    scaled_vector: np.ndarray,
    kmer_list,
    title: str,
    top_k: int = 10,
):
    """
    Plot one sample's top-``top_k`` k-mers, ranked by absolute SHAP value.

    Each selected k-mer gets two side-by-side bars:
    - its raw frequency in percent, on the left axis, colored green for a
      non-negative SHAP value and red for a negative one;
    - its scaled value, on a twin right axis labelled as standard deviations
      from the mean.

    Args:
        single_shap_values: SHAP values for this sample, one per feature.
        raw_freq_vector: unscaled k-mer frequencies (fractions) for this sample.
        scaled_vector: scaler-transformed feature values for this sample.
        kmer_list: k-mer labels, in the same feature order as the vectors.
        title: extra text appended below the plot title.
        top_k: number of k-mers to show (default 10).

    Returns:
        The matplotlib ``Figure``. The caller is responsible for closing it.
    """
    abs_vals = np.abs(single_shap_values)
    # Feature indices in descending |SHAP| order, truncated to top_k. This
    # ordering is already final, so no further sort is needed.
    top_indices = np.argsort(abs_vals)[::-1][:top_k]

    kmers = [kmer_list[i] for i in top_indices]
    freqs = [raw_freq_vector[i] * 100.0 for i in top_indices]  # fraction -> %
    sigmas = [scaled_vector[i] for i in top_indices]
    # Color by SHAP sign: positive = green, negative = red.
    colors = ["green" if single_shap_values[i] >= 0 else "red" for i in top_indices]

    x = np.arange(len(kmers))
    width = 0.4
    fig, ax = plt.subplots(figsize=(8, 5))

    # Frequency bars (left axis).
    ax.bar(x - width / 2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)")
    ax.set_ylabel("Frequency (%)", color='black')
    # All-zero frequencies (e.g. a sequence shorter than k, or one made only of
    # ambiguous bases) would produce a degenerate (0, 0) y-range. Fall back to
    # (0, 1) in that case.
    max_freq = max(freqs, default=0.0)
    ax.set_ylim(0, max_freq * 1.2 if max_freq > 0 else 1)

    # Scaled-value bars on a twin axis (right).
    ax2 = ax.twinx()
    ax2.bar(x + width / 2, sigmas, width, color="gray", alpha=0.5, label="σ from Mean")
    ax2.set_ylabel("Standard Deviations (σ)", color='black')

    ax.set_xticks(x)
    ax.set_xticklabels(kmers, rotation=45, ha='right')
    ax.set_title(f"Top-{len(kmers)} K-mers (Frequency & σ)\n{title}")

    # One combined legend for both axes.
    lines1, labels1 = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
    fig.tight_layout()
    return fig
############################################################################### | |
# Main Inference & SHAP Logic | |
############################################################################### | |
def run_classification_and_shap(file_obj):
    """
    Classify every FASTA record in ``file_obj`` and explain the predictions with SHAP.

    Args:
        file_obj: FASTA input. Either a ``str``, or a bytes-like object with a
            ``decode`` method (``gr.File(type="binary")`` delivers ``bytes``),
            which is decoded as UTF-8.

    Returns:
        Always a 5-tuple ``(results_table, shap_values, scaled_data, kmer_list, error)``,
        so callers can unpack it unconditionally:

        - results_table: list of dicts, one per record, with keys ``header``,
          ``sequence`` (first 50 bases, plus "..." if truncated),
          ``pred_label``, ``human_prob``, ``non_human_prob`` and ``confidence``.
        - shap_values: a ``shap.Explanation`` with one row per record and one
          column per k-mer. It explains the model's softmax probability of
          class 1 ("human"), and its feature names are the k-mers.
        - scaled_data: ``scaler.transform`` output, one row per record.
        - kmer_list: the ``4**k`` k-mers, in feature order.
        - error: ``None`` on success. Otherwise it is a user-facing message
          and the other four items are ``None``.
    """
    k = 4
    kmer_list = [''.join(p) for p in product("ACGT", repeat=k)]

    def _failure(message):
        # Error results use the same 5-item shape as success results.
        return None, None, None, None, message

    # 1. Read input
    if isinstance(file_obj, str):
        text = file_obj
    else:
        try:
            text = file_obj.decode("utf-8")
        except (AttributeError, UnicodeDecodeError) as e:
            # AttributeError covers a missing upload (None) or an unexpected type.
            return _failure(f"Error reading file: {str(e)}")

    # 2. Parse FASTA
    sequences = parse_fasta(text)
    if not sequences:
        return _failure("No valid FASTA sequences found!")
    headers = [hdr for hdr, _ in sequences]
    seqs = [seq for _, seq in sequences]

    # 3. Build the k-mer frequency matrix, shape (num_seqs, 4**k)
    all_raw_vectors = np.stack([sequence_to_kmer_vector(s, k=k) for s in seqs], axis=0)

    # 4. Load model & scaler (local artifacts shipped with the app)
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = VirusClassifier(input_shape=4**k).to(device)
        state_dict = torch.load("model.pt", map_location=device)
        model.load_state_dict(state_dict)
        model.eval()  # BatchNorm/Dropout in inference mode
        scaler = joblib.load("scaler.pkl")
    except Exception as e:
        return _failure(f"Error loading model or scaler: {str(e)}")

    # 5. Scale features
    scaled_data = scaler.transform(all_raw_vectors)

    # 6. Predictions
    with torch.no_grad():
        logits = model(torch.as_tensor(scaled_data, dtype=torch.float32, device=device))
        probs = torch.softmax(logits, dim=1).cpu().numpy()
    preds = np.argmax(probs, axis=1)  # 0 = non-human, 1 = human

    results_table = [
        {
            "header": hdr,
            "sequence": seq[:50] + ("..." if len(seq) > 50 else ""),  # truncated
            "pred_label": "human" if preds[i] == 1 else "non-human",
            "human_prob": float(probs[i][1]),
            "non_human_prob": float(probs[i][0]),
            "confidence": float(max(probs[i])),
        }
        for i, (hdr, seq) in enumerate(zip(headers, seqs))
    ]

    # 7. SHAP explanations
    def predict_human_prob(batch):
        # SHAP's model-agnostic explainers call this with NumPy arrays. Convert
        # them to a float32 tensor on the model's device, and return a 1-D
        # array of P(human) so the Explanation has a single output. The
        # waterfall plot and per-sample ``.values[i]`` indexing require that.
        tensor = torch.as_tensor(np.asarray(batch), dtype=torch.float32, device=device)
        with torch.no_grad():
            return torch.softmax(model(tensor), dim=1)[:, 1].cpu().numpy()

    # Background (masker) data: at most the first 50 records, to bound cost.
    # NOTE(review): the background is drawn from the batch being explained,
    # so with a single uploaded record the sample equals its own background
    # and its SHAP values are expected to be all zero. Consider a fixed
    # reference background instead.
    background_data = scaled_data[:50]
    n_features = scaled_data.shape[1]
    try:
        explainer = shap.Explainer(
            predict_human_prob,
            background_data,
            algorithm="permutation",
            feature_names=kmer_list,
        )
        # The permutation explainer needs at least 2 * n_features + 1 model
        # evaluations per sample. shap's default of 500 is too low for 256
        # features.
        shap_values = explainer(scaled_data, max_evals=2 * n_features + 1)
    except Exception as e:
        return _failure(f"Error computing SHAP values: {str(e)}")

    return (results_table, shap_values, scaled_data, kmer_list, None)
############################################################################### | |
# Gradio Callback Functions | |
############################################################################### | |
def main_predict(file_obj):
    """
    Gradio callback for the 'Run Classification' button.

    Runs classification + SHAP on the uploaded FASTA and returns a 5-tuple
    ``(markdown, shap_values, scaled_data, kmer_list, results)``. The last
    four items are stored in ``gr.State`` for the plot callbacks. On error,
    the markdown slot holds the message and the other four items are ``None``.
    """
    outcome = run_classification_and_shap(file_obj)
    if len(outcome) != 5:
        # Defensive: treat any short result tuple as (..., error_message) so a
        # malformed error result never crashes the UI while unpacking.
        return (str(outcome[-1]), None, None, None, None)
    results, shap_vals, scaled_data, kmer_list, err = outcome
    if err:
        return (err, None, None, None, None)
    if results is None or shap_vals is None:
        return ("An unknown error occurred.", None, None, None, None)

    def _cell(value):
        # FASTA headers often contain '|' (e.g. NCBI "gi|...|ref|..."), which
        # would split a markdown table cell. Escape it.
        return str(value).replace("|", "\\|")

    lines = [
        "# Classification Results",
        "",
        "| # | Header | Pred Label | Confidence | Human Prob | Non-human Prob |",
        "|---|--------|------------|------------|------------|----------------|",
    ]
    for i, row in enumerate(results):
        lines.append(
            f"| {i} | {_cell(row['header'])} | {row['pred_label']} | "
            f"{row['confidence']:.4f} | {row['human_prob']:.4f} | {row['non_human_prob']:.4f} |"
        )
    lines.append("")
    lines.append("Select a sequence index below to view SHAP Waterfall & Frequency plots.")
    md = "\n".join(lines)

    return (md, shap_vals, scaled_data, kmer_list, results)
def update_waterfall_plot(selected_index, shap_values_obj):
    """
    Render a SHAP waterfall plot for one sample as a PIL image.

    Args:
        selected_index: 0-based sample index from the UI. Non-numeric input
            (including None) and out-of-range values fall back to 0.
        shap_values_obj: the ``shap.Explanation`` produced by the classifier,
            or None if classification has not run (then None is returned).

    Returns:
        A ``PIL.Image`` of the plot, or None when there is nothing to show.
    """
    if shap_values_obj is None:
        return None
    try:
        selected_index = int(selected_index)
    except (TypeError, ValueError):
        selected_index = 0

    n_samples = shap_values_obj.values.shape[0]
    if n_samples == 0:
        return None
    if not 0 <= selected_index < n_samples:
        selected_index = 0

    fig = plt.figure(figsize=(8, 5))
    try:
        # show=False keeps shap from calling plt.show(); it draws on the current figure.
        shap.plots.waterfall(shap_values_obj[selected_index], max_display=14, show=False)
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
        buf.seek(0)
        return Image.open(buf)
    finally:
        # Always release the figure, even if plotting fails, so repeated
        # requests do not accumulate open matplotlib figures.
        plt.close(fig)
def update_beeswarm_plot(shap_values_obj):
    """
    Render a SHAP beeswarm plot across all samples as a PIL image.

    Args:
        shap_values_obj: the ``shap.Explanation`` produced by the classifier,
            or None if classification has not run.

    Returns:
        A ``PIL.Image`` of the plot, or None if ``shap_values_obj`` is None.
    """
    if shap_values_obj is None:
        return None
    fig = plt.figure(figsize=(8, 5))
    try:
        shap.plots.beeswarm(shap_values_obj, show=False)
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
        buf.seek(0)
        return Image.open(buf)
    finally:
        # Close the figure on every path, including when plotting raises.
        plt.close(fig)
def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
    """
    Render the top-10 frequency & sigma bar chart for one sample as a PIL image.

    The unscaled k-mer frequencies are not stored in state, so they are
    recomputed by re-parsing the original upload (``file_obj``) and
    re-vectorizing the selected record.

    Args:
        selected_index: 0-based sample index from the UI. Non-numeric and
            out-of-range values fall back to 0.
        shap_values_obj: ``shap.Explanation`` from the classifier.
        scaled_data: scaled feature matrix, one row per record.
        kmer_list: k-mer labels in feature order. Their length gives k.
        file_obj: the raw upload, either ``str`` or bytes decoded as UTF-8.

    Returns:
        A ``PIL.Image``, or None when required state is missing or unreadable.
    """
    if (shap_values_obj is None or scaled_data is None
            or kmer_list is None or file_obj is None):
        return None
    try:
        selected_index = int(selected_index)
    except (TypeError, ValueError):
        selected_index = 0

    if isinstance(file_obj, str):
        text = file_obj
    else:
        try:
            text = file_obj.decode('utf-8')
        except (AttributeError, UnicodeDecodeError):
            return None
    sequences = parse_fasta(text)

    # Every per-sample source must have a row for the chosen index.
    n_samples = min(len(sequences), len(scaled_data), shap_values_obj.values.shape[0])
    if n_samples == 0:
        return None
    if not 0 <= selected_index < n_samples:
        selected_index = 0

    header, seq = sequences[selected_index]
    # Use the same k as the feature space instead of hard-coding it.
    k = len(kmer_list[0])
    raw_vec = sequence_to_kmer_vector(seq, k=k)

    fig = create_freq_sigma_plot(
        shap_values_obj.values[selected_index],
        raw_freq_vector=raw_vec,
        scaled_vector=scaled_data[selected_index],
        kmer_list=kmer_list,
        title=f"Sample #{selected_index} — {header}",
    )
    try:
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=120)
        buf.seek(0)
        return Image.open(buf)
    finally:
        plt.close(fig)
############################################################################### | |
# Gradio Interface | |
############################################################################### | |
with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
    # shap.initjs() is intentionally not called. Every SHAP plot here is
    # rendered to a static PNG, and initjs() only injects notebook JavaScript
    # (it requires IPython, which a server deployment may not have).
    gr.Markdown(
        """
        # **Advanced Virus Host Classifier with SHAP**
        **Upload a FASTA file** with one or more nucleotide sequences.
        This app will:
        1. Predict each sequence's **host** (human vs. non-human).
        2. Provide **SHAP** explanations (waterfall & beeswarm).
        3. Let you explore **frequency & σ** for top-10 k-mers for a chosen sequence.
        """
    )
    with gr.Row():
        file_input = gr.File(label="Upload FASTA", type="binary")
        run_btn = gr.Button("Run Classification")

    # Per-session intermediate results consumed by the plot callbacks.
    shap_values_state = gr.State()
    scaled_data_state = gr.State()
    kmer_list_state = gr.State()
    results_state = gr.State()
    # Raw upload, kept so the frequency chart can recompute unscaled k-mer frequencies.
    file_data_state = gr.State()

    with gr.Tabs():
        with gr.Tab("Results Table"):
            md_out = gr.Markdown()
        with gr.Tab("SHAP Waterfall"):
            with gr.Row():
                seq_index_dropdown = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
                update_wf_btn = gr.Button("Update Waterfall")
            wf_plot = gr.Image(label="SHAP Waterfall Plot")
        with gr.Tab("SHAP Beeswarm"):
            bs_plot = gr.Image(label="Global Beeswarm Plot", height=500)
        with gr.Tab("Top-10 Frequency & Sigma"):
            with gr.Row():
                seq_index_dropdown2 = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
                update_fs_btn = gr.Button("Update Frequency Chart")
            fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")

    # --- Button Logic ---
    # The follow-up steps are chained with .then() so they run only after
    # main_predict has written the State values. Independent .click()
    # listeners have no ordering guarantee, so the beeswarm could otherwise
    # read shap_values_state before (or from a previous run than) it is set.
    run_event = run_btn.click(
        fn=main_predict,
        inputs=[file_input],
        outputs=[md_out, shap_values_state, scaled_data_state, kmer_list_state, results_state],
    )
    run_event.then(
        fn=lambda x: x,
        inputs=file_input,
        outputs=file_data_state,
    )
    run_event.then(
        fn=update_beeswarm_plot,
        inputs=[shap_values_state],
        outputs=[bs_plot],
    )

    update_wf_btn.click(
        fn=update_waterfall_plot,
        inputs=[seq_index_dropdown, shap_values_state],
        outputs=[wf_plot],
    )
    update_fs_btn.click(
        fn=update_freq_plot,
        inputs=[seq_index_dropdown2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
        outputs=[fs_plot],
    )
if __name__ == "__main__":
    # Listen on all interfaces at port 7860. share=True additionally asks
    # Gradio for a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)