Spaces:
Running
Running
import gradio as gr | |
import torch | |
import joblib | |
import numpy as np | |
from itertools import product | |
import torch.nn as nn | |
import matplotlib | |
matplotlib.use("Agg") # In case we're running in a no-display environment | |
import matplotlib.pyplot as plt | |
import io | |
from PIL import Image | |
import shap | |
############################################################################### | |
# Model Definition | |
############################################################################### | |
class VirusClassifier(nn.Module):
    """
    Feed-forward classifier over k-mer frequency features.

    All layers live inside ``self.network`` (an ``nn.Sequential``), so saved
    checkpoints use keys of the form ``network.<index>.*``. Layout:

      Linear(input_shape -> 64) -> GELU -> BatchNorm1d(64) -> Dropout(0.3)
      Linear(64 -> 32)          -> GELU -> BatchNorm1d(32) -> Dropout(0.3)
      Linear(32 -> 32)          -> GELU
      Linear(32 -> 2)           (raw logits, no softmax)

    Args:
        input_shape: Number of input features per sample.
    """

    def __init__(self, input_shape: int):
        super().__init__()
        # Layer order (and therefore module indices / init RNG order) must stay
        # fixed so existing state_dict checkpoints keep loading.
        layers = [
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network's raw 2-class logits for input batch ``x``."""
        logits = self.network(x)
        return logits
############################################################################### | |
# Torch Model Wrapper for SHAP | |
############################################################################### | |
class TorchModelWrapper:
    """
    Adapt a PyTorch module into a NumPy-in / NumPy-out callable for SHAP.

    SHAP passes NumPy batches; each call converts the batch to a float32
    tensor on ``device``, runs the model without autograd, and returns the
    outputs as a NumPy array on the CPU.

    Args:
        model: The ``nn.Module`` to evaluate.
        device: Device the model lives on; inputs are moved there.
    """

    def __init__(self, model: nn.Module, device='cpu'):
        self.model = model
        self.device = device

    def __call__(self, x_np: np.ndarray):
        """
        Evaluate the wrapped model on a batch.

        Args:
            x_np: NumPy array, expected shape (batch_size, num_features).

        Returns:
            NumPy array of the model outputs, shape (batch_size, num_outputs).
        """
        batch = torch.from_numpy(x_np).float().to(self.device)
        with torch.no_grad():
            outputs = self.model(batch)
        return outputs.cpu().numpy()
############################################################################### | |
# Utility Functions | |
############################################################################### | |
def parse_fasta(text):
    """
    Parse FASTA-formatted text into a list of ``(header, sequence)`` tuples.

    A line starting with ``>`` opens a new record; its header is the rest of
    that line. Following non-blank lines are stripped, upper-cased and
    concatenated into the record's sequence. Blank lines are ignored, and
    sequence lines that appear before the first header are discarded.

    Args:
        text: FASTA content as a string; may hold several records.

    Returns:
        List of ``(header, sequence)`` tuples in input order (empty if the
        text contains no header line).
    """
    sequences = []
    current_header = None
    current_sequence = []
    # splitlines() handles \n, \r\n and bare \r line endings.
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith('>'):
            # Compare with None: an empty header ('>') is falsy but still a
            # real record and must not be silently dropped.
            if current_header is not None:
                sequences.append((current_header, ''.join(current_sequence)))
            current_header = line[1:]
            current_sequence = []
        else:
            current_sequence.append(line.upper())
    if current_header is not None:
        sequences.append((current_header, ''.join(current_sequence)))
    return sequences
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """
    Turn a nucleotide string into a normalized k-mer frequency vector.

    The vector has 4**k float32 entries, one per k-mer over "ACGT", in
    ``itertools.product("ACGT", repeat=k)`` order. Every length-k window is
    slid over ``sequence``; windows made only of A/C/G/T are counted, and the
    counts are divided by the total number of windows (windows with other
    characters count toward the total but toward no k-mer). A sequence
    shorter than k yields an all-zero vector.

    Args:
        sequence: Nucleotide string (matched case-sensitively, upper-case).
        k: k-mer length.

    Returns:
        np.ndarray of dtype float32 and length 4**k.
    """
    position_of = {
        ''.join(letters): pos
        for pos, letters in enumerate(product("ACGT", repeat=k))
    }
    counts = np.zeros(4 ** k, dtype=np.float32)
    n_windows = len(sequence) - k + 1
    for start in range(n_windows):
        pos = position_of.get(sequence[start:start + k])
        if pos is not None:
            counts[pos] += 1
    if n_windows > 0:
        counts = counts / n_windows  # frequencies instead of raw counts
    return counts
############################################################################### | |
# Visualization Helpers | |
############################################################################### | |
def create_freq_sigma_plot(
    single_shap_values: np.ndarray,
    raw_freq_vector: np.ndarray,
    scaled_vector: np.ndarray,
    kmer_list,
    title: str,
    top_k: int = 10,
):
    """
    Bar chart of the k-mers with the largest absolute SHAP values for one sample.

    For each selected k-mer two bars are drawn side by side:
      * its raw frequency in percent (left axis), coloured green when the
        SHAP value is >= 0 (pushes toward class 1, "human") and red otherwise;
      * its scaled feature value (right axis, labelled as standard deviations
        from the mean; this assumes the scaler standardises the features).

    Args:
        single_shap_values: 1-D SHAP values for one sample, one per k-mer.
        raw_freq_vector: 1-D raw k-mer frequencies (fractions) for the sample.
        scaled_vector: 1-D scaled feature values for the same sample.
        kmer_list: k-mer names aligned with the three vectors above.
        title: Extra text appended to the plot title.
        top_k: How many k-mers to show (default 10).

    Returns:
        The matplotlib Figure. The caller is responsible for closing it.
    """
    single_shap_values = np.asarray(single_shap_values)
    abs_vals = np.abs(single_shap_values)
    # Largest -> smallest |SHAP|. Truncating *after* reversing keeps top_k=0
    # from selecting everything (which `[-0:]` would do).
    top_indices = [int(i) for i in np.argsort(abs_vals)[::-1][:max(top_k, 0)]]

    kmers = [kmer_list[i] for i in top_indices]
    freqs = [float(raw_freq_vector[i]) * 100.0 for i in top_indices]  # percentage
    sigmas = [float(scaled_vector[i]) for i in top_indices]
    colors = ["green" if single_shap_values[i] >= 0 else "red" for i in top_indices]

    x = np.arange(len(kmers))
    width = 0.4
    fig, ax = plt.subplots(figsize=(8, 5))

    # Frequency bars (left axis)
    ax.bar(x - width / 2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)")
    ax.set_ylabel("Frequency (%)", color='black')
    max_freq = max(freqs, default=0.0)
    # set_ylim(0, 0) is degenerate; it happens when every shown frequency is 0
    # (e.g. a sequence shorter than k), so let matplotlib autoscale then.
    if max_freq > 0:
        ax.set_ylim(0, max_freq * 1.2)

    # Sigma bars on a twin axis (right)
    ax2 = ax.twinx()
    ax2.bar(x + width / 2, sigmas, width, color="gray", alpha=0.5, label="σ from Mean")
    ax2.set_ylabel("Standard Deviations (σ)", color='black')

    ax.set_xticks(x)
    ax.set_xticklabels(kmers, rotation=45, ha='right')
    ax.set_title(f"Top-{top_k} K-mers (Frequency & σ)\n{title}")

    # One legend covering both axes
    handles1, labels1 = ax.get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(handles1 + handles2, labels1 + labels2, loc='upper right')
    fig.tight_layout()
    return fig
############################################################################### | |
# Main Inference & SHAP Logic | |
############################################################################### | |
def run_classification_and_shap(file_obj):
    """
    Classify every FASTA record in ``file_obj`` and compute SHAP explanations.

    Args:
        file_obj: FASTA content as ``str`` or as UTF-8 encoded ``bytes``
            (what ``gr.File(type="binary")`` delivers), or ``None`` when no
            file was uploaded.

    Returns:
        A 5-tuple ``(results_table, shap_values, scaled_data, kmer_list, error)``:
          - results_table: list of dicts, one per sequence, with keys
            "header", "sequence" (first 50 chars), "pred_label",
            "human_prob", "non_human_prob", "confidence".
          - shap_values: ``shap.Explanation`` for the whole batch, explaining
            the model's raw logits (2 outputs). In shap's multi-output layout
            ``.values`` is (num_samples, num_features, 2) and ``.base_values``
            is (num_samples, 2).
          - scaled_data: scaled feature matrix, shape (num_samples, 4**k).
          - kmer_list: the 4**k k-mer names in feature order.
          - error: ``None`` on success; otherwise a message, with the other
            four entries set to ``None``.
    """
    # 1. Basic read
    if file_obj is None:
        return None, None, None, None, "No file uploaded. Please upload a FASTA file."
    if isinstance(file_obj, str):
        text = file_obj
    else:
        try:
            text = file_obj.decode("utf-8")
        except Exception as e:
            return None, None, None, None, f"Error reading file: {str(e)}"

    # 2. Parse FASTA
    sequences = parse_fasta(text)
    if not sequences:
        return None, None, None, None, "No valid FASTA sequences found!"

    # 3. Convert each sequence to a k-mer frequency vector
    k = 4
    headers = [hdr for hdr, _ in sequences]
    seqs = [seq for _, seq in sequences]
    all_raw_vectors = np.stack(
        [sequence_to_kmer_vector(seq, k=k) for seq in seqs], axis=0
    )  # shape=(num_seqs, 4**k)

    # 4. Load model & scaler
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = VirusClassifier(input_shape=4**k).to(device)
        # weights_only=True restricts torch.load to tensors/primitive containers.
        state_dict = torch.load("model.pt", map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        # NOTE: joblib.load unpickles; scaler.pkl must be a trusted local file.
        scaler = joblib.load("scaler.pkl")
    except Exception as e:
        return None, None, None, None, f"Error loading model or scaler: {str(e)}"

    # 5. Scale data
    scaled_data = scaler.transform(all_raw_vectors)  # shape=(num_seqs, 4**k)

    # 6. Predictions
    X_tensor = torch.FloatTensor(scaled_data).to(device)
    with torch.no_grad():
        logits = model(X_tensor)  # shape=(num_seqs, 2)
    probs = torch.softmax(logits, dim=1).cpu().numpy()
    preds = np.argmax(probs, axis=1)  # 0 = non-human, 1 = human

    results_table = [
        {
            "header": hdr,
            "sequence": seq[:50] + ("..." if len(seq) > 50 else ""),
            "pred_label": "human" if pred == 1 else "non-human",
            "human_prob": float(p[1]),
            "non_human_prob": float(p[0]),
            "confidence": float(np.max(p)),
        }
        for hdr, seq, p, pred in zip(headers, seqs, probs, preds)
    ]

    # 7. SHAP explainer over the raw logits; cap the background set at 50 rows.
    background_data = scaled_data[:50]
    wrapped_model = TorchModelWrapper(model, device)
    explainer = shap.Explainer(wrapped_model, background_data)
    # For a plain callable with >10 features shap picks the Permutation
    # explainer, whose default max_evals (500) is below the minimum it
    # requires, 2 * num_features + 1 (= 513 for 256 k-mers), and it raises.
    num_features = scaled_data.shape[1]
    shap_values = explainer(scaled_data, max_evals=2 * num_features + 1)

    # k-mer names in the same order as sequence_to_kmer_vector's features
    kmer_list = [''.join(p) for p in product("ACGT", repeat=k)]
    return (results_table, shap_values, scaled_data, kmer_list, None)
############################################################################### | |
# Gradio Callback Functions | |
############################################################################### | |
def main_predict(file_obj):
    """
    Gradio callback for the 'Run Classification' button.

    Runs classification + SHAP on the uploaded FASTA and renders a Markdown
    summary table of the per-sequence predictions.

    Args:
        file_obj: Uploaded FASTA content (``bytes`` or ``str``), or ``None``.

    Returns:
        ``(markdown, shap_values, scaled_data, kmer_list, results)`` for the
        Markdown output and the four ``gr.State`` slots. On error the first
        element is the error message and the other four are ``None``.
    """
    results, shap_vals, scaled_data, kmer_list, err = run_classification_and_shap(file_obj)
    if err:
        return (err, None, None, None, None)
    if results is None or shap_vals is None:
        return ("An unknown error occurred.", None, None, None, None)

    def _md_cell(value):
        # A raw '|' (common in FASTA headers) would end the table cell early
        # and shift every following column of the row.
        return str(value).replace("|", "\\|")

    parts = [
        "# Classification Results\n\n",
        "| # | Header | Pred Label | Confidence | Human Prob | Non-human Prob |\n",
        "|---|--------|------------|------------|------------|----------------|\n",
    ]
    for i, row in enumerate(results):
        parts.append(
            f"| {i} | {_md_cell(row['header'])} | {row['pred_label']} | "
            f"{row['confidence']:.4f} | {row['human_prob']:.4f} | {row['non_human_prob']:.4f} |\n"
        )
    parts.append(
        "\nSelect a sequence index below to view SHAP Waterfall & Frequency plots (class=1/human)."
    )
    return ("".join(parts), shap_vals, scaled_data, kmer_list, results)
def update_waterfall_plot(selected_index, shap_values_obj):
    """
    Render a SHAP waterfall plot for one sequence, for class 1 ("human") only.

    Args:
        selected_index: 0-based sequence index from the UI. Values that cannot
            be converted to int, or that fall outside the batch, become 0.
        shap_values_obj: Batch ``shap.Explanation`` from
            run_classification_and_shap, or ``None`` before the first run.

    Returns:
        A PIL image of the plot, or ``None`` when there is nothing to plot.
    """
    if shap_values_obj is None:
        return None
    try:
        selected_index = int(selected_index)
    except (TypeError, ValueError, OverflowError):
        selected_index = 0

    values = np.asarray(shap_values_obj.values)
    data = np.asarray(shap_values_obj.data)
    if values.shape[0] == 0:
        return None
    num_features = data.shape[-1]
    if not 0 <= selected_index < values.shape[0]:
        selected_index = 0

    # shap stores multi-output attributions as (samples, features, outputs);
    # indexing [i, 1, :] would pick feature 1 for both classes. Fall back to
    # (samples, outputs, features) only if that is what we were handed.
    if values.ndim == 3 and values.shape[1] == num_features:
        single_ex_values = values[selected_index, :, 1]
    else:
        single_ex_values = values[selected_index, 1, :]
    single_ex_base = np.asarray(shap_values_obj.base_values)[selected_index, 1]
    single_ex_data = data[selected_index]

    # Features are k-mer frequencies in product("ACGT", repeat=k) order (see
    # sequence_to_kmer_vector); label them by k-mer when the count is 4**k.
    k = 0
    while 4 ** k < num_features:
        k += 1
    if 4 ** k == num_features:
        feature_names = [''.join(p) for p in product("ACGT", repeat=k)]
    else:
        feature_names = [f"feat_{i}" for i in range(num_features)]

    single_expl = shap.Explanation(
        values=single_ex_values,
        base_values=single_ex_base,
        data=single_ex_data,
        feature_names=feature_names,
    )

    fig = plt.figure(figsize=(8, 5))
    buf = io.BytesIO()
    try:
        shap.plots.waterfall(single_expl, max_display=14, show=False)
        plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
    finally:
        # Release the figure even on failure; pyplot otherwise keeps it alive.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def update_beeswarm_plot(shap_values_obj):
    """
    Render a SHAP beeswarm plot across all sequences, for class 1 ("human").

    Args:
        shap_values_obj: Batch ``shap.Explanation`` from
            run_classification_and_shap, or ``None`` before the first run.

    Returns:
        A PIL image of the plot, or ``None`` when there is nothing to plot.
    """
    if shap_values_obj is None:
        return None

    values = np.asarray(shap_values_obj.values)
    data = np.asarray(shap_values_obj.data)  # shape=(num_samples, num_features)
    num_features = data.shape[-1]

    # shap stores multi-output attributions as (samples, features, outputs);
    # [:, 1, :] would select feature 1 instead of class 1. Fall back to
    # (samples, outputs, features) only if that is what we were handed.
    if values.ndim == 3 and values.shape[1] == num_features:
        class1_vals = values[:, :, 1]
    else:
        class1_vals = values[:, 1, :]
    class1_base = np.asarray(shap_values_obj.base_values)[:, 1]

    # Features are k-mer frequencies in product("ACGT", repeat=k) order (see
    # sequence_to_kmer_vector); label them by k-mer when the count is 4**k.
    k = 0
    while 4 ** k < num_features:
        k += 1
    if 4 ** k == num_features:
        feature_names = [''.join(p) for p in product("ACGT", repeat=k)]
    else:
        feature_names = [f"feat_{i}" for i in range(num_features)]

    class1_expl = shap.Explanation(
        values=class1_vals,
        base_values=class1_base,
        data=data,
        feature_names=feature_names,
    )

    fig = plt.figure(figsize=(8, 5))
    buf = io.BytesIO()
    try:
        shap.plots.beeswarm(class1_expl, show=False)
        plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
    finally:
        # Release the figure even on failure; pyplot otherwise keeps it alive.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
    """
    Render the frequency & σ bar chart for one sequence's top-10 k-mers.

    k-mers are ranked by absolute class-1 ("human") SHAP value. The raw k-mer
    frequencies are recomputed by re-parsing the uploaded FASTA.

    Args:
        selected_index: 0-based sequence index from the UI. Values that cannot
            be converted to int, or that fall outside the batch, become 0.
        shap_values_obj: Batch ``shap.Explanation`` from the last run.
        scaled_data: Scaled feature matrix from the last run.
        kmer_list: k-mer names in feature order from the last run.
        file_obj: The FASTA content (``bytes`` or ``str``) that was classified.

    Returns:
        A PIL image of the chart, or ``None`` when any input is missing.
    """
    if shap_values_obj is None or scaled_data is None or not kmer_list or file_obj is None:
        return None
    try:
        selected_index = int(selected_index)
    except (TypeError, ValueError, OverflowError):
        selected_index = 0

    # Re-parse the FASTA to get the corresponding sequence
    if isinstance(file_obj, str):
        text = file_obj
    else:
        text = file_obj.decode('utf-8')
    sequences = parse_fasta(text)

    values = np.asarray(shap_values_obj.values)
    num_samples = min(len(sequences), values.shape[0])
    if num_samples == 0:
        return None
    # Out-of-range (including negative) indices fall back to the first sample.
    if not 0 <= selected_index < num_samples:
        selected_index = 0

    # shap stores multi-output attributions as (samples, features, outputs);
    # fall back to (samples, outputs, features) only if that is what we got.
    num_features = len(kmer_list)
    if values.ndim == 3 and values.shape[1] == num_features:
        single_shap_values = values[selected_index, :, 1]
    else:
        single_shap_values = values[selected_index, 1, :]

    header, seq = sequences[selected_index]
    k = len(kmer_list[0])  # keep k in sync with the features actually computed
    raw_vec = sequence_to_kmer_vector(seq, k=k)

    freq_sigma_fig = create_freq_sigma_plot(
        single_shap_values,
        raw_freq_vector=raw_vec,
        scaled_vector=scaled_data[selected_index],
        kmer_list=kmer_list,
        title=f"Sample #{selected_index} — {header}",
    )
    buf = io.BytesIO()
    try:
        freq_sigma_fig.savefig(buf, format='png', bbox_inches='tight', dpi=120)
    finally:
        plt.close(freq_sigma_fig)
    buf.seek(0)
    return Image.open(buf)
############################################################################### | |
# Gradio Interface | |
############################################################################### | |
with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
    # Intro text (page title previously read "irus Host Classifier").
    gr.Markdown(
        """
# **Virus Host Classifier**
Upload a FASTA file with one or more nucleotide sequences.
This app will:
1. Predict each sequence's **host** (human vs. non-human).
2. Provide **SHAP** explanations focusing on the 'human' class (index=1).
3. Display:
   - A **waterfall** plot per-sequence (top features).
   - A **beeswarm** plot across all sequences (global summary).
   - A **frequency & σ** bar chart for the top-10 k-mers of any selected sequence.
"""
    )

    with gr.Row():
        file_input = gr.File(label="Upload FASTA", type="binary")
        run_btn = gr.Button("Run Classification")

    # Intermediate results shared between callbacks
    shap_values_state = gr.State()
    scaled_data_state = gr.State()
    kmer_list_state = gr.State()
    results_state = gr.State()
    file_data_state = gr.State()

    with gr.Tabs():
        with gr.Tab("Results Table"):
            md_out = gr.Markdown()
        with gr.Tab("SHAP Waterfall"):
            with gr.Row():
                seq_index_input = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
                update_wf_btn = gr.Button("Update Waterfall")
            wf_plot = gr.Image(label="SHAP Waterfall Plot")
        with gr.Tab("SHAP Beeswarm"):
            bs_plot = gr.Image(label="Global Beeswarm Plot", height=500)
        with gr.Tab("Top-10 Frequency & Sigma"):
            with gr.Row():
                seq_index_input2 = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
                update_fs_btn = gr.Button("Update Frequency Chart")
            fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")

    # 1) Main classification, then the beeswarm. Chaining with .then() matters:
    #    a separate run_btn.click listener would read shap_values_state when the
    #    click fires, i.e. before main_predict has stored the new SHAP values.
    run_btn.click(
        fn=main_predict,
        inputs=[file_input],
        outputs=[md_out, shap_values_state, scaled_data_state, kmer_list_state, results_state],
    ).then(
        fn=update_beeswarm_plot,
        inputs=[shap_values_state],
        outputs=[bs_plot],
    )
    # Remember the classified file so the frequency chart can re-parse it.
    run_btn.click(
        fn=lambda x: x,
        inputs=file_input,
        outputs=file_data_state,
    )

    # 2) Update Waterfall
    update_wf_btn.click(
        fn=update_waterfall_plot,
        inputs=[seq_index_input, shap_values_state],
        outputs=[wf_plot],
    )

    # 3) Update Frequency & σ
    update_fs_btn.click(
        fn=update_freq_plot,
        inputs=[seq_index_input2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
        outputs=[fs_plot],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)