# Spaces: Running  (page-status banner left over from the web scrape; not part of the program)
import io
from itertools import product

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
############################################################################### | |
# Model Definition | |
############################################################################### | |
class VirusClassifier(nn.Module):
    """Feed-forward classifier over a fixed-length feature vector.

    The network maps ``input_shape`` features through hidden layers of
    64, 32 and 32 units (GELU activations; batch norm and dropout after the
    first two) to two output logits. ``get_feature_importance`` treats
    output index 1 as the 'human' class.
    """

    def __init__(self, input_shape: int):
        """Build the layer stack for inputs with ``input_shape`` features."""
        super().__init__()
        # NOTE: layer order determines the state_dict keys (network.0 ...),
        # so it must stay in sync with any saved checkpoint.
        self.network = nn.Sequential(
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        """Return the raw two-class logits for the batch ``x``."""
        return self.network(x)

    def get_feature_importance(self, x):
        """
        Calculate gradient-based feature importance.

        Computes the gradient of the 'human' probability (softmax output at
        index 1) with respect to the input vector.

        Args:
            x: Input tensor holding a single sample, shape (1, n_features).
               It is not modified; the gradient is taken on a detached copy.

        Returns:
            A tuple ``(importance, human_prob)`` where ``importance`` is the
            gradient tensor with the same shape as ``x`` and ``human_prob``
            is the 'human' probability as a Python float.

        Raises:
            RuntimeError: If ``x`` holds more than one sample, because the
                gradient is only defined here for a single scalar probability.
        """
        # Re-enable autograd locally so the method also works when the caller
        # is inside a torch.no_grad() block.
        with torch.enable_grad():
            # Detached copy: avoids flipping requires_grad on the caller's
            # tensor and guarantees a leaf tensor to differentiate against.
            x = x.detach().clone().requires_grad_(True)
            output = self.network(x)
            probs = torch.softmax(output, dim=1)
            # Gradient wrt 'human' class probability (index=1)
            human_prob = probs[..., 1]
            # autograd.grad returns the input gradient directly and does not
            # accumulate anything into the model parameters' .grad fields.
            (importance,) = torch.autograd.grad(human_prob, x)
        return importance, float(human_prob)
############################################################################### | |
# Utility Functions | |
############################################################################### | |
def parse_fasta(text):
    """Parse FASTA-formatted text into a list of (header, sequence) tuples.

    Lines starting with '>' open a new record; the header is the rest of that
    line. All following non-blank lines are upper-cased and concatenated into
    the record's sequence. Blank lines are skipped, and any sequence lines
    appearing before the first header are discarded.

    Args:
        text: The FASTA content as a single string.

    Returns:
        A list of ``(header, sequence)`` tuples in input order. A record whose
        header line is just '>' is kept with an empty-string header.
    """
    sequences = []
    current_header = None
    current_sequence = []
    for line in text.split('\n'):
        line = line.strip()  # also removes a trailing '\r' from CRLF input
        if not line:
            continue
        if line.startswith('>'):
            # Compare against None rather than truthiness so that a record
            # with an empty header ('>' alone) is not silently dropped.
            if current_header is not None:
                sequences.append((current_header, ''.join(current_sequence)))
            current_header = line[1:]
            current_sequence = []
        else:
            current_sequence.append(line.upper())
    # Flush the final record, which has no following header to trigger it.
    if current_header is not None:
        sequences.append((current_header, ''.join(current_sequence)))
    return sequences
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """Convert a single nucleotide sequence to a k-mer frequency vector.

    The vector has one entry per k-mer over the alphabet "ACGT", ordered as
    produced by ``itertools.product`` (so 4**k entries). Each entry is the
    number of windows equal to that k-mer divided by the total number of
    length-k windows in the sequence. Windows containing any other character
    (e.g. 'N') match no entry but still count toward the total.

    Args:
        sequence: Nucleotide string. Matching is case-insensitive, so
            lower-case (soft-masked) input is handled like upper-case.
        k: k-mer length.

    Returns:
        A float32 array of k-mer frequencies; all zeros when the sequence is
        shorter than ``k``.
    """
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    kmer_index = {km: i for i, km in enumerate(kmers)}

    # Normalise case up front; otherwise lower-case input would silently
    # produce an all-zero vector.
    sequence = sequence.upper()
    total_kmers = len(sequence) - k + 1

    # Tally in plain Python ints, then convert once to float32.
    counts = [0] * len(kmers)
    for start in range(total_kmers):  # empty range when the sequence is too short
        idx = kmer_index.get(sequence[start:start + k])
        if idx is not None:
            counts[idx] += 1

    vec = np.asarray(counts, dtype=np.float32)
    if total_kmers > 0:
        vec = vec / total_kmers  # normalize frequencies
    return vec
############################################################################### | |
# Visualization | |
############################################################################### | |
def create_visualization(important_kmers, human_prob, title):
    """
    Create a three-panel figure summarising the most influential k-mers:
    1) Top-left: a waterfall-like step plot that starts at a baseline of 0.5
       and moves up or down by each k-mer's scaled impact, in the order given.
    2) Top-right: side-by-side bars for frequency (%) and σ from mean for each
       k-mer, on twin y-axes.
    3) Bottom (spanning both columns): impact of each k-mer, sorted from
       largest to smallest.

    Args:
        important_kmers: List of dicts, each with the keys 'kmer' (label),
            'impact' (magnitude), 'direction' ('human' or anything else),
            'occurrence' (frequency in %) and 'sigma'.
        human_prob: Final 'human' probability, shown in the waterfall title.
        title: Text used as the figure's overall title.

    Returns:
        The matplotlib figure. The caller is responsible for closing it.
    """
    # Figure & GridSpec Layout
    fig = plt.figure(figsize=(14, 10))
    gs = plt.GridSpec(2, 2, width_ratios=[1.2, 1], height_ratios=[1.2, 1], hspace=0.35, wspace=0.3)

    # -------------------------------------------------------------------------
    # 1. Waterfall-like Plot (top-left subplot)
    # -------------------------------------------------------------------------
    # Axes are added through `fig` so the function does not depend on
    # pyplot's notion of the "current" figure.
    ax_waterfall = fig.add_subplot(gs[0, 0])

    # Start from baseline prob=0.5
    baseline = 0.5
    current_prob = baseline
    steps = [("Baseline", current_prob, 0.0)]

    # Build up the step changes
    for kmer in important_kmers:
        direction_multiplier = 1 if kmer["direction"] == "human" else -1
        # Scale changes so that the sum doesn't overshadow the final probability.
        change = kmer["impact"] * 0.05 * direction_multiplier
        current_prob += change
        steps.append((kmer["kmer"], current_prob, change))

    # X-values for step plot
    x_vals = range(len(steps))
    y_vals = [s[1] for s in steps]
    ax_waterfall.step(x_vals, y_vals, where='post', color='blue', linewidth=2, label='Probability')
    ax_waterfall.plot(x_vals, y_vals, 'b.', markersize=8)

    # Reference lines
    ax_waterfall.axhline(y=baseline, color='gray', linestyle='--', label='Baseline=0.5')

    # Annotate each step
    for i, (kmer, prob, change) in enumerate(steps):
        if i == 0:  # baseline
            ax_waterfall.annotate(kmer, (i, prob), textcoords="offset points", xytext=(0, -15), ha='center', color='black')
            continue
        color = "green" if change > 0 else "red"
        ax_waterfall.annotate(
            f"{kmer}\n({change:+.3f})",
            (i, prob),
            textcoords="offset points",
            xytext=(0, -15),
            ha='center',
            color=color,
            fontsize=9
        )

    ax_waterfall.set_ylim(0, 1)
    ax_waterfall.set_xlabel("k-mer Step")
    ax_waterfall.set_ylabel("Running Probability (Human)")
    ax_waterfall.set_title(f"K-mer Waterfall Plot — Final Probability: {human_prob:.3f}")
    ax_waterfall.grid(alpha=0.3)
    ax_waterfall.legend()

    # -------------------------------------------------------------------------
    # 2. Frequency & σ from Mean (top-right subplot)
    # -------------------------------------------------------------------------
    ax_bar = fig.add_subplot(gs[0, 1])
    kmers = [k["kmer"] for k in important_kmers]
    frequencies = [k["occurrence"] for k in important_kmers]  # in %
    sigmas = [k["sigma"] for k in important_kmers]
    directions = [k["direction"] for k in important_kmers]

    # X-locations
    x = np.arange(len(kmers))
    width = 0.4

    # We will create twin axes: one for frequency, one for σ
    ax_bar.bar(x - width/2, frequencies, width, label='Frequency (%)',
               alpha=0.7, color=['green' if d == 'human' else 'red' for d in directions])
    ax_bar.set_ylabel("Frequency (%)")
    # Fall back to an upper limit of 1 when there are no bars or every
    # frequency is 0; otherwise identical (0, 0) limits would be requested.
    peak_frequency = max(frequencies, default=0)
    ax_bar.set_ylim(0, peak_frequency * 1.2 if peak_frequency > 0 else 1)
    ax_bar.set_title("Frequency vs. σ from Mean")

    # Twin axis for σ
    ax_bar_twin = ax_bar.twinx()
    ax_bar_twin.bar(x + width/2, sigmas, width, label='σ from Mean',
                    alpha=0.5, color='gray')
    ax_bar_twin.set_ylabel("Standard Deviations (σ)")
    ax_bar.set_xticks(x)
    ax_bar.set_xticklabels(kmers, rotation=45, ha='right', fontsize=9)

    # Combine legends from both y-axes into a single box
    lines1, labels1 = ax_bar.get_legend_handles_labels()
    lines2, labels2 = ax_bar_twin.get_legend_handles_labels()
    ax_bar.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

    # -------------------------------------------------------------------------
    # 3. Top Feature Importances (Bottom, spanning both columns)
    # -------------------------------------------------------------------------
    ax_imp = fig.add_subplot(gs[1, :])

    # Sort by impact, largest first
    sorted_kmers = sorted(important_kmers, key=lambda entry: entry['impact'], reverse=True)
    top_kmer_labels = [k['kmer'] for k in sorted_kmers]
    top_kmer_impacts = [k['impact'] for k in sorted_kmers]
    top_kmer_dirs = [k['direction'] for k in sorted_kmers]
    x_imp = np.arange(len(top_kmer_impacts))
    bar_colors = ['green' if d == 'human' else 'red' for d in top_kmer_dirs]
    ax_imp.bar(x_imp, top_kmer_impacts, color=bar_colors, alpha=0.7)
    ax_imp.set_xticks(x_imp)
    ax_imp.set_xticklabels(top_kmer_labels, rotation=45, ha='right', fontsize=9)
    ax_imp.set_title("Absolute Feature Importance (Top k-mers)")
    ax_imp.set_ylabel("Importance (gradient magnitude)")
    ax_imp.grid(alpha=0.3, axis='y')

    fig.suptitle(title, fontsize=14, y=1.02)
    fig.tight_layout()
    return fig
############################################################################### | |
# Prediction Function | |
############################################################################### | |
def predict(file_obj):
    """
    Main function that Gradio will call:
    1. Reads the uploaded FASTA file (or text).
    2. Loads the model and scaler.
    3. Generates predictions, probabilities, and top k-mers.
    4. Creates a summary text and a matplotlib figure for visualization.

    Only the first sequence in the FASTA input is classified.

    Args:
        file_obj: ``None``, a ``str`` (used directly as the FASTA text), or an
            object with ``.decode('utf-8')`` such as ``bytes``.

    Returns:
        A tuple ``(results_text, plot_image)``. On success ``results_text`` is
        a Markdown summary and ``plot_image`` a PIL image of the figure. On
        any failure ``results_text`` is an error message and ``plot_image``
        is ``None``; exceptions are reported this way rather than raised.
    """
    if file_obj is None:
        return "Please upload a FASTA file.", None

    # Read text from file
    try:
        if isinstance(file_obj, str):
            text = file_obj
        else:
            text = file_obj.decode('utf-8')
    except Exception as e:
        return f"Error reading file: {str(e)}", None

    # k-mer names in the same order sequence_to_kmer_vector uses for its
    # vector, so a feature index maps straight to its k-mer by position.
    k = 4
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]

    # Load model & scaler
    # NOTE(security): torch.load and joblib.load both unpickle their input;
    # 'model.pt' and 'scaler.pkl' must be trusted files shipped with the app.
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = VirusClassifier(256).to(device)
        state_dict = torch.load('model.pt', map_location=device)
        model.load_state_dict(state_dict)
        scaler = joblib.load('scaler.pkl')
        model.eval()
    except Exception as e:
        return f"Error loading model or scaler: {str(e)}", None

    plot_image = None
    try:
        # Parse FASTA
        sequences = parse_fasta(text)
        if len(sequences) == 0:
            return "No valid FASTA sequences found. Please check your input.", None
        header, seq = sequences[0]  # For simplicity, we'll only classify the first sequence

        # Transform sequence to scaled k-mer vector
        raw_freq_vector = sequence_to_kmer_vector(seq)
        kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
        X_tensor = torch.FloatTensor(kmer_vector).to(device)

        # Inference
        with torch.no_grad():
            output = model(X_tensor)
            probs = torch.softmax(output, dim=1)

        # Feature Importance (the probability it also returns is not needed
        # here; the no_grad result above is what gets reported)
        importance, _ = model.get_feature_importance(X_tensor)
        kmer_importance = importance[0].cpu().numpy()  # one value per k-mer feature

        # Top k-mers by absolute importance
        top_k = 10
        top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]  # largest -> smallest
        important_kmers = []
        for idx in top_indices:
            imp_val = float(abs(kmer_importance[idx]))
            direction = 'human' if kmer_importance[idx] > 0 else 'non-human'
            freq = float(raw_freq_vector[idx] * 100)  # frequency in %
            sigma = float(kmer_vector[0][idx])  # scaled value (Z-score if standard scaler)
            important_kmers.append({
                'kmer': kmers[idx],
                'impact': imp_val,
                'direction': direction,
                'occurrence': freq,
                'sigma': sigma
            })

        pred_class = 1 if probs[0][1] > probs[0][0] else 0
        pred_label = 'human' if pred_class == 1 else 'non-human'
        human_prob = float(probs[0][1])
        non_human_prob = float(probs[0][0])
        conf = float(max(probs[0]))  # confidence in the predicted class

        # Generate text results
        summary_parts = [
            f"**Sequence Header**: {header}\n\n"
            f"**Predicted Label**: {pred_label}\n"
            f"**Confidence**: {conf:.4f}\n\n"
            f"**Human Probability**: {human_prob:.4f}\n"
            f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
            "### Most Influential k-mers:\n"
        ]
        for entry in important_kmers:
            direction_text = f"pushes toward {entry['direction']}"
            occurrence_text = f"{entry['occurrence']:.2f}% of sequence"
            sigma_text = f"{abs(entry['sigma']):.2f}σ " + ("above" if entry['sigma'] > 0 else "below") + " mean"
            summary_parts.append(
                f"- **{entry['kmer']}**: "
                f"impact = {entry['impact']:.4f}, {direction_text}, "
                f"occurrence = {occurrence_text}, "
                f"({sigma_text})\n"
            )
        results_text = "".join(summary_parts)

        # Create figure
        fig = create_visualization(important_kmers, human_prob, f"{header}")
        try:
            # Convert figure to image
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
            buf.seek(0)
            plot_image = Image.open(buf)
        finally:
            # Always release the figure, even if rendering fails, so repeated
            # requests do not accumulate open pyplot figures.
            plt.close(fig)
    except Exception as e:
        return f"Error during prediction or visualization: {str(e)}", None

    return results_text, plot_image
############################################################################### | |
# Gradio Interface | |
############################################################################### | |
# Upload widget; with type="binary" the handler is expected to receive the
# file's raw bytes (predict decodes them as UTF-8) — confirm against the
# installed Gradio version.
_fasta_upload = gr.File(label="Upload FASTA file", type="binary")

# Output widgets, in the same order as predict's (text, image) return tuple.
_result_views = [
    gr.Markdown(label="Prediction Results"),
    gr.Image(label="K-mer Analysis Visualization"),
]

# Blurb shown under the page title.
_app_description = (
    "Upload a FASTA file containing a single nucleotide sequence. "
    "This model will predict whether the virus host is **human** or **non-human**, "
    "provide a confidence score, and highlight the most influential k-mers in the classification."
)

# The web UI object: wires the upload widget to predict and its two outputs.
iface = gr.Interface(
    fn=predict,
    inputs=_fasta_upload,
    outputs=_result_views,
    title="Virus Host Classifier",
    description=_app_description,
    allow_flagging="never",
)

# Start the server only when this file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)