# Hugging Face Spaces page header (status: Running) - web-scrape residue, not part of the program.
import gradio as gr | |
import torch | |
import joblib | |
import numpy as np | |
from itertools import product | |
import torch.nn as nn | |
import matplotlib.pyplot as plt | |
import matplotlib.colors as mcolors | |
import io | |
from PIL import Image | |
from scipy.interpolate import interp1d | |
############################################################################### | |
# 1. MODEL DEFINITION | |
############################################################################### | |
class VirusClassifier(nn.Module):
    """Fully connected network that turns a feature vector into two logits.

    All layers live in ``self.network`` (an ``nn.Sequential``) in a fixed
    order, so parameter names such as ``network.0.weight`` stay stable for
    checkpoints loaded with ``load_state_dict``.
    """

    def __init__(self, input_shape: int):
        """Build the layer stack for inputs with ``input_shape`` features."""
        super().__init__()
        hidden_layers = [
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
        ]
        output_layer = nn.Linear(32, 2)
        self.network = nn.Sequential(*hidden_layers, output_layer)

    def forward(self, x):
        """Return the raw (pre-softmax) two-class logits for ``x``."""
        logits = self.network(x)
        return logits
############################################################################### | |
# 2. FASTA PARSING & K-MER FEATURE ENGINEERING | |
############################################################################### | |
def parse_fasta(text):
    """Parse FASTA-formatted text into ``(header, sequence)`` tuples.

    Args:
        text: FASTA content. Blank lines are ignored and any residue lines
            that appear before the first ``>`` header are discarded.

    Returns:
        A list of ``(header, sequence)`` pairs in file order. The header has
        its leading ``>`` removed; the sequence is the upper-cased
        concatenation of the record's lines.
    """
    records = []
    header = None
    chunks = []
    # splitlines() also copes with CR-only and CRLF line endings.
    for raw_line in text.strip().splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('>'):
            # Compare against None: a bare ">" yields an empty-string header,
            # which is falsy but still marks a real record.
            if header is not None:
                records.append((header, ''.join(chunks)))
            header = line[1:]
            chunks = []
        else:
            chunks.append(line.upper())
    if header is not None:
        records.append((header, ''.join(chunks)))
    return records
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """Return the relative frequency of every A/C/G/T k-mer in ``sequence``.

    The vector has ``4**k`` float32 entries ordered like
    ``itertools.product("ACGT", repeat=k)``. Windows containing any other
    character are not counted, but they still contribute to the denominator
    (the total number of length-``k`` windows).
    """
    slot_of = {
        ''.join(letters): position
        for position, letters in enumerate(product("ACGT", repeat=k))
    }
    frequencies = np.zeros(len(slot_of), dtype=np.float32)
    window_count = len(sequence) - k + 1
    for offset in range(window_count):
        slot = slot_of.get(sequence[offset:offset + k])
        if slot is not None:
            frequencies[slot] += 1
    if window_count > 0:
        frequencies /= window_count
    return frequencies
############################################################################### | |
# 3. SHAP-VALUE (ABLATION) CALCULATION | |
############################################################################### | |
def calculate_shap_values(model, x_tensor):
    """Score each input feature by zero-ablation of the class-1 probability.

    The model is switched to eval mode and evaluated once on ``x_tensor`` and
    then once per feature with that feature of row 0 set to 0.0. A feature's
    score is ``baseline - ablated`` for the softmax probability at index 1
    (treated as the "human" class by the callers in this file).

    Returns:
        ``(scores, baseline_prob)`` where ``scores`` is a NumPy array with one
        entry per column of ``x_tensor``. ``x_tensor`` itself is not modified.
    """
    def _class_one_probability(batch):
        return torch.softmax(model(batch), dim=1)[0, 1].item()

    model.eval()
    with torch.no_grad():
        baseline_prob = _class_one_probability(x_tensor)
        scratch = x_tensor.clone()
        deltas = []
        for column in range(x_tensor.shape[1]):
            saved = scratch[0, column].item()
            scratch[0, column] = 0.0
            deltas.append(baseline_prob - _class_one_probability(scratch))
            # Put the feature back so only one is ablated at a time.
            scratch[0, column] = saved
    return np.array(deltas), baseline_prob
############################################################################### | |
# 4. PER-BASE SHAP AGGREGATION | |
############################################################################### | |
def compute_positionwise_scores(sequence, shap_values, k=4):
    """Spread per-k-mer scores onto individual sequence positions.

    Every A/C/G/T k-mer occurrence adds its score to each of the ``k``
    positions it covers; a position's result is the mean over the k-mers
    covering it. Positions covered by no recognised k-mer get 0.0.

    Returns:
        An array with one value per character of ``sequence``.
    """
    slot_of = {
        ''.join(letters): position
        for position, letters in enumerate(product("ACGT", repeat=k))
    }
    n_bases = len(sequence)
    totals = np.zeros(n_bases, dtype=np.float32)
    hits = np.zeros(n_bases, dtype=np.float32)
    for left in range(n_bases - k + 1):
        right = left + k
        slot = slot_of.get(sequence[left:right])
        if slot is None:
            continue
        totals[left:right] += shap_values[slot]
        hits[left:right] += 1
    # The division is evaluated everywhere, including uncovered positions;
    # silence the 0/0 warnings and let np.where substitute 0.0 there.
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(hits > 0, totals / hits, 0.0)
############################################################################### | |
# 5. FIND EXTREME SHAP REGIONS | |
############################################################################### | |
def find_extreme_subregion(shap_means, window_size=500, mode="max"):
    """Locate the fixed-width window with the highest or lowest mean score.

    Args:
        shap_means: 1-D sequence of per-position scores.
        window_size: Window width in positions. Non-integer numbers (e.g. a
            float coming from a UI slider) are truncated to ``int``.
        mode: ``"max"`` for the highest-mean window, ``"min"`` for the lowest.

    Returns:
        ``(start, end, mean)`` with ``end`` exclusive. Empty input gives
        ``(0, 0, 0.0)``; a window at least as long as the input gives the
        whole range. Ties resolve to the left-most window.

    Raises:
        ValueError: if ``mode`` is not ``"max"``/``"min"`` or the window is
            smaller than one position.
    """
    if mode not in ("max", "min"):
        raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")

    # Accumulate in float64: float32 prefix sums over long genomes drop
    # low-order digits, which corrupts the window differences below.
    scores = np.asarray(shap_means, dtype=np.float64)
    n = scores.size
    if n == 0:
        return (0, 0, 0.0)

    window_size = int(window_size)
    if window_size >= n:
        return (0, n, float(scores.mean()))
    if window_size < 1:
        raise ValueError("window_size must be at least 1")

    prefix = np.concatenate(([0.0], np.cumsum(scores)))
    window_sums = prefix[window_size:] - prefix[:-window_size]
    # argmax/argmin return the first extreme, i.e. the left-most window.
    pick = np.argmax if mode == "max" else np.argmin
    best_start = int(pick(window_sums))
    best_avg = window_sums[best_start] / window_size
    return (best_start, best_start + window_size, float(best_avg))
