Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -474,7 +474,230 @@ def analyze_subregion(state, header, region_start, region_end):
|
|
474 |
|
475 |
return (region_info, heatmap_img, hist_img)
|
476 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
478 |
###############################################################################
|
479 |
# 9. BUILD GRADIO INTERFACE
|
480 |
###############################################################################
|
|
|
474 |
|
475 |
return (region_info, heatmap_img, hist_img)
|
476 |
|
477 |
+
# Add these imports at the top of the file, after existing imports
|
478 |
+
from scipy.interpolate import interp1d
|
479 |
+
import numpy as np
|
480 |
+
|
481 |
+
###############################################################################
|
482 |
+
# NEW SECTION: COMPARATIVE ANALYSIS FUNCTIONS
|
483 |
+
###############################################################################
|
484 |
|
485 |
+
def normalize_shap_lengths(shap1, shap2, num_points=1000):
|
486 |
+
"""
|
487 |
+
Normalize two SHAP arrays to the same length using interpolation.
|
488 |
+
Returns (normalized_shap1, normalized_shap2)
|
489 |
+
"""
|
490 |
+
# Create x coordinates for both sequences
|
491 |
+
x1 = np.linspace(0, 1, len(shap1))
|
492 |
+
x2 = np.linspace(0, 1, len(shap2))
|
493 |
+
|
494 |
+
# Create interpolation functions
|
495 |
+
f1 = interp1d(x1, shap1, kind='linear')
|
496 |
+
f2 = interp1d(x2, shap2, kind='linear')
|
497 |
+
|
498 |
+
# Create new x coordinates for interpolation
|
499 |
+
x_new = np.linspace(0, 1, num_points)
|
500 |
+
|
501 |
+
# Interpolate both sequences to new length
|
502 |
+
shap1_norm = f1(x_new)
|
503 |
+
shap2_norm = f2(x_new)
|
504 |
+
|
505 |
+
return shap1_norm, shap2_norm
|
506 |
+
|
507 |
+
def compute_shap_difference(shap1_norm, shap2_norm):
|
508 |
+
"""
|
509 |
+
Compute the difference between two normalized SHAP arrays.
|
510 |
+
Positive values indicate seq2 is more "human-like" than seq1.
|
511 |
+
"""
|
512 |
+
return shap2_norm - shap1_norm
|
513 |
+
|
514 |
+
def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
|
515 |
+
"""
|
516 |
+
Plot the difference between two sequences' SHAP values.
|
517 |
+
Red indicates seq2 is more human-like, blue indicates seq1 is more human-like.
|
518 |
+
"""
|
519 |
+
# Build 2D array for imshow
|
520 |
+
heatmap_data = shap_diff.reshape(1, -1)
|
521 |
+
|
522 |
+
# Force symmetrical range
|
523 |
+
extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
|
524 |
+
|
525 |
+
# Create figure with adjusted height ratio
|
526 |
+
fig, ax = plt.subplots(figsize=(12, 1.8))
|
527 |
+
|
528 |
+
# Create custom colormap
|
529 |
+
custom_cmap = get_zero_centered_cmap()
|
530 |
+
|
531 |
+
# Plot heatmap
|
532 |
+
cax = ax.imshow(
|
533 |
+
heatmap_data,
|
534 |
+
aspect='auto',
|
535 |
+
cmap=custom_cmap,
|
536 |
+
vmin=-extent,
|
537 |
+
vmax=+extent
|
538 |
+
)
|
539 |
+
|
540 |
+
# Configure colorbar
|
541 |
+
cbar = plt.colorbar(
|
542 |
+
cax,
|
543 |
+
orientation='horizontal',
|
544 |
+
pad=0.25,
|
545 |
+
aspect=40,
|
546 |
+
shrink=0.8
|
547 |
+
)
|
548 |
+
|
549 |
+
# Style the colorbar
|
550 |
+
cbar.ax.tick_params(labelsize=8)
|
551 |
+
cbar.set_label(
|
552 |
+
'SHAP Difference (Seq2 - Seq1)',
|
553 |
+
fontsize=9,
|
554 |
+
labelpad=5
|
555 |
+
)
|
556 |
+
|
557 |
+
# Configure main plot
|
558 |
+
ax.set_yticks([])
|
559 |
+
ax.set_xlabel('Normalized Position (0-100%)', fontsize=10)
|
560 |
+
ax.set_title(title, pad=10)
|
561 |
+
|
562 |
+
plt.subplots_adjust(
|
563 |
+
bottom=0.25,
|
564 |
+
left=0.05,
|
565 |
+
right=0.95
|
566 |
+
)
|
567 |
+
|
568 |
+
return fig
|
569 |
+
|
570 |
+
def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
571 |
+
"""
|
572 |
+
Compare two sequences by analyzing their SHAP differences.
|
573 |
+
Returns comparison text and visualizations.
|
574 |
+
"""
|
575 |
+
# Process first sequence
|
576 |
+
results1 = analyze_sequence(file1, fasta_text=fasta1)
|
577 |
+
if isinstance(results1[0], str) and "Error" in results1[0]:
|
578 |
+
return (f"Error in sequence 1: {results1[0]}", None, None)
|
579 |
+
|
580 |
+
# Process second sequence
|
581 |
+
results2 = analyze_sequence(file2, fasta_text=fasta2)
|
582 |
+
if isinstance(results2[0], str) and "Error" in results2[0]:
|
583 |
+
return (f"Error in sequence 2: {results2[0]}", None, None)
|
584 |
+
|
585 |
+
# Get SHAP means from state dictionaries
|
586 |
+
shap1 = results1[3]["shap_means"]
|
587 |
+
shap2 = results2[3]["shap_means"]
|
588 |
+
|
589 |
+
# Normalize lengths
|
590 |
+
shap1_norm, shap2_norm = normalize_shap_lengths(shap1, shap2)
|
591 |
+
|
592 |
+
# Compute difference (positive = seq2 more human-like)
|
593 |
+
shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
|
594 |
+
|
595 |
+
# Calculate some statistics
|
596 |
+
avg_diff = np.mean(shap_diff)
|
597 |
+
std_diff = np.std(shap_diff)
|
598 |
+
max_diff = np.max(shap_diff)
|
599 |
+
min_diff = np.min(shap_diff)
|
600 |
+
|
601 |
+
# Calculate what fraction of positions show substantial differences
|
602 |
+
threshold = 0.05 # Arbitrary threshold for "substantial" difference
|
603 |
+
substantial_diffs = np.abs(shap_diff) > threshold
|
604 |
+
frac_different = np.mean(substantial_diffs)
|
605 |
+
|
606 |
+
# Generate comparison text
|
607 |
+
comparison_text = f"""Sequence Comparison Results:
|
608 |
+
Sequence 1: {results1[4]}
|
609 |
+
Length: {len(shap1):,} bases
|
610 |
+
Classification: {results1[0].split('Classification: ')[1].split('\n')[0]}
|
611 |
+
|
612 |
+
Sequence 2: {results2[4]}
|
613 |
+
Length: {len(shap2):,} bases
|
614 |
+
Classification: {results2[0].split('Classification: ')[1].split('\n')[0]}
|
615 |
+
|
616 |
+
Comparison Statistics:
|
617 |
+
Average SHAP difference: {avg_diff:.4f}
|
618 |
+
Standard deviation: {std_diff:.4f}
|
619 |
+
Max difference: {max_diff:.4f} (Seq2 more human-like)
|
620 |
+
Min difference: {min_diff:.4f} (Seq1 more human-like)
|
621 |
+
Fraction of positions with substantial differences: {frac_different:.2%}
|
622 |
+
|
623 |
+
Interpretation:
|
624 |
+
Positive values (red) indicate regions where Sequence 2 is more "human-like"
|
625 |
+
Negative values (blue) indicate regions where Sequence 1 is more "human-like"
|
626 |
+
"""
|
627 |
+
|
628 |
+
# Create comparison heatmap
|
629 |
+
heatmap_fig = plot_comparative_heatmap(shap_diff)
|
630 |
+
heatmap_img = fig_to_image(heatmap_fig)
|
631 |
+
|
632 |
+
# Create histogram of differences
|
633 |
+
hist_fig = plot_shap_histogram(
|
634 |
+
shap_diff,
|
635 |
+
title="Distribution of SHAP Differences"
|
636 |
+
)
|
637 |
+
hist_img = fig_to_image(hist_fig)
|
638 |
+
|
639 |
+
return comparison_text, heatmap_img, hist_img
|
640 |
+
|
641 |
+
###############################################################################
|
642 |
+
# NEW TAB TO GRADIO
|
643 |
+
###############################################################################
|
644 |
+
|
645 |
+
# Inside the Gradio interface definition, add this new tab:
|
646 |
+
with gr.Tab("3) Comparative Analysis"):
|
647 |
+
gr.Markdown("""
|
648 |
+
**Compare Two Sequences**
|
649 |
+
Upload or paste two FASTA sequences to compare their SHAP patterns.
|
650 |
+
The sequences will be normalized to the same length for comparison.
|
651 |
+
|
652 |
+
**Color Scale**:
|
653 |
+
- Red: Sequence 2 is more human-like in this region
|
654 |
+
- Blue: Sequence 1 is more human-like in this region
|
655 |
+
- White: No substantial difference
|
656 |
+
""")
|
657 |
+
|
658 |
+
with gr.Row():
|
659 |
+
with gr.Column(scale=1):
|
660 |
+
file_input1 = gr.File(
|
661 |
+
label="Upload first FASTA file",
|
662 |
+
file_types=[".fasta", ".fa", ".txt"],
|
663 |
+
type="filepath"
|
664 |
+
)
|
665 |
+
text_input1 = gr.Textbox(
|
666 |
+
label="Or paste first FASTA sequence",
|
667 |
+
placeholder=">sequence1\nACGTACGT...",
|
668 |
+
lines=5
|
669 |
+
)
|
670 |
+
|
671 |
+
with gr.Column(scale=1):
|
672 |
+
file_input2 = gr.File(
|
673 |
+
label="Upload second FASTA file",
|
674 |
+
file_types=[".fasta", ".fa", ".txt"],
|
675 |
+
type="filepath"
|
676 |
+
)
|
677 |
+
text_input2 = gr.Textbox(
|
678 |
+
label="Or paste second FASTA sequence",
|
679 |
+
placeholder=">sequence2\nACGTACGT...",
|
680 |
+
lines=5
|
681 |
+
)
|
682 |
+
|
683 |
+
compare_btn = gr.Button("Compare Sequences", variant="primary")
|
684 |
+
|
685 |
+
comparison_text = gr.Textbox(
|
686 |
+
label="Comparison Results",
|
687 |
+
lines=12,
|
688 |
+
interactive=False
|
689 |
+
)
|
690 |
+
|
691 |
+
with gr.Row():
|
692 |
+
diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
|
693 |
+
diff_hist = gr.Image(label="Distribution of SHAP Differences")
|
694 |
+
|
695 |
+
compare_btn.click(
|
696 |
+
analyze_sequence_comparison,
|
697 |
+
inputs=[file_input1, file_input2, text_input1, text_input2],
|
698 |
+
outputs=[comparison_text, diff_heatmap, diff_hist]
|
699 |
+
)
|
700 |
+
|
701 |
###############################################################################
|
702 |
# 9. BUILD GRADIO INTERFACE
|
703 |
###############################################################################
|