"""Gradio app that predicts human vs. non-human host for a viral sequence.

Pipeline: FASTA text -> 4-mer frequency vector -> pre-fitted scaler ->
``VirusClassifier`` -> per-feature ablation scores -> text summary + plots.
"""
import io
from collections import Counter
from functools import lru_cache
from itertools import product

import gradio as gr
import joblib
import matplotlib
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

# Figures are created inside Gradio worker threads and only ever rendered to
# PNG, so force the non-interactive backend before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402

KMER_SIZE = 4
MODEL_PATH = 'model.pt'
SCALER_PATH = 'scaler.pkl'


class VirusClassifier(nn.Module):
    """Small feed-forward network producing two logits per input row.

    Callers in this file apply softmax to the output and read index 1 as the
    "human" class probability.
    """

    def __init__(self, input_shape: int):
        super().__init__()
        # NOTE: the layer order defines the state-dict keys ("network.<i>.*")
        # of the saved checkpoint; do not reorder or insert layers.
        self.network = nn.Sequential(
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        return self.network(x)


def parse_fasta(text):
    """Parse FASTA formatted text into a list of (header, sequence).

    Sequence lines are upper-cased and concatenated per record. Lines that
    appear before the first ``>`` header are ignored. A record whose header
    is empty (a bare ``>``) is kept, with ``""`` as its header.
    """
    sequences = []
    current_header = None
    current_sequence = []

    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('>'):
            # Compare against None: an empty header string is still a record.
            if current_header is not None:
                sequences.append((current_header, ''.join(current_sequence)))
            current_header = line[1:]
            current_sequence = []
        elif current_header is not None:
            current_sequence.append(line.upper())

    if current_header is not None:
        sequences.append((current_header, ''.join(current_sequence)))
    return sequences


@lru_cache(maxsize=8)
def _kmer_vocabulary(k):
    """Return all 4**k k-mers over ACGT as a tuple, in product() order."""
    return tuple(''.join(p) for p in product("ACGT", repeat=k))


@lru_cache(maxsize=8)
def _kmer_index(k):
    """Return a k-mer -> vector position mapping (cached; treat as read-only)."""
    return {km: i for i, km in enumerate(_kmer_vocabulary(k))}


def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """Convert a sequence to a k-mer frequency vector.

    Returns a float32 vector of length 4**k. Each entry is the count of that
    k-mer divided by the total number of k-length windows in the sequence,
    so windows containing non-ACGT characters lower the sum below 1.
    A sequence shorter than ``k`` yields an all-zero vector.
    """
    index = _kmer_index(k)
    vec = np.zeros(len(index), dtype=np.float32)

    sequence = sequence.upper()  # accept soft-masked (lower-case) input
    total_kmers = len(sequence) - k + 1
    if total_kmers <= 0:
        return vec

    counts = Counter(sequence[i:i + k] for i in range(total_kmers))
    for kmer, count in counts.items():
        slot = index.get(kmer)
        if slot is not None:
            vec[slot] = count
    return vec / total_kmers


def calculate_shap_values(model, x_tensor):
    """
    Score each feature by single-feature ablation.

    For every feature i, the input is copied with feature i set to 0 and the
    drop in P(class 1) relative to the unmodified input is recorded. These
    are ablation deltas, not true Shapley values; the name is kept for
    compatibility with the rest of the app.

    Only the first row of ``x_tensor`` is analysed.
    NOTE(review): the tensor passed by predict() is already scaled, so 0 is
    presumably the scaler's centre rather than "k-mer absent" - confirm
    against how scaler.pkl was fitted.

    Returns:
        (impacts, baseline_prob): float64 array with one entry per feature,
        and the class-1 probability of the unmodified input.
    """
    model.eval()
    with torch.no_grad():
        baseline_prob = torch.softmax(model(x_tensor), dim=1)[0, 1].item()

        # One row per feature; zeroing the diagonal ablates feature i in
        # row i, so a single forward pass replaces n separate ones.
        n_features = x_tensor.shape[1]
        ablated = x_tensor[:1].repeat(n_features, 1)
        ablated.fill_diagonal_(0)
        ablated_probs = torch.softmax(model(ablated), dim=1)[:, 1]

    impacts = baseline_prob - ablated_probs.double().cpu().numpy()
    return impacts, baseline_prob


def _top_kmer_indices(shap_values, top_k):
    """Return indices of the top_k largest |values|, most important first."""
    # Sliders may deliver floats; clamp so slicing is always valid.
    top_k = max(1, min(int(top_k), len(shap_values)))
    return np.argsort(np.abs(shap_values))[::-1][:top_k]


def create_importance_bar_plot(shap_values, kmers, top_k=10):
    """Create a horizontal bar plot of the most important k-mers.

    Bars are ordered with the largest absolute value at the top. Positive
    values are drawn red, others blue. Returns the matplotlib figure; the
    caller is responsible for closing it.
    """
    shap_values = np.asarray(shap_values)
    order = _top_kmer_indices(shap_values, top_k)
    values = shap_values[order]
    labels = [kmers[i] for i in order]
    colors = ['#ff9999' if v > 0 else '#99ccff' for v in values]

    plt.rcParams.update({'font.size': 10})
    fig, ax = plt.subplots(figsize=(10, 6))
    positions = list(range(len(values)))
    ax.barh(positions, values, color=colors)
    ax.set_yticks(positions)
    ax.set_yticklabels(labels)
    ax.set_xlabel('SHAP value (impact on model output)')
    ax.set_title(f'Top {len(order)} Most Influential k-mers')
    # barh puts position 0 at the bottom; invert so the first (most
    # important) entry is drawn at the top.
    ax.invert_yaxis()
    return fig


def visualize_sequence_impacts(sequence, kmers, shap_values, base_prob):
    """
    Create a SHAP-style visualization of sequence impacts.

    Every occurrence of a known k-mer in ``sequence`` is looked up in
    ``shap_values``; the 30 occurrences with the largest absolute score are
    drawn one per row with up to 20 characters of context on each side, the
    k-mer highlighted, and the score shown on the right.

    Returns the matplotlib figure; the caller is responsible for closing it.
    """
    k = len(kmers[0]) if len(kmers) else KMER_SIZE
    context = 20          # characters of flanking sequence shown per side
    chars_per_unit = 50   # rough characters-per-axis-width used for x offsets
    max_rows = 30

    kmer_dict = {km: i for i, km in enumerate(kmers)}

    # Collect (position, k-mer, impact) for every recognised window.
    kmer_impacts = []
    for pos in range(len(sequence) - k + 1):
        kmer = sequence[pos:pos + k]
        slot = kmer_dict.get(kmer)
        if slot is not None:
            kmer_impacts.append((pos, kmer, shap_values[slot]))

    kmer_impacts.sort(key=lambda item: abs(item[2]), reverse=True)
    display_kmers = kmer_impacts[:max_rows]

    fig_height = min(20, max(8, len(display_kmers) * 0.4))
    fig, ax = plt.subplots(figsize=(12, fig_height))

    ax.text(0.01, 1.02, f"base value = {base_prob:.3f}",
            transform=ax.transAxes, fontsize=10)

    y_spacing = 0.9 / max(len(display_kmers), 1)
    y_position = 0.95

    for pos, kmer, impact in display_kmers:
        pre_sequence = sequence[max(0, pos - context):pos]
        post_sequence = sequence[pos + k:min(pos + k + context, len(sequence))]

        # Ellipses mark context that was cut off on either side.
        pre_ellipsis = "..." if pos > context else ""
        post_ellipsis = "..." if pos + k + context < len(sequence) else ""

        color = '#ffcccb' if impact > 0 else '#cce0ff'
        arrow = '↑' if impact > 0 else '↓'

        prefix = f"{pre_ellipsis}{pre_sequence}"
        ax.text(0.01, y_position, prefix, fontsize=9)
        ax.text(0.01 + len(prefix) / chars_per_unit, y_position, kmer,
                fontsize=9, bbox=dict(facecolor=color, alpha=0.3, pad=1))
        ax.text(0.01 + (len(prefix) + len(kmer)) / chars_per_unit, y_position,
                f"{post_sequence}{post_ellipsis}", fontsize=9)

        ax.text(0.8, y_position, f"{arrow} {impact:+.3f}", fontsize=9)

        y_position -= y_spacing

    ax.axis('off')
    fig.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05)
    return fig


