import gradio as gr import torch import joblib import numpy as np import torch.nn as nn import matplotlib.pyplot as plt import io from PIL import Image from itertools import product # --------------- Model Definition --------------- class VirusClassifier(nn.Module): def __init__(self, input_shape: int): super(VirusClassifier, self).__init__() self.network = nn.Sequential( nn.Linear(input_shape, 64), nn.GELU(), nn.BatchNorm1d(64), nn.Dropout(0.3), nn.Linear(64, 32), nn.GELU(), nn.BatchNorm1d(32), nn.Dropout(0.3), nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 2) ) def forward(self, x): return self.network(x) def get_gradient_importance(self, x, class_index=1): """ Calculate gradient-based importance for each input feature. By default, we compute the gradient wrt the 'human' class (index=1). This method is akin to a raw gradient or 'saliency' approach. """ x = x.clone().detach().requires_grad_(True) output = self.network(x) probs = torch.softmax(output, dim=1) # Probability of the specified class target_prob = probs[..., class_index] # Zero existing gradients if any if x.grad is not None: x.grad.zero_() # Backprop on that probability target_prob.backward() # Raw gradient is now in x.grad importance = x.grad.detach() # Optional: Multiply by input to get a more "integrated gradients"-like measure # importance = importance * x.detach() return importance, float(target_prob) # --------------- Utility Functions --------------- def parse_fasta(text: str): """ Parse a FASTA string and return a list of (header, sequence) pairs. """ sequences = [] current_header = None current_sequence = [] for line in text.split('\n'): line = line.strip() if not line: continue if line.startswith('>'): if current_header: sequences.append((current_header, ''.join(current_sequence))) current_header = line[1:] current_sequence = [] else: current_sequence.append(line.upper()) if current_header: sequences.append((current_header, ''.join(current_sequence))) return sequences def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray: """ Convert a nucleotide sequence into a k-mer frequency vector. Defaults to k=4. """ # Generate all possible k-mers kmers = [''.join(p) for p in product("ACGT", repeat=k)] kmer_dict = {km: i for i, km in enumerate(kmers)} vec = np.zeros(len(kmers), dtype=np.float32) for i in range(len(sequence) - k + 1): kmer = sequence[i:i+k] if kmer in kmer_dict: vec[kmer_dict[kmer]] += 1 total_kmers = len(sequence) - k + 1 if total_kmers > 0: vec = vec / total_kmers return vec def compute_sequence_stats(sequence: str): """ Compute various statistics for a given sequence: - Length - GC content (%) - A/C/G/T counts """ length = len(sequence) if length == 0: return { 'length': 0, 'gc_content': 0, 'counts': {'A': 0, 'C': 0, 'G': 0, 'T': 0} } counts = { 'A': sequence.count('A'), 'C': sequence.count('C'), 'G': sequence.count('G'), 'T': sequence.count('T') } gc_content = (counts['G'] + counts['C']) / length * 100.0 return { 'length': length, 'gc_content': gc_content, 'counts': counts } # --------------- Visualization Functions --------------- def plot_shap_like_bars(kmers, importance_values, top_k=10): """ Create a bar chart that mimics a SHAP summary plot: - k-mers on y-axis - importance magnitude on x-axis - color indicating positive (push towards human) vs negative (push towards non-human) """ abs_importance = np.abs(importance_values) # Sort by absolute importance sorted_indices = np.argsort(abs_importance)[::-1] top_indices = sorted_indices[:top_k] # Prepare data top_kmers = [kmers[i] for i in top_indices] top_importances = importance_values[top_indices] # Create plot fig, ax = plt.subplots(figsize=(8, 6)) colors = ['green' if val > 0 else 'red' for val in top_importances] ax.barh(range(len(top_kmers)), np.abs(top_importances), color=colors) ax.set_yticks(range(len(top_kmers))) ax.set_yticklabels(top_kmers) ax.invert_yaxis() # So that the highest value is at the top ax.set_xlabel("Feature Importance (Gradient Magnitude)") ax.set_title(f"Top-{top_k} SHAP-like Feature Importances") plt.tight_layout() return fig def plot_kmer_distribution(kmer_freq_vector, kmers): """ Plot a histogram of k-mer frequencies for the entire vector. (Optional if you want a quick distribution overview) """ fig, ax = plt.subplots(figsize=(10, 4)) ax.bar(range(len(kmer_freq_vector)), kmer_freq_vector, color='blue', alpha=0.6) ax.set_xlabel("K-mer Index") ax.set_ylabel("Frequency") ax.set_title("K-mer Frequency Distribution") ax.set_xticks([]) plt.tight_layout() return fig def create_step_visualization(important_kmers, human_prob): """ Re-implementation of your step-wise probability plot. Shows how each top k-mer 'pushes' the probability from 0.5 to the final value. """ fig = plt.figure(figsize=(8, 5)) ax = fig.add_subplot(111) # Start from 0.5 current_prob = 0.5 steps = [('Start', current_prob, 0)] for kmer in important_kmers: change = kmer['impact'] * (-1 if kmer['direction'] == 'non-human' else 1) current_prob += change steps.append((kmer['kmer'], current_prob, change)) x_vals = range(len(steps)) y_vals = [s[1] for s in steps] ax.step(x_vals, y_vals, 'b-', where='post', label='Probability', linewidth=2) ax.plot(x_vals, y_vals, 'b.', markersize=10) # Reference line at 0.5 ax.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)') ax.set_ylim(0, 1) ax.set_ylabel('Human Probability') ax.set_title(f'K-mer Contributions (final p={human_prob:.3f})') ax.grid(True, linestyle='--', alpha=0.7) for i, (kmer, prob, change) in enumerate(steps): ax.annotate(kmer, (i, prob), xytext=(0, 10 if i % 2 == 0 else -20), textcoords='offset points', ha='center', rotation=45) if i > 0: change_text = f'{change:+.3f}' color = 'green' if change > 0 else 'red' ax.annotate(change_text, (i, prob), xytext=(0, -20 if i % 2 == 0 else 10), textcoords='offset points', ha='center', color=color) ax.legend() plt.tight_layout() return fig def plot_kmer_freq_and_sigma(important_kmers): """ Plot frequencies vs. sigma from mean for the top k-mers. This reuses logic from the original create_visualization second subplot, but as its own function for clarity. """ fig, ax = plt.subplots(figsize=(8, 5)) # Prepare data kmers = [k['kmer'] for k in important_kmers] frequencies = [k['occurrence'] for k in important_kmers] sigmas = [k['sigma'] for k in important_kmers] colors = ['green' if k['direction'] == 'human' else 'red' for k in important_kmers] x = np.arange(len(kmers)) width = 0.35 # Frequency bars ax.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6) # Create a twin axis for sigma ax2 = ax.twinx() # Sigma bars ax2.bar(x + width/2, sigmas, width, label='σ from mean', color=[c if s > 0 else 'gray' for c, s in zip(colors, sigmas)], alpha=0.3) ax.set_xticks(x) ax.set_xticklabels(kmers, rotation=45) ax.set_ylabel('Frequency (%)') ax2.set_ylabel('Standard Deviations (σ) from Mean') ax.set_title("K-mer Frequencies & Statistical Significance") lines1, labels1 = ax.get_legend_handles_labels() lines2, labels2 = ax2.get_legend_handles_labels() ax.legend(lines1 + lines2, labels1 + labels2, loc='best') plt.tight_layout() return fig # --------------- Main Prediction Logic --------------- def predict_fasta( file_obj, k_size=4, top_k=10, advanced_analysis=False ): """ Main function to predict classes for each sequence in an uploaded FASTA. Returns: - Combined textual report for all sequences - A list of generated PIL Image plots """ # 1. Read raw text from file or string if file_obj is None: return "Please upload a FASTA file", [] try: if isinstance(file_obj, str): text = file_obj else: text = file_obj.decode('utf-8', errors='replace') except Exception as e: return f"Error reading file: {str(e)}", [] # 2. Parse the FASTA sequences = parse_fasta(text) if not sequences: return "No valid FASTA sequences found!", [] # 3. Load model & scaler try: device = 'cuda' if torch.cuda.is_available() else 'cpu' model = VirusClassifier(input_shape=(4 ** k_size)).to(device) state_dict = torch.load('model.pt', map_location=device) model.load_state_dict(state_dict) model.eval() scaler = joblib.load('scaler.pkl') except Exception as e: return f"Error loading model/scaler: {str(e)}", [] # 4. Prepare k-mer dictionary for reference all_kmers = [''.join(p) for p in product("ACGT", repeat=k_size)] kmer_dict = {km: i for i, km in enumerate(all_kmers)} # 5. Iterate over sequences and build output final_text_report = [] plots = [] for idx, (header, seq) in enumerate(sequences, start=1): seq_stats = compute_sequence_stats(seq) # Convert sequence -> raw freq -> scaled freq raw_kmer_freq = sequence_to_kmer_vector(seq, k=k_size) scaled_kmer_freq = scaler.transform(raw_kmer_freq.reshape(1, -1)) X_tensor = torch.FloatTensor(scaled_kmer_freq).to(device) # Predict with torch.no_grad(): output = model(X_tensor) probs = torch.softmax(output, dim=1) # Determine class pred_class = torch.argmax(probs, dim=1).item() pred_label = 'human' if pred_class == 1 else 'non-human' human_prob = float(probs[0][1]) non_human_prob = float(probs[0][0]) confidence = float(torch.max(probs[0]).item()) # Compute gradient-based importance importance, target_prob = model.get_gradient_importance(X_tensor, class_index=1) importance = importance[0].cpu().numpy() # shape: (num_features,) # Identify top-k features (by absolute gradient) abs_importance = np.abs(importance) sorted_indices = np.argsort(abs_importance)[::-1] top_indices = sorted_indices[:top_k] # Build a list of top k-mers top_kmers_info = [] for i in top_indices: kmer_name = all_kmers[i] imp_val = float(importance[i]) direction = 'human' if imp_val > 0 else 'non-human' freq_perc = float(raw_kmer_freq[i] * 100.0) # in percent sigma = float(scaled_kmer_freq[0][i]) # This is the scaled value (stdev from mean if the scaler is StandardScaler) top_kmers_info.append({ 'kmer': kmer_name, 'impact': abs(imp_val), 'direction': direction, 'occurrence': freq_perc, 'sigma': sigma }) # Text summary for this sequence seq_report = [] seq_report.append(f"=== Sequence {idx} ===") seq_report.append(f"Header: {header}") seq_report.append(f"Length: {seq_stats['length']}") seq_report.append(f"GC Content: {seq_stats['gc_content']:.2f}%") seq_report.append(f"A: {seq_stats['counts']['A']}, C: {seq_stats['counts']['C']}, G: {seq_stats['counts']['G']}, T: {seq_stats['counts']['T']}") seq_report.append(f"Prediction: {pred_label} (Confidence: {confidence:.4f})") seq_report.append(f" Human Probability: {human_prob:.4f}") seq_report.append(f" Non-human Probability: {non_human_prob:.4f}") seq_report.append(f"\nTop-{top_k} Influential k-mers (by gradient magnitude):") for tkm in top_kmers_info: seq_report.append( f" {tkm['kmer']}: pushes towards {tkm['direction']} " f"(impact={tkm['impact']:.4f}), occurrence={tkm['occurrence']:.2f}%, " f"sigma={tkm['sigma']:.2f}" ) final_text_report.append("\n".join(seq_report)) # 6. Generate Plots (for each sequence) if advanced_analysis: # 6A. SHAP-like bar chart fig_shap = plot_shap_like_bars( kmers=all_kmers, importance_values=importance, top_k=top_k ) buf_shap = io.BytesIO() fig_shap.savefig(buf_shap, format='png', bbox_inches='tight', dpi=150) buf_shap.seek(0) plots.append(Image.open(buf_shap)) plt.close(fig_shap) # 6B. k-mer distribution histogram fig_kmer_dist = plot_kmer_distribution(raw_kmer_freq, all_kmers) buf_dist = io.BytesIO() fig_kmer_dist.savefig(buf_dist, format='png', bbox_inches='tight', dpi=150) buf_dist.seek(0) plots.append(Image.open(buf_dist)) plt.close(fig_kmer_dist) # 6C. Original step visualization for top k k-mers # Sort by actual 'impact' to preserve that step logic # (largest absolute impact first) top_kmers_info_step = sorted(top_kmers_info, key=lambda x: x['impact'], reverse=True) fig_step = create_step_visualization(top_kmers_info_step, human_prob) buf_step = io.BytesIO() fig_step.savefig(buf_step, format='png', bbox_inches='tight', dpi=150) buf_step.seek(0) plots.append(Image.open(buf_step)) plt.close(fig_step) # 6D. Frequency vs. sigma bar chart fig_freq_sigma = plot_kmer_freq_and_sigma(top_kmers_info_step) buf_freq_sigma = io.BytesIO() fig_freq_sigma.savefig(buf_freq_sigma, format='png', bbox_inches='tight', dpi=150) buf_freq_sigma.seek(0) plots.append(Image.open(buf_freq_sigma)) plt.close(fig_freq_sigma) # Combine all text results combined_text = "\n\n".join(final_text_report) return combined_text, plots # --------------- Gradio Interface --------------- def run_prediction( file_obj, k_size, top_k, advanced_analysis ): """ Wrapper for Gradio to handle the outputs in (text, List[Image]) form. """ text_output, pil_images = predict_fasta( file_obj=file_obj, k_size=k_size, top_k=top_k, advanced_analysis=advanced_analysis ) return text_output, pil_images with gr.Blocks() as demo: gr.Markdown("# Virus Host Classifier (Improved!)") gr.Markdown( "Upload a FASTA file and configure k-mer size, number of top features, " "and whether to run advanced analysis (plots of SHAP-like bars & k-mer distribution)." ) with gr.Row(): with gr.Column(): fasta_file = gr.File(label="Upload FASTA file", type="binary") kmer_slider = gr.Slider(minimum=2, maximum=6, value=4, step=1, label="K-mer Size") topk_slider = gr.Slider(minimum=5, maximum=20, value=10, step=1, label="Top-k Features") advanced_check = gr.Checkbox(value=False, label="Advanced Analysis") predict_button = gr.Button("Predict") with gr.Column(): results_text = gr.Textbox( label="Results", lines=20, placeholder="Prediction results will appear here..." ) # We can display multiple images in a Gallery or as separate outputs. plots_gallery = gr.Gallery(label="Analysis Plots").style(grid=[2], height="auto") predict_button.click( fn=run_prediction, inputs=[fasta_file, kmer_slider, topk_slider, advanced_check], outputs=[results_text, plots_gallery] ) if __name__ == "__main__": demo.launch(share=True)