# HostClassifier / app.py
# hiyata's picture
# Update app.py
# 18efb8a verified
# raw
# history blame
# 47.1 kB
import gradio as gr
import torch
import joblib
import numpy as np
from itertools import product
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import io
from io import BytesIO # Import io then BytesIO
from PIL import Image, ImageDraw, ImageFont
from Bio.Graphics import GenomeDiagram
from Bio.SeqFeature import SeqFeature, FeatureLocation
from reportlab.lib import colors
import pandas as pd
import tempfile
import os
from typing import List, Dict, Tuple, Optional, Any
import seaborn as sns
###############################################################################
# 1. MODEL DEFINITION
###############################################################################
class VirusClassifier(nn.Module):
    """Feed-forward classifier producing two logits from a feature vector.

    The stack is Linear(input_shape, 64) -> GELU -> BatchNorm -> Dropout(0.3)
    -> Linear(64, 32) -> GELU -> BatchNorm -> Dropout(0.3) -> Linear(32, 32)
    -> GELU -> Linear(32, 2).
    """

    def __init__(self, input_shape: int):
        super().__init__()
        # The position of each layer inside ``network`` determines the
        # state-dict key names, so the order below must not change.
        layers = [
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (pre-softmax) outputs of the network for ``x``."""
        logits = self.network(x)
        return logits
###############################################################################
# 2. FASTA PARSING & K-MER FEATURE ENGINEERING
###############################################################################
def parse_fasta(text):
    """Parse FASTA-formatted text into a list of ``(header, sequence)`` tuples.

    Args:
        text: Raw FASTA content. Blank lines are ignored and any newline
            convention (LF, CRLF, CR) is accepted.

    Returns:
        A list of ``(header, sequence)`` pairs in file order. The header is
        the text after ``>``; sequence lines are upper-cased and concatenated.
        Sequence lines that appear before the first header are discarded.
    """
    records = []
    header = None
    chunks = []
    for raw_line in text.strip().splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('>'):
            # Compare against None rather than truthiness so a record whose
            # header is empty (a bare ">") is still kept instead of dropped.
            if header is not None:
                records.append((header, ''.join(chunks)))
            header = line[1:]
            chunks = []
        else:
            chunks.append(line.upper())
    if header is not None:
        records.append((header, ''.join(chunks)))
    return records
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """
    Build a normalised frequency vector over every possible ACGT k-mer.

    The vector has ``4 ** k`` float32 entries ordered like
    ``itertools.product("ACGT", repeat=k)``. Windows containing any other
    character are not counted, but they still contribute to the divisor
    (the total number of windows in the sequence).
    """
    index_of = {''.join(letters): pos
                for pos, letters in enumerate(product("ACGT", repeat=k))}
    counts = np.zeros(len(index_of), dtype=np.float32)
    window_count = len(sequence) - k + 1
    for offset in range(window_count):
        slot = index_of.get(sequence[offset:offset + k])
        if slot is not None:
            counts[slot] += 1
    if window_count > 0:
        counts /= window_count
    return counts
###############################################################################
# 3. SHAP-VALUE (ABLATION) CALCULATION
###############################################################################
def calculate_shap_values(model, x_tensor):
    """
    Ablation-style approximation of SHAP values.

    Each feature of the first row of ``x_tensor`` is set to zero in turn and
    the resulting drop in the class-1 ('human') softmax probability is
    recorded.

    Returns:
        ``(values, baseline_prob)`` where ``values`` is a NumPy array with one
        entry per feature (baseline minus ablated probability) and
        ``baseline_prob`` is the unperturbed class-1 probability.
    """
    def class_one_probability(batch):
        # Softmax over the class dimension, then read row 0 / class 1.
        return torch.softmax(model(batch), dim=1)[0, 1].item()

    model.eval()
    contributions = []
    with torch.no_grad():
        baseline_prob = class_one_probability(x_tensor)
        for feature_idx in range(x_tensor.shape[1]):
            # Work on a fresh copy so the caller's tensor is never modified.
            ablated = x_tensor.clone()
            ablated[0, feature_idx] = 0.0
            contributions.append(baseline_prob - class_one_probability(ablated))
    return np.array(contributions), baseline_prob
###############################################################################
# 4. PER-BASE SHAP AGGREGATION
###############################################################################
def compute_positionwise_scores(sequence, shap_values, k=4):
    """
    Spread k-mer level SHAP values onto individual sequence positions.

    Every window of length ``k`` that is a pure ACGT k-mer adds its SHAP value
    to each of the ``k`` positions it covers; the per-position result is the
    mean over all covering windows (0.0 where no valid window covers it).
    """
    lookup = {''.join(combo): idx
              for idx, combo in enumerate(product("ACGT", repeat=k))}
    length = len(sequence)
    totals = np.zeros(length, dtype=np.float32)
    hits = np.zeros(length, dtype=np.float32)
    for left in range(length - k + 1):
        right = left + k
        slot = lookup.get(sequence[left:right])
        if slot is None:
            continue
        totals[left:right] += shap_values[slot]
        hits[left:right] += 1
    # Uncovered positions would divide 0 by 0; silence the warning and map
    # them to 0.0 explicitly.
    with np.errstate(divide='ignore', invalid='ignore'):
        per_base = np.where(hits > 0, totals / hits, 0.0)
    return per_base
###############################################################################
# 5. FIND EXTREME SHAP REGIONS
###############################################################################
def find_extreme_subregion(shap_means, window_size=500, mode="max"):
    """
    Find the fixed-width window with the highest or lowest mean SHAP value.

    Args:
        shap_means: 1-D sequence of per-position SHAP values.
        window_size: Window width in positions (coerced to ``int``).
        mode: ``"max"`` for the highest-mean window, ``"min"`` for the lowest.

    Returns:
        ``(start, end, average)`` with ``end`` exclusive. An empty input gives
        ``(0, 0, 0.0)``; a window at least as long as the input gives the
        whole range with its overall mean. Ties resolve to the leftmost window.

    Raises:
        ValueError: If ``mode`` is not ``"max"``/``"min"`` or ``window_size``
            is smaller than 1.
    """
    if mode not in ("max", "min"):
        raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
    n = len(shap_means)
    if n == 0:
        return (0, 0, 0.0)
    window_size = int(window_size)
    if window_size < 1:
        raise ValueError("window_size must be at least 1")
    if window_size >= n:
        return (0, n, float(np.mean(shap_means)))

    # Accumulate in float64: a float32 running sum over tens of thousands of
    # positions loses the low-order bits that distinguish neighbouring windows.
    prefix = np.concatenate(([0.0], np.cumsum(shap_means, dtype=np.float64)))
    window_sums = prefix[window_size:] - prefix[:-window_size]

    # argmax/argmin return the first extreme index, matching a left-to-right
    # scan that only replaces the best window on a strict improvement.
    pick = np.argmax if mode == "max" else np.argmin
    best_start = int(pick(window_sums))
    best_avg = float(window_sums[best_start] / window_size)
    return (best_start, best_start + window_size, best_avg)
###############################################################################
# 6. PLOTTING / UTILITIES
###############################################################################
def fig_to_image(fig):
    """
    Render a Matplotlib figure to a PIL Image and close the figure.

    The figure is saved as a 150-dpi PNG into an in-memory buffer. It is
    closed even if rendering fails, so a failed plot does not leave the figure
    registered with pyplot.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    finally:
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    # Image.open is lazy; decode now so the returned image is fully loaded.
    img.load()
    return img
def get_zero_centered_cmap():
    """
    Build a blue -> white -> red colormap whose midpoint (0.5) is white.
    """
    anchor_points = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
    return LinearSegmentedColormap.from_list("blue_white_red", anchor_points)
def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
    """
    Plot a one-row heatmap of per-base SHAP values.

    Args:
        shap_means: 1-D per-position SHAP values.
        title: Figure title; the plotted range is appended when a subregion
            is requested.
        start, end: Optional slice bounds. Both must be given to select a
            subregion; otherwise the whole array is plotted.

    Returns:
        The Matplotlib figure (the caller is responsible for closing it).
    """
    if start is not None and end is not None:
        local_shap = shap_means[start:end]
        subtitle = f" (positions {start}-{end})"
        x_origin = start
    else:
        local_shap = shap_means
        subtitle = ""
        x_origin = 0

    local_shap = np.asarray(local_shap, dtype=float)
    if local_shap.size == 0:
        local_shap = np.array([0.0])

    # Symmetric colour limits keep zero on the white midpoint. With an
    # all-zero (or non-finite) range vmin == vmax, which maps every cell to
    # the bottom colour (blue); fall back to a unit range so zeros stay white.
    color_limit = float(np.max(np.abs(local_shap)))
    if not np.isfinite(color_limit) or color_limit == 0.0:
        color_limit = 1.0

    fig, ax = plt.subplots(figsize=(12, 1.8))
    image = ax.imshow(
        local_shap.reshape(1, -1),
        aspect='auto',
        cmap=get_zero_centered_cmap(),
        vmin=-color_limit,
        vmax=color_limit,
        # Label the x-axis with real sequence positions, not slice offsets.
        extent=(x_origin, x_origin + local_shap.size, 0, 1),
    )
    cbar = fig.colorbar(image, ax=ax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
    cbar.ax.tick_params(labelsize=8)
    cbar.set_label('SHAP Contribution', fontsize=9, labelpad=5)
    ax.set_yticks([])
    ax.set_xlabel('Position in Sequence', fontsize=10)
    ax.set_title(f"{title}{subtitle}", pad=10)
    fig.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
    return fig
def create_importance_bar_plot(shap_values, kmers, top_k=10):
    """
    Draw a horizontal bar chart of the k-mers with the largest |SHAP| values.

    Args:
        shap_values: 1-D array of per-k-mer SHAP values.
        kmers: k-mer labels, indexed like ``shap_values``.
        top_k: Number of bars to show (coerced to ``int`` and clamped to the
            number of available values).

    Returns:
        The Matplotlib figure. Bars are ordered with the most influential
        k-mer at the top; negative values are blue, non-negative values red.
    """
    plt.rcParams.update({'font.size': 10})
    shap_values = np.asarray(shap_values)
    top_k = max(0, min(int(top_k), len(shap_values)))

    # Sort by descending magnitude; combined with the inverted y-axis below
    # this puts the strongest k-mer on the first (top) row. Slicing from the
    # front also avoids the ``[-0:]`` pitfall, which returns every element.
    order = np.argsort(np.abs(shap_values))[::-1][:top_k]
    values = shap_values[order]
    labels = [kmers[i] for i in order]
    bar_colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]

    fig, ax = plt.subplots(figsize=(10, 5))
    rows = range(len(values))
    ax.barh(rows, values, color=bar_colors)
    ax.set_yticks(rows)
    ax.set_yticklabels(labels)
    ax.set_xlabel('SHAP Value (impact on model output)')
    ax.set_title(f'Top {top_k} Most Influential k-mers')
    ax.invert_yaxis()
    fig.tight_layout()
    return fig
def plot_shap_histogram(shap_array, title="SHAP Distribution in Region", num_bins=30):
    """
    Draw a histogram of SHAP values with a dashed red marker at zero.

    NOTE: this module defines another ``plot_shap_histogram`` further down;
    that later definition rebinds the name, so it is the one callers get.
    """
    figure, axes = plt.subplots(figsize=(6, 4))
    axes.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
    axes.axvline(0, color='red', linestyle='--', label='0.0')
    axes.set(xlabel="SHAP Value", ylabel="Count", title=title)
    axes.legend()
    figure.tight_layout()
    return figure
def compute_gc_content(sequence):
    """
    Return the GC content of ``sequence`` as a percentage (0.0-100.0).

    Counting is case-insensitive, so soft-masked (lowercase) bases are
    included. An empty or ``None`` sequence yields 0.0.
    """
    if not sequence:
        return 0.0
    normalized = sequence.upper()
    gc_count = normalized.count('G') + normalized.count('C')
    return (gc_count / len(sequence)) * 100.0
###############################################################################
# 7. MAIN ANALYSIS STEP (Gradio Step 1)
###############################################################################
def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
    """
    Classify the first sequence of a FASTA input and explain the prediction.

    Args:
        file_obj: Path to a FASTA file, used only when ``fasta_text`` is empty.
        top_kmers: Number of k-mers shown in the importance bar plot.
        fasta_text: Pasted FASTA text; takes precedence over ``file_obj``.
        window_size: Width of the sliding window used to locate the most
            human-pushing and most non-human-pushing subregions.

    Returns:
        A 6-tuple ``(results_text, bar_img, heatmap_img, state, header, None)``.
        On failure the first element is a message and all others are ``None``;
        ``state`` is a dict with keys ``"seq"`` and ``"shap_means"`` on success.
    """
    def _failure(message):
        # Every failure uses the same 6-slot shape as a successful result.
        return (message, None, None, None, None, None)

    k = 4

    # 1) Read input (tolerate None from an untouched textbox)
    pasted = (fasta_text or "").strip()
    if pasted:
        text = pasted
    elif file_obj is not None:
        try:
            with open(file_obj, 'r') as f:
                text = f.read()
        except Exception as e:
            return _failure(f"Error reading file: {str(e)}")
    else:
        return _failure("Please provide a FASTA sequence.")

    # 2) Parse FASTA
    sequences = parse_fasta(text)
    if not sequences:
        return _failure("No valid FASTA sequences found.")
    header, seq = sequences[0]
    if len(seq) < k:
        # Without at least one k-mer the feature vector is all zeros and the
        # prediction would be meaningless.
        return _failure(f"Error: sequence '{header}' is shorter than {k} bases.")

    # Slider widgets may deliver floats; slicing and indexing need ints.
    top_kmers = int(top_kmers)
    window_size = int(window_size)

    # 3) Load model, scaler, and run inference
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        # NOTE(review): both files are unpickled; they must be trusted
        # artifacts shipped with the app, never user uploads.
        state_dict = torch.load('model.pt', map_location=device)
        model = VirusClassifier(256).to(device)
        model.load_state_dict(state_dict)
        scaler = joblib.load('scaler.pkl')
    except Exception as e:
        return _failure(f"Error loading model/scaler: {str(e)}")

    try:
        freq_vector = sequence_to_kmer_vector(seq)
        scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
        x_tensor = torch.FloatTensor(scaled_vector).to(device)
        shap_values, prob_human = calculate_shap_values(model, x_tensor)
    except Exception as e:
        return _failure(f"Error during inference: {str(e)}")

    prob_nonhuman = 1.0 - prob_human
    classification = "Human" if prob_human > 0.5 else "Non-human"
    confidence = max(prob_human, prob_nonhuman)

    # 4) Per-base SHAP & subregion detection
    shap_means = compute_positionwise_scores(seq, shap_values, k=k)
    max_start, max_end, max_avg = find_extreme_subregion(shap_means, window_size, mode="max")
    min_start, min_end, min_avg = find_extreme_subregion(shap_means, window_size, mode="min")

    # 5) Prepare result text (other functions parse the "Classification: "
    #    line, so the layout must stay stable)
    results_text = (
        f"Sequence: {header}\n"
        f"Length: {len(seq):,} bases\n"
        f"Classification: {classification}\n"
        f"Confidence: {confidence:.3f}\n"
        f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
        f"---\n"
        f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
        f"Start: {max_start}, End: {max_end}, Avg SHAP: {max_avg:.4f}\n\n"
        f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
        f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
    )

    # 6) Create bar & heatmap figures
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    bar_img = fig_to_image(create_importance_bar_plot(shap_values, kmers, top_kmers))
    heatmap_img = fig_to_image(plot_linear_heatmap(shap_means, title="Genome-wide SHAP"))

    # 7) State consumed by the subregion / gene / comparison steps
    state_dict_out = {"seq": seq, "shap_means": shap_means}
    return (results_text, bar_img, heatmap_img, state_dict_out, header, None)
###############################################################################
# 8. SUBREGION ANALYSIS (Gradio Step 2)
###############################################################################
def analyze_subregion(state, header, region_start, region_end):
    """
    Summarise SHAP values and GC content for a user-selected subregion.

    Args:
        state: Dict produced by the full-sequence step, with keys ``"seq"``
            and ``"shap_means"``.
        header: Sequence name used in the report text.
        region_start, region_end: Slice bounds (end exclusive); values are
            converted to ``int`` and clamped to the sequence length.

    Returns:
        ``(report_text, heatmap_img, histogram_img, None)``. When the state is
        missing or the bounds are invalid, the first element is a message and
        the remaining three are ``None``.
    """
    if not state or "seq" not in state or "shap_means" not in state:
        return ("No sequence data found. Please run Step 1 first.", None, None, None)
    seq = state["seq"]
    shap_means = state["shap_means"]

    # Number inputs can arrive as None (cleared field) or non-finite floats;
    # report that instead of raising.
    try:
        region_start = int(region_start)
        region_end = int(region_end)
    except (TypeError, ValueError, OverflowError):
        return ("Invalid region range. Start and End must be numbers.", None, None, None)

    region_start = max(0, min(region_start, len(seq)))
    region_end = max(0, min(region_end, len(seq)))
    if region_end <= region_start:
        return ("Invalid region range. End must be > Start.", None, None, None)

    region_seq = seq[region_start:region_end]
    region_shap = shap_means[region_start:region_end]
    gc_percent = compute_gc_content(region_seq)
    avg_shap = float(np.mean(region_shap))
    positive_fraction = np.mean(region_shap > 0)
    negative_fraction = np.mean(region_shap < 0)

    # An average within +/-0.05 is reported as neutral.
    if avg_shap > 0.05:
        region_classification = "Likely pushing toward human"
    elif avg_shap < -0.05:
        region_classification = "Likely pushing toward non-human"
    else:
        region_classification = "Near neutral (no strong push)"

    region_info = (
        f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
        f"Region length: {len(region_seq)} bases\n"
        f"GC content: {gc_percent:.2f}%\n"
        f"Average SHAP in region: {avg_shap:.4f}\n"
        f"Fraction with SHAP > 0 (toward human): {positive_fraction:.2f}\n"
        f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
        f"Subregion interpretation: {region_classification}\n"
    )
    heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion SHAP", start=region_start, end=region_end)
    heatmap_img = fig_to_image(heatmap_fig)
    hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
    hist_img = fig_to_image(hist_fig)
    # Four items, matching the Gradio outputs wired to this step.
    return (region_info, heatmap_img, hist_img, None)
###############################################################################
# 9. COMPARISON ANALYSIS FUNCTIONS (Step 4)
###############################################################################
def compute_shap_difference(shap1_norm, shap2_norm):
    """
    Return the element-wise SHAP difference, sequence 2 minus sequence 1.
    """
    difference = shap2_norm - shap1_norm
    return difference
def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
    """
    Plot a one-row heatmap of SHAP differences on a relative 0-100% axis.

    Args:
        shap_diff: 1-D array of differences (Seq2 - Seq1) on a common grid.
        title: Figure title.

    Returns:
        The Matplotlib figure (the caller is responsible for closing it).
    """
    shap_diff = np.asarray(shap_diff, dtype=float)
    if shap_diff.size == 0:
        # Nothing to compare: draw a single neutral cell instead of failing
        # on the min/max of an empty array.
        shap_diff = np.array([0.0])

    # Symmetric limits keep zero on the white midpoint; a zero or non-finite
    # range would collapse vmin == vmax and paint everything blue.
    color_limit = float(np.max(np.abs(shap_diff)))
    if not np.isfinite(color_limit) or color_limit == 0.0:
        color_limit = 1.0

    fig, ax = plt.subplots(figsize=(12, 1.8))
    image = ax.imshow(
        shap_diff.reshape(1, -1),
        aspect='auto',
        cmap=get_zero_centered_cmap(),
        vmin=-color_limit,
        vmax=color_limit,
    )

    # Percentage-based x-axis ticks
    num_ticks = 5
    tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
    tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)

    cbar = fig.colorbar(image, ax=ax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
    cbar.ax.tick_params(labelsize=8)
    cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
    ax.set_yticks([])
    ax.set_xlabel('Relative Position in Sequence', fontsize=10)
    ax.set_title(title, pad=10)
    fig.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
    return fig
def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
    """
    Draw a semi-transparent histogram of SHAP values with a dashed red marker at zero.

    ``num_bins`` is passed straight to Matplotlib's ``hist`` as ``bins``.
    """
    figure, axes = plt.subplots(figsize=(6, 4))
    axes.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
    axes.axvline(0, color='red', linestyle='--', label='0.0')
    axes.set(xlabel="SHAP Value", ylabel="Count", title=title)
    axes.legend()
    figure.tight_layout()
    return figure
def calculate_adaptive_parameters(len1, len2):
    """
    Choose interpolation and smoothing parameters from two sequence lengths.

    Args:
        len1, len2: Lengths of the two SHAP tracks being compared.

    Returns:
        ``(num_points, smooth_window, resolution_factor)``: the size of the
        common interpolation grid, the smoothing window width, and a
        resolution multiplier tied to the length-difference tier. Larger
        length differences use fewer points and wider smoothing windows.
    """
    length_diff = abs(len1 - len2)
    max_length = max(len1, len2)
    min_length = min(len1, len2)
    # Two empty tracks are treated as identical in length instead of
    # dividing by zero.
    length_ratio = min_length / max_length if max_length > 0 else 1.0

    # Base number of points
    base_points = min(2000, max(500, max_length // 100))

    if length_diff < 500:
        resolution_factor = 2.0
        num_points = min(3000, base_points * 2)
        smooth_window = max(10, length_diff // 50)
    elif length_diff < 5000:
        resolution_factor = 1.5
        num_points = min(2000, base_points * 1.5)
        smooth_window = max(20, length_diff // 100)
    elif length_diff < 50000:
        resolution_factor = 1.0
        num_points = base_points
        smooth_window = max(50, length_diff // 200)
    else:
        resolution_factor = 0.75
        num_points = max(500, base_points // 2)
        smooth_window = max(100, length_diff // 500)

    # Widen the window further the more the two lengths differ.
    smooth_window = int(smooth_window * (1 + (1 - length_ratio)))
    return int(num_points), int(smooth_window), resolution_factor
def sliding_window_smooth(values, window_size=50):
    """
    Smooth a 1-D signal with a flat-topped kernel that decays towards its edges.

    Args:
        values: 1-D array-like signal.
        window_size: Kernel width. Widths below 3 return ``values`` unchanged;
            widths larger than the signal are clamped to the signal length.

    Returns:
        An array of the same length as ``values``. Interior samples are the
        weighted average; samples too close to either end for a full window
        keep their original values. Floating dtypes are preserved, other
        dtypes are promoted to float64.
    """
    if window_size < 3:
        return values
    signal = np.asarray(values)
    n = signal.shape[0]
    # A kernel longer than the signal makes 'valid' convolution swap its
    # operands and return the wrong length, so clamp the width first.
    window_size = min(int(window_size), n)
    if window_size < 3:
        return signal

    half = window_size // 2
    ramp = np.exp(-np.linspace(0, 3, half))  # 1.0 -> ~0.05
    kernel = np.ones(window_size)
    # Weights rise from ~0.05 at the outer edge to 1.0 at the centre and fall
    # again, so the centre of the window dominates.
    kernel[:half] = ramp[::-1]
    kernel[-half:] = ramp
    kernel /= kernel.sum()

    smoothed = np.convolve(signal, kernel, mode='valid')
    pad_left = (n - len(smoothed)) // 2

    out_dtype = signal.dtype if np.issubdtype(signal.dtype, np.floating) else np.float64
    # Start from a copy so the un-smoothable edges keep their original values.
    result = signal.astype(out_dtype, copy=True)
    result[pad_left:pad_left + len(smoothed)] = smoothed
    return result
def normalize_shap_lengths(shap1, shap2):
    """
    Bring two SHAP tracks onto a shared relative grid for direct comparison.

    Both tracks are smoothed with the same adaptive window, then linearly
    interpolated onto ``num_points`` evenly spaced positions in [0, 1].

    Returns:
        ``(track1, track2, smooth_window)``.
    """
    num_points, smooth_window, _resolution = calculate_adaptive_parameters(len(shap1), len(shap2))
    shared_axis = np.linspace(0, 1, num_points)
    resampled = []
    for track in (shap1, shap2):
        smoothed = sliding_window_smooth(track, smooth_window)
        own_axis = np.linspace(0, 1, len(smoothed))
        resampled.append(np.interp(shared_axis, own_axis, smoothed))
    return resampled[0], resampled[1], smooth_window
def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
    """
    Compare the per-base SHAP profiles of two sequences.

    Each input is run through ``analyze_sequence``; the two SHAP tracks are
    smoothed and resampled to a common length and their difference
    (Seq2 - Seq1) is summarised and plotted.

    Args:
        file1, file2: FASTA file paths (used when the matching text is empty).
        fasta1, fasta2: Pasted FASTA text for sequence 1 and 2.

    Returns:
        ``(comparison_text, heatmap_img, histogram_img, None)``; on failure
        the first element is an error message and the rest are ``None``.
    """
    def _classification_from(report):
        # The label is only available inside the report text, on the line
        # that starts with "Classification: ".
        try:
            return report.split('Classification: ')[1].split('\n')[0].strip()
        except (IndexError, AttributeError):
            return "Unknown"

    try:
        results = []
        for label, file_obj, fasta in (("1", file1, fasta1), ("2", file2, fasta2)):
            res = analyze_sequence(file_obj, top_kmers=10, fasta_text=fasta, window_size=500)
            # A missing state dict marks failure. Matching on the word
            # "Error" misses messages such as "Please provide a FASTA
            # sequence." and mis-fires on headers that contain "Error".
            if res[3] is None:
                return (f"Error in sequence {label}: {res[0]}", None, None, None)
            results.append(res)
        res1, res2 = results

        shap1 = res1[3]["shap_means"]
        shap2 = res2[3]["shap_means"]
        len1, len2 = len(shap1), len(shap2)
        if len1 == 0 or len2 == 0:
            return ("Error during sequence comparison: one of the sequences is empty.",
                    None, None, None)
        length_diff = abs(len1 - len2)
        length_ratio = min(len1, len2) / max(len1, len2)

        # Normalize both to the same length
        shap1_norm, shap2_norm, smooth_window = normalize_shap_lengths(shap1, shap2)
        shap_diff = compute_shap_difference(shap1_norm, shap2_norm)

        # The threshold for a "substantial" difference grows as the two
        # lengths diverge.
        base_threshold = 0.05
        adaptive_threshold = base_threshold * (1 + (1 - length_ratio))
        if length_diff > 50000:
            adaptive_threshold *= 1.5

        avg_diff = np.mean(shap_diff)
        std_diff = np.std(shap_diff)
        max_diff = np.max(shap_diff)
        min_diff = np.min(shap_diff)
        frac_different = np.mean(np.abs(shap_diff) > adaptive_threshold)

        classification1 = _classification_from(res1[0])
        classification2 = _classification_from(res2[0])

        comparison_text = (
            "Sequence Comparison Results:\n"
            f"Sequence 1: {res1[4]}\n"
            f"Length: {len1:,} bases\n"
            f"Classification: {classification1}\n\n"
            f"Sequence 2: {res2[4]}\n"
            f"Length: {len2:,} bases\n"
            f"Classification: {classification2}\n\n"
            "Comparison Parameters:\n"
            f"Length Difference: {length_diff:,} bases\n"
            f"Length Ratio: {length_ratio:.3f}\n"
            f"Smoothing Window: {smooth_window} points\n"
            f"Adaptive Threshold: {adaptive_threshold:.3f}\n\n"
            "Statistics:\n"
            f"Average SHAP difference: {avg_diff:.4f}\n"
            f"Standard deviation: {std_diff:.4f}\n"
            f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
            f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
            f"Fraction with substantial differences: {frac_different:.2%}\n\n"
            "Note: All parameters automatically adjusted based on sequence properties\n\n"
            "Interpretation:\n"
            "- Red regions: Sequence 2 more human-like\n"
            "- Blue regions: Sequence 1 more human-like\n"
            "- White regions: Similar between sequences"
        )

        heatmap_fig = plot_comparative_heatmap(
            shap_diff,
            title=f"SHAP Difference Heatmap (window: {smooth_window})"
        )
        heatmap_img = fig_to_image(heatmap_fig)

        # Roughly sqrt(n) bins, kept between 20 and 50.
        num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
        hist_fig = plot_shap_histogram(
            shap_diff,
            title="Distribution of SHAP Differences",
            num_bins=num_bins
        )
        hist_img = fig_to_image(hist_fig)
        return (comparison_text, heatmap_img, hist_img, None)
    except Exception as e:
        error_msg = f"Error during sequence comparison: {str(e)}"
        return (error_msg, None, None, None)
###############################################################################
# 10. ADDITIONAL / ADVANCED VISUALIZATIONS & STATISTICS
###############################################################################
def n50_length(sequence):
    """
    Return the N50 of a single contiguous sequence.

    With only one piece there is nothing to rank or accumulate, so the N50 is
    simply the length of that piece. A multi-contig N50 would need a sorted
    list of contig lengths and a cumulative sum.
    """
    contig_length = len(sequence)
    return contig_length
def sequence_complexity(sequence):
    """
    Return the Shannon entropy (in bits) of the symbol distribution of ``sequence``.

    An empty sequence has complexity 0.0.
    """
    from collections import Counter
    from math import log2

    total = len(sequence)
    if total == 0:
        return 0.0
    entropy = 0.0
    # Counter keeps symbols in first-seen order, so terms are accumulated in
    # the same order as a hand-rolled frequency dict would give.
    for occurrences in Counter(sequence).values():
        share = occurrences / total
        entropy -= share * log2(share)
    return entropy
def advanced_gene_statistics(gene_shap: np.ndarray, gene_seq: str) -> Dict[str, float]:
    """
    Compute supplementary per-gene statistics.

    Args:
        gene_shap: Per-base SHAP values covering the gene.
        gene_seq: Nucleotide sequence of the gene.

    Returns:
        Dict with keys ``n50`` (the gene length), ``entropy`` (Shannon entropy
        of the bases), ``avg_shap``, ``max_shap`` and ``min_shap``. All SHAP
        statistics are 0.0 for an empty ``gene_shap``.
    """
    has_values = len(gene_shap) > 0
    stats = {}
    stats['n50'] = len(gene_seq)  # trivial for a single gene region
    stats['entropy'] = sequence_complexity(gene_seq)
    # Guard the mean like max/min: np.mean of an empty array is NaN and
    # emits a RuntimeWarning.
    stats['avg_shap'] = float(np.mean(gene_shap)) if has_values else 0.0
    stats['max_shap'] = float(np.max(gene_shap)) if has_values else 0.0
    stats['min_shap'] = float(np.min(gene_shap)) if has_values else 0.0
    return stats
###############################################################################
# 11. GENE FEATURE ANALYSIS
###############################################################################
def parse_gene_features(text: str) -> List[Dict[str, Any]]:
    """Parse gene features from FASTA-like text.

    Record splitting is delegated to ``parse_fasta`` so the two parsers cannot
    drift apart; each record is then annotated with the metadata extracted
    from its header.

    Returns:
        One dict per record with keys ``header``, ``sequence`` (upper-cased,
        concatenated) and ``metadata`` (from ``parse_gene_metadata``).
    """
    return [
        {
            'header': header,
            'sequence': sequence,
            'metadata': parse_gene_metadata(header),
        }
        for header, sequence in parse_fasta(text)
    ]
def parse_gene_metadata(header: str) -> Dict[str, str]:
    """Extract ``[key=value]`` qualifiers from a gene header line.

    Qualifiers are matched as whole bracketed groups, so values containing
    spaces (e.g. ``[protein=surface glycoprotein]``) are kept intact. The
    value is everything after the first ``=``.

    Returns:
        Dict mapping qualifier names to values; empty if none are present.
    """
    import re

    metadata = {}
    # key: anything up to the first '='; value: the rest, up to the closing ']'
    for key, value in re.findall(r'\[([^\[\]=]+)=([^\[\]]*)\]', header):
        metadata[key.strip()] = value.strip()
    return metadata
def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
    """Parse a gene location string into its overall ``(start, end)`` span.

    Handles plain ranges (``100..200``), ``complement(...)`` wrappers,
    multi-segment ``join(...)`` locations and partial-end markers (``<`` /
    ``>``). For multi-segment locations the smallest start and largest end
    are returned. Coordinates are returned exactly as written.

    Returns:
        ``(start, end)`` as ints, or ``(None, None)`` when no ``a..b`` range
        can be found.
    """
    import re

    if not isinstance(location_str, str):
        print(f"Error parsing location {location_str}: expected a string")
        return None, None

    segments = re.findall(r'[<>]?\s*(\d+)\s*\.\.\s*[<>]?\s*(\d+)', location_str)
    if not segments:
        if '..' in location_str:
            print(f"Error parsing location {location_str}: no numeric range found")
        return None, None

    start = min(int(seg_start) for seg_start, _ in segments)
    end = max(int(seg_end) for _, seg_end in segments)
    return start, end
def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
    """Summarise a gene's SHAP values.

    Returns mean, median, standard deviation, max, min and the fraction of
    strictly positive values. Every statistic is 0.0 for an empty input.
    """
    stat_names = ('avg_shap', 'median_shap', 'std_shap',
                  'max_shap', 'min_shap', 'pos_fraction')
    if not len(gene_shap):
        return dict.fromkeys(stat_names, 0.0)
    computed = (
        np.mean(gene_shap),
        np.median(gene_shap),
        np.std(gene_shap),
        np.max(gene_shap),
        np.min(gene_shap),
        np.mean(gene_shap > 0),
    )
    return {name: float(value) for name, value in zip(stat_names, computed)}
def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
    """
    Draw a simple PIL diagram of genes along a linear genome.

    Each gene is a rectangle on a central axis. Red means positive average
    SHAP, blue means zero or negative, and the colour saturates with
    magnitude. Genes are drawn in order of increasing |avg SHAP| so the
    strongest sit on top.

    Args:
        gene_results: Dicts with at least ``start``, ``end`` and ``avg_shap``
            (``gene_name`` is used for labels). The dicts are not modified.
        genome_length: Total genome length used for the x scale.

    Returns:
        A 1500x600 image, or a small error image when the input is unusable.
    """
    if not gene_results or genome_length <= 0:
        img = Image.new('RGB', (800, 100), color='white')
        draw = ImageDraw.Draw(img)
        draw.text((10, 40), "Error: Invalid input data", fill='black')
        return img

    # Clamp coordinates on copies so the caller's data (also used for the
    # text and CSV report) is left untouched, and skip empty or inverted
    # ranges, which cannot be drawn as rectangles.
    drawable = []
    for gene in gene_results:
        start = max(0, int(gene['start']))
        end = min(genome_length, int(gene['end']))
        if start >= end:
            print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}")
            continue
        drawable.append({**gene, 'start': start, 'end': end})

    width = 1500
    height = 600
    margin = 50
    track_height = 40
    img = Image.new('RGB', (width, height), 'white')
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
        title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
    except (OSError, ImportError):
        # Font files missing or FreeType unavailable: use PIL's built-in font.
        font = ImageFont.load_default()
        title_font = ImageFont.load_default()

    draw.text((margin, margin // 2), "Genome SHAP Analysis (Simple)", fill='black', font=title_font or font)
    line_y = height // 2
    draw.line([(int(margin), int(line_y)), (int(width - margin), int(line_y))], fill='black', width=2)
    scale = float(width - 2 * margin) / float(genome_length)

    # Scale markers
    num_ticks = 10
    step = max(1, genome_length // num_ticks)
    for position in range(0, genome_length + 1, step):
        x_coord = margin + position * scale
        draw.line([
            (int(x_coord), int(line_y - 5)),
            (int(x_coord), int(line_y + 5))
        ], fill='black', width=1)
        draw.text((int(x_coord - 20), int(line_y + 10)), f"{position:,}", fill='black', font=font)

    sorted_genes = sorted(drawable, key=lambda g: abs(g['avg_shap']))
    for idx, gene in enumerate(sorted_genes):
        start_x = margin + int(gene['start'] * scale)
        end_x = margin + int(gene['end'] * scale)
        avg_shap = gene['avg_shap']

        # Map |SHAP| to colour strength, clamped to a visible 50..255 range.
        intensity = int(abs(avg_shap) * 500)
        intensity = max(50, min(255, intensity))
        if avg_shap > 0:
            color = (255, 255 - intensity, 255 - intensity)  # reddish
        else:
            color = (255 - intensity, 255 - intensity, 255)  # bluish

        draw.rectangle([
            (int(start_x), int(line_y - track_height // 2)),
            (int(end_x), int(line_y + track_height // 2))
        ], fill=color, outline='black')

        label = str(gene.get('gene_name', '?'))
        label_width, label_height = font.getmask(label).size

        # Alternate labels above and below the axis to reduce overlap.
        if idx % 2 == 0:
            text_y = line_y - track_height - 15
        else:
            text_y = line_y + track_height + 5

        gene_width = end_x - start_x
        if gene_width > label_width:
            text_x = start_x + (gene_width - label_width) // 2
            draw.text((int(text_x), int(text_y)), label, fill='black', font=font)
        elif gene_width > 20:
            # Too narrow for horizontal text: render the label rotated 90 degrees.
            txt_img = Image.new('RGBA', (label_width, label_height), (255, 255, 255, 0))
            txt_draw = ImageDraw.Draw(txt_img)
            txt_draw.text((0, 0), label, font=font, fill='black')
            rotated_img = txt_img.rotate(90, expand=True)
            img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)
    return img
def create_advanced_genome_diagram(gene_results: List[Dict[str, Any]],
                                   genome_length: int,
                                   shap_means: np.ndarray,
                                   diagram_title: str = "Advanced Genome Diagram") -> Image.Image:
    """
    Render a genome diagram with Biopython's GenomeDiagram.

    Track 1 shows the genes coloured by average SHAP (red for zero or
    positive, blue for negative). Track 2 is a line graph of the per-base SHAP
    values scaled to +/-50 around a zero baseline. The diagram is written to a
    temporary PDF and rasterised with ``pdf2image``.

    Args:
        gene_results: Dicts with ``start``, ``end``, ``avg_shap`` and
            optionally ``gene_name``.
        genome_length: Genome length; features are clamped to it.
        shap_means: Per-base SHAP values.
        diagram_title: Title passed to the GenomeDiagram ``Diagram``.

    Returns:
        A PIL image of the first PDF page, or a small placeholder image when
        the input is insufficient or ``pdf2image`` is not installed.
    """
    if not gene_results or genome_length <= 0 or len(shap_means) == 0:
        # Fallback if data is invalid
        img = Image.new('RGB', (800, 100), color='white')
        d = ImageDraw.Draw(img)
        d.text((10, 40), "Error: Not enough data for advanced diagram", fill='black')
        return img

    diagram = GenomeDiagram.Diagram(diagram_title)
    gene_track = diagram.new_track(1, name="Genes", greytrack=False, height=0.5)
    gene_set = gene_track.new_set()

    # Add each gene as a feature
    for gene in gene_results:
        start = max(0, int(gene['start']))
        end = min(genome_length, int(gene['end']))
        if start >= end:
            # Empty or inverted ranges are not valid feature locations.
            continue
        avg_shap = gene['avg_shap']
        # Colour scale: negative = blue, positive = red
        intensity = max(50, min(255, abs(avg_shap) * 500))
        fade = 1.0 - intensity / 255.0
        if avg_shap >= 0:
            feature_color = colors.Color(1.0, fade, fade)
        else:
            feature_color = colors.Color(fade, fade, 1.0)
        # Strand belongs on the location; recent Biopython releases no longer
        # accept it as a SeqFeature argument.
        feature = SeqFeature(FeatureLocation(start, end, strand=1))
        gene_set.add_feature(
            feature,
            color=feature_color,
            label=True,
            name=str(gene.get('gene_name', '?')),
            label_size=8,
            label_color=colors.black
        )

    # SHAP line track
    shap_means = np.asarray(shap_means, dtype=float)
    max_abs = float(np.max(np.abs(shap_means)))
    if max_abs > 0:
        shap_track = diagram.new_track(2, name="SHAP Score", greytrack=False, height=0.3)
        shap_set = shap_track.new_set("graph")
        scaled = shap_means / max_abs * 50  # scale to +/- 50
        # GraphSet.new_graph expects (position, value) pairs; center=0 keeps
        # negative values below the baseline and positive values above it.
        shap_set.new_graph(
            data=list(enumerate(scaled.tolist())),
            name="shap_line",
            style="line",
            color=colors.darkgreen,
            altcolor=colors.red,
            linewidth=1,
            center=0
        )
    # A perfectly flat SHAP profile has no range to plot, so the graph track
    # is omitted in that case.

    # Reserve a temp path and close the handle before writing to it by name.
    tmp = tempfile.NamedTemporaryFile(suffix=".pdf", delete=False)
    tmp.close()
    try:
        diagram.draw(format="linear", pagesize='A3', fragments=1, start=0, end=genome_length)
        diagram.write(tmp.name, "PDF")
        # PDF -> PIL image conversion requires pdf2image (and poppler).
        try:
            import pdf2image
        except ImportError:
            img = Image.new('RGB', (800, 100), color='white')
            d = ImageDraw.Draw(img)
            d.text((10, 40), "pdf2image not installed, can't show advanced diagram as image.", fill='black')
        else:
            pages = pdf2image.convert_from_path(tmp.name, dpi=100)
            img = pages[0] if pages else Image.new('RGB', (800, 100), color='white')
    finally:
        # Always remove the temp file, including when drawing or conversion fails.
        try:
            os.remove(tmp.name)
        except OSError:
            pass
    return img
def _write_gene_csv(gene_results: List[Dict[str, Any]]) -> Optional[str]:
    """Write per-gene results to a temporary CSV file and return its path (None on failure)."""
    import csv

    columns = ['gene_name', 'location', 'start', 'end', 'locus_tag', 'avg_shap',
               'median_shap', 'std_shap', 'max_shap', 'min_shap', 'pos_fraction',
               'n50', 'entropy', 'classification', 'confidence']
    four_decimals = {'avg_shap', 'median_shap', 'std_shap', 'max_shap', 'min_shap',
                     'pos_fraction', 'entropy', 'confidence'}
    try:
        with tempfile.NamedTemporaryFile(mode='w', newline='', prefix='gene_analysis_',
                                         suffix='.csv', delete=False) as handle:
            # csv.writer quotes fields containing commas, such as
            # "join(1..5,7..10)" locations, which would otherwise shift columns.
            writer = csv.writer(handle, lineterminator='\n')
            writer.writerow(columns)
            for gene in gene_results:
                writer.writerow([
                    f"{gene[col]:.4f}" if col in four_decimals else gene[col]
                    for col in columns
                ])
            return handle.name
    except OSError as e:
        print(f"Error saving CSV: {str(e)}")
        return None


def analyze_gene_features(sequence_file: str,
                          features_file: str,
                          fasta_text: str = "",
                          features_text: str = "",
                          diagram_mode: str = "advanced"
                          ) -> Tuple[str, Optional[str], Optional[Image.Image]]:
    """
    Compute gene-level SHAP statistics for every gene in a features file.

    Args:
        sequence_file: Path to the genome FASTA (used when ``fasta_text`` is empty).
        features_file: Path to the FASTA-like gene features file (used when
            ``features_text`` is empty).
        fasta_text: Pasted genome FASTA text.
        features_text: Pasted gene features text.
        diagram_mode: ``"advanced"`` for the GenomeDiagram rendering; any
            other value selects the simple PIL diagram.

    Returns:
        ``(results_text, csv_path, diagram_image)``. On failure the first
        element is a message and the other two are ``None``. ``csv_path`` is
        ``None`` if the CSV could not be written.
    """
    # 1) Analyze the entire sequence with the top-level function
    sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
    # A missing state dict marks failure; matching on the word "Error" misses
    # messages such as "No valid FASTA sequences found."
    if sequence_results[3] is None:
        return f"Error in sequence analysis: {sequence_results[0]}", None, None
    seq = sequence_results[3]["seq"]
    shap_means = sequence_results[3]["shap_means"]
    genome_length = len(seq)

    # 2) Read gene features
    pasted_features = (features_text or "").strip()
    if not pasted_features and not features_file:
        return "Error reading features file: no gene features were provided", None, None
    try:
        if pasted_features:
            genes = parse_gene_features(pasted_features)
        else:
            with open(features_file, 'r') as f:
                genes = parse_gene_features(f.read())
    except Exception as e:
        return f"Error reading features file: {str(e)}", None, None

    gene_results = []
    for gene in genes:
        location = gene['metadata'].get('location', '')
        if not location:
            continue
        start, end = parse_location(location)
        if start is None or end is None or start >= end or end > genome_length:
            continue

        # Feature-table locations are 1-based and inclusive; convert to a
        # 0-based half-open slice so the gene's first base is not dropped.
        slice_start = max(0, start - 1)
        gene_shap = shap_means[slice_start:end]
        gene_seq = seq[slice_start:end]

        basic_stats = compute_gene_statistics(gene_shap)
        adv_stats = advanced_gene_statistics(gene_shap, gene_seq)
        all_stats = {**basic_stats, **adv_stats}

        gene_results.append({
            'gene_name': gene['metadata'].get('gene', 'Unknown'),
            'location': location,
            'start': start,
            'end': end,
            'locus_tag': gene['metadata'].get('locus_tag', ''),
            'avg_shap': all_stats['avg_shap'],
            'median_shap': basic_stats['median_shap'],
            'std_shap': basic_stats['std_shap'],
            'max_shap': basic_stats['max_shap'],
            'min_shap': basic_stats['min_shap'],
            'pos_fraction': basic_stats['pos_fraction'],
            'n50': all_stats['n50'],
            'entropy': all_stats['entropy'],
            'classification': 'Human' if basic_stats['avg_shap'] > 0 else 'Non-human',
            'confidence': abs(all_stats['avg_shap'])
        })

    if not gene_results:
        return "No valid genes could be processed", None, None

    # 3) Summaries
    sorted_genes = sorted(gene_results, key=lambda g: abs(g['avg_shap']), reverse=True)
    num_human = sum(1 for g in gene_results if g['classification'] == 'Human')
    report_parts = [
        "Gene Analysis Results:\n\n",
        f"Total genes analyzed: {len(gene_results)}\n",
        f"Human-like genes: {num_human}\n",
        f"Non-human-like genes: {len(gene_results) - num_human}\n\n",
        "Top 10 most distinctive genes (by avg SHAP magnitude):\n",
    ]
    for gene in sorted_genes[:10]:
        report_parts.append(
            f"Gene: {gene['gene_name']}\n"
            f"Location: {gene['location']}\n"
            f"Classification: {gene['classification']} "
            f"(confidence: {gene['confidence']:.4f})\n"
            f"Average SHAP: {gene['avg_shap']:.4f}\n"
            f"N50: {gene['n50']}, Entropy: {gene['entropy']:.3f}\n\n"
        )
    results_text = "".join(report_parts)

    # 4) Make CSV
    temp_path = _write_gene_csv(gene_results)

    # 5) Create diagram
    try:
        if diagram_mode == "advanced":
            diagram_img = create_advanced_genome_diagram(gene_results, genome_length, shap_means)
        else:
            diagram_img = create_simple_genome_diagram(gene_results, genome_length)
    except Exception as e:
        # Best effort: the text and CSV results are still useful without a picture.
        print(f"Error creating visualization: {str(e)}")
        diagram_img = Image.new('RGB', (800, 100), color='white')
        draw = ImageDraw.Draw(diagram_img)
        draw.text((10, 40), f"Error creating visualization: {str(e)}", fill='black')
    return results_text, temp_path, diagram_img
###############################################################################
# 12. DOWNLOAD FUNCTIONS
###############################################################################
def prepare_csv_download(data, filename="analysis_results.csv"):
    """
    Convert data to CSV bytes for a Gradio download button.

    Args:
        data: One of
            * ``str``  - already-formatted CSV text, encoded unchanged;
            * ``dict`` - a single record, written as header + one row;
            * ``list`` of ``dict`` - one row per record.
        filename: Name to hand back alongside the payload.

    Returns:
        ``(payload, filename)`` where ``payload`` is UTF-8 encoded bytes.
        An empty list yields an empty payload.

    Raises:
        ValueError: If ``data`` is not a str, dict, or list of dicts.
    """
    if isinstance(data, str):
        return data.encode("utf-8"), filename

    # A lone mapping is one record; wrap it so it shares the list code path
    # (indexing a dict with [0] would otherwise raise KeyError).
    if isinstance(data, dict):
        data = [data]

    if isinstance(data, list):
        import csv
        from io import StringIO

        # Nothing to write: there is no first row to derive a header from.
        if not data:
            return b"", filename

        if not all(isinstance(row, dict) for row in data):
            raise ValueError("Unsupported data type for CSV download")

        # Union of keys in first-seen order, so a later row carrying an extra
        # column does not make DictWriter raise; missing values are left blank.
        fieldnames = list(dict.fromkeys(key for row in data for key in row))

        output = StringIO()
        writer = csv.DictWriter(output, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(data)
        return output.getvalue().encode("utf-8"), filename

    raise ValueError("Unsupported data type for CSV download")
###############################################################################
# 13. BUILD GRADIO INTERFACE
###############################################################################
# Custom stylesheet passed to gr.Blocks(css=css) below.
# `.gradio-container` sets the font family for the app container;
# `.download-button` is the class attached (via elem_classes) to the
# download widgets and adds a top margin to them.
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.download-button {
margin-top: 10px;
}
"""
# NOTE(review): the container nesting below (which widgets live inside which
# Row/Column) was reconstructed from a copy of this file whose indentation had
# been stripped - confirm the layout against the running app.


def _fasta_upload(label):
    """Build a file-upload widget that accepts FASTA-style extensions."""
    return gr.File(
        label=label,
        file_types=[".fasta", ".fa", ".txt"],
        type="filepath",
    )


def _hidden_download(label):
    """Build a download widget that is created hidden and styled as a download button."""
    return gr.File(
        label=label,
        visible=False,
        elem_classes="download-button",
    )


with gr.Blocks(css=css) as iface:
    # Page header / usage overview.
    gr.Markdown("""
# Virus Host Classifier + Extended Genome Visualization
**Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme subregions.
**Step 2**: Explore subregions (local SHAP, GC content, histogram).
**Step 3**: Analyze gene features (per-gene SHAP, advanced stats, improved diagrams).
**Step 4**: Compare sequences for SHAP differences.
**Color Scale**: Negative SHAP = Blue, 0 = White, Positive = Red.
""")

    # ------------------------------------------------------------------
    # Tab 1: whole-sequence classification
    # ------------------------------------------------------------------
    with gr.Tab("1) Full-Sequence Analysis"):
        with gr.Row():
            # Left column: inputs and the trigger button.
            with gr.Column(scale=1):
                file_input = _fasta_upload("Upload FASTA file")
                text_input = gr.Textbox(
                    label="Or paste FASTA",
                    placeholder=">name\nACGT...",
                    lines=5,
                )
                top_k = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Number of top k-mers to display",
                )
                win_size = gr.Slider(
                    minimum=100,
                    maximum=5000,
                    value=500,
                    step=100,
                    label="Subregion Window Size",
                )
                analyze_btn = gr.Button("Analyze Sequence", variant="primary")
            # Right column: text summary, plots and the download slot.
            with gr.Column(scale=2):
                results_box = gr.Textbox(
                    label="Classification Results",
                    lines=12,
                    interactive=False,
                )
                kmer_img = gr.Image(label="Top k-mer SHAP")
                genome_img = gr.Image(
                    label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)"
                )
                download_results = _hidden_download("Download Results")

        # State slots filled by analyze_sequence and read back by tab 2.
        seq_state = gr.State()
        header_state = gr.State()

        analyze_btn.click(
            fn=analyze_sequence,
            inputs=[file_input, top_k, text_input, win_size],
            outputs=[
                results_box,
                kmer_img,
                genome_img,
                seq_state,
                header_state,
                download_results,
            ],
        )

    # ------------------------------------------------------------------
    # Tab 2: inspect a sub-range of the sequence stored by tab 1
    # ------------------------------------------------------------------
    with gr.Tab("2) Subregion Exploration"):
        gr.Markdown("""
**Subregion Analysis**
View SHAP signals, GC content, etc. for a specific region.
""")
        with gr.Row():
            region_start = gr.Number(label="Region Start", value=0)
            region_end = gr.Number(label="Region End", value=500)
            region_btn = gr.Button("Analyze Subregion")
        subregion_info = gr.Textbox(
            label="Subregion Analysis",
            lines=7,
            interactive=False,
        )
        with gr.Row():
            subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
            subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
        download_subregion = _hidden_download("Download Subregion")

        region_btn.click(
            fn=analyze_subregion,
            inputs=[seq_state, header_state, region_start, region_end],
            outputs=[
                subregion_info,
                subregion_img,
                subregion_hist_img,
                download_subregion,
            ],
        )

    # ------------------------------------------------------------------
    # Tab 3: per-gene analysis from a FASTA plus a gene-features file
    # ------------------------------------------------------------------
    with gr.Tab("3) Gene Features Analysis"):
        gr.Markdown("""
**Analyze Gene Features**
- Upload a FASTA file and a gene features file.
- See per-gene SHAP, classification, N50, entropy, etc.
- Choose a diagram mode (simple or advanced).
""")
        with gr.Row():
            with gr.Column(scale=1):
                gene_fasta_file = _fasta_upload("FASTA file")
                gene_fasta_text = gr.Textbox(
                    label="Or paste FASTA sequence",
                    lines=5,
                )
            with gr.Column(scale=1):
                features_file = gr.File(
                    label="Gene features file",
                    file_types=[".txt"],
                    type="filepath",
                )
                features_text = gr.Textbox(
                    label="Or paste gene features",
                    lines=5,
                )
        diagram_mode = gr.Radio(
            choices=["simple", "advanced"],
            value="advanced",
            label="Diagram Mode",
        )
        analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
        gene_results = gr.Textbox(
            label="Gene Analysis Results",
            lines=12,
            interactive=False,
        )
        gene_diagram = gr.Image(label="Genome Diagram")
        # Unlike the other download widgets, this one is created visible.
        download_gene_results = gr.File(
            label="Download Gene Analysis (CSV)",
            visible=True,
        )

        analyze_genes_btn.click(
            fn=analyze_gene_features,
            inputs=[
                gene_fasta_file,
                features_file,
                gene_fasta_text,
                features_text,
                diagram_mode,
            ],
            outputs=[gene_results, download_gene_results, gene_diagram],
        )

    # ------------------------------------------------------------------
    # Tab 4: compare two sequences
    # ------------------------------------------------------------------
    with gr.Tab("4) Comparative Analysis"):
        gr.Markdown("""
**Compare Two Sequences**
- Upload or paste two FASTA sequences.
- We'll compare SHAP patterns (normalized for different lengths).
""")
        with gr.Row():
            with gr.Column(scale=1):
                file_input1 = _fasta_upload("1st FASTA file")
                text_input1 = gr.Textbox(label="Or paste 1st FASTA", lines=5)
            with gr.Column(scale=1):
                file_input2 = _fasta_upload("2nd FASTA file")
                text_input2 = gr.Textbox(label="Or paste 2nd FASTA", lines=5)
        compare_btn = gr.Button("Compare Sequences", variant="primary")
        comparison_text = gr.Textbox(
            label="Comparison Results",
            lines=12,
            interactive=False,
        )
        with gr.Row():
            diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
            diff_hist = gr.Image(label="Distribution of SHAP Differences")
        download_comparison = _hidden_download("Download Comparison")

        compare_btn.click(
            fn=analyze_sequence_comparison,
            inputs=[file_input1, file_input2, text_input1, text_input2],
            outputs=[
                comparison_text,
                diff_heatmap,
                diff_hist,
                download_comparison,
            ],
        )

    # Footer notes shown after the tab definitions.
    gr.Markdown("""
### Notes & Features
- **Advanced Genome Diagram** uses Biopython’s `GenomeDiagram` (requires `pdf2image` if you want it as an image).
- **Additional Stats**: N50, Shannon entropy, etc.
- **Auto-scaling** for comparative analysis with adaptive smoothing.
- **Data Export**: Download CSV of analysis results.
""")

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()