"""Gradio app that predicts whether a viral sequence is of human origin.

The pipeline is: parse FASTA -> 4-mer frequency vector -> scaler ->
``VirusClassifier`` -> per-feature ablation ("SHAP-like") importances ->
text summary plus two matplotlib plots rendered to PIL images.
"""
import io
from functools import lru_cache
from itertools import product

import gradio as gr
import joblib
import matplotlib

# Figures are only ever rasterised to PNG, and request handlers may run
# outside the main thread, so force the non-interactive backend.
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

KMER_SIZE = 4
N_FEATURES = 4 ** KMER_SIZE
MODEL_PATH = "model.pt"
SCALER_PATH = "scaler.pkl"


class VirusClassifier(nn.Module):
    """Small MLP mapping a k-mer frequency vector to two class logits.

    The layer layout inside ``self.network`` must stay unchanged so that the
    keys of a previously saved ``state_dict`` keep matching.
    """

    def __init__(self, input_shape: int):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        """Return the raw (un-normalised) logits for input batch ``x``."""
        return self.network(x)


@lru_cache(maxsize=8)
def _kmer_vocabulary(k=KMER_SIZE):
    """Return ``(kmers, index_of)`` for all ACGT k-mers of length ``k``.

    ``kmers`` is a tuple in ``itertools.product`` order and ``index_of`` maps
    each k-mer to its position in that tuple. The result is cached and shared,
    so callers must not mutate it.
    """
    kmers = tuple("".join(p) for p in product("ACGT", repeat=k))
    return kmers, {kmer: i for i, kmer in enumerate(kmers)}


def _plot_style():
    """Return a matplotlib style name that exists in the installed version.

    The bundled seaborn styles were renamed to ``seaborn-v0_8`` in
    matplotlib 3.6 and the old ``seaborn`` alias was later removed, so a
    hard-coded ``'seaborn'`` raises on current releases.
    """
    for name in ("seaborn-v0_8", "seaborn"):
        if name in plt.style.available:
            return name
    return "default"


def parse_fasta(text):
    """
    Parses FASTA formatted text into a list of (header, sequence).

    Header lines start with ``>``; the header is everything after that
    character. Sequence lines are upper-cased and concatenated. Blank lines
    are ignored, and sequence lines appearing before the first header are
    discarded. A record with an empty header (a bare ``>``) is kept.
    """
    records = []
    header = None
    chunks = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(">"):
            # Compare against None: an empty header string is still a record.
            if header is not None:
                records.append((header, "".join(chunks)))
            header = line[1:]
            chunks = []
        else:
            chunks.append(line.upper())
    if header is not None:
        records.append((header, "".join(chunks)))
    return records


def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """
    Convert a sequence to a k-mer frequency vector.

    Returns a float32 vector of length ``4**k`` in ``itertools.product``
    order over "ACGT". Windows containing any other character are not
    counted, but they still contribute to the normalising window count
    ``len(sequence) - k + 1``. A sequence shorter than ``k`` yields zeros.
    """
    _, index_of = _kmer_vocabulary(k)
    vec = np.zeros(len(index_of), dtype=np.float32)
    total_kmers = len(sequence) - k + 1
    for start in range(total_kmers):  # empty range when the sequence is short
        idx = index_of.get(sequence[start:start + k])
        if idx is not None:
            vec[idx] += 1
    if total_kmers > 0:
        vec = vec / total_kmers
    return vec


def calculate_shap_values(model, x_tensor):
    """
    Calculate SHAP-like values using a simple ablation approach.

    Each feature of the first row of ``x_tensor`` is set to zero in turn and
    the resulting drop in the class-1 probability is recorded.

    All ablated variants are evaluated in one batched forward pass (an
    ``n_features x n_features`` matrix) instead of one pass per feature. The
    model is put in eval mode first, so rows are assumed to be scored
    independently of each other.

    Returns:
        ``(values, baseline_prob)`` where ``values`` is a float64 array with
        one entry per feature (``baseline_prob - ablated_prob``) and
        ``baseline_prob`` is the unperturbed class-1 probability as a float.
    """
    model.eval()
    n_features = x_tensor.shape[1]
    with torch.no_grad():
        baseline_prob = torch.softmax(model(x_tensor), dim=1)[0, 1].item()
        # Row i is a copy of the input with feature i zeroed.
        ablated = x_tensor[:1].repeat(n_features, 1)
        ablated.fill_diagonal_(0)
        ablated_probs = torch.softmax(model(ablated), dim=1)[:, 1]
    shap_values = baseline_prob - ablated_probs.cpu().numpy().astype(np.float64)
    return shap_values, baseline_prob