@lru_cache(maxsize=1)
def _load_model_and_scaler(device_type):
    """Load the classifier and scaler once per device type and cache them.

    Exceptions propagate (and are not cached), so a later call retries.
    """
    device = torch.device(device_type)
    model = VirusClassifier(4 ** KMER_SIZE).to(device)
    model.load_state_dict(
        torch.load(MODEL_PATH, map_location=device, weights_only=True)
    )
    model.eval()
    # SECURITY: joblib.load unpickles arbitrary objects; scaler.pkl must be a
    # trusted artifact shipped with the app, never user-supplied.
    scaler = joblib.load(SCALER_PATH)
    return model, scaler


def _fig_to_image(fig):
    """Render a figure to a PIL image and close the figure, even on error."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    finally:
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode now so the image no longer depends on lazy reads
    return img


def predict(file_obj, top_kmers=10, fasta_text=""):
    """Main prediction function for Gradio interface.

    Args:
        file_obj: path of an uploaded FASTA file, or None.
        top_kmers: number of k-mers to show in the bar plot and summary.
        fasta_text: pasted FASTA text; takes precedence over ``file_obj``.

    Returns:
        (summary_text, importance_image, sequence_image). On any input or
        loading problem the text carries the error and both images are None.
        Only the first record of a multi-record FASTA input is analysed.
    """
    # Handle input
    fasta_text = (fasta_text or "").strip()
    if fasta_text:
        text = fasta_text
    elif file_obj is not None:
        try:
            with open(file_obj, 'r', encoding='utf-8') as f:
                text = f.read()
        except Exception as e:  # UI boundary: report instead of crashing
            return f"Error reading file: {str(e)}", None, None
    else:
        return "Please provide a FASTA sequence.", None, None

    # Parse FASTA
    sequences = parse_fasta(text)
    if not sequences:
        return "No valid FASTA sequences found.", None, None

    header, seq = sequences[0]
    if len(seq) < KMER_SIZE:
        # No k-mer windows exist; the model would see an all-zero vector.
        return (
            f"Sequence '{header}' is shorter than {KMER_SIZE} bases; "
            "cannot compute k-mer features.",
            None,
            None,
        )

    # Load model and scaler (cached after the first successful call)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        model, scaler = _load_model_and_scaler(device.type)
    except Exception as e:  # UI boundary: report instead of crashing
        return f"Error loading model: {str(e)}", None, None

    # Generate features
    # NOTE(review): assumes scaler exposes a scikit-learn style transform()
    # over a (1, 256) array - confirm against how scaler.pkl was produced.
    freq_vector = sequence_to_kmer_vector(seq, KMER_SIZE)
    scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
    x_tensor = torch.as_tensor(scaled_vector, dtype=torch.float32, device=device)

    # Calculate per-feature scores and get prediction
    shap_values, prob_human = calculate_shap_values(model, x_tensor)

    kmers = _kmer_vocabulary(KMER_SIZE)

    # Generate result text
    results = [
        f"Sequence: {header}",
        f"Prediction: {'Human' if prob_human > 0.5 else 'Non-human'} Origin",
        f"Confidence: {max(prob_human, 1-prob_human):.3f}",
        f"Human Probability: {prob_human:.3f}",
        "\nTop Contributing k-mers:"
    ]
    for rank, idx in enumerate(_top_kmer_indices(shap_values, top_kmers), start=1):
        value = shap_values[idx]
        direction = "human" if value > 0 else "non-human"
        results.append(f"  {rank}. {kmers[idx]}: {value:+.4f} (toward {direction})")

    # Render each figure right after creating it so a failure in the second
    # plot cannot leak the first figure.
    importance_img = _fig_to_image(
        create_importance_bar_plot(shap_values, kmers, top_kmers)
    )
    sequence_img = _fig_to_image(
        visualize_sequence_impacts(seq, kmers, shap_values, prob_human)
    )

    return "\n".join(results), importance_img, sequence_img


# Create Gradio interface
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
"""