############################################################################### | |
# 6. PLOTTING / UTILITIES | |
############################################################################### | |
def fig_to_image(fig):
    """Render a Matplotlib figure to a PIL image and close the figure.

    The figure is saved as a 150-dpi PNG into an in-memory buffer. It is
    closed even when rendering fails, so failed plots do not pile up in
    pyplot's figure registry.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    finally:
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    # Image.open is lazy; decode now so the result no longer depends on
    # reading from the buffer later.
    img.load()
    return img
def get_zero_centered_cmap():
    """Return a blue -> white -> red colormap named ``blue_white_red``.

    White sits at the midpoint of the map, so it corresponds to zero when the
    caller plots with symmetric ``vmin``/``vmax``.
    """
    anchor_positions = [0.0, 0.5, 1.0]
    anchor_colors = ['blue', 'white', 'red']
    return mcolors.LinearSegmentedColormap.from_list(
        "blue_white_red", list(zip(anchor_positions, anchor_colors))
    )
def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
    """Draw a one-row heatmap of per-position scores.

    Args:
        shap_means: 1-D sequence of per-position scores.
        title: Figure title; a position-range suffix is appended when a
            subregion is requested.
        start, end: Optional slice bounds. Both must be given to restrict the
            plot to ``shap_means[start:end]``.

    Returns:
        The Matplotlib figure. The colour range is symmetric around zero and
        the x axis is labelled with positions in the full sequence.
    """
    scores = np.asarray(shap_means)
    if start is not None and end is not None:
        local_shap = scores[start:end]
        subtitle = f" (positions {start}-{end})"
        first_position = start
    else:
        local_shap = scores
        subtitle = ""
        first_position = 0
    if len(local_shap) == 0:
        # Keep imshow happy with a single neutral cell.
        local_shap = np.array([0.0])
    heatmap_data = local_shap.reshape(1, -1)

    bound = max(abs(np.min(local_shap)), abs(np.max(local_shap)))
    if not bound > 0:
        # With vmin == vmax == 0 every cell is normalised to 0.0 and drawn in
        # the colormap's first colour (blue); use a unit range so an all-zero
        # (or NaN-bounded) track renders as neutral white instead.
        bound = 1.0

    fig, ax = plt.subplots(figsize=(12, 1.8))
    image = ax.imshow(
        heatmap_data,
        aspect='auto',
        cmap=get_zero_centered_cmap(),
        vmin=-bound,
        vmax=bound,
        # Map columns onto real sequence coordinates so a subregion's ticks
        # read start..end rather than 0..length.
        extent=(first_position, first_position + heatmap_data.shape[1], 0, 1),
    )
    cbar = fig.colorbar(image, ax=ax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
    cbar.ax.tick_params(labelsize=8)
    cbar.set_label('SHAP Contribution', fontsize=9, labelpad=5)
    ax.set_yticks([])
    ax.set_xlabel('Position in Sequence', fontsize=10)
    ax.set_title(f"{title}{subtitle}", pad=10)
    fig.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
    return fig
def create_importance_bar_plot(shap_values, kmers, top_k=10):
    """Plot the ``top_k`` k-mers with the largest absolute score as bars.

    Args:
        shap_values: One score per k-mer.
        kmers: K-mer labels, index-aligned with ``shap_values``.
        top_k: Number of bars. Truncated to ``int`` and clamped to
            ``1..len(shap_values)``.

    Returns:
        The Matplotlib figure, with the most influential k-mer on the top row.
        Negative scores are drawn light blue, non-negative ones light red.
    """
    scores = np.asarray(shap_values)
    # A float from a slider cannot be used as a slice bound, and top_k == 0
    # would make a negative slice ([-0:]) select every feature.
    top_k = max(1, min(int(top_k), len(scores)))

    # Strongest first: row 0 ends up on top once the y axis is inverted.
    order = np.argsort(np.abs(scores))[::-1][:top_k]
    values = scores[order]
    features = [kmers[i] for i in order]
    colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
    rows = list(range(len(values)))

    # Scope the font size to this figure instead of mutating global rcParams.
    with plt.rc_context({'font.size': 10}):
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.barh(rows, values, color=colors)
        ax.set_yticks(rows)
        ax.set_yticklabels(features)
        ax.set_xlabel('SHAP Value (impact on model output)')
        ax.set_title(f'Top {top_k} Most Influential k-mers')
        ax.invert_yaxis()
        fig.tight_layout()
    return fig
def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
    """Return a 30-bin histogram figure of ``shap_array`` with a marker at 0.

    Note: a later definition with the same name in this file replaces this
    one when the module is imported.
    """
    figure, axes = plt.subplots(figsize=(6, 4))
    bar_style = dict(bins=30, color='gray', edgecolor='black')
    axes.hist(shap_array, **bar_style)
    axes.axvline(0, color='red', linestyle='--', label='0.0')
    axes.set_xlabel("SHAP Value")
    axes.set_ylabel("Count")
    axes.set_title(title)
    axes.legend()
    figure.tight_layout()
    return figure
def compute_gc_content(sequence):
    """Return the percentage of G and C characters in ``sequence``.

    Counting is case-insensitive. An empty sequence yields ``0.0`` so the
    return type is always ``float``.
    """
    if not sequence:
        return 0.0
    upper = sequence.upper()
    gc_count = upper.count('G') + upper.count('C')
    return (gc_count / len(sequence)) * 100.0
############################################################################### | |
# 7. MAIN ANALYSIS STEP (Gradio Step 1) | |
############################################################################### | |
def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
    """Classify the first FASTA record and build the Step 1 outputs.

    Args:
        file_obj: Path of an uploaded FASTA file, or ``None``.
        top_kmers: Number of k-mers shown in the importance bar plot.
        fasta_text: Pasted FASTA text; takes precedence over ``file_obj``.
            ``None`` is treated like an empty string.
        window_size: Window width used for the extreme-subregion search.

    Returns:
        A 6-tuple, one value per Gradio output wired to this handler:
        ``(results_text, bar_image, heatmap_image, state, header, download)``.
        ``state`` is ``{"seq": ..., "shap_means": ...}`` and ``download`` is a
        Gradio update that reveals a CSV of per-k-mer frequencies and scores.
        On failure the first element is a message and the other five are
        ``None``; callers can detect failure by checking ``result[3] is None``.
        Positions 0-4 keep their original meaning.
    """
    import csv
    import tempfile

    kmer_size = 4

    def _failure(message):
        # One slot per output component so Gradio never sees a short tuple.
        return (message, None, None, None, None, None)

    pasted = (fasta_text or "").strip()
    if pasted:
        text = pasted
    elif file_obj is not None:
        try:
            with open(file_obj, 'r') as f:
                text = f.read()
        except (OSError, ValueError, TypeError) as e:
            return _failure(f"Error reading file: {str(e)}")
    else:
        return _failure("Please provide a FASTA sequence.")

    sequences = parse_fasta(text)
    if not sequences:
        return _failure("No valid FASTA sequences found.")
    header, seq = sequences[0]
    if len(seq) < kmer_size:
        # No k-mer window exists, so the feature vector would be all zeros
        # and the resulting classification meaningless.
        return _failure(
            f"Error: sequence '{header}' has fewer than {kmer_size} bases; nothing to analyze."
        )

    # Sliders may deliver floats; slicing and indexing need ints.
    top_kmers = int(top_kmers)
    window_size = int(window_size)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        state_dict = torch.load('model.pt', map_location=device, weights_only=True)
        model = VirusClassifier(256).to(device)
        model.load_state_dict(state_dict)
        # NOTE: joblib.load unpickles; scaler.pkl must be a trusted file.
        scaler = joblib.load('scaler.pkl')
    except Exception as e:
        # UI boundary: report any load problem as text rather than crashing.
        return _failure(f"Error loading model/scaler: {str(e)}")

    freq_vector = sequence_to_kmer_vector(seq)
    scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
    x_tensor = torch.FloatTensor(scaled_vector).to(device)
    shap_values, prob_human = calculate_shap_values(model, x_tensor)
    prob_nonhuman = 1.0 - prob_human
    classification = "Human" if prob_human > 0.5 else "Non-human"
    confidence = max(prob_human, prob_nonhuman)

    shap_means = compute_positionwise_scores(seq, shap_values, k=kmer_size)
    max_start, max_end, max_avg = find_extreme_subregion(shap_means, window_size, mode="max")
    min_start, min_end, min_avg = find_extreme_subregion(shap_means, window_size, mode="min")

    results_text = (
        f"Sequence: {header}\n"
        f"Length: {len(seq):,} bases\n"
        f"Classification: {classification}\n"
        f"Confidence: {confidence:.3f}\n"
        f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
        f"---\n"
        f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
        f"Start: {max_start}, End: {max_end}, Avg SHAP: {max_avg:.4f}\n\n"
        f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
        f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
    )

    kmers = [''.join(p) for p in product("ACGT", repeat=kmer_size)]
    bar_img = fig_to_image(create_importance_bar_plot(shap_values, kmers, top_kmers))
    heatmap_img = fig_to_image(plot_linear_heatmap(shap_means, title="Genome-wide SHAP"))

    # Per-k-mer table for the download component. The temporary file is left
    # on disk (delete=False) so Gradio can serve it after this call returns.
    with tempfile.NamedTemporaryFile(
        'w', suffix='.csv', prefix='kmer_shap_', delete=False, newline='', encoding='utf-8'
    ) as tmp:
        writer = csv.writer(tmp)
        writer.writerow(['kmer', 'frequency', 'shap_value'])
        for kmer, freq, value in zip(kmers, freq_vector, shap_values):
            writer.writerow([kmer, f"{float(freq):.6g}", f"{float(value):.6g}"])
        download_path = tmp.name

    state_dict_out = {"seq": seq, "shap_means": shap_means}
    return (
        results_text,
        bar_img,
        heatmap_img,
        state_dict_out,
        header,
        gr.update(value=download_path, visible=True),
    )
############################################################################### | |
# 8. SUBREGION ANALYSIS (Gradio Step 2) | |
############################################################################### | |
def analyze_subregion(state, header, region_start, region_end):
    """Summarise and plot the per-base scores of one subregion (Step 2).

    Args:
        state: Dict with ``"seq"`` and ``"shap_means"`` as produced by Step 1.
        header: Sequence name used in the report text.
        region_start, region_end: Slice bounds; converted to ``int`` and
            clamped to ``0..len(seq)``. ``region_end`` is exclusive.

    Returns:
        A 4-tuple, one value per Gradio output wired to this handler:
        ``(report_text, heatmap_image, histogram_image, download)``. On
        failure the first element is a message and the rest are ``None``.
    """
    import tempfile

    def _failure(message):
        return (message, None, None, None)

    if not state or "seq" not in state or "shap_means" not in state:
        return _failure("No sequence data found. Please run Step 1 first.")
    seq = state["seq"]
    shap_means = state["shap_means"]

    try:
        region_start = int(region_start)
        region_end = int(region_end)
    except (TypeError, ValueError):
        # An emptied Number field arrives as None.
        return _failure("Invalid region range. Start and End must be numbers.")

    region_start = max(0, min(region_start, len(seq)))
    region_end = max(0, min(region_end, len(seq)))
    if region_end <= region_start:
        return _failure("Invalid region range. End must be > Start.")

    region_seq = seq[region_start:region_end]
    region_shap = shap_means[region_start:region_end]
    gc_percent = compute_gc_content(region_seq)
    avg_shap = float(np.mean(region_shap))
    positive_fraction = np.mean(region_shap > 0)
    negative_fraction = np.mean(region_shap < 0)

    if avg_shap > 0.05:
        region_classification = "Likely pushing toward human"
    elif avg_shap < -0.05:
        region_classification = "Likely pushing toward non-human"
    else:
        region_classification = "Near neutral (no strong push)"

    region_info = (
        f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
        f"Region length: {len(region_seq)} bases\n"
        f"GC content: {gc_percent:.2f}%\n"
        f"Average SHAP in region: {avg_shap:.4f}\n"
        f"Fraction with SHAP > 0 (toward human): {positive_fraction:.2f}\n"
        f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
        f"Subregion interpretation: {region_classification}\n"
    )

    heatmap_img = fig_to_image(
        plot_linear_heatmap(shap_means, title="Subregion SHAP", start=region_start, end=region_end)
    )
    hist_img = fig_to_image(
        plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
    )

    # Left on disk (delete=False) so Gradio can serve it after we return.
    with tempfile.NamedTemporaryFile(
        'w', suffix='.txt', prefix='subregion_', delete=False, encoding='utf-8'
    ) as tmp:
        tmp.write(region_info)
        download_path = tmp.name

    return (region_info, heatmap_img, hist_img, gr.update(value=download_path, visible=True))
############################################################################### | |
# 9. COMPARISON ANALYSIS FUNCTIONS | |
############################################################################### | |
def get_zero_centered_cmap():
    """Build the blue/white/red colormap (white at the midpoint).

    This repeats an identical definition that appears earlier in the file;
    being later, this is the one in effect after import.
    """
    stops = ((0.0, 'blue'), (0.5, 'white'), (1.0, 'red'))
    build = mcolors.LinearSegmentedColormap.from_list
    return build("blue_white_red", list(stops))
def compute_shap_difference(shap1_norm, shap2_norm):
    """Return ``shap2_norm - shap1_norm`` (positive where sequence 2 is higher)."""
    difference = shap2_norm - shap1_norm
    return difference
def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
    """
    Draw a one-row heatmap of score differences on a 0-100% position axis.

    The colour range is symmetric around zero, sized by the largest absolute
    difference in ``shap_diff``.
    """
    single_row = shap_diff.reshape(1, -1)
    lowest, highest = np.min(shap_diff), np.max(shap_diff)
    extent = max(abs(lowest), abs(highest))

    fig, ax = plt.subplots(figsize=(12, 1.8))
    cax = ax.imshow(
        single_row,
        aspect='auto',
        cmap=get_zero_centered_cmap(),
        vmin=-extent,
        vmax=extent,
    )

    # Five evenly spaced ticks labelled as a percentage of the length.
    num_ticks = 5
    fractions = np.linspace(0, 1, num_ticks)
    ax.set_xticks(np.linspace(0, shap_diff.shape[0] - 1, num_ticks))
    ax.set_xticklabels([f"{int(fraction * 100)}%" for fraction in fractions])

    cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
    cbar.ax.tick_params(labelsize=8)
    cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)

    ax.set_yticks([])
    ax.set_xlabel('Relative Position in Sequence', fontsize=10)
    ax.set_title(title, pad=10)
    plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
    return fig
def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
    """
    Return a histogram figure of ``shap_array`` using ``num_bins`` bins,
    with a dashed red reference line at zero.
    """
    figure, axes = plt.subplots(figsize=(6, 4))
    bar_style = dict(bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
    axes.hist(shap_array, **bar_style)
    axes.axvline(0, color='red', linestyle='--', label='0.0')
    axes.set_xlabel("SHAP Value")
    axes.set_ylabel("Count")
    axes.set_title(title)
    axes.legend()
    figure.tight_layout()
    return figure
def calculate_adaptive_parameters(len1, len2):
    """
    Choose resampling and smoothing parameters from two sequence lengths.

    The absolute length difference selects one of four tiers; the smoothing
    window is then widened by up to 2x as the shorter/longer length ratio
    falls toward zero.

    Returns: (num_points, smooth_window, resolution_factor), where the first
    two are ints.
    """
    length_diff = abs(len1 - len2)
    max_length = max(len1, len2)
    min_length = min(len1, len2)
    # Two empty inputs are "equal length"; avoid dividing by zero.
    length_ratio = min_length / max_length if max_length > 0 else 1.0

    # Base number of points scales with sequence length
    base_points = min(2000, max(500, max_length // 100))

    if length_diff < 500:
        resolution_factor = 2.0
        num_points = min(3000, base_points * 2)
        smooth_window = max(10, length_diff // 50)
    elif length_diff < 5000:
        resolution_factor = 1.5
        num_points = min(2000, base_points * 1.5)
        smooth_window = max(20, length_diff // 100)
    elif length_diff < 50000:
        resolution_factor = 1.0
        num_points = base_points
        smooth_window = max(50, length_diff // 200)
    else:
        resolution_factor = 0.75
        num_points = max(500, base_points // 2)
        smooth_window = max(100, length_diff // 500)

    # Widen the window as the lengths become more dissimilar.
    smooth_window = int(smooth_window * (1 + (1 - length_ratio)))
    return int(num_points), int(smooth_window), resolution_factor
def sliding_window_smooth(values, window_size=50):
    """
    Smooth a 1-D signal with a tapered moving-average window.

    Args:
        values: 1-D sequence of numbers.
        window_size: Kernel length. It is clamped to ``len(values)``; if the
            effective length is below 3 the input is returned unchanged.

    Returns:
        A float array of the same length as ``values``. Positions too close
        to either end for a full window keep their original value.

    The kernel is flat in the middle and decays exponentially (down to
    ``exp(-3)``) toward both ends, normalised to sum to 1. The previous
    implementation assigned the decay in the opposite direction, which gave
    the heaviest weights to the window edges instead of tapering them.
    """
    # A kernel longer than the signal makes np.convolve(..., 'valid') swap
    # its operands, which broke the edge bookkeeping below.
    window_size = min(int(window_size), len(values))
    if window_size < 3:
        return values

    signal = np.asarray(values, dtype=float)
    half = window_size // 2
    decay = np.exp(-np.linspace(0, 3, half))  # 1.0 -> exp(-3)
    window = np.ones(window_size)
    window[:half] = decay[::-1]   # rises toward the centre
    window[-half:] = decay        # falls toward the right edge
    window /= window.sum()

    smoothed = np.convolve(signal, window, mode='valid')

    # 'valid' drops window_size - 1 samples; keep the raw values there.
    pad_size = len(signal) - len(smoothed)
    pad_left = pad_size // 2
    pad_right = pad_size - pad_left
    result = signal.copy()
    result[pad_left:len(signal) - pad_right] = smoothed
    return result
def normalize_shap_lengths(shap1, shap2):
    """
    Smooth both score tracks and resample them onto one shared 0-1 grid.

    Returns ``(track1, track2, smooth_window)`` where both tracks have the
    same number of points, chosen by ``calculate_adaptive_parameters``.
    """
    num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
    shared_grid = np.linspace(0, 1, num_points)

    resampled = []
    for track in (shap1, shap2):
        smoothed = sliding_window_smooth(track, smooth_window)
        # Each track gets its own relative axis before interpolation.
        own_axis = np.linspace(0, 1, len(smoothed))
        resampled.append(np.interp(shared_grid, own_axis, smoothed))

    return resampled[0], resampled[1], smooth_window
def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
    """
    Compare the per-base score profiles of two sequences (Step 4).

    Each input is analysed with ``analyze_sequence``; the two profiles are
    smoothed, resampled to a common length and subtracted (Seq2 - Seq1).

    Returns:
        A 4-tuple, one value per Gradio output wired to this handler:
        ``(comparison_text, heatmap_image, histogram_image, download)``.
        On any failure the first element is an error message and the other
        three are ``None``.
    """
    import tempfile

    def _failure(message):
        return (message, None, None, None)

    def _classification(results_text):
        # The label is embedded in the Step 1 report text.
        marker = 'Classification: '
        if not isinstance(results_text, str) or marker not in results_text:
            return "Unknown"
        return results_text.split(marker, 1)[1].split('\n', 1)[0].strip()

    try:
        # A failed analysis carries no state dict in slot 3; checking that is
        # more reliable than searching the message for the word "Error".
        res1 = analyze_sequence(file1, top_kmers=10, fasta_text=fasta1, window_size=500)
        if res1[3] is None:
            return _failure(f"Error in sequence 1: {res1[0]}")
        res2 = analyze_sequence(file2, top_kmers=10, fasta_text=fasta2, window_size=500)
        if res2[3] is None:
            return _failure(f"Error in sequence 2: {res2[0]}")

        shap1 = res1[3]["shap_means"]
        shap2 = res2[3]["shap_means"]

        len1, len2 = len(shap1), len(shap2)
        if len1 == 0 or len2 == 0:
            return _failure("Error: both sequences must contain at least one base.")
        length_diff = abs(len1 - len2)
        length_ratio = min(len1, len2) / max(len1, len2)

        shap1_norm, shap2_norm, smooth_window = normalize_shap_lengths(shap1, shap2)
        shap_diff = compute_shap_difference(shap1_norm, shap2_norm)

        # Threshold grows as the lengths diverge.
        base_threshold = 0.05
        adaptive_threshold = base_threshold * (1 + (1 - length_ratio))
        if length_diff > 50000:
            adaptive_threshold *= 1.5

        avg_diff = np.mean(shap_diff)
        std_diff = np.std(shap_diff)
        max_diff = np.max(shap_diff)
        min_diff = np.min(shap_diff)
        substantial_diffs = np.abs(shap_diff) > adaptive_threshold
        frac_different = np.mean(substantial_diffs)

        classification1 = _classification(res1[0])
        classification2 = _classification(res2[0])

        comparison_text = (
            "Sequence Comparison Results:\n"
            f"Sequence 1: {res1[4]}\n"
            f"Length: {len1:,} bases\n"
            f"Classification: {classification1}\n\n"
            f"Sequence 2: {res2[4]}\n"
            f"Length: {len2:,} bases\n"
            f"Classification: {classification2}\n\n"
            "Comparison Parameters:\n"
            f"Length Difference: {length_diff:,} bases\n"
            f"Length Ratio: {length_ratio:.3f}\n"
            f"Smoothing Window: {smooth_window} points\n"
            f"Adaptive Threshold: {adaptive_threshold:.3f}\n\n"
            "Statistics:\n"
            f"Average SHAP difference: {avg_diff:.4f}\n"
            f"Standard deviation: {std_diff:.4f}\n"
            f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
            f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
            f"Fraction with substantial differences: {frac_different:.2%}\n\n"
            "Note: All parameters automatically adjusted based on sequence properties\n\n"
            "Interpretation:\n"
            "- Red regions: Sequence 2 more human-like\n"
            "- Blue regions: Sequence 1 more human-like\n"
            "- White regions: Similar between sequences"
        )

        heatmap_fig = plot_comparative_heatmap(
            shap_diff,
            title=f"SHAP Difference Heatmap (window: {smooth_window})"
        )
        heatmap_img = fig_to_image(heatmap_fig)

        # Square-root rule for the bin count, kept within 20..50.
        num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
        hist_fig = plot_shap_histogram(
            shap_diff,
            title="Distribution of SHAP Differences",
            num_bins=num_bins
        )
        hist_img = fig_to_image(hist_fig)

        # Left on disk (delete=False) so Gradio can serve it after we return.
        with tempfile.NamedTemporaryFile(
            'w', suffix='.txt', prefix='comparison_', delete=False, encoding='utf-8'
        ) as tmp:
            tmp.write(comparison_text)
            download_path = tmp.name

        return (
            comparison_text,
            heatmap_img,
            hist_img,
            gr.update(value=download_path, visible=True),
        )
    except Exception as e:
        # UI boundary: surface the problem as text instead of crashing.
        return _failure(f"Error during sequence comparison: {str(e)}")
############################################################################### | |
# 11. GENE FEATURE ANALYSIS | |
############################################################################### | |
def parse_gene_features(text):
    """Parse gene features from text file in FASTA-like format.

    Record splitting is delegated to ``parse_fasta`` (the format is the
    same), so the two parsers cannot drift apart. Each record becomes a dict
    with the keys ``'header'``, ``'sequence'`` and ``'metadata'``; the last
    is whatever ``parse_gene_metadata`` extracts from the header.
    """
    return [
        {
            'header': header,
            'sequence': sequence,
            'metadata': parse_gene_metadata(header),
        }
        for header, sequence in parse_fasta(text)
    ]
def parse_gene_metadata(header):
    """Extract ``[key=value]`` tags from a gene header into a dict.

    Tags are matched as whole bracketed groups, so values may contain spaces
    (``[protein=spike glycoprotein]``) or further ``=`` characters. Brackets
    without an ``=`` are ignored. Keys and values are stripped of
    surrounding whitespace; a repeated key keeps its last value.
    """
    import re

    tag_pattern = r'\[([^\[\]=]+)=([^\[\]]*)\]'
    return {
        key.strip(): value.strip()
        for key, value in re.findall(tag_pattern, header)
    }
def _parse_feature_span(location):
    """Convert a 1-based inclusive location string to a 0-based half-open span.

    Every integer in the string is considered, so wrapped forms such as
    ``complement(10..20)`` or ``join(1..5,9..12)`` resolve to their outer
    bounds. Raises ``ValueError`` when the string contains no coordinates.
    """
    import re

    coordinates = [int(number) for number in re.findall(r'\d+', location)]
    if not coordinates:
        raise ValueError(f"no coordinates in location {location!r}")
    return max(min(coordinates) - 1, 0), max(coordinates)


def analyze_gene_features(sequence_file, features_file, fasta_text="", features_text=""):
    """Analyze SHAP values for each gene feature (Step 3).

    The whole sequence is analysed first; each gene's score is then the mean
    per-base value over its ``[location=...]`` span.

    Returns:
        ``(summary_text, download, diagram_image)`` in the order of the
        Gradio outputs wired to this handler (Textbox, File, Image).
        ``download`` is a Gradio update pointing at a CSV of the per-gene
        results. On failure the first element is a message and the other two
        are ``None``.
    """
    import csv
    import tempfile

    def _failure(message):
        return message, None, None

    # A failed whole-sequence analysis carries no state dict in slot 3.
    sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
    if sequence_results[3] is None:
        return _failure(f"Error in sequence analysis: {sequence_results[0]}")
    shap_means = sequence_results[3]["shap_means"]

    if (features_text or "").strip():
        genes = parse_gene_features(features_text)
    elif features_file is not None:
        try:
            with open(features_file, 'r') as f:
                genes = parse_gene_features(f.read())
        except (OSError, ValueError, TypeError) as e:
            return _failure(f"Error reading features file: {str(e)}")
    else:
        return _failure("Error reading features file: no file uploaded and no text pasted.")

    gene_results = []
    for gene in genes:
        metadata = gene['metadata']
        location = metadata.get('location', '')
        if not location:
            continue
        gene_name = metadata.get('gene', 'Unknown')
        try:
            start, end = _parse_feature_span(location)
        except ValueError as e:
            print(f"Error processing gene {gene_name}: {str(e)}")
            continue
        gene_shap = shap_means[start:end]
        if len(gene_shap) == 0:
            # Span lies outside the analysed sequence; a mean would be NaN.
            print(f"Error processing gene {gene_name}: location {location} is outside the sequence")
            continue
        avg_shap = float(np.mean(gene_shap))
        gene_results.append({
            'gene_name': gene_name,
            'location': location,
            'avg_shap': avg_shap,
            'start': start,
            'end': end,
            'locus_tag': metadata.get('locus_tag', ''),
            'classification': 'Human' if avg_shap > 0 else 'Non-human',
            'confidence': abs(avg_shap)
        })

    if not gene_results:
        return _failure("No gene features with a usable [location=start..end] tag were found.")

    # csv.writer quotes fields that contain commas, unlike manual joining.
    # The file is left on disk (delete=False) so Gradio can serve it.
    with tempfile.NamedTemporaryFile(
        'w', suffix='.csv', prefix='gene_shap_', delete=False, newline='', encoding='utf-8'
    ) as tmp:
        writer = csv.writer(tmp)
        writer.writerow(['gene_name', 'location', 'avg_shap', 'classification', 'confidence', 'locus_tag'])
        for result in gene_results:
            writer.writerow([
                result['gene_name'],
                result['location'],
                f"{result['avg_shap']:.4f}",
                result['classification'],
                f"{result['confidence']:.4f}",
                result['locus_tag'],
            ])
        csv_path = tmp.name

    summary_lines = [f"Analyzed {len(gene_results)} gene feature(s):"]
    summary_lines.extend(
        f"{result['gene_name']} ({result['location']}): "
        f"avg SHAP {result['avg_shap']:.4f} -> {result['classification']}"
        for result in gene_results
    )

    try:
        diagram_img = create_genome_diagram(gene_results, len(shap_means))
    except Exception as e:
        # The diagram needs optional drawing packages; keep the tabular
        # results usable when rendering is not possible.
        diagram_img = None
        summary_lines.append(f"Genome diagram unavailable: {str(e)}")

    return "\n".join(summary_lines), gr.update(value=csv_path, visible=True), diagram_img
def create_genome_diagram(gene_results, genome_length):
    """Create genome diagram using BioPython.

    Args:
        gene_results: Dicts with ``'start'``, ``'end'``, ``'avg_shap'`` and
            ``'gene_name'`` keys, one per gene.
        genome_length: Right-hand end of the drawn coordinate range.

    Returns:
        A PIL image of a linear diagram. Genes with a positive average score
        are shaded red, all others blue; the shade deepens with
        ``abs(avg_shap) * 2``, saturating at 1.0.
    """
    from io import BytesIO

    from Bio.Graphics import GenomeDiagram
    from Bio.SeqFeature import SeqFeature, FeatureLocation
    from PIL import Image
    from reportlab.lib import colors
    from reportlab.lib.units import inch

    gd_diagram = GenomeDiagram.Diagram("Genome SHAP Analysis")
    gd_track = gd_diagram.new_track(1, name="Genes")
    gd_feature_set = gd_track.new_set()

    for gene in gene_results:
        feature = SeqFeature(
            FeatureLocation(gene['start'], gene['end']),
            type="gene"
        )
        intensity = min(1.0, abs(gene['avg_shap']) * 2)
        fade = 1 - intensity
        if gene['avg_shap'] > 0:
            # Red: keep the red channel and fade green/blue. (Both branches
            # previously built the same blue colour.)
            color = colors.Color(1, fade, fade)
        else:
            color = colors.Color(fade, fade, 1)  # Blue
        gd_feature_set.add_feature(
            feature,
            color=color,
            label=True,
            name=f"{gene['gene_name']}\n(SHAP: {gene['avg_shap']:.3f})"
        )

    gd_diagram.draw(
        format="linear",
        orientation="landscape",
        # Page sizes are in points; a bare (15, 5) is a 15x5-point page.
        pagesize=(15 * inch, 5 * inch),
        start=0,
        end=max(1, genome_length),
        fragments=1
    )

    buffer = BytesIO()
    gd_diagram.write(buffer, "PNG")
    buffer.seek(0)
    image = Image.open(buffer)
    image.load()  # decode now rather than lazily from the buffer
    return image
############################################################################### | |
# 12. DOWNLOAD FUNCTIONS | |
############################################################################### | |
def prepare_csv_download(data, filename="analysis_results.csv"):
    """Prepare CSV data for download.

    Args:
        data: Ready-made CSV text, a list of row dicts, or a single row dict.
        filename: Name to return alongside the payload.

    Returns:
        ``(csv_bytes, filename)``. A list of dicts is written with a header
        row taken from the first dict's keys; an empty list gives ``b""``.

    Raises:
        ValueError: if ``data`` is none of the supported types.
    """
    import csv
    from io import StringIO

    if isinstance(data, str):
        return data.encode(), filename
    if isinstance(data, dict):
        # A lone mapping is a one-row table; data[0] would be a KeyError.
        data = [data]
    if isinstance(data, list):
        if not data:
            return b"", filename
        output = StringIO()
        writer = csv.DictWriter(output, fieldnames=list(data[0].keys()))
        writer.writeheader()
        writer.writerows(data)
        return output.getvalue().encode(), filename
    raise ValueError("Unsupported data type for CSV download")
############################################################################### | |
# 13. BUILD GRADIO INTERFACE | |
############################################################################### | |
# Custom CSS handed to gr.Blocks below.
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.download-button {
margin-top: 10px;
}
"""
# Four-tab Gradio app. Each tab wires one button to one handler function.
# Gradio expects a handler to return one value per component in `outputs`.
# NOTE(review): the nesting of rows/columns below was reconstructed because
# the indentation of this file was lost - confirm against the intended layout.
with gr.Blocks(css=css) as iface:
    gr.Markdown("""
# Virus Host Classifier
**Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
**Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
**Step 3**: Analyze gene features and their contributions.
**Step 4**: Compare sequences and analyze differences.
**Color Scale**: Negative SHAP = Blue, Zero = White, Positive = Red.
""")
    # --- Tab 1: whole-sequence classification --------------------------------
    with gr.Tab("1) Full-Sequence Analysis"):
        with gr.Row():
            # Left column: inputs and the run button.
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                text_input = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
                top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
                win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Window size for 'most pushing' subregions")
                analyze_btn = gr.Button("Analyze Sequence", variant="primary")
            # Right column: text report, plots and the (initially hidden) download.
            with gr.Column(scale=2):
                results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
                kmer_img = gr.Image(label="Top k-mer SHAP")
                genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
                download_results = gr.File(label="Download Results", visible=False, elem_classes="download-button")
        # State shared with Tab 2: the analysed sequence/scores and its header.
        seq_state = gr.State()
        header_state = gr.State()
        # Six outputs are listed here.
        analyze_btn.click(
            analyze_sequence,
            inputs=[file_input, top_k, text_input, win_size],
            outputs=[results_box, kmer_img, genome_img, seq_state, header_state, download_results]
        )
    # --- Tab 2: subregion drill-down (reads the state stored by Tab 1) -------
    with gr.Tab("2) Subregion Exploration"):
        gr.Markdown("""
**Subregion Analysis**
Select start/end positions to view local SHAP signals, distribution, GC content, etc.
The heatmap uses the same Blue-White-Red scale.
""")
        with gr.Row():
            region_start = gr.Number(label="Region Start", value=0)
            region_end = gr.Number(label="Region End", value=500)
            region_btn = gr.Button("Analyze Subregion")
        subregion_info = gr.Textbox(label="Subregion Analysis", lines=7, interactive=False)
        with gr.Row():
            subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
            subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
        download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
        # Four outputs are listed here.
        region_btn.click(
            analyze_subregion,
            inputs=[seq_state, header_state, region_start, region_end],
            outputs=[subregion_info, subregion_img, subregion_hist_img, download_subregion]
        )
    # --- Tab 3: per-gene analysis ---------------------------------------------
    with gr.Tab("3) Gene Features Analysis"):
        gr.Markdown("""
**Analyze Gene Features**
Upload a FASTA file and corresponding gene features file to analyze SHAP values per gene.
Gene features should be in the format:
```
>gene_name [gene=X] [locus_tag=Y] [location=start..end]
SEQUENCE
```
The genome viewer will show genes color-coded by their contribution:
- Red: Genes pushing toward human origin
- Blue: Genes pushing toward non-human origin
- Color intensity indicates strength of signal
""")
        with gr.Row():
            # Left: the genome sequence; right: the gene feature records.
            with gr.Column(scale=1):
                gene_fasta_file = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                gene_fasta_text = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
            with gr.Column(scale=1):
                features_file = gr.File(label="Upload gene features file", file_types=[".txt"], type="filepath")
                features_text = gr.Textbox(label="Or paste gene features", placeholder=">gene_1 [gene=U12]...\nACGT...", lines=5)
        analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
        gene_results = gr.Textbox(label="Gene Analysis Results", lines=12, interactive=False)
        gene_diagram = gr.Image(label="Genome Diagram with Gene Features")
        download_gene_results = gr.File(label="Download Gene Analysis", visible=False, elem_classes="download-button")
        # Output order is Textbox, File, Image - the handler's return order must match.
        analyze_genes_btn.click(
            analyze_gene_features,
            inputs=[gene_fasta_file, features_file, gene_fasta_text, features_text],
            outputs=[gene_results, download_gene_results, gene_diagram]
        )
    # --- Tab 4: two-sequence comparison ---------------------------------------
    with gr.Tab("4) Comparative Analysis"):
        gr.Markdown("""
**Compare Two Sequences**
Upload or paste two FASTA sequences to compare their SHAP patterns.
The sequences will be normalized to the same length for comparison.
**Color Scale**:
- Red: Sequence 2 is more human-like in this region
- Blue: Sequence 1 is more human-like in this region
- White: No substantial difference
""")
        with gr.Row():
            with gr.Column(scale=1):
                file_input1 = gr.File(label="Upload first FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                text_input1 = gr.Textbox(label="Or paste first FASTA sequence", placeholder=">sequence1\nACGTACGT...", lines=5)
            with gr.Column(scale=1):
                file_input2 = gr.File(label="Upload second FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                text_input2 = gr.Textbox(label="Or paste second FASTA sequence", placeholder=">sequence2\nACGTACGT...", lines=5)
        compare_btn = gr.Button("Compare Sequences", variant="primary")
        comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
        with gr.Row():
            diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
            diff_hist = gr.Image(label="Distribution of SHAP Differences")
        download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
        # Four outputs are listed here.
        compare_btn.click(
            analyze_sequence_comparison,
            inputs=[file_input1, file_input2, text_input1, text_input2],
            outputs=[comparison_text, diff_heatmap, diff_hist, download_comparison]
        )
    # Footer describing the app's features.
    gr.Markdown("""
### Interface Features
- **Overall Classification** (human vs non-human) using k-mer frequencies
- **SHAP Analysis** shows which k-mers push classification toward or away from human
- **White-Centered SHAP Gradient**:
- Negative (blue), 0 (white), Positive (red)
- Symmetrical color range around 0
- **Identify Subregions** with strongest push for human or non-human
- **Gene Feature Analysis**:
- Analyze individual genes' contributions
- Interactive genome viewer
- Gene-level statistics and classification
- **Sequence Comparison**:
- Compare two sequences to identify regions of difference
- Normalized comparison to handle different lengths
- Statistical summary of differences
- **Data Export**:
- Download results as CSV files
- Save analysis outputs for further processing
""")
# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()