def create_importance_plot(shap_values, kmers, top_k=10):
    """
    Create horizontal bar plot of feature importance.

    The ``top_k`` features with the largest absolute value are shown, most
    influential at the top. Positive values are green, others red.

    Returns:
        The matplotlib figure (the caller is responsible for closing it).
    """
    shap_values = np.asarray(shap_values)
    # UI sliders may deliver floats; also avoid ``[-0:]`` selecting everything.
    top_k = max(1, int(top_k))

    # Descending by absolute importance so that, with the inverted y-axis,
    # the most influential k-mer ends up on the top row.
    order = np.argsort(np.abs(shap_values))[-top_k:][::-1]
    values = shap_values[order]
    features = [kmers[i] for i in order]
    colors = ['#2ecc71' if v > 0 else '#e74c3c' for v in values]
    positions = list(range(len(values)))

    # Scope the style to this figure instead of mutating global rcParams.
    with plt.style.context(_plot_style()):
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.barh(positions, values, color=colors)
        ax.set_yticks(positions)
        ax.set_yticklabels(features)
        ax.set_xlabel('Impact on Prediction (SHAP value)')
        ax.set_title(f'Top {len(values)} Most Influential k-mers')
        ax.invert_yaxis()
    return fig


def create_contribution_plot(important_kmers, final_prob):
    """
    Create waterfall plot showing cumulative feature contributions.

    Starting from a base of 0.5, the ``'impact'`` of each entry in
    ``important_kmers`` is added in order and the running total is plotted.
    ``final_prob`` (the model's actual class-1 probability) is drawn as a
    dotted reference line; the running total is not guaranteed to reach it
    because ablation impacts are not additive.

    Returns:
        The matplotlib figure (the caller is responsible for closing it).
    """
    base_prob = 0.5
    cumulative = [base_prob]
    labels = ['Base']
    for kmer_info in important_kmers:
        cumulative.append(cumulative[-1] + kmer_info['impact'])
        labels.append(kmer_info['kmer'])
    positions = list(range(len(cumulative)))

    with plt.style.context(_plot_style()):
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(positions, cumulative, 'b-o', linewidth=2,
                label='Cumulative contribution')
        ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
        if final_prob is not None:
            ax.axhline(y=final_prob, color='#2c3e50', linestyle=':',
                       label=f'Model prediction ({final_prob:.3f})')
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45)
        # Keep the usual [0, 1] window but widen it if the running total
        # leaves that range, so the curve is never clipped.
        ax.set_ylim(min(0.0, min(cumulative)), max(1.0, max(cumulative)))
        ax.grid(True, alpha=0.3)
        ax.set_title('Cumulative Feature Contributions')
        ax.set_ylabel('Probability of Human Origin')
        ax.legend(loc='best')
    return fig


@lru_cache(maxsize=1)
def _load_model_and_scaler(device_str):
    """Load the classifier weights and the feature scaler once per process.

    Failures are not cached (``lru_cache`` does not store exceptions), so a
    missing file can be fixed without restarting. A successful load is
    reused, so replacing the files on disk requires a restart.

    NOTE(security): both ``torch.load`` and ``joblib.load`` can unpickle
    arbitrary objects; only point these paths at trusted files.
    """
    device = torch.device(device_str)
    model = VirusClassifier(N_FEATURES).to(device)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    scaler = joblib.load(SCALER_PATH)
    return model, scaler


