import io
from itertools import product

import gradio as gr
import joblib
import numpy as np
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")  # render figures off-screen; the app runs headless
import matplotlib.pyplot as plt
from PIL import Image


###############################################################################
# Model Definition
###############################################################################
class VirusClassifier(nn.Module):
    def __init__(self, input_shape: int):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2)
        )

    def forward(self, x):
        return self.network(x)

    def get_feature_importance(self, x):
        """
        Calculate gradient-based feature importance for the 'human' class
        (index 1) by computing the gradient of that probability w.r.t. x.

        Assumes a single-sequence batch (batch_size=1), since backward()
        can only implicitly create gradients for one-element outputs.
        """
        x.requires_grad_(True)
        output = self.network(x)
        probs = torch.softmax(output, dim=1)
        # Probability of the 'human' class (index 1)
        human_prob = probs[..., 1]
        if x.grad is not None:
            x.grad.zero_()
        human_prob.backward()
        importance = x.grad  # shape: (batch_size, n_features)
        return importance, float(human_prob)
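
# Illustrative usage (hypothetical variable names; assumes a loaded model and a
# 1x256 standardized input tensor):
#     grads, p_human = model.get_feature_importance(X_tensor)
#     grads[0, i] > 0  -> feature i pushes the prediction toward 'human'
#     grads[0, i] < 0  -> feature i pushes toward 'non-human'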


###############################################################################
# Utility Functions
###############################################################################
def parse_fasta(text):
    """Parse FASTA-formatted text into a list of (header, sequence) tuples."""
    sequences = []
    current_header = None
    current_sequence = []
    for line in text.split('\n'):
        line = line.strip()
        if not line:
            continue
        if line.startswith('>'):
            if current_header:
                sequences.append((current_header, ''.join(current_sequence)))
            current_header = line[1:]
            current_sequence = []
        else:
            current_sequence.append(line.upper())
    if current_header:
        sequences.append((current_header, ''.join(current_sequence)))
    return sequences
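
# Illustrative example:
#     parse_fasta(">seq1\nacgt\nTTGA\n>seq2\nGGGG")
#     -> [('seq1', 'ACGTTTGA'), ('seq2', 'GGGG')]
# (sequence lines are uppercased and concatenated; blank lines are skipped)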


def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """Convert a single nucleotide sequence to a normalized k-mer frequency vector."""
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    kmer_dict = {km: i for i, km in enumerate(kmers)}
    vec = np.zeros(len(kmers), dtype=np.float32)
    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i+k]
        # k-mers containing non-ACGT characters (e.g. N) are skipped
        if kmer in kmer_dict:
            vec[kmer_dict[kmer]] += 1
    total_kmers = len(sequence) - k + 1
    if total_kmers > 0:
        vec = vec / total_kmers  # normalize counts to frequencies
    return vec
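
# Illustrative example: for k=4, "ACGTAC" contains 3 overlapping 4-mers
# (ACGT, CGTA, GTAC), so those three entries of the 256-dim vector each
# get frequency 1/3 and every other entry stays 0.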


###############################################################################
# Visualization
###############################################################################
def create_shap_waterfall_plot(important_kmers, all_kmer_importance, human_prob, title):
    """
    Create a SHAP-like waterfall plot:
      - Start at a baseline of 0.5.
      - Add one bar for "Other": the combined effect of all less-important k-mers.
      - Then apply each top k-mer in descending order of absolute importance.
      - Show the final accumulated value as the endpoint (this need not equal
        the model's predicted probability; see the note below).
    """
    # 1) Sort 'important_kmers' by absolute impact, descending
    sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)

    # 2) Compute the total effect of the "other" k-mers.
    #    There are 256 features in total; we selected the top 10, so sum the rest.
    top_ids = {km['idx'] for km in sorted_kmers}
    other_contributions = [
        val for i, val in enumerate(all_kmer_importance) if i not in top_ids
    ]
    other_sum = np.sum(other_contributions)

    # 3) Build the list of bars: "Other" first, then each top k-mer.
    #    Each bar is (label, signed_contribution); the sign indicates direction.
    bars = [("Other", other_sum)]  # lumps the leftover k-mers together
    for km in sorted_kmers:
        # 'impact' stores only the absolute gradient, so re-apply the sign
        # from the stored direction.
        signed_val = km['impact'] if km['direction'] == 'human' else -km['impact']
        bars.append((km['kmer'], signed_val))

    # 4) Accumulate partial sums starting from the baseline of 0.5.
    baseline = 0.5
    running_val = baseline
    x_labels = []
    y_vals = []
    bar_colors = []
    # Green bars push toward 'human'; red bars push away from 'human'.
    for label, contrib in bars:
        x_labels.append(label)
        # New value after adding this contribution,
        # scaled by 0.05 for better display. Adjust as desired.
        new_val = running_val + (0.05 * contrib)
        y_vals.append((running_val, new_val))
        running_val = new_val
        bar_colors.append("green" if contrib >= 0 else "red")
    final_prob = running_val
    # Note: the endpoint is not forced to match the model's probability; this
    # is a SHAP-like idea, not an exact decomposition. To force agreement,
    # one could add a correction term:
    #     running_val += human_prob - running_val
    # but here the waterfall stays purely additive in the gradients.

    fig, ax = plt.subplots(figsize=(10, 6))
    x_positions = np.arange(len(x_labels))
    for i, ((start_val, end_val), color) in enumerate(zip(y_vals, bar_colors)):
        # The bar's height is the difference between consecutive running values
        height = end_val - start_val
        ax.bar(i, height, bottom=start_val, color=color, edgecolor='black', alpha=0.7)
        ax.text(i, (start_val + end_val) / 2, f"{height:+.3f}",
                ha='center', va='center', color='white', fontsize=8)
    ax.axhline(y=baseline, color='black', linestyle='--', linewidth=1)
    ax.set_xticks(x_positions)
    ax.set_xticklabels(x_labels, rotation=45, ha='right')
    ax.set_ylim(0, 1)
    ax.set_ylabel("Running Probability (Human)")
    ax.set_title(
        f"SHAP-like Waterfall — {title}\n"
        f"Final: {final_prob:.3f} (Model Probability: {human_prob:.3f})"
    )
    plt.tight_layout()
    return fig
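
# Reading the waterfall (numbers illustrative): with baseline 0.5 and display
# scale 0.05, a k-mer whose signed gradient is +1.2 draws a green bar from the
# current running value v up to v + 0.06; the bars for "Other" and each top
# k-mer stack additively from 0.5 to the plotted final value.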


def create_frequency_sigma_plot(important_kmers, title):
    """Bar plot of the top k-mers showing frequency (%) and σ from the training mean."""
    # Sort by absolute impact, descending
    sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
    kmers = [k["kmer"] for k in sorted_kmers]
    frequencies = [k["occurrence"] for k in sorted_kmers]  # in %
    sigmas = [k["sigma"] for k in sorted_kmers]
    directions = [k["direction"] for k in sorted_kmers]
    x = np.arange(len(kmers))
    width = 0.4

    fig, ax_bar = plt.subplots(figsize=(10, 6))
    # Frequency bars, colored by the direction each k-mer pushes the prediction
    ax_bar.bar(
        x - width / 2, frequencies, width, alpha=0.7,
        color=["green" if d == "human" else "red" for d in directions],
        label="Frequency (%)"
    )
    ax_bar.set_ylabel("Frequency (%)")
    ax_bar.set_ylim(0, max(frequencies) * 1.2 if frequencies else 1)

    # Twin axis for σ
    ax_bar_twin = ax_bar.twinx()
    ax_bar_twin.bar(
        x + width / 2, sigmas, width, alpha=0.5, color="gray", label="σ from Mean"
    )
    ax_bar_twin.set_ylabel("Standard Deviations (σ)")

    ax_bar.set_title(f"Frequency & σ from Mean for Top k-mers — {title}")
    ax_bar.set_xticks(x)
    ax_bar.set_xticklabels(kmers, rotation=45, ha='right')

    # Combined legend across both axes
    lines1, labels1 = ax_bar.get_legend_handles_labels()
    lines2, labels2 = ax_bar_twin.get_legend_handles_labels()
    ax_bar.legend(lines1 + lines2, labels1 + labels2, loc="upper right")
    plt.tight_layout()
    return fig


def create_importance_bar_plot(important_kmers, title):
    """
    Simple bar chart of the absolute gradient magnitude for the
    top k-mers, sorted descending.
    """
    sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
    kmers = [k['kmer'] for k in sorted_kmers]
    impacts = [k['impact'] for k in sorted_kmers]
    directions = [k["direction"] for k in sorted_kmers]
    x = np.arange(len(kmers))

    fig, ax = plt.subplots(figsize=(10, 6))
    bar_colors = ["green" if d == "human" else "red" for d in directions]
    ax.bar(x, impacts, color=bar_colors, alpha=0.7)
    ax.set_xticks(x)
    ax.set_xticklabels(kmers, rotation=45, ha='right')
    ax.set_title(f"Absolute Feature Importance (Top k-mers) — {title}")
    ax.set_ylabel("Gradient Magnitude")
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    return fig


###############################################################################
# Prediction Function
###############################################################################
def predict(file_obj):
    """
    Main function for Gradio:
      1. Read the uploaded FASTA file (or raw text).
      2. Load the model and scaler.
      3. Generate the prediction, probabilities, and top k-mers.
      4. Return four outputs:
         - a textual summary (Markdown),
         - the SHAP-like waterfall plot,
         - the frequency & σ plot,
         - the absolute-importance bar plot.
    """
    # 0. Basic input handling
    if file_obj is None:
        return "Please upload a FASTA file.", None, None, None
    try:
        if isinstance(file_obj, str):
            # Raw text was passed directly
            text = file_obj
        else:
            # An uploaded file arrives as bytes (gr.File with type="binary")
            text = file_obj.decode('utf-8')
    except Exception as e:
        return f"Error reading file: {str(e)}", None, None, None

    # 1. Parse FASTA
    sequences = parse_fasta(text)
    if len(sequences) == 0:
        return "No valid FASTA sequences found. Please check your input.", None, None, None

    # Only the first sequence is classified, for demonstration
    header, seq = sequences[0]

    # 2. Build the k-mer vector & load the model
    k = 4
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Raw (unscaled) k-mer frequency vector
        raw_freq_vector = sequence_to_kmer_vector(seq, k=k)

        # Load model & scaler
        model = VirusClassifier(input_shape=4**k).to(device)
        state_dict = torch.load('model.pt', map_location=device)
        model.load_state_dict(state_dict)
        scaler = joblib.load('scaler.pkl')
        model.eval()

        scaled_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
        X_tensor = torch.FloatTensor(scaled_vector).to(device)

        # 3. Inference
        with torch.no_grad():
            logits = model(X_tensor)
            probs = torch.softmax(logits, dim=1)
        human_prob = float(probs[0][1])
        non_human_prob = float(probs[0][0])
        pred_label = "human" if human_prob >= non_human_prob else "non-human"
        confidence = float(max(probs[0]))

        # 4. Feature importance: gradient of the human-class probability
        importance, _ = model.get_feature_importance(X_tensor)
        kmer_importances = importance[0].cpu().numpy()  # shape: (256,)

        # Map feature indices back to k-mer strings
        kmers_list = [''.join(p) for p in product("ACGT", repeat=k)]

        # 5. Top 10 k-mers by absolute importance
        abs_importance = np.abs(kmer_importances)
        top_k = 10
        top_idxs = np.argsort(abs_importance)[-top_k:][::-1]  # descending
        important_kmers = []
        for idx in top_idxs:
            important_kmers.append({
                'kmer': kmers_list[idx],
                'idx': int(idx),
                'impact': float(abs_importance[idx]),
                # sign of the gradient gives the push direction
                'direction': "human" if kmer_importances[idx] > 0 else "non-human",
                # occurrence as a percentage, from the raw frequency vector
                'occurrence': float(raw_freq_vector[idx] * 100),
                # σ from the training mean (assumes a StandardScaler)
                'sigma': float(scaled_vector[0][idx]),
            })
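
        # Illustrative entry in important_kmers (values assumed):
        #     {'kmer': 'ACGT', 'idx': 27, 'impact': 0.051, 'direction': 'human',
        #      'occurrence': 1.25, 'sigma': 0.87}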

        # 6. Text summary (Markdown)
        summary_text = (
            f"**Sequence Header**: {header}\n\n"
            f"**Predicted Label**: {pred_label}\n"
            f"**Confidence**: {confidence:.4f}\n\n"
            f"**Human Probability**: {human_prob:.4f}\n"
            f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
            "### Most Influential k-mers:\n"
        )
        for km in important_kmers:
            direction_text = f"(pushes toward {km['direction']})"
            freq_text = f"{km['occurrence']:.2f}%"
            sigma_text = (
                f"{abs(km['sigma']):.2f}σ "
                + ("above" if km['sigma'] > 0 else "below")
                + " mean"
            )
            summary_text += (
                f"- **{km['kmer']}**: impact={km['impact']:.4f}, {direction_text}, "
                f"occurrence={freq_text}, ({sigma_text})\n"
            )

        # 7. Render each figure to a PIL image for Gradio
        def fig_to_image(fig):
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=120)
            buf.seek(0)
            img = Image.open(buf)
            plt.close(fig)
            return img

        # a) SHAP-like waterfall plot
        waterfall_img = fig_to_image(
            create_shap_waterfall_plot(important_kmers, kmer_importances,
                                       human_prob, header)
        )
        # b) Frequency & σ plot (top 10 k-mers)
        freq_sigma_img = fig_to_image(
            create_frequency_sigma_plot(important_kmers, header)
        )
        # c) Absolute-importance bar plot
        importance_img = fig_to_image(
            create_importance_bar_plot(important_kmers, header)
        )

        return summary_text, waterfall_img, freq_sigma_img, importance_img
    except Exception as e:
        return f"Error during prediction or visualization: {str(e)}", None, None, None


###############################################################################
# Gradio Interface
###############################################################################
with gr.Blocks(title="Advanced Virus Host Classifier") as demo:
    gr.Markdown(
        """
        # Advanced Virus Host Classifier

        **Upload a FASTA file** containing a single nucleotide sequence.
        The model predicts whether the sequence is **human** or **non-human**,
        reports a confidence score, and highlights the most influential k-mers
        (via a SHAP-like waterfall plot) along with two additional plots.
        """
    )
    with gr.Row():
        file_in = gr.File(label="Upload FASTA", type="binary")
        btn = gr.Button("Run Prediction")

    # One tab per output
    with gr.Tabs():
        with gr.Tab("Prediction Results"):
            md_out = gr.Markdown()
        with gr.Tab("SHAP-like Waterfall Plot"):
            water_out = gr.Image()
        with gr.Tab("Frequency & σ Plot"):
            freq_out = gr.Image()
        with gr.Tab("Importance Bar Plot"):
            imp_out = gr.Image()

    # Wire the button to the prediction function
    btn.click(
        fn=predict,
        inputs=[file_in],
        outputs=[md_out, water_out, freq_out, imp_out]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
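
# Running locally (a sketch; assumes model.pt and scaler.pkl sit next to this
# script): launch it with Python, open http://localhost:7860, and upload a
# small FASTA file such as:
#     >example_sequence
#     ACGTACGTACGTACGT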