import gradio as gr
import torch
import joblib
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import io
from PIL import Image
from itertools import product
# --------------- Model Definition ---------------
class VirusClassifier(nn.Module):
    def __init__(self, input_shape: int):
        super(VirusClassifier, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2)
        )

    def forward(self, x):
        return self.network(x)
    def get_gradient_importance(self, x, class_index=1):
        """
        Calculate gradient-based importance for each input feature.
        By default, we compute the gradient wrt the 'human' class (index=1).
        This is a raw-gradient ('saliency') approach.
        """
        x = x.clone().detach().requires_grad_(True)
        output = self.network(x)
        probs = torch.softmax(output, dim=1)
        # Probability of the specified class
        target_prob = probs[..., class_index]
        # Zero existing gradients if any
        if x.grad is not None:
            x.grad.zero_()
        # Backprop on that probability; .sum() keeps this valid for batches > 1
        target_prob.sum().backward()
        # The raw gradient is now in x.grad
        importance = x.grad.detach()
        # Optional: multiply by the input for an "input x gradient" measure
        # importance = importance * x.detach()
        # float(target_prob) assumes a batch of size 1, as used below
        return importance, float(target_prob)
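
# Minimal saliency sketch (random weights, not the trained checkpoint;
# the input width of 256 assumes k=4):
#   model = VirusClassifier(input_shape=256)
#   model.eval()  # BatchNorm requires eval mode for a batch of one
#   x = torch.randn(1, 256)
#   saliency, p_human = model.get_gradient_importance(x, class_index=1)
#   saliency.shape  # -> torch.Size([1, 256])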
# --------------- Utility Functions ---------------
def parse_fasta(text: str):
    """
    Parse a FASTA string and return a list of (header, sequence) pairs.
    """
    sequences = []
    current_header = None
    current_sequence = []
    for line in text.split('\n'):
        line = line.strip()
        if not line:
            continue
        if line.startswith('>'):
            if current_header is not None:
                sequences.append((current_header, ''.join(current_sequence)))
            current_header = line[1:]
            current_sequence = []
        else:
            current_sequence.append(line.upper())
    if current_header is not None:
        sequences.append((current_header, ''.join(current_sequence)))
    return sequences
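
# Example (hypothetical two-record input, for illustration only):
#   parse_fasta(">seq1\nACGT\n>seq2\nTT\nGG")
#   -> [('seq1', 'ACGT'), ('seq2', 'TTGG')]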
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """
    Convert a nucleotide sequence into a k-mer frequency vector.
    Defaults to k=4.
    """
    # Generate all possible k-mers
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    kmer_dict = {km: i for i, km in enumerate(kmers)}
    vec = np.zeros(len(kmers), dtype=np.float32)
    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i+k]
        if kmer in kmer_dict:
            vec[kmer_dict[kmer]] += 1
    total_kmers = len(sequence) - k + 1
    if total_kmers > 0:
        vec = vec / total_kmers
    return vec
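
# Example (k=2 yields a 16-dimensional vector; values below are exact):
#   v = sequence_to_kmer_vector("ACGTACGT", k=2)
#   # 7 overlapping 2-mers: AC, CG, GT, TA, AC, CG, GT
#   v[1]  # "AC" sits at index 1 -> 2/7 ≈ 0.2857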
def compute_sequence_stats(sequence: str):
    """
    Compute various statistics for a given sequence:
    - Length
    - GC content (%)
    - A/C/G/T counts
    """
    length = len(sequence)
    if length == 0:
        return {
            'length': 0,
            'gc_content': 0,
            'counts': {'A': 0, 'C': 0, 'G': 0, 'T': 0}
        }
    counts = {
        'A': sequence.count('A'),
        'C': sequence.count('C'),
        'G': sequence.count('G'),
        'T': sequence.count('T')
    }
    gc_content = (counts['G'] + counts['C']) / length * 100.0
    return {
        'length': length,
        'gc_content': gc_content,
        'counts': counts
    }
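
# Example:
#   compute_sequence_stats("ACGT")
#   -> {'length': 4, 'gc_content': 50.0, 'counts': {'A': 1, 'C': 1, 'G': 1, 'T': 1}}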
# --------------- Visualization Functions ---------------
def plot_shap_like_bars(kmers, importance_values, top_k=10):
    """
    Create a bar chart that mimics a SHAP summary plot:
    - k-mers on the y-axis
    - importance magnitude on the x-axis
    - color indicating positive (pushes towards human) vs. negative
      (pushes towards non-human) gradients
    """
    abs_importance = np.abs(importance_values)
    # Sort by absolute importance
    sorted_indices = np.argsort(abs_importance)[::-1]
    top_indices = sorted_indices[:top_k]
    # Prepare data
    top_kmers = [kmers[i] for i in top_indices]
    top_importances = importance_values[top_indices]
    # Create plot
    fig, ax = plt.subplots(figsize=(8, 6))
    colors = ['green' if val > 0 else 'red' for val in top_importances]
    ax.barh(range(len(top_kmers)), np.abs(top_importances), color=colors)
    ax.set_yticks(range(len(top_kmers)))
    ax.set_yticklabels(top_kmers)
    ax.invert_yaxis()  # Highest value at the top
    ax.set_xlabel("Feature Importance (Gradient Magnitude)")
    ax.set_title(f"Top-{top_k} SHAP-like Feature Importances")
    plt.tight_layout()
    return fig
def plot_kmer_distribution(kmer_freq_vector):
    """
    Plot the full k-mer frequency vector as a quick distribution overview.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(range(len(kmer_freq_vector)), kmer_freq_vector, color='blue', alpha=0.6)
    ax.set_xlabel("K-mer Index")
    ax.set_ylabel("Frequency")
    ax.set_title("K-mer Frequency Distribution")
    ax.set_xticks([])
    plt.tight_layout()
    return fig
def create_step_visualization(important_kmers, human_prob):
    """
    Step-wise probability plot: shows how each top k-mer 'pushes' the
    probability from the neutral 0.5 towards the final value. This is an
    illustrative walk based on gradient magnitudes, not an exact additive
    attribution, so the final step need not land exactly on human_prob.
    """
    fig = plt.figure(figsize=(8, 5))
    ax = fig.add_subplot(111)
    # Start from 0.5
    current_prob = 0.5
    steps = [('Start', current_prob, 0)]
    for kmer in important_kmers:
        change = kmer['impact'] * (-1 if kmer['direction'] == 'non-human' else 1)
        current_prob += change
        steps.append((kmer['kmer'], current_prob, change))
    x_vals = range(len(steps))
    y_vals = [s[1] for s in steps]
    ax.step(x_vals, y_vals, 'b-', where='post', label='Probability', linewidth=2)
    ax.plot(x_vals, y_vals, 'b.', markersize=10)
    # Reference line at 0.5
    ax.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
    ax.set_ylim(0, 1)
    ax.set_ylabel('Human Probability')
    ax.set_title(f'K-mer Contributions (final p={human_prob:.3f})')
    ax.grid(True, linestyle='--', alpha=0.7)
    for i, (kmer, prob, change) in enumerate(steps):
        ax.annotate(kmer,
                    (i, prob),
                    xytext=(0, 10 if i % 2 == 0 else -20),
                    textcoords='offset points',
                    ha='center',
                    rotation=45)
        if i > 0:
            change_text = f'{change:+.3f}'
            color = 'green' if change > 0 else 'red'
            ax.annotate(change_text,
                        (i, prob),
                        xytext=(0, -20 if i % 2 == 0 else 10),
                        textcoords='offset points',
                        ha='center',
                        color=color)
    ax.legend()
    plt.tight_layout()
    return fig
def plot_kmer_freq_and_sigma(important_kmers):
    """
    Plot frequency vs. deviation from the mean (in σ) for the top k-mers,
    kept as its own function for clarity.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    # Prepare data
    kmers = [k['kmer'] for k in important_kmers]
    frequencies = [k['occurrence'] for k in important_kmers]
    sigmas = [k['sigma'] for k in important_kmers]
    colors = ['green' if k['direction'] == 'human' else 'red' for k in important_kmers]
    x = np.arange(len(kmers))
    width = 0.35
    # Frequency bars
    ax.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)
    # Twin axis for sigma
    ax2 = ax.twinx()
    # Sigma bars
    ax2.bar(x + width/2, sigmas, width, label='σ from mean',
            color=[c if s > 0 else 'gray' for c, s in zip(colors, sigmas)], alpha=0.3)
    ax.set_xticks(x)
    ax.set_xticklabels(kmers, rotation=45)
    ax.set_ylabel('Frequency (%)')
    ax2.set_ylabel('Standard Deviations (σ) from Mean')
    ax.set_title("K-mer Frequencies & Statistical Significance")
    lines1, labels1 = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines1 + lines2, labels1 + labels2, loc='best')
    plt.tight_layout()
    return fig
# --------------- Main Prediction Logic ---------------
def predict_fasta(
    file_obj,
    k_size=4,
    top_k=10,
    advanced_analysis=False
):
    """
    Predict a class for each sequence in an uploaded FASTA.
    Returns:
    - A combined textual report for all sequences
    - A list of generated PIL Image plots
    """
    # Gradio sliders may deliver floats; cast since these are used as sizes/indices
    k_size = int(k_size)
    top_k = int(top_k)
    # 1. Read raw text from file or string
    if file_obj is None:
        return "Please upload a FASTA file", []
    try:
        if isinstance(file_obj, str):
            text = file_obj
        else:
            text = file_obj.decode('utf-8', errors='replace')
    except Exception as e:
        return f"Error reading file: {str(e)}", []
    # 2. Parse the FASTA
    sequences = parse_fasta(text)
    if not sequences:
        return "No valid FASTA sequences found!", []
    # 3. Load model & scaler
    # Note: model.pt and scaler.pkl must match the chosen k-mer size;
    # loading fails if the checkpoint was trained with a different input width.
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = VirusClassifier(input_shape=(4 ** k_size)).to(device)
        state_dict = torch.load('model.pt', map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        scaler = joblib.load('scaler.pkl')
    except Exception as e:
        return f"Error loading model/scaler: {str(e)}", []
    # 4. Reference list of all k-mers, index-aligned with the feature vector
    all_kmers = [''.join(p) for p in product("ACGT", repeat=k_size)]
    # 5. Iterate over sequences and build output
    final_text_report = []
    plots = []
    for idx, (header, seq) in enumerate(sequences, start=1):
        seq_stats = compute_sequence_stats(seq)
        # Convert sequence -> raw freq -> scaled freq
        raw_kmer_freq = sequence_to_kmer_vector(seq, k=k_size)
        scaled_kmer_freq = scaler.transform(raw_kmer_freq.reshape(1, -1))
        X_tensor = torch.FloatTensor(scaled_kmer_freq).to(device)
        # Predict
        with torch.no_grad():
            output = model(X_tensor)
            probs = torch.softmax(output, dim=1)
        # Determine class
        pred_class = torch.argmax(probs, dim=1).item()
        pred_label = 'human' if pred_class == 1 else 'non-human'
        human_prob = float(probs[0][1])
        non_human_prob = float(probs[0][0])
        confidence = torch.max(probs[0]).item()
        # Compute gradient-based importance
        importance, target_prob = model.get_gradient_importance(X_tensor, class_index=1)
        importance = importance[0].cpu().numpy()  # shape: (num_features,)
        # Identify top-k features (by absolute gradient)
        abs_importance = np.abs(importance)
        sorted_indices = np.argsort(abs_importance)[::-1]
        top_indices = sorted_indices[:top_k]
        # Build a list of top k-mers
        top_kmers_info = []
        for i in top_indices:
            kmer_name = all_kmers[i]
            imp_val = float(importance[i])
            direction = 'human' if imp_val > 0 else 'non-human'
            freq_perc = float(raw_kmer_freq[i] * 100.0)  # in percent
            # Scaled value; equals σ from the training mean for a StandardScaler
            sigma = float(scaled_kmer_freq[0][i])
            top_kmers_info.append({
                'kmer': kmer_name,
                'impact': abs(imp_val),
                'direction': direction,
                'occurrence': freq_perc,
                'sigma': sigma
            })
        # Text summary for this sequence
        seq_report = []
        seq_report.append(f"=== Sequence {idx} ===")
        seq_report.append(f"Header: {header}")
        seq_report.append(f"Length: {seq_stats['length']}")
        seq_report.append(f"GC Content: {seq_stats['gc_content']:.2f}%")
        seq_report.append(f"A: {seq_stats['counts']['A']}, C: {seq_stats['counts']['C']}, G: {seq_stats['counts']['G']}, T: {seq_stats['counts']['T']}")
        seq_report.append(f"Prediction: {pred_label} (Confidence: {confidence:.4f})")
        seq_report.append(f"  Human Probability: {human_prob:.4f}")
        seq_report.append(f"  Non-human Probability: {non_human_prob:.4f}")
        seq_report.append(f"\nTop-{top_k} Influential k-mers (by gradient magnitude):")
        for tkm in top_kmers_info:
            seq_report.append(
                f"  {tkm['kmer']}: pushes towards {tkm['direction']} "
                f"(impact={tkm['impact']:.4f}), occurrence={tkm['occurrence']:.2f}%, "
                f"sigma={tkm['sigma']:.2f}"
            )
        final_text_report.append("\n".join(seq_report))
        # 6. Generate plots (per sequence)
        if advanced_analysis:
            # 6A. SHAP-like bar chart
            fig_shap = plot_shap_like_bars(
                kmers=all_kmers,
                importance_values=importance,
                top_k=top_k
            )
            buf_shap = io.BytesIO()
            fig_shap.savefig(buf_shap, format='png', bbox_inches='tight', dpi=150)
            buf_shap.seek(0)
            plots.append(Image.open(buf_shap))
            plt.close(fig_shap)
            # 6B. k-mer distribution histogram
            fig_kmer_dist = plot_kmer_distribution(raw_kmer_freq)
            buf_dist = io.BytesIO()
            fig_kmer_dist.savefig(buf_dist, format='png', bbox_inches='tight', dpi=150)
            buf_dist.seek(0)
            plots.append(Image.open(buf_dist))
            plt.close(fig_kmer_dist)
            # 6C. Step visualization for the top k-mers,
            # sorted by absolute impact (largest first)
            top_kmers_info_step = sorted(top_kmers_info, key=lambda x: x['impact'], reverse=True)
            fig_step = create_step_visualization(top_kmers_info_step, human_prob)
            buf_step = io.BytesIO()
            fig_step.savefig(buf_step, format='png', bbox_inches='tight', dpi=150)
            buf_step.seek(0)
            plots.append(Image.open(buf_step))
            plt.close(fig_step)
            # 6D. Frequency vs. sigma bar chart
            fig_freq_sigma = plot_kmer_freq_and_sigma(top_kmers_info_step)
            buf_freq_sigma = io.BytesIO()
            fig_freq_sigma.savefig(buf_freq_sigma, format='png', bbox_inches='tight', dpi=150)
            buf_freq_sigma.seek(0)
            plots.append(Image.open(buf_freq_sigma))
            plt.close(fig_freq_sigma)
    # Combine all text results
    combined_text = "\n\n".join(final_text_report)
    return combined_text, plots
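
# Direct (non-UI) usage sketch, assuming model.pt and scaler.pkl sit next to
# this script and were trained with k=4; the FASTA text here is made up:
#   fasta_text = ">example\n" + "ACGT" * 50
#   report, images = predict_fasta(fasta_text, k_size=4, top_k=10,
#                                  advanced_analysis=True)
#   print(report)                        # per-sequence textual summary
#   images[0].save("saliency_bars.png")  # first advanced-analysis plot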
# --------------- Gradio Interface ---------------
def run_prediction(
    file_obj,
    k_size,
    top_k,
    advanced_analysis
):
    """
    Wrapper for Gradio that returns outputs in (text, List[Image]) form.
    """
    text_output, pil_images = predict_fasta(
        file_obj=file_obj,
        k_size=k_size,
        top_k=top_k,
        advanced_analysis=advanced_analysis
    )
    return text_output, pil_images
with gr.Blocks() as demo:
    gr.Markdown("# Virus Host Classifier (Improved!)")
    gr.Markdown(
        "Upload a FASTA file and configure the k-mer size, the number of top features, "
        "and whether to run the advanced analysis (SHAP-like bars, k-mer distribution, "
        "step-wise contribution, and frequency/σ plots)."
    )
    with gr.Row():
        with gr.Column():
            fasta_file = gr.File(label="Upload FASTA file", type="binary")
            kmer_slider = gr.Slider(minimum=2, maximum=6, value=4, step=1, label="K-mer Size")
            topk_slider = gr.Slider(minimum=5, maximum=20, value=10, step=1, label="Top-k Features")
            advanced_check = gr.Checkbox(value=False, label="Advanced Analysis")
            predict_button = gr.Button("Predict")
        with gr.Column():
            results_text = gr.Textbox(
                label="Results",
                lines=20,
                placeholder="Prediction results will appear here..."
            )
            # Multiple images are shown in a Gallery. Note: .style() is the
            # Gradio 3.x API; Gradio 4+ takes these as constructor arguments
            # (e.g. columns=2).
            plots_gallery = gr.Gallery(label="Analysis Plots").style(grid=[2], height="auto")
    predict_button.click(
        fn=run_prediction,
        inputs=[fasta_file, kmer_slider, topk_slider, advanced_check],
        outputs=[results_text, plots_gallery]
    )
if __name__ == "__main__":
    demo.launch(share=True)  # share=True has no effect when hosted on Spaces