def _fig_to_image(fig):
    """Render ``fig`` to a PIL image and close the figure, even on error."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    finally:
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode now rather than lazily from the buffer
    return img


def predict(file_obj, top_kmers=10, fasta_text=""):
    """
    Main prediction function for the Gradio interface.

    Args:
        file_obj: Path of an uploaded FASTA file, or None.
        top_kmers: Number of most influential k-mers to report and plot.
        fasta_text: FASTA text pasted by the user; takes precedence over
            ``file_obj`` when non-empty.

    Returns:
        ``(text, importance_image, contribution_image)``. On any input or
        loading error the text holds the message and both images are None.
        Only the first record of the FASTA input is analysed.
    """
    # Handle input
    fasta_text = (fasta_text or "").strip()
    if fasta_text:
        text = fasta_text
    elif file_obj is not None:
        try:
            # File input will be a filepath since we specified type="filepath".
            # utf-8-sig drops a leading BOM that would otherwise hide the
            # first '>' header; undecodable bytes are replaced, not fatal.
            with open(file_obj, 'r', encoding='utf-8-sig', errors='replace') as f:
                text = f.read()
        except Exception as e:  # UI boundary: report instead of crashing
            return f"Error reading file: {str(e)}\nPlease ensure you're uploading a valid FASTA text file.", None, None
    else:
        return "Please provide a FASTA sequence either by file upload or text input.", None, None

    # Parse FASTA
    sequences = parse_fasta(text)
    if not sequences:
        return "No valid FASTA sequences found in input.", None, None

    header, seq = sequences[0]
    if len(seq) < KMER_SIZE:
        # An all-zero feature vector would still yield a (meaningless) score.
        return (
            f"Sequence '{header}' is shorter than {KMER_SIZE} bases; "
            "cannot compute k-mer frequencies.",
            None,
            None,
        )

    # Process sequence
    device_str = 'cuda' if torch.cuda.is_available() else 'cpu'
    try:
        model, scaler = _load_model_and_scaler(device_str)
    except Exception as e:  # UI boundary: report instead of crashing
        return f"Error loading model: {str(e)}", None, None

    # Generate features
    freq_vector = sequence_to_kmer_vector(seq, KMER_SIZE)
    scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
    x_tensor = torch.as_tensor(
        scaled_vector, dtype=torch.float32, device=torch.device(device_str)
    )

    # Calculate SHAP values and predictions
    shap_values, human_prob = calculate_shap_values(model, x_tensor)

    # Generate k-mer information, most influential first.
    top_kmers = max(1, int(top_kmers))  # sliders may deliver floats
    kmers, _ = _kmer_vocabulary(KMER_SIZE)
    important_indices = np.argsort(np.abs(shap_values))[-top_kmers:][::-1]

    important_kmers = [
        {
            'kmer': kmers[idx],
            'impact': shap_values[idx],
            'frequency': freq_vector[idx] * 100,
            'significance': scaled_vector[0][idx],
        }
        for idx in important_indices
    ]

    # Format results text
    results = [
        f"Sequence: {header}",
        f"Prediction: {'Human' if human_prob > 0.5 else 'Non-human'} Origin",
        f"Confidence: {max(human_prob, 1-human_prob):.3f}",
        f"Human Probability: {human_prob:.3f}",
        "\nTop Contributing k-mers:",
    ]

    for info in important_kmers:
        direction = "→ Human" if info['impact'] > 0 else "→ Non-human"
        results.append(
            f"• {info['kmer']}: {direction} "
            f"(impact: {info['impact']:.3f}, "
            f"freq: {info['frequency']:.2f}%)"
        )

    # Generate plots and convert them to images
    shap_plot = create_importance_plot(shap_values, kmers, top_kmers)
    contribution_plot = create_contribution_plot(important_kmers, human_prob)

    return (
        "\n".join(results),
        _fig_to_image(shap_plot),
        _fig_to_image(contribution_plot),
    )


# Create Gradio interface
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.interpretation-container {
    margin-top: 20px;
    padding: 15px;
    border-radius: 8px;
    background-color: #f8f9fa;
}
"""

with gr.Blocks(css=css) as iface:
    gr.Markdown("""
    # Virus Host Classifier

    This tool predicts whether a viral sequence is likely of human or non-human origin using k-mer frequency analysis.

    ### Instructions
    1. Upload a FASTA file or paste your sequence in FASTA format
    2. Adjust the number of top k-mers to display (default: 10)
    3. View the prediction results and feature importance visualizations
    """)

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload FASTA file",
                file_types=[".fasta", ".fa", ".txt"],
                type="filepath"  # predict() opens the upload by path
            )
            text_input = gr.Textbox(
                label="Or paste FASTA sequence",
                placeholder=">sequence_name\nACGTACGT...",
                lines=5
            )
            top_k = gr.Slider(
                minimum=5,
                maximum=20,
                value=10,
                step=1,
                label="Number of top k-mers to display"
            )
            submit_btn = gr.Button("Analyze Sequence", variant="primary")

        with gr.Column(scale=2):
            results = gr.Textbox(label="Analysis Results", lines=10)
            shap_plot = gr.Image(label="Feature Importance Plot")
            contribution_plot = gr.Image(label="Cumulative Contribution Plot")

    submit_btn.click(
        predict,
        inputs=[file_input, top_k, text_input],
        outputs=[results, shap_plot, contribution_plot]
    )

    gr.Markdown("""
    ### About
    - Uses 4-mer frequencies as sequence features
    - Employs SHAP-like values for feature importance interpretation
    - Visualizes cumulative feature contributions to the final prediction
    """)

if __name__ == "__main__":
    iface.launch()