Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,6 @@ import matplotlib.colors as mcolors
|
|
9 |
import io
|
10 |
from PIL import Image
|
11 |
from scipy.interpolate import interp1d
|
12 |
-
import numpy as np
|
13 |
|
14 |
###############################################################################
|
15 |
# 1. MODEL DEFINITION
|
@@ -317,91 +316,90 @@ def compute_gc_content(sequence):
|
|
317 |
return (gc_count / len(sequence)) * 100.0
|
318 |
|
319 |
###############################################################################
|
320 |
-
# 7.
|
321 |
###############################################################################
|
322 |
|
323 |
-
def
|
324 |
"""
|
325 |
-
|
326 |
-
Returns
|
327 |
"""
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
"
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
###############################################################################
|
405 |
|
406 |
def analyze_subregion(state, header, region_start, region_end):
|
407 |
"""
|
@@ -468,9 +466,8 @@ def analyze_subregion(state, header, region_start, region_end):
|
|
468 |
|
469 |
return (region_info, heatmap_img, hist_img)
|
470 |
|
471 |
-
|
472 |
###############################################################################
|
473 |
-
#
|
474 |
###############################################################################
|
475 |
|
476 |
def normalize_shap_lengths(shap1, shap2, num_points=1000):
|
@@ -583,7 +580,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
|
583 |
# Compute difference (positive = seq2 more human-like)
|
584 |
shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
|
585 |
|
586 |
-
# Calculate
|
587 |
avg_diff = np.mean(shap_diff)
|
588 |
std_diff = np.std(shap_diff)
|
589 |
max_diff = np.max(shap_diff)
|
@@ -594,7 +591,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
|
594 |
substantial_diffs = np.abs(shap_diff) > threshold
|
595 |
frac_different = np.mean(substantial_diffs)
|
596 |
|
597 |
-
# Extract classifications safely
|
598 |
classification1 = results1[0].split('Classification: ')[1].split('\n')[0].strip()
|
599 |
classification2 = results2[0].split('Classification: ')[1].split('\n')[0].strip()
|
600 |
|
@@ -603,7 +600,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
|
603 |
len2_formatted = "{:,}".format(len(shap2))
|
604 |
frac_formatted = "{:.2%}".format(frac_different)
|
605 |
|
606 |
-
# Build comparison text
|
607 |
comparison_text = (
|
608 |
"Sequence Comparison Results:\n"
|
609 |
f"Sequence 1: {results1[4]}\n"
|
@@ -635,7 +632,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
|
635 |
hist_img = fig_to_image(hist_fig)
|
636 |
|
637 |
return comparison_text, heatmap_img, hist_img
|
638 |
-
|
639 |
###############################################################################
|
640 |
# 9. BUILD GRADIO INTERFACE
|
641 |
###############################################################################
|
@@ -694,7 +691,6 @@ with gr.Blocks(css=css) as iface:
|
|
694 |
seq_state = gr.State()
|
695 |
header_state = gr.State()
|
696 |
|
697 |
-
# analyze_sequence(...) returns 5 items
|
698 |
analyze_btn.click(
|
699 |
analyze_sequence,
|
700 |
inputs=[file_input, top_k, text_input, win_size],
|
@@ -781,6 +777,7 @@ with gr.Blocks(css=css) as iface:
|
|
781 |
inputs=[file_input1, file_input2, text_input1, text_input2],
|
782 |
outputs=[comparison_text, diff_heatmap, diff_hist]
|
783 |
)
|
|
|
784 |
gr.Markdown("""
|
785 |
### Interface Features
|
786 |
- **Overall Classification** (human vs non-human) using k-mer frequencies.
|
@@ -793,7 +790,28 @@ with gr.Blocks(css=css) as iface:
|
|
793 |
- GC content
|
794 |
- Fraction of positions pushing human vs. non-human
|
795 |
- Simple logic-based classification
|
|
|
|
|
|
|
|
|
796 |
""")
|
797 |
|
|
|
|
|
|
|
|
|
798 |
if __name__ == "__main__":
|
799 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
import io
|
10 |
from PIL import Image
|
11 |
from scipy.interpolate import interp1d
|
|
|
12 |
|
13 |
###############################################################################
|
14 |
# 1. MODEL DEFINITION
|
|
|
316 |
return (gc_count / len(sequence)) * 100.0
|
317 |
|
318 |
###############################################################################
|
319 |
+
# 7. SEQUENCE ANALYSIS FUNCTIONS
|
320 |
###############################################################################
|
321 |
|
322 |
+
def analyze_sequence(file_path, top_k=10, fasta_text="", window_size=500):
    """
    Analyze a virus sequence from a FASTA file or text input.

    Parameters:
        file_path: Path of a FASTA file. When truthy, the file's contents
            replace ``fasta_text`` (file takes precedence over text).
        top_k: Forwarded to ``create_importance_bar_plot`` to limit how many
            k-mers are drawn.
        fasta_text: FASTA-formatted text, used when no file is given.
        window_size: Window length forwarded to ``find_extreme_subregion``
            and echoed in the results text.

    Returns (results_text, kmer_plot, heatmap_plot, state_dict, header).
    Only the first FASTA record is analyzed. On any failure the function
    does not raise; it returns (error_message, None, None, {}, "") so the
    caller always receives a 5-tuple.
    """
    try:
        # Process input (file takes precedence over text)
        if file_path:
            with open(file_path, 'r') as f:
                fasta_text = f.read()

        # Validate the input before doing any model work. ``fasta_text`` may
        # be None rather than "" (e.g. a cleared text box), so guard the
        # .strip() call instead of letting it surface as a generic error.
        if not (fasta_text or "").strip():
            return ("Error: No sequence provided", None, None, {}, "")

        # Parse FASTA
        sequences = parse_fasta(fasta_text)
        if not sequences:
            return ("Error: No valid FASTA sequences found", None, None, {}, "")

        header, sequence = sequences[0]  # Take first sequence

        # A header without any bases cannot be scored; report it explicitly
        # rather than relying on a downstream failure.
        if not sequence:
            return ("Error: First FASTA record contains no sequence data",
                    None, None, {}, "")

        # Load model and k-mer info
        model = VirusClassifier(256)  # 4^4 = 256 k-mers for k=4
        # map_location="cpu": the model above is constructed on the CPU and
        # load_state_dict copies weights into it, so the checkpoint only has
        # to deserialize. Without this, a checkpoint saved from a CUDA device
        # cannot be loaded on a CPU-only host.
        model.load_state_dict(torch.load("model.pt", map_location="cpu"))
        model.eval()
        kmers = [''.join(p) for p in product("ACGT", repeat=4)]

        # Convert to k-mer frequencies
        x = sequence_to_kmer_vector(sequence)
        x_tensor = torch.tensor(x).float().unsqueeze(0)

        # Get model prediction; column 1 of the softmax output is used as
        # the probability of the "human" class.
        with torch.no_grad():
            output = model(x_tensor)
            probs = torch.softmax(output, dim=1)
            pred_human = probs[0, 1].item()

        # Calculate SHAP values (the second returned value is not needed
        # here because pred_human was computed above).
        shap_values, _ = calculate_shap_values(model, x_tensor)

        # Find most extreme regions
        shap_means = compute_positionwise_scores(sequence, shap_values)
        start_max, end_max, avg_max = find_extreme_subregion(
            shap_means, window_size, mode="max"
        )
        start_min, end_min, avg_min = find_extreme_subregion(
            shap_means, window_size, mode="min"
        )

        # Format results text
        classification = "Human" if pred_human > 0.5 else "Non-human"
        results = (
            f"Classification: {classification} "
            f"(probability of human = {pred_human:.3f})\n\n"
            f"Sequence length: {len(sequence):,} bases\n"
            f"Overall GC content: {compute_gc_content(sequence):.1f}%\n\n"
            f"Most human-like {window_size}bp region:\n"
            f"Position {start_max:,} to {end_max:,}\n"
            f"Average SHAP: {avg_max:.4f}\n"
            f"GC content: {compute_gc_content(sequence[start_max:end_max]):.1f}%\n\n"
            f"Least human-like {window_size}bp region:\n"
            f"Position {start_min:,} to {end_min:,}\n"
            f"Average SHAP: {avg_min:.4f}\n"
            f"GC content: {compute_gc_content(sequence[start_min:end_min]):.1f}%"
        )

        # Create k-mer importance plot
        kmer_fig = create_importance_bar_plot(shap_values, kmers, top_k)
        kmer_img = fig_to_image(kmer_fig)

        # Create genome-wide heatmap
        heatmap_fig = plot_linear_heatmap(shap_means)
        heatmap_img = fig_to_image(heatmap_fig)

        # Store data for subregion analysis
        state = {
            "seq": sequence,
            "shap_means": shap_means
        }

        return results, kmer_img, heatmap_img, state, header

    except Exception as e:
        # UI boundary: convert any failure into the same 5-tuple shape so the
        # caller's outputs are always populated.
        error_msg = f"Error analyzing sequence: {str(e)}"
        return (error_msg, None, None, {}, "")
|
|
|
403 |
|
404 |
def analyze_subregion(state, header, region_start, region_end):
|
405 |
"""
|
|
|
466 |
|
467 |
return (region_info, heatmap_img, hist_img)
|
468 |
|
|
|
469 |
###############################################################################
|
470 |
+
# 8. COMPARISON ANALYSIS FUNCTIONS
|
471 |
###############################################################################
|
472 |
|
473 |
def normalize_shap_lengths(shap1, shap2, num_points=1000):
|
|
|
580 |
# Compute difference (positive = seq2 more human-like)
|
581 |
shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
|
582 |
|
583 |
+
# Calculate statistics
|
584 |
avg_diff = np.mean(shap_diff)
|
585 |
std_diff = np.std(shap_diff)
|
586 |
max_diff = np.max(shap_diff)
|
|
|
591 |
substantial_diffs = np.abs(shap_diff) > threshold
|
592 |
frac_different = np.mean(substantial_diffs)
|
593 |
|
594 |
+
# Extract classifications safely
|
595 |
classification1 = results1[0].split('Classification: ')[1].split('\n')[0].strip()
|
596 |
classification2 = results2[0].split('Classification: ')[1].split('\n')[0].strip()
|
597 |
|
|
|
600 |
len2_formatted = "{:,}".format(len(shap2))
|
601 |
frac_formatted = "{:.2%}".format(frac_different)
|
602 |
|
603 |
+
# Build comparison text
|
604 |
comparison_text = (
|
605 |
"Sequence Comparison Results:\n"
|
606 |
f"Sequence 1: {results1[4]}\n"
|
|
|
632 |
hist_img = fig_to_image(hist_fig)
|
633 |
|
634 |
return comparison_text, heatmap_img, hist_img
|
635 |
+
|
636 |
###############################################################################
|
637 |
# 9. BUILD GRADIO INTERFACE
|
638 |
###############################################################################
|
|
|
691 |
seq_state = gr.State()
|
692 |
header_state = gr.State()
|
693 |
|
|
|
694 |
analyze_btn.click(
|
695 |
analyze_sequence,
|
696 |
inputs=[file_input, top_k, text_input, win_size],
|
|
|
777 |
inputs=[file_input1, file_input2, text_input1, text_input2],
|
778 |
outputs=[comparison_text, diff_heatmap, diff_hist]
|
779 |
)
|
780 |
+
|
781 |
gr.Markdown("""
|
782 |
### Interface Features
|
783 |
- **Overall Classification** (human vs non-human) using k-mer frequencies.
|
|
|
790 |
- GC content
|
791 |
- Fraction of positions pushing human vs. non-human
|
792 |
- Simple logic-based classification
|
793 |
+
- **Sequence Comparison**:
|
794 |
+
- Compare two sequences to identify regions of difference
|
795 |
+
- Normalized comparison to handle different sequence lengths
|
796 |
+
- Statistical summary of differences
|
797 |
""")
|
798 |
|
799 |
+
###############################################################################
|
800 |
+
# 10. MAIN EXECUTION
|
801 |
+
###############################################################################
|
802 |
+
|
803 |
if __name__ == "__main__":
    # Global matplotlib configuration, applied only when run as a script.
    plt.style.use('default')
    plot_defaults = {
        'figure.figsize': [10, 6],
        'figure.dpi': 100,
        'font.size': 10,
    }
    for setting, value in plot_defaults.items():
        plt.rcParams[setting] = value

    # Launch the interface
    launch_options = {
        "share": False,            # Set to True to create a public link
        "server_name": "0.0.0.0",  # Listen on all network interfaces
        "server_port": 7860,       # Default Gradio port
        "show_api": False,         # Hide API docs
        "debug": False,            # Set to True for debugging
    }
    iface.launch(**launch_options)
|