HostClassifier / app.py
hiyata's picture
Update app.py
0681a74 verified
raw
history blame
9.36 kB
import gradio as gr
import torch
import joblib
import numpy as np
from itertools import product
import torch.nn as nn
import matplotlib.pyplot as plt
import io
from PIL import Image
class VirusClassifier(nn.Module):
    """Small feed-forward network mapping a k-mer feature vector to 2 logits.

    The layers live in ``self.network`` in a fixed order, so the state-dict
    keys are ``network.0`` ... ``network.10``; saved checkpoints depend on
    that order.
    """

    def __init__(self, input_shape: int):
        super().__init__()
        stack = []
        # Two regularised hidden stages: Linear -> GELU -> BatchNorm -> Dropout.
        for fan_in, fan_out in ((input_shape, 64), (64, 32)):
            stack += [
                nn.Linear(fan_in, fan_out),
                nn.GELU(),
                nn.BatchNorm1d(fan_out),
                nn.Dropout(0.3),
            ]
        # A final un-normalised hidden layer, then the 2-way output head.
        stack += [nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 2)]
        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw (pre-softmax) 2-class logits for ``x``."""
        return self.network(x)
def parse_fasta(text):
    """
    Parse FASTA-formatted text into a list of ``(header, sequence)`` tuples.

    A line starting with ``>`` begins a new record; the rest of that line
    (without the ``>``) is the header. The non-empty lines that follow, up to
    the next header, are concatenated with any internal whitespace removed.
    They are then upper-cased to form the sequence. Lines before the first
    header are ignored.

    Args:
        text: Raw FASTA text.

    Returns:
        List of (header, sequence) tuples in input order. The list is empty
        when no header line is present. A record with an empty header
        (a bare ``>``) is kept with header ``""``.
    """
    sequences = []
    current_header = None
    current_sequence = []
    for raw_line in text.strip().splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('>'):
            # Compare with None (not truthiness) so a record whose header is
            # empty is still emitted rather than silently dropped.
            if current_header is not None:
                sequences.append((current_header, ''.join(current_sequence)))
            current_header = line[1:]
            current_sequence = []
        else:
            # Remove embedded whitespace (e.g. bases grouped in blocks of 10)
            # so it cannot break k-mer windows downstream.
            current_sequence.append(''.join(line.split()).upper())
    if current_header is not None:
        sequences.append((current_header, ''.join(current_sequence)))
    return sequences
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """
    Build a normalised k-mer frequency vector for ``sequence``.

    The vector has one slot per k-mer over ``ACGT``, ordered as generated by
    ``itertools.product("ACGT", repeat=k)``. Each length-``k`` window of the
    sequence increments its slot. Windows containing any other character
    are not counted, but they still contribute to the denominator
    ``len(sequence) - k + 1``. A sequence shorter than ``k`` yields all zeros.

    Returns:
        float32 array of length 4**k.
    """
    index_of = {
        ''.join(letters): slot
        for slot, letters in enumerate(product("ACGT", repeat=k))
    }
    counts = np.zeros(len(index_of), dtype=np.float32)
    n_windows = len(sequence) - k + 1
    for start in range(max(n_windows, 0)):
        slot = index_of.get(sequence[start:start + k])
        if slot is not None:
            counts[slot] += 1
    if n_windows > 0:
        counts = counts / n_windows
    return counts
def calculate_shap_values(model, x_tensor):
    """
    Estimate per-feature importance by single-feature ablation (occlusion).

    Despite the name, these are not Shapley values. Each feature is set to 0
    on its own, and the change in the softmax probability of class 1 is
    recorded. All ablated copies are scored in a single batched forward pass
    rather than one pass per feature.

    Args:
        model: Classifier returning 2-class logits. As a side effect it is
            switched to eval mode.
        x_tensor: Input tensor; only row 0 is used (the code indexes
            ``[0, 1]`` / ``x_tensor[:1]``). Presumably shape (1, n_features).

    Returns:
        ``(importances, baseline_prob)``. ``importances`` is a float64 array
        of length ``x_tensor.shape[1]`` holding
        ``baseline_prob - ablated_prob`` for each feature. ``baseline_prob``
        is the unablated class-1 probability as a Python float.
    """
    # Batched scoring relies on eval mode: BatchNorm then uses running
    # statistics and Dropout is disabled, so rows do not influence each other.
    model.eval()
    with torch.no_grad():
        baseline_output = model(x_tensor)
        baseline_prob = torch.softmax(baseline_output, dim=1)[0, 1].item()
        n_features = x_tensor.shape[1]
        # Row i of this batch is the input with only feature i zeroed.
        perturbed = x_tensor[:1].repeat(n_features, 1)
        perturbed.fill_diagonal_(0)
        ablated_probs = torch.softmax(model(perturbed), dim=1)[:, 1].tolist()
    shap_values = [baseline_prob - prob for prob in ablated_probs]
    return np.array(shap_values), baseline_prob
def create_importance_plot(shap_values, kmers, top_k=10):
    """
    Create a horizontal bar plot of the ``top_k`` most influential k-mers.

    Features are ranked by absolute importance, and the most influential
    one is drawn at the top. Positive values are green; zero and negative
    values are red.

    Args:
        shap_values: 1-D numpy array of per-feature importances.
        kmers: Feature labels aligned index-for-index with ``shap_values``.
        top_k: Number of features to show (coerced to int).

    Returns:
        The matplotlib Figure. The caller is responsible for closing it.
    """
    # Matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8'
    # and 3.8 removed the old name, so a hard-coded 'seaborn' raises OSError
    # there. Use whichever name this installation provides.
    for style_name in ('seaborn-v0_8', 'seaborn'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break
    top_k = int(top_k)
    fig, ax = plt.subplots(figsize=(10, 8))
    # Indices of the top_k largest |value|, most influential first.
    # (Slicing from the front also makes top_k=0 select nothing, not everything.)
    indices = np.argsort(np.abs(shap_values))[::-1][:top_k]
    values = shap_values[indices]
    features = [kmers[i] for i in indices]
    colors = ['#2ecc71' if v > 0 else '#e74c3c' for v in values]
    positions = np.arange(len(values))
    ax.barh(positions, values, color=colors)
    ax.set_yticks(positions)
    ax.set_yticklabels(features)
    ax.set_xlabel('Impact on Prediction (SHAP value)')
    ax.set_title(f'Top {top_k} Most Influential k-mers')
    # Bar 0 (the most influential) is drawn at the top.
    ax.invert_yaxis()
    return fig
def create_contribution_plot(important_kmers, final_prob):
    """
    Plot the running sum of k-mer impacts, starting from a 0.5 baseline.

    Each k-mer's ``impact`` is added in the order given. These ablation
    impacts are not additive, so the end of the line generally differs from
    the model's actual output. That output (``final_prob``) is therefore
    drawn as a separate labelled reference line.

    Args:
        important_kmers: Iterable of dicts with at least a ``'kmer'`` label
            and a numeric ``'impact'``.
        final_prob: Model-predicted probability of human origin.

    Returns:
        The matplotlib Figure. The caller is responsible for closing it.
    """
    # Matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8'
    # and 3.8 removed the old name; use whichever one is available.
    for style_name in ('seaborn-v0_8', 'seaborn'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break
    fig, ax = plt.subplots(figsize=(12, 6))
    base_prob = 0.5
    cumulative = [base_prob]
    labels = ['Base']
    for kmer_info in important_kmers:
        cumulative.append(cumulative[-1] + kmer_info['impact'])
        labels.append(kmer_info['kmer'])
    positions = list(range(len(cumulative)))
    ax.plot(positions, cumulative, 'b-o', linewidth=2)
    ax.axhline(y=base_prob, color='gray', linestyle='--', alpha=0.5)
    ax.axhline(
        y=final_prob,
        color='green',
        linestyle=':',
        alpha=0.8,
        label=f'Model output ({final_prob:.3f})',
    )
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45)
    # Widen beyond [0, 1] when needed so no cumulative point is clipped.
    ax.set_ylim(min(0.0, min(cumulative)), max(1.0, max(cumulative)))
    ax.grid(True, alpha=0.3)
    ax.set_title('Cumulative Feature Contributions')
    ax.set_ylabel('Probability of Human Origin')
    ax.legend(loc='best')
    return fig
# Loaded (model, scaler) pairs keyed by device string. The artifacts are read
# from disk once per process instead of on every request.
_MODEL_CACHE = {}


def _load_model_and_scaler(device):
    """Return the cached ``(model, scaler)`` for ``device``, loading on first use.

    Exceptions raised by ``torch.load`` / ``joblib.load`` propagate
    unchanged. A failed load is not cached, so the next call retries.
    """
    key = str(device)
    if key not in _MODEL_CACHE:
        # 256 = 4**4 input features, matching the 4-mer vectors built in predict().
        model = VirusClassifier(256).to(device)
        model.load_state_dict(torch.load('model.pt', map_location=device))
        model.eval()
        scaler = joblib.load('scaler.pkl')
        _MODEL_CACHE[key] = (model, scaler)
    return _MODEL_CACHE[key]


def _fig_to_image(fig):
    """Render a matplotlib figure to a PNG-backed PIL image and close the figure."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    finally:
        # Release the figure even if rendering fails; pyplot keeps open figures alive.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def predict(file_obj, top_kmers=10, fasta_text=""):
    """
    Classify the first FASTA record as human or non-human origin.

    Pasted text takes precedence over an uploaded file. Only the first
    record is scored; when more are present, a note is added to the report.

    Args:
        file_obj: Path of an uploaded FASTA file (the upload component uses
            ``type="filepath"``), or None.
        top_kmers: How many of the most influential k-mers to report and
            plot. It is cast to int because slider values may arrive as floats.
        fasta_text: FASTA text pasted by the user; may be empty or None.

    Returns:
        ``(report_text, importance_image, contribution_image)``. On input,
        validation or model-loading errors, both images are None and the
        text explains the problem.
    """
    kmer_size = 4  # must agree with the model's 256 = 4**4 input features

    # Handle input: pasted text wins over an uploaded file.
    pasted = (fasta_text or "").strip()
    if pasted:
        text = pasted
    elif file_obj is not None:
        try:
            # File input is a filepath since the upload component uses type="filepath".
            with open(file_obj, 'r', encoding='utf-8') as f:
                text = f.read()
        except Exception as e:
            return f"Error reading file: {str(e)}\nPlease ensure you're uploading a valid FASTA text file.", None, None
    else:
        return "Please provide a FASTA sequence either by file upload or text input.", None, None

    sequences = parse_fasta(text)
    if not sequences:
        return "No valid FASTA sequences found in input.", None, None
    header, seq = sequences[0]

    freq_vector = sequence_to_kmer_vector(seq, k=kmer_size)
    if not freq_vector.any():
        # Empty/too-short sequences or sequences with no ACGT k-mers would
        # otherwise be scored as an all-zero feature vector.
        return (
            f"No valid ACGT {kmer_size}-mers found in sequence '{header}'. "
            f"Please provide a nucleotide sequence of at least {kmer_size} bases.",
            None,
            None,
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        model, scaler = _load_model_and_scaler(device)
    except Exception as e:
        return f"Error loading model: {str(e)}", None, None

    scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
    x_tensor = torch.FloatTensor(scaled_vector).to(device)
    shap_values, human_prob = calculate_shap_values(model, x_tensor)

    kmers = [''.join(p) for p in product("ACGT", repeat=kmer_size)]
    top_kmers = 10 if top_kmers is None else max(1, int(top_kmers))
    # Most influential first, so the report and the cumulative plot lead
    # with the strongest k-mers.
    important_indices = np.argsort(np.abs(shap_values))[::-1][:top_kmers]
    important_kmers = [
        {
            'kmer': kmers[idx],
            'impact': shap_values[idx],
            'frequency': freq_vector[idx] * 100,
            'significance': scaled_vector[0][idx],
        }
        for idx in important_indices
    ]

    results = [
        f"Sequence: {header}",
        f"Prediction: {'Human' if human_prob > 0.5 else 'Non-human'} Origin",
        f"Confidence: {max(human_prob, 1-human_prob):.3f}",
        f"Human Probability: {human_prob:.3f}",
    ]
    if len(sequences) > 1:
        results.append(
            f"Note: input contained {len(sequences)} sequences; only the first was analyzed."
        )
    results.append("\nTop Contributing k-mers:")
    for kmer in important_kmers:
        direction = "→ Human" if kmer['impact'] > 0 else "→ Non-human"
        results.append(
            f"• {kmer['kmer']}: {direction} "
            f"(impact: {kmer['impact']:.3f}, "
            f"freq: {kmer['frequency']:.2f}%)"
        )

    shap_plot = create_importance_plot(shap_values, kmers, top_kmers)
    contribution_plot = create_contribution_plot(important_kmers, human_prob)
    return "\n".join(results), _fig_to_image(shap_plot), _fig_to_image(contribution_plot)
# Create Gradio interface
# Page-level CSS handed to gr.Blocks below.
# NOTE(review): no component in this layout sets
# elem_classes="interpretation-container", so that rule appears unused in the
# visible code — confirm before relying on it.
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.interpretation-container {
margin-top: 20px;
padding: 15px;
border-radius: 8px;
background-color: #f8f9fa;
}
"""
# The layout is determined by the order in which components are created
# inside the nested Row/Column contexts: inputs go in the left column
# (scale=1) and outputs in the wider right column (scale=2).
with gr.Blocks(css=css) as iface:
    gr.Markdown("""
# Virus Host Classifier
This tool predicts whether a viral sequence is likely of human or non-human origin using k-mer frequency analysis.
### Instructions
1. Upload a FASTA file or paste your sequence in FASTA format
2. Adjust the number of top k-mers to display (default: 10)
3. View the prediction results and feature importance visualizations
""")
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload FASTA file",
                file_types=[".fasta", ".fa", ".txt"],
                type="filepath"  # predict() receives a path string and open()s it itself
            )
            text_input = gr.Textbox(
                label="Or paste FASTA sequence",
                placeholder=">sequence_name\nACGTACGT...",
                lines=5
            )
            # Upper bound on the number of k-mers listed and plotted.
            top_k = gr.Slider(
                minimum=5,
                maximum=20,
                value=10,
                step=1,
                label="Number of top k-mers to display"
            )
            submit_btn = gr.Button("Analyze Sequence", variant="primary")
        with gr.Column(scale=2):
            results = gr.Textbox(label="Analysis Results", lines=10)
            shap_plot = gr.Image(label="Feature Importance Plot")
            contribution_plot = gr.Image(label="Cumulative Contribution Plot")
    # predict's positional parameters are (file_obj, top_kmers, fasta_text),
    # so `inputs` must stay in this order. Its 3-tuple return value maps
    # one-to-one onto `outputs`.
    submit_btn.click(
        predict,
        inputs=[file_input, top_k, text_input],
        outputs=[results, shap_plot, contribution_plot]
    )
    gr.Markdown("""
### About
- Uses 4-mer frequencies as sequence features
- Employs SHAP-like values for feature importance interpretation
- Visualizes cumulative feature contributions to the final prediction
""")
# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()