import gradio as gr
import torch
import joblib
import numpy as np
from itertools import product
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import io
from io import BytesIO  # Import io then BytesIO
from PIL import Image, ImageDraw, ImageFont
from Bio.Graphics import GenomeDiagram
from Bio.SeqFeature import SeqFeature, FeatureLocation
from reportlab.lib import colors
import pandas as pd
import tempfile
import os
from typing import List, Dict, Tuple, Optional, Any
import seaborn as sns


###############################################################################
# 1. MODEL DEFINITION
###############################################################################

class VirusClassifier(nn.Module):
    """Feed-forward network mapping a feature vector to two output logits."""

    def __init__(self, input_shape: int):
        super().__init__()
        # The layer order determines the parameter names inside the saved
        # state dict ("network.0.weight", ...), so it must not be rearranged.
        self.network = nn.Sequential(
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        """Return the raw (pre-softmax) logits for ``x``."""
        return self.network(x)


###############################################################################
# 2. FASTA PARSING & K-MER FEATURE ENGINEERING
###############################################################################

def parse_fasta(text):
    """
    Parse FASTA text into a list of ``(header, sequence)`` tuples.

    Blank lines are skipped, sequence lines are upper-cased and concatenated,
    and sequence lines that appear before the first ``>`` header are ignored.
    """
    sequences = []
    current_header = None
    current_sequence = []
    for line in text.strip().split('\n'):
        line = line.strip()
        if not line:
            continue
        if line.startswith('>'):
            # Compare against None: a bare ">" yields an empty header string,
            # which is falsy but still marks a real record.
            if current_header is not None:
                sequences.append((current_header, ''.join(current_sequence)))
            current_header = line[1:]
            current_sequence = []
        else:
            current_sequence.append(line.upper())
    if current_header is not None:
        sequences.append((current_header, ''.join(current_sequence)))
    return sequences


def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """
    Convert a sequence into a frequency vector over all ACGT k-mers.

    K-mers containing characters other than A/C/G/T are not counted, but the
    counts are still divided by the total number of k-mer windows in the
    sequence. Sequences shorter than ``k`` give an all-zero vector.
    """
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    kmer_dict = {km: i for i, km in enumerate(kmers)}
    vec = np.zeros(len(kmers), dtype=np.float32)
    total_kmers = len(sequence) - k + 1
    for i in range(total_kmers):
        idx = kmer_dict.get(sequence[i:i + k])
        if idx is not None:
            vec[idx] += 1
    if total_kmers > 0:
        vec /= total_kmers
    return vec


###############################################################################
# 3. SHAP-VALUE (ABLATION) CALCULATION
###############################################################################

def calculate_shap_values(model, x_tensor):
    """
    A simple ablation-based SHAP approximation.

    Each feature of the single input row is set to zero in turn and the drop
    in the class-1 ('human') softmax probability is recorded.

    Returns:
        ``(shap_values, baseline_prob)`` where ``shap_values`` is a NumPy array
        with one entry per feature and ``baseline_prob`` is the unablated
        class-1 probability.
    """
    model.eval()
    with torch.no_grad():
        baseline_output = model(x_tensor)
        baseline_probs = torch.softmax(baseline_output, dim=1)
        baseline_prob = baseline_probs[0, 1].item()  # Probability for 'human'

        shap_values = []
        x_zeroed = x_tensor.clone()
        for i in range(x_tensor.shape[1]):
            original_val = x_zeroed[0, i].item()
            x_zeroed[0, i] = 0.0
            output = model(x_zeroed)
            probs = torch.softmax(output, dim=1)
            prob = probs[0, 1].item()
            shap_values.append(baseline_prob - prob)
            # Restore the feature so only one is ablated at a time.
            x_zeroed[0, i] = original_val
    return np.array(shap_values), baseline_prob


###############################################################################
# 4. PER-BASE SHAP AGGREGATION
###############################################################################

def compute_positionwise_scores(sequence, shap_values, k=4):
    """
    Distribute each k-mer's SHAP contribution across its k underlying positions.

    Every position receives the mean SHAP value of the recognised k-mers that
    cover it; positions covered by none get 0.0.
    """
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    kmer_dict = {km: i for i, km in enumerate(kmers)}
    seq_len = len(sequence)
    shap_sums = np.zeros(seq_len, dtype=np.float32)
    coverage = np.zeros(seq_len, dtype=np.float32)
    for i in range(seq_len - k + 1):
        idx = kmer_dict.get(sequence[i:i + k])
        if idx is not None:
            shap_sums[i:i + k] += shap_values[idx]
            coverage[i:i + k] += 1
    # np.where evaluates both branches, so silence the 0/0 warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        shap_means = np.where(coverage > 0, shap_sums / coverage, 0.0)
    return shap_means


###############################################################################
# 5. FIND EXTREME SHAP REGIONS
###############################################################################

def find_extreme_subregion(shap_means, window_size=500, mode="max"):
    """
    Use a sliding window to find the subregion with the highest (or lowest)
    average SHAP.

    Args:
        shap_means: per-position scores.
        window_size: window length; coerced to an int of at least 1.
        mode: "max" or "min". Any other value selects the first window.

    Returns:
        ``(start, end, average)`` with ``end`` exclusive. Ties resolve to the
        left-most window. If the window is at least as long as the data, the
        whole range is returned.
    """
    n = len(shap_means)
    if n == 0:
        return (0, 0, 0.0)
    # UI sliders may deliver floats; a non-positive size would divide by zero.
    window_size = max(1, int(window_size))
    if window_size >= n:
        return (0, n, float(np.mean(shap_means)))

    # float64 prefix sums: float32 accumulates visible error on long genomes.
    csum = np.zeros(n + 1, dtype=np.float64)
    csum[1:] = np.cumsum(shap_means, dtype=np.float64)
    window_avgs = (csum[window_size:] - csum[:-window_size]) / window_size

    if mode == "max":
        best_start = int(np.argmax(window_avgs))
    elif mode == "min":
        best_start = int(np.argmin(window_avgs))
    else:
        best_start = 0
    return (best_start, best_start + window_size, float(window_avgs[best_start]))


###############################################################################
# 6. PLOTTING / UTILITIES
###############################################################################

def fig_to_image(fig):
    """
    Render a Matplotlib figure to a PIL Image and close the figure.
    """
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img


def get_zero_centered_cmap():
    """
    Create a symmetrical (blue-white-red) colormap around zero.
    """
    # Named color_stops so the module-level reportlab `colors` is not shadowed.
    color_stops = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
    return mcolors.LinearSegmentedColormap.from_list("blue_white_red", color_stops)


def _symmetric_extent(values):
    """Return max(|min|, |max|) of ``values``, or 1.0 if that is zero/non-finite."""
    extent = max(abs(np.min(values)), abs(np.max(values)))
    # With vmin == vmax the normalisation maps everything to the bottom of the
    # colormap (blue); a non-zero extent keeps all-zero data at neutral white.
    if not np.isfinite(extent) or extent == 0:
        return 1.0
    return extent


def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
    """
    Plot an inline heatmap for the chosen region (or entire genome if
    start/end not provided). Returns the Matplotlib figure.
    """
    if start is not None and end is not None:
        local_shap = shap_means[start:end]
        subtitle = f" (positions {start}-{end})"
    else:
        local_shap = shap_means
        subtitle = ""
    if len(local_shap) == 0:
        local_shap = np.array([0.0])

    heatmap_data = local_shap.reshape(1, -1)
    extent = _symmetric_extent(local_shap)
    cmap = get_zero_centered_cmap()

    fig, ax = plt.subplots(figsize=(12, 1.8))
    cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
    cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
    cbar.ax.tick_params(labelsize=8)
    cbar.set_label('SHAP Contribution', fontsize=9, labelpad=5)
    ax.set_yticks([])
    ax.set_xlabel('Position in Sequence', fontsize=10)
    ax.set_title(f"{title}{subtitle}", pad=10)
    plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
    return fig


def create_importance_bar_plot(shap_values, kmers, top_k=10):
    """
    Show bar chart of top k-mers by absolute SHAP value.

    Note: updates the global Matplotlib ``font.size`` rcParam.
    """
    top_k = int(top_k)  # sliders may deliver floats, which cannot be used to slice
    plt.rcParams.update({'font.size': 10})
    fig = plt.figure(figsize=(10, 5))
    indices = np.argsort(np.abs(shap_values))[-top_k:]
    values = shap_values[indices]
    features = [kmers[i] for i in indices]
    bar_colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
    plt.barh(range(len(values)), values, color=bar_colors)
    plt.yticks(range(len(values)), features)
    plt.xlabel('SHAP Value (impact on model output)')
    plt.title(f'Top {top_k} Most Influential k-mers')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    return fig


def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
    """
    Plot a histogram of SHAP values with optional # of bins.

    This is the single definition used by both the subregion and comparison
    views. The file previously defined this function twice and the later
    definition shadowed the earlier one; that effective version is kept.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
    ax.axvline(0, color='red', linestyle='--', label='0.0')
    ax.set_xlabel("SHAP Value")
    ax.set_ylabel("Count")
    ax.set_title(title)
    ax.legend()
    plt.tight_layout()
    return fig


def compute_gc_content(sequence):
    """
    Compute GC content (%) for a given sequence (0.0 for an empty one).
    """
    if not sequence:
        return 0.0
    gc_count = sequence.count('G') + sequence.count('C')
    return (gc_count / len(sequence)) * 100.0


###############################################################################
# 7. MAIN ANALYSIS STEP (Gradio Step 1)
###############################################################################

def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
    """
    Perform the main classification, SHAP analysis, and extreme subregion
    detection for a single sequence (the first record of the FASTA input).

    Pasted ``fasta_text`` takes precedence over ``file_obj`` (a file path).

    Returns a 6-tuple
    ``(results_text, bar_img, heatmap_img, state, header, None)``.
    On failure the first item is a message and the remaining five are None.
    """
    failure_tail = (None, None, None, None, None)

    # 1) Read input
    if fasta_text and fasta_text.strip():
        text = fasta_text.strip()
    elif file_obj is not None:
        try:
            with open(file_obj, 'r') as f:
                text = f.read()
        except Exception as e:
            return (f"Error reading file: {str(e)}",) + failure_tail
    else:
        return ("Please provide a FASTA sequence.",) + failure_tail

    # 2) Parse FASTA
    sequences = parse_fasta(text)
    if not sequences:
        return ("No valid FASTA sequences found.",) + failure_tail
    header, seq = sequences[0]

    # 3) Load model, scaler, and run inference
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        state_dict = torch.load('model.pt', map_location=device)
        model = VirusClassifier(256).to(device)
        model.load_state_dict(state_dict)
        scaler = joblib.load('scaler.pkl')
    except Exception as e:
        return (f"Error loading model/scaler: {str(e)}",) + failure_tail

    freq_vector = sequence_to_kmer_vector(seq)
    scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
    x_tensor = torch.FloatTensor(scaled_vector).to(device)

    shap_values, prob_human = calculate_shap_values(model, x_tensor)
    prob_nonhuman = 1.0 - prob_human
    classification = "Human" if prob_human > 0.5 else "Non-human"
    confidence = max(prob_human, prob_nonhuman)

    # 4) Per-base SHAP & subregion detection
    window_size = int(window_size)  # sliders may deliver floats
    shap_means = compute_positionwise_scores(seq, shap_values, k=4)
    max_start, max_end, max_avg = find_extreme_subregion(shap_means, window_size, mode="max")
    min_start, min_end, min_avg = find_extreme_subregion(shap_means, window_size, mode="min")

    # 5) Prepare result text
    results_text = (
        f"Sequence: {header}\n"
        f"Length: {len(seq):,} bases\n"
        f"Classification: {classification}\n"
        f"Confidence: {confidence:.3f}\n"
        f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
        f"---\n"
        f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
        f"Start: {max_start}, End: {max_end}, Avg SHAP: {max_avg:.4f}\n\n"
        f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
        f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
    )

    # 6) Create bar & heatmap figures
    kmers = [''.join(p) for p in product("ACGT", repeat=4)]
    bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
    bar_img = fig_to_image(bar_fig)

    heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
    heatmap_img = fig_to_image(heatmap_fig)

    # 7) Build the "state" dictionary so we can do subregion analysis
    state_dict_out = {"seq": seq, "shap_means": shap_means}

    # Return 6 items to match the Gradio outputs
    return (results_text, bar_img, heatmap_img, state_dict_out, header, None)


###############################################################################
# 8. SUBREGION ANALYSIS (Gradio Step 2)
###############################################################################

def analyze_subregion(state, header, region_start, region_end):
    """
    Examine a subregion's SHAP distribution, GC content, etc.

    ``state`` is the dict produced by ``analyze_sequence`` (keys "seq" and
    "shap_means"). Bounds are clamped to the sequence length.

    Returns a 4-tuple ``(region_info, heatmap_img, hist_img, None)``; on
    invalid input the first item is a message and the rest are None.
    """
    if not state or "seq" not in state or "shap_means" not in state:
        return ("No sequence data found. Please run Step 1 first.", None, None, None)

    seq = state["seq"]
    shap_means = state["shap_means"]

    # An empty numeric field arrives as None; report it instead of crashing.
    try:
        region_start = int(region_start)
        region_end = int(region_end)
    except (TypeError, ValueError):
        return ("Invalid region range. Start and End must be numbers.", None, None, None)

    region_start = max(0, min(region_start, len(seq)))
    region_end = max(0, min(region_end, len(seq)))
    if region_end <= region_start:
        return ("Invalid region range. End must be > Start.", None, None, None)

    region_seq = seq[region_start:region_end]
    region_shap = shap_means[region_start:region_end]

    gc_percent = compute_gc_content(region_seq)
    avg_shap = float(np.mean(region_shap))
    positive_fraction = np.mean(region_shap > 0)
    negative_fraction = np.mean(region_shap < 0)

    if avg_shap > 0.05:
        region_classification = "Likely pushing toward human"
    elif avg_shap < -0.05:
        region_classification = "Likely pushing toward non-human"
    else:
        region_classification = "Near neutral (no strong push)"

    region_info = (
        f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
        f"Region length: {len(region_seq)} bases\n"
        f"GC content: {gc_percent:.2f}%\n"
        f"Average SHAP in region: {avg_shap:.4f}\n"
        f"Fraction with SHAP > 0 (toward human): {positive_fraction:.2f}\n"
        f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
        f"Subregion interpretation: {region_classification}\n"
    )

    heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion SHAP",
                                      start=region_start, end=region_end)
    heatmap_img = fig_to_image(heatmap_fig)

    hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
    hist_img = fig_to_image(hist_fig)

    # Return 4 items to match the Gradio outputs
    return (region_info, heatmap_img, hist_img, None)


###############################################################################
# 9. COMPARISON ANALYSIS FUNCTIONS (Step 4)
###############################################################################

def compute_shap_difference(shap1_norm, shap2_norm):
    """
    Compute the SHAP difference (Seq2 - Seq1).
    """
    return shap2_norm - shap1_norm


def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
    """
    Plot a 1D heatmap of differences using relative positions 0-100%.
    """
    heatmap_data = shap_diff.reshape(1, -1)
    extent = _symmetric_extent(shap_diff)

    fig, ax = plt.subplots(figsize=(12, 1.8))
    cmap = get_zero_centered_cmap()
    cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)

    # Create percentage-based x-axis ticks
    num_ticks = 5
    tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
    tick_labels = [f"{int(x * 100)}%" for x in np.linspace(0, 1, num_ticks)]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)

    cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
    cbar.ax.tick_params(labelsize=8)
    cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
    ax.set_yticks([])
    ax.set_xlabel('Relative Position in Sequence', fontsize=10)
    ax.set_title(title, pad=10)
    plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
    return fig


def calculate_adaptive_parameters(len1, len2):
    """
    Choose smoothing & interpolation parameters automatically based on length
    difference.

    Returns ``(num_points, smooth_window, resolution_factor)``; the first two
    are ints.
    """
    length_diff = abs(len1 - len2)
    max_length = max(len1, len2)
    min_length = min(len1, len2)
    # Two empty inputs are treated as identical in length (ratio 1.0).
    length_ratio = min_length / max_length if max_length else 1.0

    # Base number of points
    base_points = min(2000, max(500, max_length // 100))

    if length_diff < 500:
        resolution_factor = 2.0
        num_points = min(3000, base_points * 2)
        smooth_window = max(10, length_diff // 50)
    elif length_diff < 5000:
        resolution_factor = 1.5
        num_points = min(2000, base_points * 1.5)
        smooth_window = max(20, length_diff // 100)
    elif length_diff < 50000:
        resolution_factor = 1.0
        num_points = base_points
        smooth_window = max(50, length_diff // 200)
    else:
        resolution_factor = 0.75
        num_points = max(500, base_points // 2)
        smooth_window = max(100, length_diff // 500)

    # Widen the window as the two lengths diverge.
    smooth_window = int(smooth_window * (1 + (1 - length_ratio)))
    return int(num_points), int(smooth_window), resolution_factor


def sliding_window_smooth(values, window_size=50):
    """
    A custom smoothing approach, including exponential decay at edges.

    The window is clamped to ``len(values)``; inputs are returned unchanged
    when the effective window is shorter than 3. The first/last samples that a
    'valid' convolution cannot cover keep their original values.
    """
    # A window longer than the data makes the 'valid' convolution come back
    # with the wrong length, so clamp it first.
    window_size = min(int(window_size), len(values))
    if window_size < 3:
        return values

    half = window_size // 2
    window = np.ones(window_size)
    decay = np.exp(-np.linspace(0, 3, half))
    # NOTE(review): as written the weights are largest at the outer edges and
    # smallest next to the centre, which looks inverted for an edge taper.
    # Kept unchanged to preserve existing numerical results - confirm intent.
    window[:half] = decay
    window[-half:] = decay[::-1]
    window = window / window.sum()

    smoothed = np.convolve(values, window, mode='valid')
    pad_size = len(values) - len(smoothed)
    pad_left = pad_size // 2
    pad_right = pad_size - pad_left

    result = np.zeros_like(values)
    result[pad_left:-pad_right] = smoothed
    result[:pad_left] = values[:pad_left]
    result[-pad_right:] = values[-pad_right:]
    return result


def normalize_shap_lengths(shap1, shap2):
    """
    Smooth, interpolate, and return arrays of the same length for direct
    comparison.

    Returns ``(shap1_interp, shap2_interp, smooth_window)``.
    """
    num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
    shap1_smooth = sliding_window_smooth(shap1, smooth_window)
    shap2_smooth = sliding_window_smooth(shap2, smooth_window)

    # Resample both tracks onto a shared relative-position axis [0, 1].
    x1 = np.linspace(0, 1, len(shap1_smooth))
    x2 = np.linspace(0, 1, len(shap2_smooth))
    x_norm = np.linspace(0, 1, num_points)
    shap1_interp = np.interp(x_norm, x1, shap1_smooth)
    shap2_interp = np.interp(x_norm, x2, shap2_smooth)
    return shap1_interp, shap2_interp, smooth_window


def _extract_classification(results_text):
    """Pull the value of the 'Classification: ' line out of a results string."""
    _, marker, rest = results_text.partition('Classification: ')
    if not marker:
        return "Unknown"
    return rest.split('\n')[0].strip()


def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
    """
    Compare two sequences using the previously defined analysis pipeline and
    produce difference visualizations & stats.

    Returns a 4-tuple ``(comparison_text, heatmap_img, hist_img, None)``; on
    failure the first item is an error message and the rest are None.
    """
    try:
        # analyze_sequence returns a None state on every failure path, including
        # messages that do not contain the word "Error".
        res1 = analyze_sequence(file1, top_kmers=10, fasta_text=fasta1, window_size=500)
        if res1[3] is None:
            return (f"Error in sequence 1: {res1[0]}", None, None, None)

        res2 = analyze_sequence(file2, top_kmers=10, fasta_text=fasta2, window_size=500)
        if res2[3] is None:
            return (f"Error in sequence 2: {res2[0]}", None, None, None)

        shap1 = res1[3]["shap_means"]
        shap2 = res2[3]["shap_means"]
        len1, len2 = len(shap1), len(shap2)
        length_diff = abs(len1 - len2)
        length_ratio = min(len1, len2) / max(len1, len2)

        # Normalize both to the same length
        shap1_norm, shap2_norm, smooth_window = normalize_shap_lengths(shap1, shap2)
        shap_diff = compute_shap_difference(shap1_norm, shap2_norm)

        # Compute stats
        base_threshold = 0.05
        adaptive_threshold = base_threshold * (1 + (1 - length_ratio))
        if length_diff > 50000:
            adaptive_threshold *= 1.5

        avg_diff = np.mean(shap_diff)
        std_diff = np.std(shap_diff)
        max_diff = np.max(shap_diff)
        min_diff = np.min(shap_diff)
        substantial_diffs = np.abs(shap_diff) > adaptive_threshold
        frac_different = np.mean(substantial_diffs)

        # Extract classification from text
        classification1 = _extract_classification(res1[0])
        classification2 = _extract_classification(res2[0])

        comparison_text = (
            "Sequence Comparison Results:\n"
            f"Sequence 1: {res1[4]}\n"
            f"Length: {len1:,} bases\n"
            f"Classification: {classification1}\n\n"
            f"Sequence 2: {res2[4]}\n"
            f"Length: {len2:,} bases\n"
            f"Classification: {classification2}\n\n"
            "Comparison Parameters:\n"
            f"Length Difference: {length_diff:,} bases\n"
            f"Length Ratio: {length_ratio:.3f}\n"
            f"Smoothing Window: {smooth_window} points\n"
            f"Adaptive Threshold: {adaptive_threshold:.3f}\n\n"
            "Statistics:\n"
            f"Average SHAP difference: {avg_diff:.4f}\n"
            f"Standard deviation: {std_diff:.4f}\n"
            f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
            f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
            f"Fraction with substantial differences: {frac_different:.2%}\n\n"
            "Note: All parameters automatically adjusted based on sequence properties\n\n"
            "Interpretation:\n"
            "- Red regions: Sequence 2 more human-like\n"
            "- Blue regions: Sequence 1 more human-like\n"
            "- White regions: Similar between sequences"
        )

        heatmap_fig = plot_comparative_heatmap(
            shap_diff,
            title=f"SHAP Difference Heatmap (window: {smooth_window})"
        )
        heatmap_img = fig_to_image(heatmap_fig)

        num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
        hist_fig = plot_shap_histogram(
            shap_diff,
            title="Distribution of SHAP Differences",
            num_bins=num_bins
        )
        hist_img = fig_to_image(hist_fig)

        return (comparison_text, heatmap_img, hist_img, None)

    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        error_msg = f"Error during sequence comparison: {str(e)}"
        return (error_msg, None, None, None)


###############################################################################
# 10. ADDITIONAL / ADVANCED VISUALIZATIONS & STATISTICS
###############################################################################

def n50_length(sequence):
    """
    Calculate the N50 for a single continuous sequence (for demonstration).

    With only one contiguous piece this is simply the sequence length; real
    contig sets would need sorted lengths and cumulative sums.
    """
    return len(sequence)  # Because we have only one contiguous region


def sequence_complexity(sequence):
    """
    Compute a simple measure of 'sequence complexity'.

    Here, complexity is the Shannon entropy (in bits) over the characters of
    the sequence; an empty sequence yields 0.0.
    """
    from collections import Counter
    from math import log2

    length = len(sequence)
    if length == 0:
        return 0.0
    complexity = 0.0
    for count in Counter(sequence).values():
        p = count / length
        complexity -= p * log2(p)
    return complexity


def advanced_gene_statistics(gene_shap: np.ndarray, gene_seq: str) -> Dict[str, float]:
    """
    Additional stats: N50, complexity, etc.

    Returns a dict with keys 'n50', 'entropy', 'avg_shap', 'max_shap' and
    'min_shap'. The SHAP statistics are 0.0 for an empty ``gene_shap``.
    """
    has_values = len(gene_shap) > 0
    stats = {}
    stats['n50'] = n50_length(gene_seq)  # trivial for a single gene region
    stats['entropy'] = sequence_complexity(gene_seq)
    # Guard the mean as well: np.mean of an empty array is NaN plus a warning.
    stats['avg_shap'] = float(np.mean(gene_shap)) if has_values else 0.0
    stats['max_shap'] = float(np.max(gene_shap)) if has_values else 0.0
    stats['min_shap'] = float(np.min(gene_shap)) if has_values else 0.0
    return stats


###############################################################################
# 11.
# GENE FEATURE ANALYSIS
###############################################################################

def _make_gene_record(header: str, sequence_parts: List[str]) -> Dict[str, Any]:
    """Build one gene entry (header, joined sequence, parsed metadata)."""
    return {
        'header': header,
        'sequence': ''.join(sequence_parts),
        'metadata': parse_gene_metadata(header),
    }


def parse_gene_features(text: str) -> List[Dict[str, Any]]:
    """
    Parse gene features from text in a FASTA-like format.

    Each ``>`` line starts a record; following lines are upper-cased and
    concatenated as its sequence. Returns a list of dicts with keys
    'header', 'sequence' and 'metadata'.
    """
    genes = []
    current_header = None
    current_sequence: List[str] = []
    for line in text.strip().split('\n'):
        line = line.strip()
        if not line:
            continue
        if line.startswith('>'):
            # Compare against None: a bare ">" yields an empty header string,
            # which is falsy but still marks a real record.
            if current_header is not None:
                genes.append(_make_gene_record(current_header, current_sequence))
            current_header = line[1:]
            current_sequence = []
        else:
            current_sequence.append(line.upper())
    if current_header is not None:
        genes.append(_make_gene_record(current_header, current_sequence))
    return genes


def parse_gene_metadata(header: str) -> Dict[str, str]:
    """
    Extract ``[key=value]`` metadata pairs from a gene header line.

    Values may contain spaces (e.g. ``[protein=ORF1ab polyprotein]``). When a
    key occurs more than once, the last occurrence wins.
    """
    import re  # local import: the file's import block is outside this section

    # key: anything up to the first '='; value: anything up to the closing ']'.
    pairs = re.findall(r'\[([^=\[\]]+)=([^\[\]]*)\]', header)
    return {key: value for key, value in pairs}


def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
    """
    Parse a gene location string into ``(start, end)``.

    Accepts plain ``start..end``, ``complement(...)``, multi-segment
    ``join(a..b,c..d)`` (the overall span is returned) and ``<`` / ``>``
    partial markers. Returns ``(None, None)`` when no range can be found.
    Numbers are returned exactly as written in the string.
    """
    import re  # local import: the file's import block is outside this section

    try:
        segments = re.findall(r'<?(\d+)\s*\.\.\s*>?(\d+)', location_str)
    except TypeError as e:
        print(f"Error parsing location {location_str}: {str(e)}")
        return None, None

    if not segments:
        if '..' in location_str:
            # Looks like a range but has no usable numbers; keep a diagnostic.
            print(f"Error parsing location {location_str}: no numeric start..end range found")
        return None, None

    start = min(int(seg_start) for seg_start, _ in segments)
    end = max(int(seg_end) for _, seg_end in segments)
    return start, end


def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
    """
    Basic statistical measures for gene SHAP values.

    Returns avg/median/std/max/min and the fraction of positive values; all
    are 0.0 for an empty array.
    """
    if len(gene_shap) == 0:
        return {
            'avg_shap': 0.0,
            'median_shap': 0.0,
            'std_shap': 0.0,
            'max_shap': 0.0,
            'min_shap': 0.0,
            'pos_fraction': 0.0,
        }
    return {
        'avg_shap': float(np.mean(gene_shap)),
        'median_shap': float(np.median(gene_shap)),
        'std_shap': float(np.std(gene_shap)),
        'max_shap': float(np.max(gene_shap)),
        'min_shap': float(np.min(gene_shap)),
        'pos_fraction': float(np.mean(gene_shap > 0)),
    }


def create_simple_genome_diagram(gene_results: List[Dict[str, Any]],
                                 genome_length: int) -> "Image.Image":
    """
    A quick PIL-based diagram to show genes along the genome.
    Color intensity = magnitude of SHAP. Red/Blue = sign of SHAP.

    The input dicts are not modified; coordinates are clamped on copies.
    Genes whose clamped start is not below their end are skipped with a
    warning.
    """
    if not gene_results or genome_length <= 0:
        img = Image.new('RGB', (800, 100), color='white')
        draw = ImageDraw.Draw(img)
        draw.text((10, 40), "Error: Invalid input data", fill='black')
        return img

    # Clamp coordinates on copies so the caller's records stay untouched.
    drawable = []
    for gene in gene_results:
        start = max(0, int(gene['start']))
        end = min(genome_length, int(gene['end']))
        if start >= end:
            print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}")
            # An inverted box makes newer Pillow versions raise, so skip it.
            continue
        drawable.append({**gene, 'start': start, 'end': end})

    width = 1500
    height = 600
    margin = 50
    track_height = 40

    img = Image.new('RGB', (width, height), 'white')
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
        title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
    except OSError:
        # Font files not present on this system; use Pillow's built-in font.
        font = ImageFont.load_default()
        title_font = ImageFont.load_default()

    draw.text((margin, margin // 2), "Genome SHAP Analysis (Simple)",
              fill='black', font=title_font or font)

    line_y = height // 2
    draw.line([(int(margin), int(line_y)), (int(width - margin), int(line_y))],
              fill='black', width=2)

    scale = float(width - 2 * margin) / float(genome_length)

    # Scale markers
    num_ticks = 10
    step = max(1, genome_length // num_ticks)
    for i in range(0, genome_length + 1, step):
        x_coord = margin + i * scale
        draw.line([
            (int(x_coord), int(line_y - 5)),
            (int(x_coord), int(line_y + 5))
        ], fill='black', width=1)
        draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)

    # Draw weakest first so the strongest genes end up on top.
    sorted_genes = sorted(drawable, key=lambda g: abs(g['avg_shap']))
    for idx, gene in enumerate(sorted_genes):
        start_x = margin + int(gene['start'] * scale)
        end_x = margin + int(gene['end'] * scale)
        avg_shap = gene['avg_shap']

        intensity = int(abs(avg_shap) * 500)
        intensity = max(50, min(255, intensity))
        if avg_shap > 0:
            color = (255, 255 - intensity, 255 - intensity)  # Reddish
        else:
            color = (255 - intensity, 255 - intensity, 255)  # Blueish

        draw.rectangle([
            (int(start_x), int(line_y - track_height // 2)),
            (int(end_x), int(line_y + track_height // 2))
        ], fill=color, outline='black')

        label = str(gene.get('gene_name', '?'))
        label_mask = font.getmask(label)
        label_width, label_height = label_mask.size

        # Alternate labels above/below the track to reduce overlap.
        if idx % 2 == 0:
            text_y = line_y - track_height - 15
        else:
            text_y = line_y + track_height + 5

        gene_width = end_x - start_x
        if gene_width > label_width:
            text_x = start_x + (gene_width - label_width) // 2
            draw.text((int(text_x), int(text_y)), label, fill='black', font=font)
        elif gene_width > 20:
            # Too narrow for horizontal text: paste the label rotated by 90 degrees.
            txt_img = Image.new('RGBA', (label_width, label_height), (255, 255, 255, 0))
            txt_draw = ImageDraw.Draw(txt_img)
            txt_draw.text((0, 0), label, font=font, fill='black')
            rotated_img = txt_img.rotate(90, expand=True)
            img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)

    return img


def create_advanced_genome_diagram(gene_results: List[Dict[str, Any]],
                                   genome_length: int,
                                   shap_means: np.ndarray,
                                   diagram_title: str = "Advanced Genome Diagram") -> "Image.Image":
    """
    An advanced genome diagram using Biopython's GenomeDiagram.

    Creates a gene track (colored by average SHAP) and a SHAP line-plot track,
    renders them to a temporary PDF, and converts the first page to a PIL
    image via ``pdf2image`` when that package is importable. The temporary
    PDF is always removed.
    """
    if not gene_results or genome_length <= 0 or len(shap_means) == 0:
        # Fallback if data is invalid
        img = Image.new('RGB', (800, 100), color='white')
        d = ImageDraw.Draw(img)
        d.text((10, 40), "Error: Not enough data for advanced diagram", fill='black')
        return img

    diagram = GenomeDiagram.Diagram(diagram_title)
    gene_track = diagram.new_track(1, name="Genes", greytrack=False, height=0.5)
    gene_set = gene_track.new_set()

    # Add each gene as a feature
    for gene in gene_results:
        start = max(0, int(gene['start']))
        end = min(genome_length, int(gene['end']))
        avg_shap = gene['avg_shap']

        # Color scale: negative = blue, positive = red
        intensity = abs(avg_shap) * 500
        intensity = max(50, min(255, intensity))
        fade = 1.0 - intensity / 255.0
        if avg_shap >= 0:
            gene_color = colors.Color(1.0, fade, fade)
        else:
            gene_color = colors.Color(fade, fade, 1.0)

        # Strand is set on the location: recent Biopython releases no longer
        # accept a `strand` argument on SeqFeature itself.
        feature = SeqFeature(FeatureLocation(start, end, strand=1))
        gene_set.add_feature(
            feature,
            color=gene_color,
            label=True,
            name=str(gene.get('gene_name', '?')),
            label_size=8,
            label_color=colors.black
        )

    # Add a track for the SHAP line
    shap_track = diagram.new_track(2, name="SHAP Score", greytrack=False, height=0.3)
    shap_set = shap_track.new_set("graph")

    # Plot the entire shap_means array, normalized to +/- 50 for visualization.
    max_abs = max(abs(shap_means.min()), abs(shap_means.max()))
    if max_abs == 0:
        scaled_shap = [0.0] * len(shap_means)
    else:
        scaled_shap = (shap_means / max_abs * 50).tolist()

    # GraphSet exposes new_graph (not add_graph) and expects a list of
    # (position, value) pairs rather than bare values.
    shap_set.new_graph(
        data=[(pos, float(val)) for pos, val in enumerate(scaled_shap)],
        name="shap_line",
        style="line",
        color=colors.darkgreen,
        altcolor=colors.red,
        linewidth=1
    )

    # Reserve a temp path, then close the handle before writing so the file
    # is not held open twice.
    tmpf = tempfile.NamedTemporaryFile(suffix=".pdf", delete=False)
    tmpf.close()
    try:
        diagram.draw(format="linear", pagesize='A3', fragments=1, start=0, end=genome_length)
        diagram.write(tmpf.name, "PDF")

        # Convert PDF to a PIL image (requires poppler or similar).
        try:
            import pdf2image
        except ImportError:
            img = Image.new('RGB', (800, 100), color='white')
            d = ImageDraw.Draw(img)
            d.text((10, 40), "pdf2image not installed, can't show advanced diagram as image.",
                   fill='black')
        else:
            pages = pdf2image.convert_from_path(tmpf.name, dpi=100)
            img = pages[0] if pages else Image.new('RGB', (800, 100), color='white')
    finally:
        # Cleanup even when drawing or conversion raises.
        os.remove(tmpf.name)

    return img


def analyze_gene_features(sequence_file: str,
                          features_file: str,
                          fasta_text: str = "",
                          features_text: str = "",
                          diagram_mode: str = "advanced"
                          ) -> Tuple[str, Optional[str], Optional["Image.Image"]]:
    """
    Analyze each gene in the features file, compute gene-level SHAP stats,
    produce tabular output, and create an optional genome diagram.

    Returns ``(results_text, csv_path, diagram_img)``. On failure the first
    item is a message and the other two are None. ``csv_path`` is None if the
    CSV could not be written.
    """
    import csv  # local import: the file's import block is outside this section

    # 1) Analyze the entire sequence with the top-level function.
    # analyze_sequence returns a None state on every failure path, including
    # messages that do not contain the word "Error".
    sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
    if sequence_results[3] is None:
        return f"Error in sequence analysis: {sequence_results[0]}", None, None

    seq = sequence_results[3]["seq"]
    shap_means = sequence_results[3]["shap_means"]
    genome_length = len(seq)

    # 2) Read gene features
    try:
        if features_text and features_text.strip():
            genes = parse_gene_features(features_text)
        else:
            with open(features_file, 'r') as f:
                genes = parse_gene_features(f.read())
    except Exception as e:
        return f"Error reading features file: {str(e)}", None, None

    gene_results = []
    for gene in genes:
        location = gene['metadata'].get('location', '')
        if not location:
            continue

        start, end = parse_location(location)
        if start is None or end is None or start >= end or end > genome_length:
            continue

        # NOTE(review): start/end are used directly as a 0-based half-open
        # slice. If the feature file uses 1-based inclusive coordinates
        # (GenBank style) this drops the first base - confirm the convention.
        gene_shap = shap_means[start:end]
        gene_seq = seq[start:end]

        basic_stats = compute_gene_statistics(gene_shap)
        adv_stats = advanced_gene_statistics(gene_shap, gene_seq)
        # Advanced stats override basic ones on shared keys.
        all_stats = {**basic_stats, **adv_stats}

        classification = 'Human' if basic_stats['avg_shap'] > 0 else 'Non-human'

        gene_results.append({
            'gene_name': gene['metadata'].get('gene', 'Unknown'),
            'location': location,
            'start': start,
            'end': end,
            'locus_tag': gene['metadata'].get('locus_tag', ''),
            'avg_shap': all_stats['avg_shap'],
            'median_shap': basic_stats['median_shap'],
            'std_shap': basic_stats['std_shap'],
            'max_shap': basic_stats['max_shap'],
            'min_shap': basic_stats['min_shap'],
            'pos_fraction': basic_stats['pos_fraction'],
            'n50': all_stats['n50'],
            'entropy': all_stats['entropy'],
            'classification': classification,
            'confidence': abs(all_stats['avg_shap'])
        })

    if not gene_results:
        return "No valid genes could be processed", None, None

    # 3) Summaries
    sorted_genes = sorted(gene_results, key=lambda g: abs(g['avg_shap']), reverse=True)
    num_human = sum(1 for g in gene_results if g['classification'] == 'Human')

    summary_parts = [
        "Gene Analysis Results:\n\n",
        f"Total genes analyzed: {len(gene_results)}\n",
        f"Human-like genes: {num_human}\n",
        f"Non-human-like genes: {len(gene_results) - num_human}\n\n",
        "Top 10 most distinctive genes (by avg SHAP magnitude):\n",
    ]
    for gene in sorted_genes[:10]:
        summary_parts.append(
            f"Gene: {gene['gene_name']}\n"
            f"Location: {gene['location']}\n"
            f"Classification: {gene['classification']} "
            f"(confidence: {gene['confidence']:.4f})\n"
            f"Average SHAP: {gene['avg_shap']:.4f}\n"
            f"N50: {gene['n50']}, Entropy: {gene['entropy']:.3f}\n\n"
        )
    results_text = "".join(summary_parts)

    # 4) Make CSV. The csv module quotes fields containing commas, which
    # locations such as "join(1..5,9..12)" do.
    csv_columns = [
        "gene_name", "location", "start", "end", "locus_tag", "avg_shap",
        "median_shap", "std_shap", "max_shap", "min_shap", "pos_fraction",
        "n50", "entropy", "classification", "confidence",
    ]
    try:
        with tempfile.NamedTemporaryFile(mode='w', newline='', suffix=".csv",
                                         prefix="gene_analysis_", delete=False) as f:
            writer = csv.writer(f, lineterminator="\n")
            writer.writerow(csv_columns)
            for g in gene_results:
                writer.writerow([
                    g['gene_name'], g['location'], g['start'], g['end'], g['locus_tag'],
                    f"{g['avg_shap']:.4f}", f"{g['median_shap']:.4f}", f"{g['std_shap']:.4f}",
                    f"{g['max_shap']:.4f}", f"{g['min_shap']:.4f}", f"{g['pos_fraction']:.4f}",
                    g['n50'], f"{g['entropy']:.4f}", g['classification'],
                    f"{g['confidence']:.4f}",
                ])
            temp_path = f.name
    except OSError as e:
        print(f"Error saving CSV: {str(e)}")
        temp_path = None

    # 5) Create diagram
    try:
        if diagram_mode == "advanced":
            diagram_img = create_advanced_genome_diagram(gene_results, genome_length, shap_means)
        else:
            diagram_img = create_simple_genome_diagram(gene_results, genome_length)
    except Exception as e:
        # UI boundary: show the failure in the image slot instead of raising.
        print(f"Error creating visualization: {str(e)}")
        diagram_img = Image.new('RGB', (800, 100), color='white')
        draw = ImageDraw.Draw(diagram_img)
        draw.text((10, 40), f"Error creating visualization: {str(e)}", fill='black')

    return results_text, temp_path, diagram_img
###############################################################################
# 12. DOWNLOAD FUNCTIONS
###############################################################################

def prepare_csv_download(data, filename: str = "analysis_results.csv"):
    """
    Convert data to CSV bytes for a Gradio download button.

    Args:
        data: One of
            - ``str``: already-formatted CSV text, encoded as-is;
            - ``list`` of dicts: one CSV row per dict, with the header taken
              from the keys of the first dict;
            - ``dict``: treated as a single row.
        filename: Name returned alongside the payload, unchanged.

    Returns:
        A ``(payload_bytes, filename)`` tuple. An empty list yields an empty
        payload because there is no row to derive a header from.

    Raises:
        ValueError: If ``data`` is not a ``str``, ``list`` or ``dict`` (or, from
            ``csv.DictWriter``, if a later row has keys missing from the header).
    """
    if isinstance(data, str):
        return data.encode(), filename

    if isinstance(data, dict):
        # A bare mapping is a single record; indexing it with [0] (as the
        # list branch does) would raise KeyError.
        rows = [data]
    elif isinstance(data, list):
        rows = data
    else:
        raise ValueError("Unsupported data type for CSV download")

    if not rows:
        # Nothing to write and no first row to read the header from.
        return b"", filename

    # Local imports: these names are not otherwise needed at module level.
    import csv
    from io import StringIO

    output = StringIO()
    writer = csv.DictWriter(output, fieldnames=list(rows[0].keys()))
    writer.writeheader()
    writer.writerows(rows)
    return output.getvalue().encode(), filename


###############################################################################
# 13. BUILD GRADIO INTERFACE
###############################################################################

css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.download-button {
    margin-top: 10px;
}
"""

with gr.Blocks(css=css) as iface:
    gr.Markdown("""
    # Virus Host Classifier + Extended Genome Visualization
    **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme subregions.
    **Step 2**: Explore subregions (local SHAP, GC content, histogram).
    **Step 3**: Analyze gene features (per-gene SHAP, advanced stats, improved diagrams).
    **Step 4**: Compare sequences for SHAP differences.

    **Color Scale**: Negative SHAP = Blue, 0 = White, Positive = Red.
    """)

    # ---- Tab 1: whole-sequence classification -------------------------------
    with gr.Tab("1) Full-Sequence Analysis"):
        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="Upload FASTA file",
                    file_types=[".fasta", ".fa", ".txt"],
                    type="filepath"
                )
                text_input = gr.Textbox(label="Or paste FASTA", placeholder=">name\nACGT...", lines=5)
                top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
                win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Subregion Window Size")
                analyze_btn = gr.Button("Analyze Sequence", variant="primary")
            with gr.Column(scale=2):
                results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
                kmer_img = gr.Image(label="Top k-mer SHAP")
                genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
                download_results = gr.File(label="Download Results", visible=False, elem_classes="download-button")
        # State written by the full-sequence callback and read back by the
        # subregion callback wired up in tab 2.
        seq_state = gr.State()
        header_state = gr.State()
        analyze_btn.click(
            analyze_sequence,
            inputs=[file_input, top_k, text_input, win_size],
            outputs=[results_box, kmer_img, genome_img, seq_state, header_state, download_results]
        )

    # ---- Tab 2: zoom into a region of the previously analyzed sequence ------
    with gr.Tab("2) Subregion Exploration"):
        gr.Markdown("""
        **Subregion Analysis**
        View SHAP signals, GC content, etc. for a specific region.
        """)
        with gr.Row():
            region_start = gr.Number(label="Region Start", value=0)
            region_end = gr.Number(label="Region End", value=500)
            region_btn = gr.Button("Analyze Subregion")
        subregion_info = gr.Textbox(label="Subregion Analysis", lines=7, interactive=False)
        with gr.Row():
            subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
            subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
        download_subregion = gr.File(label="Download Subregion", visible=False, elem_classes="download-button")
        region_btn.click(
            analyze_subregion,
            inputs=[seq_state, header_state, region_start, region_end],
            outputs=[subregion_info, subregion_img, subregion_hist_img, download_subregion]
        )

    # ---- Tab 3: per-gene analysis -------------------------------------------
    with gr.Tab("3) Gene Features Analysis"):
        gr.Markdown("""
        **Analyze Gene Features**
        - Upload a FASTA file and a gene features file.
        - See per-gene SHAP, classification, N50, entropy, etc.
        - Choose a diagram mode (simple or advanced).
        """)
        with gr.Row():
            with gr.Column(scale=1):
                gene_fasta_file = gr.File(
                    label="FASTA file",
                    file_types=[".fasta", ".fa", ".txt"],
                    type="filepath"
                )
                gene_fasta_text = gr.Textbox(label="Or paste FASTA sequence", lines=5)
            with gr.Column(scale=1):
                features_file = gr.File(label="Gene features file", file_types=[".txt"], type="filepath")
                features_text = gr.Textbox(label="Or paste gene features", lines=5)
        diagram_mode = gr.Radio(choices=["simple", "advanced"], value="advanced", label="Diagram Mode")
        analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
        gene_results = gr.Textbox(label="Gene Analysis Results", lines=12, interactive=False)
        gene_diagram = gr.Image(label="Genome Diagram")
        download_gene_results = gr.File(label="Download Gene Analysis (CSV)", visible=True)
        # Output order matches the callback's return order:
        # (results text, CSV path, diagram image).
        analyze_genes_btn.click(
            analyze_gene_features,
            inputs=[gene_fasta_file, features_file, gene_fasta_text, features_text, diagram_mode],
            outputs=[gene_results, download_gene_results, gene_diagram]
        )

    # ---- Tab 4: pairwise comparison -----------------------------------------
    with gr.Tab("4) Comparative Analysis"):
        gr.Markdown("""
        **Compare Two Sequences**
        - Upload or paste two FASTA sequences.
        - We'll compare SHAP patterns (normalized for different lengths).
        """)
        with gr.Row():
            with gr.Column(scale=1):
                file_input1 = gr.File(
                    label="1st FASTA file",
                    file_types=[".fasta", ".fa", ".txt"],
                    type="filepath"
                )
                text_input1 = gr.Textbox(label="Or paste 1st FASTA", lines=5)
            with gr.Column(scale=1):
                file_input2 = gr.File(
                    label="2nd FASTA file",
                    file_types=[".fasta", ".fa", ".txt"],
                    type="filepath"
                )
                text_input2 = gr.Textbox(label="Or paste 2nd FASTA", lines=5)
        compare_btn = gr.Button("Compare Sequences", variant="primary")
        comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
        with gr.Row():
            diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
            diff_hist = gr.Image(label="Distribution of SHAP Differences")
        download_comparison = gr.File(label="Download Comparison", visible=False, elem_classes="download-button")
        compare_btn.click(
            analyze_sequence_comparison,
            inputs=[file_input1, file_input2, text_input1, text_input2],
            outputs=[comparison_text, diff_heatmap, diff_hist, download_comparison]
        )

    gr.Markdown("""
    ### Notes & Features
    - **Advanced Genome Diagram** uses Biopython’s `GenomeDiagram` (requires `pdf2image` if you want it as an image).
    - **Additional Stats**: N50, Shannon entropy, etc.
    - **Auto-scaling** for comparative analysis with adaptive smoothing.
    - **Data Export**: Download CSV of analysis results.
    """)

if __name__ == "__main__":
    iface.launch()