with gr.Blocks(css=css) as iface:
    gr.Markdown("""
    # Virus Host Classifier
    Predicts whether a viral sequence is of human or non-human origin using k-mer analysis.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload FASTA file",
                file_types=[".fasta", ".fa", ".txt"],
                type="filepath"
            )
            text_input = gr.Textbox(
                label="Or paste FASTA sequence",
                placeholder=">sequence_name\nACGTACGT...",
                lines=5
            )
            top_k = gr.Slider(
                minimum=5,
                maximum=30,
                value=10,
                step=1,
                label="Number of top k-mers to display"
            )
            submit_btn = gr.Button("Analyze Sequence", variant="primary")

        with gr.Column(scale=2):
            results = gr.Textbox(label="Analysis Results", lines=10)
            kmer_plot = gr.Image(label="K-mer Importance Plot")
            shap_plot = gr.Image(label="Sequence Impact Visualization (SHAP-style)")

    submit_btn.click(
        predict,
        inputs=[file_input, top_k, text_input],
        outputs=[results, kmer_plot, shap_plot]
    )

    gr.Markdown("""
    ### Visualization Guide
    - **K-mer Importance Plot**: Shows the most influential k-mers and their SHAP values
    - **Sequence Impact Visualization**: Shows the sequence with highlighted k-mers:
      - Red highlights = pushing toward human origin
      - Blue highlights = pushing toward non-human origin
      - Arrows (↑/↓) show impact direction
      - Values show impact magnitude
    """)

if __name__ == "__main__":
    iface.launch()