Update app.py
app.py
CHANGED
@@ -67,6 +67,9 @@ def parse_fasta(text):
     return sequences
 
 def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
+    """
+    Convert a sequence into a frequency vector of all possible 4-mer combinations.
+    """
     kmers = [''.join(p) for p in product("ACGT", repeat=k)]
     kmer_dict = {km: i for i, km in enumerate(kmers)}
     vec = np.zeros(len(kmers), dtype=np.float32)
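The body of `sequence_to_kmer_vector` is only partially visible in this hunk. As a reference point, a minimal counting implementation consistent with the added docstring could look like the sketch below; the counting loop and the normalization step are assumptions, not the app's confirmed code. Note that 4**4 = 256 k-mers, which matches the `VirusClassifier(256)` input size used later in app.py.

```python
import numpy as np
from itertools import product

def kmer_vector_sketch(sequence: str, k: int = 4) -> np.ndarray:
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    index = {km: i for i, km in enumerate(kmers)}
    vec = np.zeros(len(kmers), dtype=np.float32)
    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i + k]
        if kmer in index:          # skips k-mers containing N or other symbols
            vec[index[kmer]] += 1.0
    if vec.sum() > 0:
        vec /= vec.sum()           # frequencies rather than raw counts (assumption)
    return vec

print(kmer_vector_sketch("ACGTACGTACGT").shape)   # (256,)
```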
@@ -84,11 +87,15 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
 ###############################################################################
 
 def calculate_shap_values(model, x_tensor):
+    """
+    A simple ablation-based SHAP approximation. Zero out each position
+    and measure the impact on the 'human' probability.
+    """
     model.eval()
     with torch.no_grad():
         baseline_output = model(x_tensor)
         baseline_probs = torch.softmax(baseline_output, dim=1)
-        baseline_prob = baseline_probs[0, 1].item()  #
+        baseline_prob = baseline_probs[0, 1].item()  # Probability for 'human'
     shap_values = []
     x_zeroed = x_tensor.clone()
     for i in range(x_tensor.shape[1]):
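The added docstring names an ablation-style approximation: zero out one feature at a time and record how the 'human' probability moves. A self-contained sketch of that idea (not necessarily the exact loop inside `calculate_shap_values`) is:

```python
import torch

def ablation_attributions(model, x_tensor, target_class=1):
    """Zero one input feature at a time; record the drop in P(target_class)."""
    model.eval()
    with torch.no_grad():
        baseline = torch.softmax(model(x_tensor), dim=1)[0, target_class].item()
        scores = []
        for i in range(x_tensor.shape[1]):
            x_ablated = x_tensor.clone()
            x_ablated[0, i] = 0.0
            prob = torch.softmax(model(x_ablated), dim=1)[0, target_class].item()
            scores.append(baseline - prob)   # positive = feature pushed toward target_class
    return scores
```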
@@ -106,6 +113,9 @@ def calculate_shap_values(model, x_tensor):
 ###############################################################################
 
 def compute_positionwise_scores(sequence, shap_values, k=4):
+    """
+    Distribute each k-mer's SHAP contribution across its k underlying positions.
+    """
     kmers = [''.join(p) for p in product("ACGT", repeat=k)]
     kmer_dict = {km: i for i, km in enumerate(kmers)}
     seq_len = len(sequence)
@@ -126,6 +136,9 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
 ###############################################################################
 
 def find_extreme_subregion(shap_means, window_size=500, mode="max"):
+    """
+    Use a sliding window to find the subregion with the highest (or lowest) average SHAP.
+    """
     n = len(shap_means)
     if n == 0:
         return (0, 0, 0.0)
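For reference, the sliding-window search described in the new docstring can be done with a cumulative sum; this is a sketch of the technique, with no claim that `find_extreme_subregion` is implemented exactly this way:

```python
import numpy as np

def best_window(shap_means: np.ndarray, window_size: int = 500, mode: str = "max"):
    """Return (start, end, mean) of the window with the extreme average value."""
    n = len(shap_means)
    if n == 0:
        return (0, 0, 0.0)
    w = min(window_size, n)
    csum = np.concatenate(([0.0], np.cumsum(shap_means)))
    window_means = (csum[w:] - csum[:-w]) / w          # mean of every length-w window
    idx = int(np.argmax(window_means)) if mode == "max" else int(np.argmin(window_means))
    return (idx, idx + w, float(window_means[idx]))
```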
@@ -152,6 +165,9 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
 ###############################################################################
 
 def fig_to_image(fig):
+    """
+    Render a Matplotlib figure to a PIL Image.
+    """
     buf = io.BytesIO()
     fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
     buf.seek(0)
@@ -160,10 +176,16 @@ def fig_to_image(fig):
     return img
 
 def get_zero_centered_cmap():
+    """
+    Create a symmetrical (blue-white-red) colormap around zero.
+    """
     colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
     return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
 
 def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
+    """
+    Plot an inline heatmap for the chosen region (or entire genome if start/end not provided).
+    """
     if start is not None and end is not None:
         local_shap = shap_means[start:end]
         subtitle = f" (positions {start}-{end})"
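One common way to keep such a heatmap's color scale symmetric about zero is to pair the blue-white-red colormap with a `TwoSlopeNorm` centered at 0; `plot_linear_heatmap` may instead set `vmin`/`vmax` explicitly, so treat this as an illustrative sketch rather than the app's code:

```python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

cmap = mcolors.LinearSegmentedColormap.from_list(
    "blue_white_red", [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')])

shap_means = np.random.randn(1000) * 0.05                 # placeholder data
extent = float(np.max(np.abs(shap_means))) or 1.0
norm = mcolors.TwoSlopeNorm(vmin=-extent, vcenter=0.0, vmax=extent)

fig, ax = plt.subplots(figsize=(8, 1.5))
ax.imshow(shap_means.reshape(1, -1), aspect='auto', cmap=cmap, norm=norm)
ax.set_yticks([])
ax.set_xlabel("Genome position")
```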
@@ -189,6 +211,9 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
     return fig
 
 def create_importance_bar_plot(shap_values, kmers, top_k=10):
+    """
+    Show bar chart of top k-mers by absolute SHAP value.
+    """
     plt.rcParams.update({'font.size': 10})
     fig = plt.figure(figsize=(10, 5))
     indices = np.argsort(np.abs(shap_values))[-top_k:]
@@ -204,6 +229,9 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
     return fig
 
 def plot_shap_histogram(shap_array, title="SHAP Distribution in Region", num_bins=30):
+    """
+    Plot a histogram of SHAP values in some region.
+    """
     fig, ax = plt.subplots(figsize=(6, 4))
     ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
     ax.axvline(0, color='red', linestyle='--', label='0.0')
@@ -215,8 +243,11 @@ def plot_shap_histogram(shap_array, title="SHAP Distribution in Region", num_bins=30):
     return fig
 
 def compute_gc_content(sequence):
+    """
+    Compute GC content (%) for a given sequence.
+    """
     if not sequence:
-        return 0
+        return 0.0
     gc_count = sequence.count('G') + sequence.count('C')
     return (gc_count / len(sequence)) * 100.0
 
@@ -225,6 +256,11 @@ def compute_gc_content(sequence):
 ###############################################################################
 
 def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
+    """
+    Perform the main classification, SHAP analysis, and extreme subregion detection
+    for a single sequence.
+    """
+    # 1) Read input
     if fasta_text.strip():
         text = fasta_text.strip()
     elif file_obj is not None:
@@ -236,14 +272,15 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
     else:
         return ("Please provide a FASTA sequence.", None, None, None, None, None)
 
+    # 2) Parse FASTA
     sequences = parse_fasta(text)
     if not sequences:
         return ("No valid FASTA sequences found.", None, None, None, None, None)
     header, seq = sequences[0]
 
+    # 3) Load model, scaler, and run inference
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     try:
-        # IMPORTANT: adjust how you load your model as needed
         state_dict = torch.load('model.pt', map_location=device)
         model = VirusClassifier(256).to(device)
         model.load_state_dict(state_dict)
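The hunks shown here do not include how the loaded model is actually fed. A hedged sketch of the inference step, assuming the `sequence_to_kmer_vector` helper defined earlier in app.py and deliberately omitting any feature scaling the real pipeline applies, would be:

```python
import torch

def classify_sequence(model, seq: str, device) -> tuple:
    """Sketch: vectorize a sequence, run the classifier, return (prob_human, prob_nonhuman).

    Assumes sequence_to_kmer_vector() from app.py; the scaler mentioned in the
    step-3 comment above is not shown in the diff and is left out here.
    """
    vec = sequence_to_kmer_vector(seq, k=4).reshape(1, -1)   # 256-dim frequency vector
    x_tensor = torch.tensor(vec, dtype=torch.float32).to(device)
    with torch.no_grad():
        probs = torch.softmax(model(x_tensor), dim=1)
    return probs[0, 1].item(), probs[0, 0].item()
```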
@@ -260,10 +297,12 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
     classification = "Human" if prob_human > 0.5 else "Non-human"
     confidence = max(prob_human, prob_nonhuman)
 
+    # 4) Per-base SHAP & subregion detection
     shap_means = compute_positionwise_scores(seq, shap_values, k=4)
     max_start, max_end, max_avg = find_extreme_subregion(shap_means, window_size, mode="max")
     min_start, min_end, min_avg = find_extreme_subregion(shap_means, window_size, mode="min")
 
+    # 5) Prepare result text
     results_text = (
         f"Sequence: {header}\n"
         f"Length: {len(seq):,} bases\n"
@@ -277,6 +316,7 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
         f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
     )
 
+    # 6) Create bar & heatmap figures
     kmers = [''.join(p) for p in product("ACGT", repeat=4)]
     bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
     bar_img = fig_to_image(bar_fig)
@@ -284,10 +324,10 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
     heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
     heatmap_img = fig_to_image(heatmap_fig)
 
-    #
-    # Here, we'll simply return None for the file download:
+    # 7) Build the "state" dictionary so we can do subregion analysis
     state_dict_out = {"seq": seq, "shap_means": shap_means}
 
+    # Return 6 items to match your Gradio output
     return (results_text, bar_img, heatmap_img, state_dict_out, header, None)
 
 ###############################################################################
@@ -295,6 +335,9 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
 ###############################################################################
 
 def analyze_subregion(state, header, region_start, region_end):
+    """
+    Examine a subregion’s SHAP distribution, GC content, etc.
+    """
     if not state or "seq" not in state or "shap_means" not in state:
         return ("No sequence data found. Please run Step 1 first.", None, None, None)
     seq = state["seq"]
@@ -305,18 +348,22 @@ def analyze_subregion(state, header, region_start, region_end):
     region_end = max(0, min(region_end, len(seq)))
     if region_end <= region_start:
         return ("Invalid region range. End must be > Start.", None, None, None)
+
     region_seq = seq[region_start:region_end]
     region_shap = shap_means[region_start:region_end]
+
     gc_percent = compute_gc_content(region_seq)
     avg_shap = float(np.mean(region_shap))
     positive_fraction = np.mean(region_shap > 0)
     negative_fraction = np.mean(region_shap < 0)
+
     if avg_shap > 0.05:
         region_classification = "Likely pushing toward human"
     elif avg_shap < -0.05:
         region_classification = "Likely pushing toward non-human"
     else:
         region_classification = "Near neutral (no strong push)"
+
     region_info = (
         f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
         f"Region length: {len(region_seq)} bases\n"
@@ -326,30 +373,29 @@ def analyze_subregion(state, header, region_start, region_end):
         f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
         f"Subregion interpretation: {region_classification}\n"
     )
+
     heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion SHAP", start=region_start, end=region_end)
     heatmap_img = fig_to_image(heatmap_fig)
+
     hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
     hist_img = fig_to_image(hist_fig)
-
-    #
+
+    # Return 4 items to match your Gradio output
     return (region_info, heatmap_img, hist_img, None)
 
 ###############################################################################
-# 9. COMPARISON ANALYSIS FUNCTIONS
+# 9. COMPARISON ANALYSIS FUNCTIONS (Step 4)
 ###############################################################################
 
-def get_zero_centered_cmap():
-    """Create a zero-centered blue-white-red colormap"""
-    colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
-    return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
-
 def compute_shap_difference(shap1_norm, shap2_norm):
-    """
+    """
+    Compute the SHAP difference (Seq2 - Seq1).
+    """
     return shap2_norm - shap1_norm
 
 def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
     """
-    Plot heatmap using relative positions
+    Plot a 1D heatmap of differences using relative positions 0-100%.
     """
     heatmap_data = shap_diff.reshape(1, -1)
     extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
@@ -378,7 +424,7 @@ def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
 
 def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
     """
-    Plot histogram of SHAP values with
+    Plot a histogram of SHAP values with optional # of bins.
     """
     fig, ax = plt.subplots(figsize=(6, 4))
     ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
@@ -392,18 +438,16 @@ def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
 
 def calculate_adaptive_parameters(len1, len2):
     """
-
-    Returns: (num_points, smooth_window, resolution_factor)
+    Choose smoothing & interpolation parameters automatically based on length difference.
     """
     length_diff = abs(len1 - len2)
     max_length = max(len1, len2)
     min_length = min(len1, len2)
     length_ratio = min_length / max_length
 
     # Base number of points
     base_points = min(2000, max(500, max_length // 100))
 
-    # Adjust parameters based on sequence properties
     if length_diff < 500:
         resolution_factor = 2.0
         num_points = min(3000, base_points * 2)
@@ -421,29 +465,22 @@ def calculate_adaptive_parameters(len1, len2):
         num_points = max(500, base_points // 2)
         smooth_window = max(100, length_diff // 500)
 
-    # Adjust window size based on length ratio
     smooth_window = int(smooth_window * (1 + (1 - length_ratio)))
-
     return int(num_points), int(smooth_window), resolution_factor
 
 def sliding_window_smooth(values, window_size=50):
     """
-
+    A custom smoothing approach, including exponential decay at edges.
     """
     if window_size < 3:
         return values
-
-    # Create window with exponential decay at edges
     window = np.ones(window_size)
     decay = np.exp(-np.linspace(0, 3, window_size // 2))
     window[:window_size // 2] = decay
     window[-(window_size // 2):] = decay[::-1]
     window = window / window.sum()
 
-    # Apply convolution
     smoothed = np.convolve(values, window, mode='valid')
-
-    # Handle edges
     pad_size = len(values) - len(smoothed)
     pad_left = pad_size // 2
     pad_right = pad_size - pad_left
@@ -457,16 +494,13 @@ def sliding_window_smooth(values, window_size=50):
 
 def normalize_shap_lengths(shap1, shap2):
     """
-
+    Smooth, interpolate, and return arrays of the same length for direct comparison.
     """
-    # Calculate adaptive parameters
     num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
 
-    # Apply initial smoothing
     shap1_smooth = sliding_window_smooth(shap1, smooth_window)
     shap2_smooth = sliding_window_smooth(shap2, smooth_window)
 
-    # Create relative positions and interpolate
     x1 = np.linspace(0, 1, len(shap1_smooth))
     x2 = np.linspace(0, 1, len(shap2_smooth))
     x_norm = np.linspace(0, 1, num_points)
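The interpolation step that follows these relative-position axes is not shown in this hunk; resampling two unequal-length profiles onto a shared 0..1 axis is typically done with `np.interp`. A standalone sketch of that idea (app.py may use a different interpolator):

```python
import numpy as np

def interp_to_common_grid(shap1: np.ndarray, shap2: np.ndarray, num_points: int = 1000):
    """Resample two profiles onto the same relative 0..1 axis."""
    x1 = np.linspace(0, 1, len(shap1))
    x2 = np.linspace(0, 1, len(shap2))
    x_norm = np.linspace(0, 1, num_points)
    return np.interp(x_norm, x1, shap1), np.interp(x_norm, x2, shap2)
```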
@@ -478,7 +512,8 @@ def normalize_shap_lengths(shap1, shap2):
 
 def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
     """
-    Compare two sequences
+    Compare two sequences using the previously defined analysis pipeline
+    and produce difference visualizations & stats.
     """
     try:
         # Analyze first sequence
@@ -491,26 +526,23 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
         if isinstance(res2[0], str) and "Error" in res2[0]:
             return (f"Error in sequence 2: {res2[0]}", None, None, None)
 
-        # Extract SHAP values and sequence info
         shap1 = res1[3]["shap_means"]
         shap2 = res2[3]["shap_means"]
 
-        # Calculate sequence properties
         len1, len2 = len(shap1), len(shap2)
         length_diff = abs(len1 - len2)
         length_ratio = min(len1, len2) / max(len1, len2)
-
-        # Normalize
+
+        # Normalize both to the same length
         shap1_norm, shap2_norm, smooth_window = normalize_shap_lengths(shap1, shap2)
         shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
 
-        #
+        # Compute stats
         base_threshold = 0.05
         adaptive_threshold = base_threshold * (1 + (1 - length_ratio))
         if length_diff > 50000:
             adaptive_threshold *= 1.5
 
-        # Calculate comparison statistics
         avg_diff = np.mean(shap_diff)
         std_diff = np.std(shap_diff)
         max_diff = np.max(shap_diff)
@@ -518,7 +550,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
         substantial_diffs = np.abs(shap_diff) > adaptive_threshold
         frac_different = np.mean(substantial_diffs)
 
-        # Extract
+        # Extract classification from text
         try:
             classification1 = res1[0].split('Classification: ')[1].split('\n')[0].strip()
             classification2 = res2[0].split('Classification: ')[1].split('\n')[0].strip()
@@ -526,7 +558,6 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
             classification1 = "Unknown"
             classification2 = "Unknown"
 
-        # Format output text
         comparison_text = (
             "Sequence Comparison Results:\n"
             f"Sequence 1: {res1[4]}\n"
@@ -553,14 +584,12 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
             "- White regions: Similar between sequences"
         )
 
-        # Generate visualizations
         heatmap_fig = plot_comparative_heatmap(
             shap_diff,
             title=f"SHAP Difference Heatmap (window: {smooth_window})"
         )
         heatmap_img = fig_to_image(heatmap_fig)
 
-        # Create histogram with adaptive bins
         num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
         hist_fig = plot_shap_histogram(
             shap_diff,
@@ -569,7 +598,6 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
         )
         hist_img = fig_to_image(hist_fig)
 
-        # Return 4 outputs (text, image, image, and a file or None for the last)
         return (comparison_text, heatmap_img, hist_img, None)
 
     except Exception as e:
@@ -577,23 +605,55 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
         return (error_msg, None, None, None)
 
 ###############################################################################
-#
+# 10. ADDITIONAL / ADVANCED VISUALIZATIONS & STATISTICS
 ###############################################################################
 
-
-
-
-
+def n50_length(sequence):
+    """
+    Calculate the N50 for a single continuous sequence (for demonstration).
+    For a single sequence, N50 is typically the length if it's just one piece,
+    but let's do a simplistic example.
+    """
+    # If you had contigs, you'd do a sorted list, cumulative sums, etc.
+    # We'll do a trivial approach here:
+    return len(sequence)  # Because we have only one contiguous region
+
+def sequence_complexity(sequence):
+    """
+    Compute a simple measure of 'sequence complexity'.
+    Here, we define complexity as the Shannon entropy over the nucleotides.
+    """
+    from math import log2
+    length = len(sequence)
+    if length == 0:
+        return 0.0
+    freq = {}
+    for base in sequence:
+        freq[base] = freq.get(base, 0) + 1
+    complexity = 0.0
+    for base, count in freq.items():
+        p = count / length
+        complexity -= p * log2(p)
+    return complexity
+
+def advanced_gene_statistics(gene_shap: np.ndarray, gene_seq: str) -> Dict[str, float]:
+    """
+    Additional stats: N50, complexity, etc.
+    """
+    stats = {}
+    stats['n50'] = len(gene_seq)  # trivial for a single gene region
+    stats['entropy'] = sequence_complexity(gene_seq)
+    stats['avg_shap'] = float(np.mean(gene_shap))
+    stats['max_shap'] = float(np.max(gene_shap)) if len(gene_shap) else 0.0
+    stats['min_shap'] = float(np.min(gene_shap)) if len(gene_shap) else 0.0
+    return stats
+
+###############################################################################
+# 11. GENE FEATURE ANALYSIS
+###############################################################################
 
 def parse_gene_features(text: str) -> List[Dict[str, Any]]:
-    """Parse gene features from text file in FASTA-like format"""
+    """Parse gene features from text file in a FASTA-like format."""
     genes = []
     current_header = None
     current_sequence = []
@@ -602,7 +662,6 @@ def parse_gene_features(text: str) -> List[Dict[str, Any]]:
         line = line.strip()
         if not line:
             continue
-
         if line.startswith('>'):
             if current_header:
                 genes.append({
@@ -614,36 +673,29 @@ def parse_gene_features(text: str) -> List[Dict[str, Any]]:
             current_sequence = []
         else:
             current_sequence.append(line.upper())
-
     if current_header:
         genes.append({
             'header': current_header,
             'sequence': ''.join(current_sequence),
             'metadata': parse_gene_metadata(current_header)
         })
-
     return genes
 
 def parse_gene_metadata(header: str) -> Dict[str, str]:
-    """Extract metadata from gene header"""
+    """Extract metadata from gene header line."""
     metadata = {}
     parts = header.split()
-
     for part in parts:
         if '[' in part and ']' in part:
             key_value = part[1:-1].split('=', 1)
             if len(key_value) == 2:
                 metadata[key_value[0]] = key_value[1]
-
     return metadata
 
 def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
-    """Parse gene location string, handling
+    """Parse gene location string, handling forward and complement strands."""
     try:
-        # Remove 'complement(' and ')' if present
         clean_loc = location_str.replace('complement(', '').replace(')', '')
-
-        # Split on '..' and convert to integers
         if '..' in clean_loc:
             start, end = map(int, clean_loc.split('..'))
             return start, end
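A usage example for the parsers above, using the feature format documented in the app's Tab 3 help text; the gene name, locus tag, and coordinates are illustrative values, not data from this diff:

```python
example = (
    ">gene_1 [gene=ORF1ab] [locus_tag=GU280_gp01] [location=266..21555]\n"
    "ATGGAGAGCCTTGTCCCTGGTTTCAACGAGAAAACACACGTCCAACTCAGTTTGCCTGT\n"
)
genes = parse_gene_features(example)
meta = genes[0]['metadata']        # {'gene': 'ORF1ab', 'locus_tag': 'GU280_gp01', 'location': '266..21555'}
start, end = parse_location(meta['location'])   # (266, 21555)
```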
@@ -654,48 +706,41 @@ def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
         return None, None
 
 def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
-    """
+    """Basic statistical measures for gene SHAP values."""
     return {
-        'avg_shap': float(np.mean(gene_shap)),
-        'median_shap': float(np.median(gene_shap)),
-        'std_shap': float(np.std(gene_shap)),
-        'max_shap': float(np.max(gene_shap)),
-        'min_shap': float(np.min(gene_shap)),
-        'pos_fraction': float(np.mean(gene_shap > 0))
+        'avg_shap': float(np.mean(gene_shap)) if len(gene_shap) else 0.0,
+        'median_shap': float(np.median(gene_shap)) if len(gene_shap) else 0.0,
+        'std_shap': float(np.std(gene_shap)) if len(gene_shap) else 0.0,
+        'max_shap': float(np.max(gene_shap)) if len(gene_shap) else 0.0,
+        'min_shap': float(np.min(gene_shap)) if len(gene_shap) else 0.0,
+        'pos_fraction': float(np.mean(gene_shap > 0)) if len(gene_shap) else 0.0
     }
 
 def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
     """
-
-
+    A quick PIL-based diagram to show genes along the genome.
+    Color intensity = magnitude of SHAP. Red/Blue = sign of SHAP.
     """
-    from PIL import Image, ImageDraw, ImageFont
-
-    # Validate inputs
     if not gene_results or genome_length <= 0:
         img = Image.new('RGB', (800, 100), color='white')
         draw = ImageDraw.Draw(img)
         draw.text((10, 40), "Error: Invalid input data", fill='black')
         return img
-
-    # Ensure all gene coordinates are valid integers
+
     for gene in gene_results:
         gene['start'] = max(0, int(gene['start']))
         gene['end'] = min(genome_length, int(gene['end']))
         if gene['start'] >= gene['end']:
-            print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}
+            print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}")
 
-    # Image dimensions
     width = 1500
     height = 600
     margin = 50
     track_height = 40
 
-    # Create image with white background
     img = Image.new('RGB', (width, height), 'white')
     draw = ImageDraw.Draw(img)
 
-    # Try to load font, fall back to default if unavailable
     try:
         font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
         title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
@@ -703,24 +748,16 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
         font = ImageFont.load_default()
         title_font = ImageFont.load_default()
 
-
-    draw.text((margin, margin // 2), "Genome SHAP Analysis", fill='black', font=title_font or font)
+    draw.text((margin, margin // 2), "Genome SHAP Analysis (Simple)", fill='black', font=title_font or font)
 
-    # Draw genome line
     line_y = height // 2
     draw.line([(int(margin), int(line_y)), (int(width - margin), int(line_y))], fill='black', width=2)
 
-    # Calculate scale factor
     scale = float(width - 2 * margin) / float(genome_length)
 
-    #
+    # Scale markers
     num_ticks = 10
-
-        step = 1
-    else:
-        step = genome_length // num_ticks
-
-    # Draw scale markers
+    step = max(1, genome_length // num_ticks)
     for i in range(0, genome_length + 1, step):
         x_coord = margin + i * scale
         draw.line([
@@ -729,50 +766,33 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
         ], fill='black', width=1)
         draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
 
-    # Sort genes by absolute SHAP value for drawing
     sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
-
-    # Draw genes
     for idx, gene in enumerate(sorted_genes):
-        # Calculate position and ensure integers
         start_x = margin + int(gene['start'] * scale)
         end_x = margin + int(gene['end'] * scale)
-
-        # Calculate color based on SHAP value
         avg_shap = gene['avg_shap']
-
-        # Convert shap -> color intensity (0 to 255)
-        # Then clamp to a minimum intensity so it never ends up plain white
         intensity = int(abs(avg_shap) * 500)
         intensity = max(50, min(255, intensity))
 
         if avg_shap > 0:
-            color = (255, 255 - intensity, 255 - intensity)
+            color = (255, 255 - intensity, 255 - intensity)  # Redish
         else:
-            color = (255 - intensity, 255 - intensity, 255)
+            color = (255 - intensity, 255 - intensity, 255)  # Blueish
 
-        # Draw gene rectangle
         draw.rectangle([
             (int(start_x), int(line_y - track_height // 2)),
             (int(end_x), int(line_y + track_height // 2))
         ], fill=color, outline='black')
 
-        # Prepare gene name label
         label = str(gene.get('gene_name','?'))
-
-        # Fallback for label size
         label_mask = font.getmask(label)
         label_width, label_height = label_mask.size
 
-        # Alternate label positions
         if idx % 2 == 0:
             text_y = line_y - track_height - 15
         else:
             text_y = line_y + track_height + 5
 
-        # Decide whether to rotate text based on space
         gene_width = end_x - start_x
         if gene_width > label_width:
             text_x = start_x + (gene_width - label_width) // 2
@@ -784,64 +804,113 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
             rotated_img = txt_img.rotate(90, expand=True)
             img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)
 
-    (
-
+    return img
+
+def create_advanced_genome_diagram(gene_results: List[Dict[str, Any]],
+                                   genome_length: int,
+                                   shap_means: np.ndarray,
+                                   diagram_title: str = "Advanced Genome Diagram") -> Image.Image:
+    """
+    An advanced genome diagram using Biopython's GenomeDiagram.
+    We'll create tracks for genes and a 'SHAP line plot' track.
+    """
+    if not gene_results or genome_length <= 0 or len(shap_means) == 0:
+        # Fallback if data is invalid
+        img = Image.new('RGB', (800, 100), color='white')
+        d = ImageDraw.Draw(img)
+        d.text((10, 40), "Error: Not enough data for advanced diagram", fill='black')
+        return img
+
+    diagram = GenomeDiagram.Diagram(diagram_title)
+    gene_track = diagram.new_track(1, name="Genes", greytrack=False, height=0.5)
+    gene_set = gene_track.new_set()
+
+    # Add each gene as a feature
+    for gene in gene_results:
+        start = max(0, int(gene['start']))
+        end = min(genome_length, int(gene['end']))
+        avg_shap = gene['avg_shap']
+        # Color scale: negative = blue, positive = red
+        intensity = abs(avg_shap) * 500
+        intensity = max(50, min(255, intensity))
+        if avg_shap >= 0:
+            color_hex = colors.Color(1.0, 1.0 - intensity/255.0, 1.0 - intensity/255.0)
+        else:
+            color_hex = colors.Color(1.0 - intensity/255.0, 1.0 - intensity/255.0, 1.0)
+
+        feature = SeqFeature(FeatureLocation(start, end), strand=1)
+        gene_set.add_feature(
+            feature,
+            color=color_hex,
+            label=True,
+            name=str(gene.get('gene_name','?')),
+            label_size=8,
+            label_color=colors.black
+        )
+
+    # Add a track for the SHAP line
+    shap_track = diagram.new_track(2, name="SHAP Score", greytrack=False, height=0.3)
+    shap_set = shap_track.new_set("graph")
+    # We'll plot the entire shap_means array.
+    # X coords = [0..genome_length], Y coords = shap_means
+    # We'll keep negative values below baseline, positive above.
+
+    # Normalizing for visualization
+    max_abs = max(abs(shap_means.min()), abs(shap_means.max()))
+    if max_abs == 0:
+        scaled_shap = [0]*len(shap_means)
+    else:
+        scaled_shap = (shap_means / max_abs * 50).tolist()  # scale to +/- 50
 
+    shap_set.add_graph(
+        data=scaled_shap,
+        name="shap_line",
+        style="line",
+        color=colors.darkgreen,
+        altcolor=colors.red,
+        linewidth=1
+    )
+
+    # Draw to a temporary file
+    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmpf:
+        diagram.draw(format="linear", pagesize='A3', fragments=1, start=0, end=genome_length)
+        diagram.write(tmpf.name, "PDF")
+
+    # Convert PDF to a PIL image (requires poppler or similar).
+    # If you do not have poppler, you can skip PDF -> image or use Cairo.
+    try:
+        import pdf2image
+        pages = pdf2image.convert_from_path(tmpf.name, dpi=100)
+        img = pages[0] if pages else Image.new('RGB', (800, 100), color='white')
+    except ImportError:
+        img = Image.new('RGB', (800, 100), color='white')
+        d = ImageDraw.Draw(img)
+        d.text((10, 40), "pdf2image not installed, can't show advanced diagram as image.", fill='black')
+
+    # Cleanup
+    os.remove(tmpf.name)
     return img
 
 def analyze_gene_features(sequence_file: str,
                           features_file: str,
                           fasta_text: str = "",
-                          features_text: str = ""
-
-
+                          features_text: str = "",
+                          diagram_mode: str = "advanced"
+                          ) -> Tuple[str, Optional[str], Optional[Image.Image]]:
+    """
+    Analyze each gene in the features file, compute gene-level SHAP stats,
+    produce tabular output, and create an optional genome diagram.
+    """
+    # 1) Analyze the entire sequence with the top-level function
     sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
     if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
         return f"Error in sequence analysis: {sequence_results[0]}", None, None
 
-
+    seq = sequence_results[3]["seq"]
     shap_means = sequence_results[3]["shap_means"]
-
-
+    genome_length = len(seq)
+
+    # 2) Read gene features
     try:
         if features_text.strip():
             genes = parse_gene_features(features_text)
@@ -850,98 +919,100 @@ def analyze_gene_features(sequence_file: str,
             genes = parse_gene_features(f.read())
     except Exception as e:
         return f"Error reading features file: {str(e)}", None, None
-
-    # Analyze each gene
+
     gene_results = []
     for gene in genes:
-
-
-            if not location:
-                continue
-
-            start, end = parse_location(location)
-            if start is None or end is None:
-                continue
-
-            # Get SHAP values for this region
-            gene_shap = shap_means[start:end]
-            stats = compute_gene_statistics(gene_shap)
-
-            gene_results.append({
-                'gene_name': gene['metadata'].get('gene', 'Unknown'),
-                'location': location,
-                'start': start,
-                'end': end,
-                'locus_tag': gene['metadata'].get('locus_tag', ''),
-                'avg_shap': stats['avg_shap'],
-                'median_shap': stats['median_shap'],
-                'std_shap': stats['std_shap'],
-                'max_shap': stats['max_shap'],
-                'min_shap': stats['min_shap'],
-                'pos_fraction': stats['pos_fraction'],
-                'classification': 'Human' if stats['avg_shap'] > 0 else 'Non-human',
-                'confidence': abs(stats['avg_shap'])
-            })
-
-        except Exception as e:
-            print(f"Error processing gene {gene['metadata'].get('gene', 'Unknown')}: {str(e)}")
+        location = gene['metadata'].get('location', '')
+        if not location:
             continue
-
+        start, end = parse_location(location)
+        if start is None or end is None or start >= end or end > genome_length:
+            continue
+        gene_shap = shap_means[start:end]
+        basic_stats = compute_gene_statistics(gene_shap)
     if not gene_results:
         return "No valid genes could be processed", None, None
-
-    #
     sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
-
-    # Create results text
     results_text = "Gene Analysis Results:\n\n"
     results_text += f"Total genes analyzed: {len(gene_results)}\n"
-
-    results_text += f"
 
-    results_text += "Top 10 most distinctive genes:\n"
     for gene in sorted_genes[:10]:
         results_text += (
             f"Gene: {gene['gene_name']}\n"
            f"Location: {gene['location']}\n"
            f"Classification: {gene['classification']} "
            f"(confidence: {gene['confidence']:.4f})\n"
-            f"Average SHAP: {gene['avg_shap']:.4f}\n
        )
-
-    #
-    csv_content = "gene_name,location,avg_shap,median_shap,std_shap,
-    csv_content += "pos_fraction,classification,confidence
-
-    for gene in gene_results:
        csv_content += (
-            f"{
-            f"{
-            f"{
-            f"{
        )
-
-    # Save CSV to temp file
    try:
        temp_dir = tempfile.gettempdir()
        temp_path = os.path.join(temp_dir, f"gene_analysis_{os.urandom(4).hex()}.csv")
-
        with open(temp_path, 'w') as f:
            f.write(csv_content)
    except Exception as e:
        print(f"Error saving CSV: {str(e)}")
        temp_path = None
-
-    # Create
    try:
-
    except Exception as e:
        print(f"Error creating visualization: {str(e)}")
-        # Create error image
        diagram_img = Image.new('RGB', (800, 100), color='white')
        draw = ImageDraw.Draw(diagram_img)
        draw.text((10, 40), f"Error creating visualization: {str(e)}", fill='black')
-
    return results_text, temp_path, diagram_img
 
 ###############################################################################
@@ -949,13 +1020,14 @@ def analyze_gene_features(sequence_file: str,
 ###############################################################################
 
 def prepare_csv_download(data, filename="analysis_results.csv"):
-    """
     if isinstance(data, str):
         return data.encode(), filename
     elif isinstance(data, (list, dict)):
         import csv
         from io import StringIO
-
         output = StringIO()
         writer = csv.DictWriter(output, fieldnames=data[0].keys())
         writer.writeheader()
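For context, the `csv.DictWriter` pattern used in `prepare_csv_download` works like this standalone sketch (the row values are placeholders):

```python
import csv
from io import StringIO

rows = [{"gene_name": "gene_1", "avg_shap": 0.0123},
        {"gene_name": "gene_2", "avg_shap": -0.0045}]
output = StringIO()
writer = csv.DictWriter(output, fieldnames=rows[0].keys())
writer.writeheader()
writer.writerows(rows)
csv_text = output.getvalue()
```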
@@ -979,22 +1051,22 @@ css = """
 
 with gr.Blocks(css=css) as iface:
     gr.Markdown("""
-    # Virus Host Classifier
-    **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme
-    **Step 2**: Explore subregions
-    **Step 3**: Analyze gene features
-    **Step 4**: Compare sequences
-
-    **Color Scale**: Negative SHAP = Blue,
     """)
 
     with gr.Tab("1) Full-Sequence Analysis"):
         with gr.Row():
             with gr.Column(scale=1):
                 file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
-                text_input = gr.Textbox(label="Or paste FASTA
                 top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
-                win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Window
                 analyze_btn = gr.Button("Analyze Sequence", variant="primary")
             with gr.Column(scale=2):
                 results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
|
|
1013 |
with gr.Tab("2) Subregion Exploration"):
|
1014 |
gr.Markdown("""
|
1015 |
**Subregion Analysis**
|
1016 |
-
|
1017 |
-
The heatmap uses the same Blue-White-Red scale.
|
1018 |
""")
|
1019 |
with gr.Row():
|
1020 |
region_start = gr.Number(label="Region Start", value=0)
|
@@ -1024,7 +1095,7 @@ with gr.Blocks(css=css) as iface:
|
|
1024 |
with gr.Row():
|
1025 |
subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
|
1026 |
subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
|
1027 |
-
download_subregion = gr.File(label="Download Subregion
|
1028 |
|
1029 |
region_btn.click(
|
1030 |
analyze_subregion,
|
@@ -1035,60 +1106,48 @@ with gr.Blocks(css=css) as iface:
|
|
1035 |
with gr.Tab("3) Gene Features Analysis"):
|
1036 |
gr.Markdown("""
|
1037 |
**Analyze Gene Features**
|
1038 |
-
Upload a FASTA file and
|
1039 |
-
|
1040 |
-
|
1041 |
-
>gene_name [gene=X] [locus_tag=Y] [location=start..end] or [location=complement(start..end)]
|
1042 |
-
SEQUENCE
|
1043 |
-
```
|
1044 |
-
The genome viewer will show genes color-coded by their contribution:
|
1045 |
-
- Red: Genes pushing toward human origin
|
1046 |
-
- Blue: Genes pushing toward non-human origin
|
1047 |
-
- Color intensity indicates strength of signal
|
1048 |
""")
|
1049 |
with gr.Row():
|
1050 |
with gr.Column(scale=1):
|
1051 |
-
gene_fasta_file = gr.File(label="
|
1052 |
-
gene_fasta_text = gr.Textbox(label="Or paste FASTA sequence",
|
1053 |
with gr.Column(scale=1):
|
1054 |
-
features_file = gr.File(label="
|
1055 |
-
features_text = gr.Textbox(label="Or paste gene features",
|
1056 |
-
|
1057 |
analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
|
1058 |
gene_results = gr.Textbox(label="Gene Analysis Results", lines=12, interactive=False)
|
1059 |
-
gene_diagram = gr.Image(label="Genome Diagram
|
1060 |
download_gene_results = gr.File(label="Download Gene Analysis (CSV)", visible=True)
|
1061 |
|
1062 |
analyze_genes_btn.click(
|
1063 |
analyze_gene_features,
|
1064 |
-
inputs=[gene_fasta_file, features_file, gene_fasta_text, features_text],
|
1065 |
outputs=[gene_results, download_gene_results, gene_diagram]
|
1066 |
)
|
1067 |
|
1068 |
with gr.Tab("4) Comparative Analysis"):
|
1069 |
gr.Markdown("""
|
1070 |
**Compare Two Sequences**
|
1071 |
-
Upload or paste two FASTA sequences
|
1072 |
-
|
1073 |
-
|
1074 |
-
**Color Scale**:
|
1075 |
-
- Red: Sequence 2 more human-like
|
1076 |
-
- Blue: Sequence 1 more human-like
|
1077 |
-
- White: No substantial difference
|
1078 |
""")
|
1079 |
with gr.Row():
|
1080 |
with gr.Column(scale=1):
|
1081 |
-
file_input1 = gr.File(label="
|
1082 |
-
text_input1 = gr.Textbox(label="Or paste
|
1083 |
with gr.Column(scale=1):
|
1084 |
-
file_input2 = gr.File(label="
|
1085 |
-
text_input2 = gr.Textbox(label="Or paste
|
1086 |
compare_btn = gr.Button("Compare Sequences", variant="primary")
|
1087 |
comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
|
1088 |
with gr.Row():
|
1089 |
diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
|
1090 |
diff_hist = gr.Image(label="Distribution of SHAP Differences")
|
1091 |
-
download_comparison = gr.File(label="Download Comparison
|
1092 |
|
1093 |
compare_btn.click(
|
1094 |
analyze_sequence_comparison,
|
@@ -1097,25 +1156,12 @@ with gr.Blocks(css=css) as iface:
|
|
1097 |
)
|
1098 |
|
1099 |
gr.Markdown("""
|
1100 |
-
###
|
1101 |
-
- **
|
1102 |
-
- **
|
1103 |
-
- **
|
1104 |
-
|
1105 |
-
- Symmetrical color range around 0
|
1106 |
-
- **Identify Subregions** with strongest push for human or non-human
|
1107 |
-
- **Gene Feature Analysis**:
|
1108 |
-
- Analyze individual genes' contributions
|
1109 |
-
- Interactive genome viewer
|
1110 |
-
- Gene-level statistics and classification
|
1111 |
-
- **Sequence Comparison**:
|
1112 |
-
- Compare two sequences to identify regions of difference
|
1113 |
-
- Normalized comparison to handle different lengths
|
1114 |
-
- Statistical summary of differences
|
1115 |
-
- **Data Export**:
|
1116 |
-
- Download results as CSV files
|
1117 |
-
- Save analysis outputs for further processing
|
1118 |
""")
|
1119 |
-
|
1120 |
if __name__ == "__main__":
|
1121 |
iface.launch()
|
|
|
67 |
return sequences
|
68 |
|
69 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
70 |
+
"""
|
71 |
+
Convert a sequence into a frequency vector of all possible 4-mer combinations.
|
72 |
+
"""
|
73 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
74 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
75 |
vec = np.zeros(len(kmers), dtype=np.float32)
|
|
|
87 |
###############################################################################
|
88 |
|
89 |
def calculate_shap_values(model, x_tensor):
|
90 |
+
"""
|
91 |
+
A simple ablation-based SHAP approximation. Zero out each position
|
92 |
+
and measure the impact on the 'human' probability.
|
93 |
+
"""
|
94 |
model.eval()
|
95 |
with torch.no_grad():
|
96 |
baseline_output = model(x_tensor)
|
97 |
baseline_probs = torch.softmax(baseline_output, dim=1)
|
98 |
+
baseline_prob = baseline_probs[0, 1].item() # Probability for 'human'
|
99 |
shap_values = []
|
100 |
x_zeroed = x_tensor.clone()
|
101 |
for i in range(x_tensor.shape[1]):
|
|
|
113 |
###############################################################################
|
114 |
|
115 |
def compute_positionwise_scores(sequence, shap_values, k=4):
|
116 |
+
"""
|
117 |
+
Distribute each k-mer's SHAP contribution across its k underlying positions.
|
118 |
+
"""
|
119 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
120 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
121 |
seq_len = len(sequence)
|
|
|
136 |
###############################################################################
|
137 |
|
138 |
def find_extreme_subregion(shap_means, window_size=500, mode="max"):
|
139 |
+
"""
|
140 |
+
Use a sliding window to find the subregion with the highest (or lowest) average SHAP.
|
141 |
+
"""
|
142 |
n = len(shap_means)
|
143 |
if n == 0:
|
144 |
return (0, 0, 0.0)
|
|
|
165 |
###############################################################################
|
166 |
|
167 |
def fig_to_image(fig):
|
168 |
+
"""
|
169 |
+
Render a Matplotlib figure to a PIL Image.
|
170 |
+
"""
|
171 |
buf = io.BytesIO()
|
172 |
fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
|
173 |
buf.seek(0)
|
|
|
176 |
return img
|
177 |
|
178 |
def get_zero_centered_cmap():
|
179 |
+
"""
|
180 |
+
Create a symmetrical (blue-white-red) colormap around zero.
|
181 |
+
"""
|
182 |
colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
|
183 |
return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
|
184 |
|
185 |
def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
|
186 |
+
"""
|
187 |
+
Plot an inline heatmap for the chosen region (or entire genome if start/end not provided).
|
188 |
+
"""
|
189 |
if start is not None and end is not None:
|
190 |
local_shap = shap_means[start:end]
|
191 |
subtitle = f" (positions {start}-{end})"
|
|
|
211 |
return fig
|
212 |
|
213 |
def create_importance_bar_plot(shap_values, kmers, top_k=10):
|
214 |
+
"""
|
215 |
+
Show bar chart of top k-mers by absolute SHAP value.
|
216 |
+
"""
|
217 |
plt.rcParams.update({'font.size': 10})
|
218 |
fig = plt.figure(figsize=(10, 5))
|
219 |
indices = np.argsort(np.abs(shap_values))[-top_k:]
|
|
|
229 |
return fig
|
230 |
|
231 |
def plot_shap_histogram(shap_array, title="SHAP Distribution in Region", num_bins=30):
|
232 |
+
"""
|
233 |
+
Plot a histogram of SHAP values in some region.
|
234 |
+
"""
|
235 |
fig, ax = plt.subplots(figsize=(6, 4))
|
236 |
ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
|
237 |
ax.axvline(0, color='red', linestyle='--', label='0.0')
|
|
|
243 |
return fig
|
244 |
|
245 |
def compute_gc_content(sequence):
|
246 |
+
"""
|
247 |
+
Compute GC content (%) for a given sequence.
|
248 |
+
"""
|
249 |
if not sequence:
|
250 |
+
return 0.0
|
251 |
gc_count = sequence.count('G') + sequence.count('C')
|
252 |
return (gc_count / len(sequence)) * 100.0
|
253 |
|
|
|
256 |
###############################################################################
|
257 |
|
258 |
def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
259 |
+
"""
|
260 |
+
Perform the main classification, SHAP analysis, and extreme subregion detection
|
261 |
+
for a single sequence.
|
262 |
+
"""
|
263 |
+
# 1) Read input
|
264 |
if fasta_text.strip():
|
265 |
text = fasta_text.strip()
|
266 |
elif file_obj is not None:
|
|
|
272 |
else:
|
273 |
return ("Please provide a FASTA sequence.", None, None, None, None, None)
|
274 |
|
275 |
+
# 2) Parse FASTA
|
276 |
sequences = parse_fasta(text)
|
277 |
if not sequences:
|
278 |
return ("No valid FASTA sequences found.", None, None, None, None, None)
|
279 |
header, seq = sequences[0]
|
280 |
|
281 |
+
# 3) Load model, scaler, and run inference
|
282 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
283 |
try:
|
|
|
284 |
state_dict = torch.load('model.pt', map_location=device)
|
285 |
model = VirusClassifier(256).to(device)
|
286 |
model.load_state_dict(state_dict)
|
|
|
297 |
classification = "Human" if prob_human > 0.5 else "Non-human"
|
298 |
confidence = max(prob_human, prob_nonhuman)
|
299 |
|
300 |
+
# 4) Per-base SHAP & subregion detection
|
301 |
shap_means = compute_positionwise_scores(seq, shap_values, k=4)
|
302 |
max_start, max_end, max_avg = find_extreme_subregion(shap_means, window_size, mode="max")
|
303 |
min_start, min_end, min_avg = find_extreme_subregion(shap_means, window_size, mode="min")
|
304 |
|
305 |
+
# 5) Prepare result text
|
306 |
results_text = (
|
307 |
f"Sequence: {header}\n"
|
308 |
f"Length: {len(seq):,} bases\n"
|
|
|
316 |
f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
|
317 |
)
|
318 |
|
319 |
+
# 6) Create bar & heatmap figures
|
320 |
kmers = [''.join(p) for p in product("ACGT", repeat=4)]
|
321 |
bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
|
322 |
bar_img = fig_to_image(bar_fig)
|
|
|
324 |
heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
|
325 |
heatmap_img = fig_to_image(heatmap_fig)
|
326 |
|
327 |
+
# 7) Build the "state" dictionary so we can do subregion analysis
|
|
|
328 |
state_dict_out = {"seq": seq, "shap_means": shap_means}
|
329 |
|
330 |
+
# Return 6 items to match your Gradio output
|
331 |
return (results_text, bar_img, heatmap_img, state_dict_out, header, None)
|
332 |
|
333 |
###############################################################################
|
|
|
335 |
###############################################################################
|
336 |
|
337 |
def analyze_subregion(state, header, region_start, region_end):
|
338 |
+
"""
|
339 |
+
Examine a subregion’s SHAP distribution, GC content, etc.
|
340 |
+
"""
|
341 |
if not state or "seq" not in state or "shap_means" not in state:
|
342 |
return ("No sequence data found. Please run Step 1 first.", None, None, None)
|
343 |
seq = state["seq"]
|
|
|
348 |
region_end = max(0, min(region_end, len(seq)))
|
349 |
if region_end <= region_start:
|
350 |
return ("Invalid region range. End must be > Start.", None, None, None)
|
351 |
+
|
352 |
region_seq = seq[region_start:region_end]
|
353 |
region_shap = shap_means[region_start:region_end]
|
354 |
+
|
355 |
gc_percent = compute_gc_content(region_seq)
|
356 |
avg_shap = float(np.mean(region_shap))
|
357 |
positive_fraction = np.mean(region_shap > 0)
|
358 |
negative_fraction = np.mean(region_shap < 0)
|
359 |
+
|
360 |
if avg_shap > 0.05:
|
361 |
region_classification = "Likely pushing toward human"
|
362 |
elif avg_shap < -0.05:
|
363 |
region_classification = "Likely pushing toward non-human"
|
364 |
else:
|
365 |
region_classification = "Near neutral (no strong push)"
|
366 |
+
|
367 |
region_info = (
|
368 |
f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
|
369 |
f"Region length: {len(region_seq)} bases\n"
|
|
|
373 |
f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
|
374 |
f"Subregion interpretation: {region_classification}\n"
|
375 |
)
|
376 |
+
|
377 |
heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion SHAP", start=region_start, end=region_end)
|
378 |
heatmap_img = fig_to_image(heatmap_fig)
|
379 |
+
|
380 |
hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
|
381 |
hist_img = fig_to_image(hist_fig)
|
382 |
+
|
383 |
+
# Return 4 items to match your Gradio output
|
384 |
return (region_info, heatmap_img, hist_img, None)
|
385 |
|
386 |
###############################################################################
|
387 |
+
# 9. COMPARISON ANALYSIS FUNCTIONS (Step 4)
|
388 |
###############################################################################
|
389 |
|
|
|
|
|
|
|
|
|
|
|
390 |
def compute_shap_difference(shap1_norm, shap2_norm):
|
391 |
+
"""
|
392 |
+
Compute the SHAP difference (Seq2 - Seq1).
|
393 |
+
"""
|
394 |
return shap2_norm - shap1_norm
|
395 |
|
396 |
def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
|
397 |
"""
|
398 |
+
Plot a 1D heatmap of differences using relative positions 0-100%.
|
399 |
"""
|
400 |
heatmap_data = shap_diff.reshape(1, -1)
|
401 |
extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
|
|
|
424 |
|
425 |
def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
|
426 |
"""
|
427 |
+
Plot a histogram of SHAP values with optional # of bins.
|
428 |
"""
|
429 |
fig, ax = plt.subplots(figsize=(6, 4))
|
430 |
ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
|
|
|
438 |
|
439 |
def calculate_adaptive_parameters(len1, len2):
|
440 |
"""
|
441 |
+
Choose smoothing & interpolation parameters automatically based on length difference.
|
|
|
442 |
"""
|
443 |
length_diff = abs(len1 - len2)
|
444 |
max_length = max(len1, len2)
|
445 |
min_length = min(len1, len2)
|
446 |
length_ratio = min_length / max_length
|
447 |
|
448 |
+
# Base number of points
|
449 |
base_points = min(2000, max(500, max_length // 100))
|
450 |
|
|
|
451 |
if length_diff < 500:
|
452 |
resolution_factor = 2.0
|
453 |
num_points = min(3000, base_points * 2)
|
|
|
465 |
num_points = max(500, base_points // 2)
|
466 |
smooth_window = max(100, length_diff // 500)
|
467 |
|
|
|
468 |
smooth_window = int(smooth_window * (1 + (1 - length_ratio)))
|
|
|
469 |
return int(num_points), int(smooth_window), resolution_factor
|
470 |
|
471 |
def sliding_window_smooth(values, window_size=50):
|
472 |
"""
|
473 |
+
A custom smoothing approach, including exponential decay at edges.
|
474 |
"""
|
475 |
if window_size < 3:
|
476 |
return values
|
|
|
|
|
477 |
window = np.ones(window_size)
|
478 |
decay = np.exp(-np.linspace(0, 3, window_size // 2))
|
479 |
window[:window_size // 2] = decay
|
480 |
window[-(window_size // 2):] = decay[::-1]
|
481 |
window = window / window.sum()
|
482 |
|
|
|
483 |
smoothed = np.convolve(values, window, mode='valid')
|
|
|
|
|
484 |
pad_size = len(values) - len(smoothed)
|
485 |
pad_left = pad_size // 2
|
486 |
pad_right = pad_size - pad_left
|
|
|
494 |
|
495 |
def normalize_shap_lengths(shap1, shap2):
|
496 |
"""
|
497 |
+
Smooth, interpolate, and return arrays of the same length for direct comparison.
|
498 |
"""
|
|
|
499 |
num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
|
500 |
|
|
|
501 |
shap1_smooth = sliding_window_smooth(shap1, smooth_window)
|
502 |
shap2_smooth = sliding_window_smooth(shap2, smooth_window)
|
503 |
|
|
|
504 |
x1 = np.linspace(0, 1, len(shap1_smooth))
|
505 |
x2 = np.linspace(0, 1, len(shap2_smooth))
|
506 |
x_norm = np.linspace(0, 1, num_points)
|
|
|
512 |
|
513 |
def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
514 |
"""
|
515 |
+
Compare two sequences using the previously defined analysis pipeline
|
516 |
+
and produce difference visualizations & stats.
|
517 |
"""
|
518 |
try:
|
519 |
# Analyze first sequence
|
|
|
526 |
if isinstance(res2[0], str) and "Error" in res2[0]:
|
527 |
return (f"Error in sequence 2: {res2[0]}", None, None, None)
|
528 |
|
|
|
529 |
shap1 = res1[3]["shap_means"]
|
530 |
shap2 = res2[3]["shap_means"]
|
531 |
|
|
|
532 |
len1, len2 = len(shap1), len(shap2)
|
533 |
length_diff = abs(len1 - len2)
|
534 |
length_ratio = min(len1, len2) / max(len1, len2)
|
535 |
+
|
536 |
+
# Normalize both to the same length
|
537 |
shap1_norm, shap2_norm, smooth_window = normalize_shap_lengths(shap1, shap2)
|
538 |
shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
|
539 |
|
540 |
+
# Compute stats
|
541 |
base_threshold = 0.05
|
542 |
adaptive_threshold = base_threshold * (1 + (1 - length_ratio))
|
543 |
if length_diff > 50000:
|
544 |
adaptive_threshold *= 1.5
|
545 |
|
|
|
546 |
avg_diff = np.mean(shap_diff)
|
547 |
std_diff = np.std(shap_diff)
|
548 |
max_diff = np.max(shap_diff)
|
|
|
550 |
substantial_diffs = np.abs(shap_diff) > adaptive_threshold
|
551 |
frac_different = np.mean(substantial_diffs)
|
552 |
|
553 |
+
# Extract classification from text
|
554 |
try:
|
555 |
classification1 = res1[0].split('Classification: ')[1].split('\n')[0].strip()
|
556 |
classification2 = res2[0].split('Classification: ')[1].split('\n')[0].strip()
|
|
|
558 |
classification1 = "Unknown"
|
559 |
classification2 = "Unknown"
|
560 |
|
|
|
561 |
comparison_text = (
|
562 |
"Sequence Comparison Results:\n"
|
563 |
f"Sequence 1: {res1[4]}\n"
|
|
|
584 |
"- White regions: Similar between sequences"
|
585 |
)
|
586 |
|
|
|
587 |
heatmap_fig = plot_comparative_heatmap(
|
588 |
shap_diff,
|
589 |
title=f"SHAP Difference Heatmap (window: {smooth_window})"
|
590 |
)
|
591 |
heatmap_img = fig_to_image(heatmap_fig)
|
592 |
|
|
|
593 |
num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
|
594 |
hist_fig = plot_shap_histogram(
|
595 |
shap_diff,
|
|
|
598 |
)
|
599 |
hist_img = fig_to_image(hist_fig)
|
600 |
|
|
|
601 |
return (comparison_text, heatmap_img, hist_img, None)
|
602 |
|
603 |
except Exception as e:
|
|
|
605 |
return (error_msg, None, None, None)
|
606 |
|
607 |
###############################################################################
|
608 |
+
# 10. ADDITIONAL / ADVANCED VISUALIZATIONS & STATISTICS
|
609 |
###############################################################################
|
610 |
|
611 |
+
def n50_length(sequence):
|
612 |
+
"""
|
613 |
+
Calculate the N50 for a single continuous sequence (for demonstration).
|
614 |
+
For a single sequence, N50 is typically the length if it's just one piece,
|
615 |
+
but let's do a simplistic example.
|
616 |
+
"""
|
617 |
+
# If you had contigs, you'd do a sorted list, cumulative sums, etc.
|
618 |
+
# We'll do a trivial approach here:
|
619 |
+
return len(sequence) # Because we have only one contiguous region
|
620 |
+
|
621 |
+
def sequence_complexity(sequence):
|
622 |
+
"""
|
623 |
+
Compute a simple measure of 'sequence complexity'.
|
624 |
+
Here, we define complexity as the Shannon entropy over the nucleotides.
|
625 |
+
"""
|
626 |
+
from math import log2
|
627 |
+
length = len(sequence)
|
628 |
+
if length == 0:
|
629 |
+
return 0.0
|
630 |
+
freq = {}
|
631 |
+
for base in sequence:
|
632 |
+
freq[base] = freq.get(base, 0) + 1
|
633 |
+
complexity = 0.0
|
634 |
+
for base, count in freq.items():
|
635 |
+
p = count / length
|
636 |
+
complexity -= p * log2(p)
|
637 |
+
return complexity
|
638 |
+
|
639 |
+
def advanced_gene_statistics(gene_shap: np.ndarray, gene_seq: str) -> Dict[str, float]:
|
640 |
+
"""
|
641 |
+
Additional stats: N50, complexity, etc.
|
642 |
+
"""
|
643 |
+
stats = {}
|
644 |
+
stats['n50'] = len(gene_seq) # trivial for a single gene region
|
645 |
+
stats['entropy'] = sequence_complexity(gene_seq)
|
646 |
+
stats['avg_shap'] = float(np.mean(gene_shap))
|
647 |
+
stats['max_shap'] = float(np.max(gene_shap)) if len(gene_shap) else 0.0
|
648 |
+
stats['min_shap'] = float(np.min(gene_shap)) if len(gene_shap) else 0.0
|
649 |
+
return stats
|
650 |
+
|
651 |
+
###############################################################################
|
652 |
+
# 11. GENE FEATURE ANALYSIS
|
653 |
+
###############################################################################
|
654 |
|
655 |
def parse_gene_features(text: str) -> List[Dict[str, Any]]:
|
656 |
+
"""Parse gene features from text file in a FASTA-like format."""
|
657 |
genes = []
|
658 |
current_header = None
|
659 |
current_sequence = []
|
|
|
662 |
line = line.strip()
|
663 |
if not line:
|
664 |
continue
|
|
|
665 |
if line.startswith('>'):
|
666 |
if current_header:
|
667 |
genes.append({
|
|
|
673 |
current_sequence = []
|
674 |
else:
|
675 |
current_sequence.append(line.upper())
|
|
|
676 |
if current_header:
|
677 |
genes.append({
|
678 |
'header': current_header,
|
679 |
'sequence': ''.join(current_sequence),
|
680 |
'metadata': parse_gene_metadata(current_header)
|
681 |
})
|
|
|
682 |
return genes
|
683 |
|
684 |
def parse_gene_metadata(header: str) -> Dict[str, str]:
|
685 |
+
"""Extract metadata from gene header line."""
|
686 |
metadata = {}
|
687 |
parts = header.split()
|
|
|
688 |
for part in parts:
|
689 |
if '[' in part and ']' in part:
|
690 |
key_value = part[1:-1].split('=', 1)
|
691 |
if len(key_value) == 2:
|
692 |
metadata[key_value[0]] = key_value[1]
|
|
|
693 |
return metadata
|
694 |
|
695 |
def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
|
696 |
+
"""Parse gene location string, handling forward and complement strands."""
|
697 |
try:
|
|
|
698 |
clean_loc = location_str.replace('complement(', '').replace(')', '')
|
|
|
|
|
699 |
if '..' in clean_loc:
|
700 |
start, end = map(int, clean_loc.split('..'))
|
701 |
return start, end
|
|
|
706 |
return None, None
|
707 |
|
708 |
def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
|
709 |
+
"""Basic statistical measures for gene SHAP values."""
|
710 |
return {
|
711 |
+
'avg_shap': float(np.mean(gene_shap)) if len(gene_shap) else 0.0,
|
712 |
+
'median_shap': float(np.median(gene_shap)) if len(gene_shap) else 0.0,
|
713 |
+
'std_shap': float(np.std(gene_shap)) if len(gene_shap) else 0.0,
|
714 |
+
'max_shap': float(np.max(gene_shap)) if len(gene_shap) else 0.0,
|
715 |
+
'min_shap': float(np.min(gene_shap)) if len(gene_shap) else 0.0,
|
716 |
+
'pos_fraction': float(np.mean(gene_shap > 0)) if len(gene_shap) else 0.0
|
717 |
}
|
718 |
|
719 |
def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
|
720 |
"""
|
721 |
+
A quick PIL-based diagram to show genes along the genome.
|
722 |
+
Color intensity = magnitude of SHAP. Red/Blue = sign of SHAP.
|
723 |
"""
|
|
|
|
|
|
|
724 |
if not gene_results or genome_length <= 0:
|
725 |
img = Image.new('RGB', (800, 100), color='white')
|
726 |
draw = ImageDraw.Draw(img)
|
727 |
draw.text((10, 40), "Error: Invalid input data", fill='black')
|
728 |
return img
|
729 |
+
|
|
|
730 |
for gene in gene_results:
|
731 |
gene['start'] = max(0, int(gene['start']))
|
732 |
gene['end'] = min(genome_length, int(gene['end']))
|
733 |
if gene['start'] >= gene['end']:
|
734 |
+
print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}")
|
735 |
|
|
|
736 |
width = 1500
|
737 |
height = 600
|
738 |
margin = 50
|
739 |
track_height = 40
|
740 |
|
|
|
741 |
img = Image.new('RGB', (width, height), 'white')
|
742 |
draw = ImageDraw.Draw(img)
|
743 |
|
|
|
744 |
try:
|
745 |
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
|
746 |
title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
|
|
|
748 |
font = ImageFont.load_default()
|
749 |
title_font = ImageFont.load_default()
|
750 |
|
751 |
+
draw.text((margin, margin // 2), "Genome SHAP Analysis (Simple)", fill='black', font=title_font or font)
|
|
|
752 |
|
|
|
753 |
line_y = height // 2
|
754 |
draw.line([(int(margin), int(line_y)), (int(width - margin), int(line_y))], fill='black', width=2)
|
755 |
|
|
|
756 |
scale = float(width - 2 * margin) / float(genome_length)
|
757 |
|
758 |
+
# Scale markers
|
759 |
num_ticks = 10
|
760 |
+
step = max(1, genome_length // num_ticks)
|
|
|
|
|
|
|
|
|
|
|
761 |
for i in range(0, genome_length + 1, step):
|
762 |
x_coord = margin + i * scale
|
763 |
draw.line([
|
|
|
766 |
], fill='black', width=1)
|
767 |
draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
|
768 |
|
|
|
769 |
sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
|
|
|
|
|
770 |
for idx, gene in enumerate(sorted_genes):
|
|
|
771 |
start_x = margin + int(gene['start'] * scale)
|
772 |
end_x = margin + int(gene['end'] * scale)
|
|
|
|
|
773 |
avg_shap = gene['avg_shap']
|
|
|
|
|
|
|
774 |
intensity = int(abs(avg_shap) * 500)
|
775 |
+
intensity = max(50, min(255, intensity))
|
776 |
|
777 |
if avg_shap > 0:
|
778 |
+
color = (255, 255 - intensity, 255 - intensity) # Redish
|
|
|
779 |
else:
|
780 |
+
color = (255 - intensity, 255 - intensity, 255) # Blueish
|
|
|
781 |
|
|
|
782 |
draw.rectangle([
|
783 |
(int(start_x), int(line_y - track_height // 2)),
|
784 |
(int(end_x), int(line_y + track_height // 2))
|
785 |
], fill=color, outline='black')
|
786 |
|
|
|
787 |
label = str(gene.get('gene_name','?'))
|
|
|
|
|
788 |
label_mask = font.getmask(label)
|
789 |
label_width, label_height = label_mask.size
|
790 |
|
|
|
791 |
if idx % 2 == 0:
|
792 |
text_y = line_y - track_height - 15
|
793 |
else:
|
794 |
text_y = line_y + track_height + 5
|
795 |
|
|
|
796 |
gene_width = end_x - start_x
|
797 |
if gene_width > label_width:
|
798 |
text_x = start_x + (gene_width - label_width) // 2
|
|
|
804 |
rotated_img = txt_img.rotate(90, expand=True)
|
805 |
img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)
|
806 |
|
807 |
+
return img
|
808 |
+
|
809 |
+
def create_advanced_genome_diagram(gene_results: List[Dict[str, Any]],
                                   genome_length: int,
                                   shap_means: np.ndarray,
                                   diagram_title: str = "Advanced Genome Diagram") -> Image.Image:
    """
    An advanced genome diagram built with Biopython's GenomeDiagram:
    one track for the genes and one 'graph' track for the per-base SHAP line.
    """
    if not gene_results or genome_length <= 0 or len(shap_means) == 0:
        # Fallback if the input data is invalid
        img = Image.new('RGB', (800, 100), color='white')
        d = ImageDraw.Draw(img)
        d.text((10, 40), "Error: Not enough data for advanced diagram", fill='black')
        return img

    diagram = GenomeDiagram.Diagram(diagram_title)
    gene_track = diagram.new_track(1, name="Genes", greytrack=False, height=0.5)
    gene_set = gene_track.new_set()

    # Add each gene as a feature
    for gene in gene_results:
        start = max(0, int(gene['start']))
        end = min(genome_length, int(gene['end']))
        avg_shap = gene['avg_shap']

        # Color scale: negative = blue, positive = red
        intensity = abs(avg_shap) * 500
        intensity = max(50, min(255, intensity))
        if avg_shap >= 0:
            color_hex = colors.Color(1.0, 1.0 - intensity / 255.0, 1.0 - intensity / 255.0)
        else:
            color_hex = colors.Color(1.0 - intensity / 255.0, 1.0 - intensity / 255.0, 1.0)

        feature = SeqFeature(FeatureLocation(start, end), strand=1)
        gene_set.add_feature(
            feature,
            color=color_hex,
            label=True,
            name=str(gene.get('gene_name', '?')),
            label_size=8,
            label_color=colors.black
        )

    # Add a graph track for the SHAP line: negative values sit below the
    # baseline, positive values above.
    shap_track = diagram.new_track(2, name="SHAP Score", greytrack=False, height=0.3)
    shap_set = shap_track.new_set("graph")

    # Normalize for visualization (scale to roughly +/- 50)
    max_abs = max(abs(shap_means.min()), abs(shap_means.max()))
    if max_abs == 0:
        scaled_shap = [0.0] * len(shap_means)
    else:
        scaled_shap = (shap_means / max_abs * 50).tolist()

    # GraphSet.new_graph expects (position, value) pairs
    graph_data = list(enumerate(scaled_shap))
    shap_set.new_graph(
        graph_data,
        name="shap_line",
        style="line",
        color=colors.darkgreen,
        altcolor=colors.red,
        linewidth=1
    )

    # Render to a temporary PDF
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmpf:
        diagram.draw(format="linear", pagesize='A3', fragments=1, start=0, end=genome_length)
        diagram.write(tmpf.name, "PDF")

    # Convert the PDF to a PIL image (pdf2image needs poppler installed)
    try:
        import pdf2image
        pages = pdf2image.convert_from_path(tmpf.name, dpi=100)
        img = pages[0] if pages else Image.new('RGB', (800, 100), color='white')
    except Exception:
        # pdf2image missing, or poppler unavailable
        img = Image.new('RGB', (800, 100), color='white')
        d = ImageDraw.Draw(img)
        d.text((10, 40), "pdf2image/poppler not available; cannot render advanced diagram as an image.", fill='black')

    # Clean up the temporary PDF
    os.remove(tmpf.name)
    return img

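
# Illustrative usage sketch, not called anywhere in the app: the gene dicts only need
# the keys read above ('start', 'end', 'avg_shap', 'gene_name'); all values here are
# made up, and rendering still requires reportlab plus pdf2image/poppler.
def _demo_advanced_diagram() -> Image.Image:
    toy_genes = [
        {'gene_name': 'geneA', 'start': 0,    'end': 900,  'avg_shap':  0.12},
        {'gene_name': 'geneB', 'start': 1200, 'end': 2400, 'avg_shap': -0.08},
    ]
    toy_shap = np.random.default_rng(0).normal(0.0, 0.05, size=3000)
    return create_advanced_genome_diagram(toy_genes, genome_length=3000, shap_means=toy_shap)
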
def analyze_gene_features(sequence_file: str,
                          features_file: str,
                          fasta_text: str = "",
                          features_text: str = "",
                          diagram_mode: str = "advanced"
                          ) -> Tuple[str, Optional[str], Optional[Image.Image]]:
    """
    Analyze each gene in the features file, compute gene-level SHAP stats,
    produce tabular output, and create an optional genome diagram.
    """
    # 1) Analyze the entire sequence with the top-level function
    sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
    if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
        return f"Error in sequence analysis: {sequence_results[0]}", None, None

    seq = sequence_results[3]["seq"]
    shap_means = sequence_results[3]["shap_means"]
    genome_length = len(seq)

    # 2) Read gene features (pasted text takes precedence over the uploaded file)
    try:
        if features_text.strip():
            genes = parse_gene_features(features_text)
        else:
            with open(features_file, 'r') as f:
                genes = parse_gene_features(f.read())
    except Exception as e:
        return f"Error reading features file: {str(e)}", None, None

    gene_results = []
    for gene in genes:
        location = gene['metadata'].get('location', '')
        if not location:
            continue
        start, end = parse_location(location)
        if start is None or end is None or start >= end or end > genome_length:
            continue

        gene_shap = shap_means[start:end]
        basic_stats = compute_gene_statistics(gene_shap)

        # Additional stats on the gene subsequence
        gene_seq = seq[start:end]
        adv_stats = advanced_gene_statistics(gene_shap, gene_seq)

        # Merge basic + advanced stats
        all_stats = {**basic_stats, **adv_stats}

        classification = 'Human' if basic_stats['avg_shap'] > 0 else 'Non-human'
        locus_tag = gene['metadata'].get('locus_tag', '')
        gene_name = gene['metadata'].get('gene', 'Unknown')

        gene_dict = {
            'gene_name': gene_name,
            'location': location,
            'start': start,
            'end': end,
            'locus_tag': locus_tag,
            'avg_shap': all_stats['avg_shap'],
            'median_shap': basic_stats['median_shap'],
            'std_shap': basic_stats['std_shap'],
            'max_shap': basic_stats['max_shap'],
            'min_shap': basic_stats['min_shap'],
            'pos_fraction': basic_stats['pos_fraction'],
            'n50': all_stats['n50'],
            'entropy': all_stats['entropy'],
            'classification': classification,
            'confidence': abs(all_stats['avg_shap'])
        }
        gene_results.append(gene_dict)

    if not gene_results:
        return "No valid genes could be processed", None, None

    # 3) Summaries
    sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
    results_text = "Gene Analysis Results:\n\n"
    results_text += f"Total genes analyzed: {len(gene_results)}\n"
    num_human = sum(1 for g in gene_results if g['classification'] == 'Human')
    results_text += f"Human-like genes: {num_human}\n"
    results_text += f"Non-human-like genes: {len(gene_results) - num_human}\n\n"

    results_text += "Top 10 most distinctive genes (by avg SHAP magnitude):\n"
    for gene in sorted_genes[:10]:
        results_text += (
            f"Gene: {gene['gene_name']}\n"
            f"Location: {gene['location']}\n"
            f"Classification: {gene['classification']} "
            f"(confidence: {gene['confidence']:.4f})\n"
            f"Average SHAP: {gene['avg_shap']:.4f}\n"
            f"N50: {gene['n50']}, Entropy: {gene['entropy']:.3f}\n\n"
        )

    # 4) Make CSV
    csv_content = "gene_name,location,start,end,locus_tag,avg_shap,median_shap,std_shap,"
    csv_content += "max_shap,min_shap,pos_fraction,n50,entropy,classification,confidence\n"
    for g in gene_results:
        csv_content += (
            f"{g['gene_name']},{g['location']},{g['start']},{g['end']},{g['locus_tag']},"
            f"{g['avg_shap']:.4f},{g['median_shap']:.4f},{g['std_shap']:.4f},"
            f"{g['max_shap']:.4f},{g['min_shap']:.4f},{g['pos_fraction']:.4f},"
            f"{g['n50']},{g['entropy']:.4f},{g['classification']},{g['confidence']:.4f}\n"
        )

    try:
        temp_dir = tempfile.gettempdir()
        temp_path = os.path.join(temp_dir, f"gene_analysis_{os.urandom(4).hex()}.csv")
        with open(temp_path, 'w') as f:
            f.write(csv_content)
    except Exception as e:
        print(f"Error saving CSV: {str(e)}")
        temp_path = None

    # 5) Create diagram
    try:
        if diagram_mode == "advanced":
            diagram_img = create_advanced_genome_diagram(gene_results, genome_length, shap_means)
        else:
            diagram_img = create_simple_genome_diagram(gene_results, genome_length)
    except Exception as e:
        print(f"Error creating visualization: {str(e)}")
        diagram_img = Image.new('RGB', (800, 100), color='white')
        draw = ImageDraw.Draw(diagram_img)
        draw.text((10, 40), f"Error creating visualization: {str(e)}", fill='black')

    return results_text, temp_path, diagram_img

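
# The app's advanced_gene_statistics() is defined elsewhere in app.py and is not shown
# in this hunk. The sketch below is NOT that implementation — it only illustrates one
# plausible way to obtain the two keys used above, under assumed definitions: Shannon
# entropy over the gene's A/C/G/T composition, and "N50" over the lengths of
# contiguous positive-SHAP runs.
def _sketch_gene_stats(gene_shap: np.ndarray, gene_seq: str) -> dict:
    # Shannon entropy (bits) of the nucleotide composition
    counts = np.array([gene_seq.count(b) for b in "ACGT"], dtype=float)
    total = counts.sum()
    entropy = 0.0
    if total > 0:
        probs = counts / total
        entropy = float(-sum(p * np.log2(p) for p in probs if p > 0))

    # Lengths of contiguous runs where the per-base SHAP is positive
    runs, current = [], 0
    for v in gene_shap:
        if v > 0:
            current += 1
        elif current:
            runs.append(current)
            current = 0
    if current:
        runs.append(current)

    # N50 of those run lengths: shortest run in the set covering half the total run length
    n50 = 0
    if runs:
        runs.sort(reverse=True)
        half, acc = sum(runs) / 2.0, 0
        for r in runs:
            acc += r
            if acc >= half:
                n50 = r
                break

    return {'n50': n50, 'entropy': entropy}
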
###############################################################################
# ... (section header comment not shown in this diff hunk) ...
###############################################################################

def prepare_csv_download(data, filename="analysis_results.csv"):
    """
    Convert data to CSV for the Gradio download button.
    """
    if isinstance(data, str):
        return data.encode(), filename
    elif isinstance(data, (list, dict)):
        import csv
        from io import StringIO
        output = StringIO()
        writer = csv.DictWriter(output, fieldnames=data[0].keys())
        writer.writeheader()
        # ... (row writing, the return value, and the surrounding setup are not shown in this diff hunk) ...

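
# Illustrative only, not called by the app: the string branch of prepare_csv_download
# simply encodes the text and returns it together with the filename (the list/dict
# branch is truncated in this hunk).
def _demo_prepare_csv_download():
    payload, name = prepare_csv_download("gene_name,avg_shap\ngeneA,0.0123\n",
                                         filename="gene_analysis.csv")
    return payload, name  # (b"gene_name,avg_shap\ngeneA,0.0123\n", "gene_analysis.csv")
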

with gr.Blocks(css=css) as iface:
    gr.Markdown("""
    # Virus Host Classifier + Extended Genome Visualization
    **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme subregions.
    **Step 2**: Explore subregions (local SHAP, GC content, histogram).
    **Step 3**: Analyze gene features (per-gene SHAP, advanced stats, improved diagrams).
    **Step 4**: Compare sequences for SHAP differences.

    **Color Scale**: Negative SHAP = Blue, 0 = White, Positive = Red.
    """)

    with gr.Tab("1) Full-Sequence Analysis"):
        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                text_input = gr.Textbox(label="Or paste FASTA", placeholder=">name\nACGT...", lines=5)
                top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
                win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Subregion Window Size")
                analyze_btn = gr.Button("Analyze Sequence", variant="primary")
            with gr.Column(scale=2):
                results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
                # ... (remaining output components and the analyze_btn.click wiring are not shown in this diff hunk) ...

    with gr.Tab("2) Subregion Exploration"):
        gr.Markdown("""
        **Subregion Analysis**
        View SHAP signals, GC content, etc. for a specific region.
        """)
        with gr.Row():
            region_start = gr.Number(label="Region Start", value=0)
            # ... (remaining region inputs are not shown in this diff hunk) ...
        with gr.Row():
            subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
            subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
        download_subregion = gr.File(label="Download Subregion", visible=False, elem_classes="download-button")

        region_btn.click(
            analyze_subregion,
            # ... (inputs/outputs lists are not shown in this diff hunk) ...
        )

    with gr.Tab("3) Gene Features Analysis"):
        gr.Markdown("""
        **Analyze Gene Features**
        - Upload a FASTA file and a gene features file.
        - See per-gene SHAP, classification, N50, entropy, etc.
        - Choose a diagram mode (simple or advanced).
        """)
        with gr.Row():
            with gr.Column(scale=1):
                gene_fasta_file = gr.File(label="FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                gene_fasta_text = gr.Textbox(label="Or paste FASTA sequence", lines=5)
            with gr.Column(scale=1):
                features_file = gr.File(label="Gene features file", file_types=[".txt"], type="filepath")
                features_text = gr.Textbox(label="Or paste gene features", lines=5)
                diagram_mode = gr.Radio(choices=["simple", "advanced"], value="advanced", label="Diagram Mode")
        analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
        gene_results = gr.Textbox(label="Gene Analysis Results", lines=12, interactive=False)
        gene_diagram = gr.Image(label="Genome Diagram")
        download_gene_results = gr.File(label="Download Gene Analysis (CSV)", visible=True)

        analyze_genes_btn.click(
            analyze_gene_features,
            inputs=[gene_fasta_file, features_file, gene_fasta_text, features_text, diagram_mode],
            outputs=[gene_results, download_gene_results, gene_diagram]
        )
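
        # Note: the three outputs map onto analyze_gene_features' return tuple
        # (results_text, csv_path, diagram_img) -> (gene_results, download_gene_results, gene_diagram).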

    with gr.Tab("4) Comparative Analysis"):
        gr.Markdown("""
        **Compare Two Sequences**
        - Upload or paste two FASTA sequences.
        - We'll compare SHAP patterns (normalized for different lengths).
        """)
        with gr.Row():
            with gr.Column(scale=1):
                file_input1 = gr.File(label="1st FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                text_input1 = gr.Textbox(label="Or paste 1st FASTA", lines=5)
            with gr.Column(scale=1):
                file_input2 = gr.File(label="2nd FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
                text_input2 = gr.Textbox(label="Or paste 2nd FASTA", lines=5)
        compare_btn = gr.Button("Compare Sequences", variant="primary")
        comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
        with gr.Row():
            diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
            diff_hist = gr.Image(label="Distribution of SHAP Differences")
        download_comparison = gr.File(label="Download Comparison", visible=False, elem_classes="download-button")

        compare_btn.click(
            analyze_sequence_comparison,
            # ... (inputs/outputs lists are not shown in this diff hunk) ...
        )

    gr.Markdown("""
    ### Notes & Features
    - **Advanced Genome Diagram** uses Biopython’s `GenomeDiagram` (requires `pdf2image` if you want it as an image).
    - **Additional Stats**: N50, Shannon entropy, etc.
    - **Auto-scaling** for comparative analysis with adaptive smoothing.
    - **Data Export**: Download CSV of analysis results.
    """)

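
# Optional illustration, not used by the Gradio UI: analyze_gene_features can also be
# called directly from Python. The file paths below are hypothetical placeholders.
def _demo_gene_feature_analysis():
    results_text, csv_path, diagram = analyze_gene_features(
        sequence_file="example.fasta",          # hypothetical path
        features_file="example_features.txt",   # hypothetical path
        diagram_mode="simple",
    )
    print(results_text)
    if diagram is not None:
        diagram.save("gene_diagram.png")
    return csv_path
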

if __name__ == "__main__":
    iface.launch()