Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -82,7 +82,7 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
|
82 |
return vec
|
83 |
|
84 |
###############################################################################
|
85 |
-
# 3.
|
86 |
###############################################################################
|
87 |
|
88 |
def calculate_shap_values(model, x_tensor):
|
@@ -105,7 +105,7 @@ def calculate_shap_values(model, x_tensor):
|
|
105 |
|
106 |
|
107 |
###############################################################################
|
108 |
-
# 4. PER-BASE
|
109 |
###############################################################################
|
110 |
|
111 |
def compute_positionwise_scores(sequence, shap_values, k=4):
|
@@ -125,7 +125,7 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
|
|
125 |
return shap_means
|
126 |
|
127 |
###############################################################################
|
128 |
-
# 5. FIND EXTREME
|
129 |
###############################################################################
|
130 |
|
131 |
def find_extreme_subregion(shap_means, window_size=500, mode="max"):
|
@@ -166,7 +166,7 @@ def get_zero_centered_cmap():
|
|
166 |
colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
|
167 |
return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
|
168 |
|
169 |
-
def plot_linear_heatmap(shap_means, title="Per-base
|
170 |
if start is not None and end is not None:
|
171 |
local_shap = shap_means[start:end]
|
172 |
subtitle = f" (positions {start}-{end})"
|
@@ -184,7 +184,7 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
|
|
184 |
cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
|
185 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
|
186 |
cbar.ax.tick_params(labelsize=8)
|
187 |
-
cbar.set_label('
|
188 |
ax.set_yticks([])
|
189 |
ax.set_xlabel('Position in Sequence', fontsize=10)
|
190 |
ax.set_title(f"{title}{subtitle}", pad=10)
|
@@ -200,17 +200,17 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
|
|
200 |
colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
|
201 |
plt.barh(range(len(values)), values, color=colors)
|
202 |
plt.yticks(range(len(values)), features)
|
203 |
-
plt.xlabel('
|
204 |
plt.title(f'Top {top_k} Most Influential k-mers')
|
205 |
plt.gca().invert_yaxis()
|
206 |
plt.tight_layout()
|
207 |
return fig
|
208 |
|
209 |
-
def plot_shap_histogram(shap_array, title="
|
210 |
fig, ax = plt.subplots(figsize=(6, 4))
|
211 |
ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
|
212 |
ax.axvline(0, color='red', linestyle='--', label='0.0')
|
213 |
-
ax.set_xlabel("
|
214 |
ax.set_ylabel("Count")
|
215 |
ax.set_title(title)
|
216 |
ax.legend()
|
@@ -227,23 +227,23 @@ def compute_gc_content(sequence):
|
|
227 |
# 7. MAIN ANALYSIS STEP (Gradio Step 1)
|
228 |
###############################################################################
|
229 |
def create_kmer_shap_csv(kmers, shap_values):
|
230 |
-
"""Create a CSV file with k-mer
|
231 |
-
# Create DataFrame with k-mers and
|
232 |
kmer_df = pd.DataFrame({
|
233 |
'kmer': kmers,
|
234 |
-
'
|
235 |
-
'
|
236 |
})
|
237 |
|
238 |
-
# Sort by absolute
|
239 |
-
kmer_df = kmer_df.sort_values('
|
240 |
|
241 |
-
# Drop the
|
242 |
-
kmer_df = kmer_df[['kmer', '
|
243 |
|
244 |
# Save to temporary file
|
245 |
temp_dir = tempfile.gettempdir()
|
246 |
-
temp_path = os.path.join(temp_dir, f"
|
247 |
kmer_df.to_csv(temp_path, index=False)
|
248 |
|
249 |
return temp_path # Return only the file path, not a tuple
|
@@ -296,19 +296,19 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
|
296 |
f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
|
297 |
f"---\n"
|
298 |
f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
|
299 |
-
f"Start: {max_start}, End: {max_end}, Avg
|
300 |
f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
|
301 |
-
f"Start: {min_start}, End: {min_end}, Avg
|
302 |
)
|
303 |
|
304 |
kmers = [''.join(p) for p in product("ACGT", repeat=4)]
|
305 |
bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
|
306 |
bar_img = fig_to_image(bar_fig)
|
307 |
|
308 |
-
heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide
|
309 |
heatmap_img = fig_to_image(heatmap_fig)
|
310 |
|
311 |
-
# Create CSV with k-mer
|
312 |
kmer_shap_csv = create_kmer_shap_csv(kmers, shap_values)
|
313 |
|
314 |
# State dictionary for subregion analysis
|
@@ -347,14 +347,14 @@ def analyze_subregion(state, header, region_start, region_end):
|
|
347 |
f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
|
348 |
f"Region length: {len(region_seq)} bases\n"
|
349 |
f"GC content: {gc_percent:.2f}%\n"
|
350 |
-
f"Average
|
351 |
-
f"Fraction with
|
352 |
-
f"Fraction with
|
353 |
f"Subregion interpretation: {region_classification}\n"
|
354 |
)
|
355 |
-
heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion
|
356 |
heatmap_img = fig_to_image(heatmap_fig)
|
357 |
-
hist_fig = plot_shap_histogram(region_shap, title="
|
358 |
hist_img = fig_to_image(hist_fig)
|
359 |
|
360 |
# For demonstration, returning None for the file download as well
|
@@ -370,10 +370,10 @@ def get_zero_centered_cmap():
|
|
370 |
return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
|
371 |
|
372 |
def compute_shap_difference(shap1_norm, shap2_norm):
|
373 |
-
"""Compute the
|
374 |
return shap2_norm - shap1_norm
|
375 |
|
376 |
-
def plot_comparative_heatmap(shap_diff, title="
|
377 |
"""
|
378 |
Plot heatmap using relative positions (0-100%)
|
379 |
"""
|
@@ -393,7 +393,7 @@ def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
|
|
393 |
|
394 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
|
395 |
cbar.ax.tick_params(labelsize=8)
|
396 |
-
cbar.set_label('
|
397 |
|
398 |
ax.set_yticks([])
|
399 |
ax.set_xlabel('Relative Position in Sequence', fontsize=10)
|
@@ -402,14 +402,14 @@ def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
|
|
402 |
|
403 |
return fig
|
404 |
|
405 |
-
def plot_shap_histogram(shap_array, title="
|
406 |
"""
|
407 |
-
Plot histogram of
|
408 |
"""
|
409 |
fig, ax = plt.subplots(figsize=(6, 4))
|
410 |
ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
|
411 |
ax.axvline(0, color='red', linestyle='--', label='0.0')
|
412 |
-
ax.set_xlabel("
|
413 |
ax.set_ylabel("Count")
|
414 |
ax.set_title(title)
|
415 |
ax.legend()
|
@@ -483,7 +483,7 @@ def sliding_window_smooth(values, window_size=50):
|
|
483 |
|
484 |
def normalize_shap_lengths(shap1, shap2):
|
485 |
"""
|
486 |
-
Normalize and smooth
|
487 |
"""
|
488 |
# Calculate adaptive parameters
|
489 |
num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
|
@@ -517,7 +517,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
|
517 |
if isinstance(res2[0], str) and "Error" in res2[0]:
|
518 |
return (f"Error in sequence 2: {res2[0]}", None, None, None)
|
519 |
|
520 |
-
# Extract
|
521 |
shap1 = res1[3]["shap_means"]
|
522 |
shap2 = res2[3]["shap_means"]
|
523 |
|
@@ -567,7 +567,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
|
567 |
f"Smoothing Window: {smooth_window} points\n"
|
568 |
f"Adaptive Threshold: {adaptive_threshold:.3f}\n\n"
|
569 |
"Statistics:\n"
|
570 |
-
f"Average
|
571 |
f"Standard deviation: {std_diff:.4f}\n"
|
572 |
f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
|
573 |
f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
|
@@ -582,7 +582,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
|
582 |
# Generate visualizations
|
583 |
heatmap_fig = plot_comparative_heatmap(
|
584 |
shap_diff,
|
585 |
-
title=f"
|
586 |
)
|
587 |
heatmap_img = fig_to_image(heatmap_fig)
|
588 |
|
@@ -590,7 +590,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
|
590 |
num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
|
591 |
hist_fig = plot_shap_histogram(
|
592 |
shap_diff,
|
593 |
-
title="Distribution of
|
594 |
num_bins=num_bins
|
595 |
)
|
596 |
hist_img = fig_to_image(hist_fig)
|
@@ -680,7 +680,7 @@ def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
|
|
680 |
return None, None
|
681 |
|
682 |
def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
|
683 |
-
"""Compute statistical measures for gene
|
684 |
return {
|
685 |
'avg_shap': float(np.mean(gene_shap)),
|
686 |
'median_shap': float(np.median(gene_shap)),
|
@@ -693,7 +693,7 @@ def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
|
|
693 |
def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
|
694 |
"""
|
695 |
Create a simple genome diagram using PIL, forcing a minimum color intensity
|
696 |
-
so that small
|
697 |
"""
|
698 |
from PIL import Image, ImageDraw, ImageFont
|
699 |
|
@@ -730,7 +730,7 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
|
|
730 |
title_font = ImageFont.load_default()
|
731 |
|
732 |
# Draw title
|
733 |
-
draw.text((margin, margin // 2), "Genome
|
734 |
|
735 |
# Draw genome line
|
736 |
line_y = height // 2
|
@@ -755,7 +755,7 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
|
|
755 |
], fill='black', width=1)
|
756 |
draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
|
757 |
|
758 |
-
# Sort genes by absolute
|
759 |
sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
|
760 |
|
761 |
# Draw genes
|
@@ -764,10 +764,10 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
|
|
764 |
start_x = margin + int(gene['start'] * scale)
|
765 |
end_x = margin + int(gene['end'] * scale)
|
766 |
|
767 |
-
# Calculate color based on
|
768 |
avg_shap = gene['avg_shap']
|
769 |
|
770 |
-
# Convert
|
771 |
# Then clamp to a minimum intensity so it never ends up plain white
|
772 |
intensity = int(abs(avg_shap) * 500)
|
773 |
intensity = max(50, min(255, intensity)) # clamp between 50 and 255
|
@@ -813,7 +813,7 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
|
|
813 |
# Draw legend
|
814 |
legend_x = margin
|
815 |
legend_y = height - margin
|
816 |
-
draw.text((int(legend_x), int(legend_y - 60)), "
|
817 |
|
818 |
# Draw legend boxes
|
819 |
box_width = 20
|
@@ -858,13 +858,13 @@ def analyze_gene_features(sequence_file: str,
|
|
858 |
features_file: str,
|
859 |
fasta_text: str = "",
|
860 |
features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
|
861 |
-
"""Analyze
|
862 |
# First analyze whole sequence
|
863 |
sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
|
864 |
if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
|
865 |
return f"Error in sequence analysis: {sequence_results[0]}", None, None
|
866 |
|
867 |
-
# Get
|
868 |
shap_means = sequence_results[3]["shap_means"]
|
869 |
|
870 |
# Parse gene features
|
@@ -889,7 +889,7 @@ def analyze_gene_features(sequence_file: str,
|
|
889 |
if start is None or end is None:
|
890 |
continue
|
891 |
|
892 |
-
# Get
|
893 |
gene_shap = shap_means[start:end]
|
894 |
stats = compute_gene_statistics(gene_shap)
|
895 |
|
@@ -916,7 +916,7 @@ def analyze_gene_features(sequence_file: str,
|
|
916 |
if not gene_results:
|
917 |
return "No valid genes could be processed", None, None
|
918 |
|
919 |
-
# Sort genes by absolute
|
920 |
sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
|
921 |
|
922 |
# Create results text
|
@@ -932,11 +932,11 @@ def analyze_gene_features(sequence_file: str,
|
|
932 |
f"Location: {gene['location']}\n"
|
933 |
f"Classification: {gene['classification']} "
|
934 |
f"(confidence: {gene['confidence']:.4f})\n"
|
935 |
-
f"Average
|
936 |
)
|
937 |
|
938 |
# Create CSV content
|
939 |
-
csv_content = "gene_name,location,
|
940 |
csv_content += "pos_fraction,classification,confidence,locus_tag\n"
|
941 |
|
942 |
for gene in gene_results:
|
@@ -1020,11 +1020,11 @@ with gr.Blocks(css=css) as iface:
|
|
1020 |
gr.Markdown("""
|
1021 |
# Virus Host Classifier
|
1022 |
**Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
|
1023 |
-
**Step 2**: Explore subregions to see local
|
1024 |
**Step 3**: Analyze gene features and their contributions.
|
1025 |
**Step 4**: Compare sequences and analyze differences.
|
1026 |
|
1027 |
-
**Color Scale**: Negative
|
1028 |
""")
|
1029 |
|
1030 |
with gr.Tab("1) Full-Sequence Analysis"):
|
@@ -1043,11 +1043,11 @@ with gr.Blocks(css=css) as iface:
|
|
1043 |
|
1044 |
with gr.Column(scale=2):
|
1045 |
results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
|
1046 |
-
kmer_img = gr.Image(label="Top k-mer
|
1047 |
-
genome_img = gr.Image(label="Genome-wide
|
1048 |
|
1049 |
# File components with the correct type parameter
|
1050 |
-
download_kmer_shap = gr.File(label="Download k-mer
|
1051 |
download_results = gr.File(label="Download Results", visible=True, elem_classes="download-button")
|
1052 |
|
1053 |
seq_state = gr.State()
|
@@ -1071,7 +1071,7 @@ with gr.Blocks(css=css) as iface:
|
|
1071 |
with gr.Tab("2) Subregion Exploration"):
|
1072 |
gr.Markdown("""
|
1073 |
**Subregion Analysis**
|
1074 |
-
Select start/end positions to view local
|
1075 |
The heatmap uses the same Blue-White-Red scale.
|
1076 |
""")
|
1077 |
with gr.Row():
|
@@ -1080,8 +1080,8 @@ with gr.Blocks(css=css) as iface:
|
|
1080 |
region_btn = gr.Button("Analyze Subregion")
|
1081 |
subregion_info = gr.Textbox(label="Subregion Analysis", lines=7, interactive=False)
|
1082 |
with gr.Row():
|
1083 |
-
subregion_img = gr.Image(label="Subregion
|
1084 |
-
subregion_hist_img = gr.Image(label="
|
1085 |
download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
|
1086 |
|
1087 |
region_btn.click(
|
@@ -1093,12 +1093,11 @@ with gr.Blocks(css=css) as iface:
|
|
1093 |
with gr.Tab("3) Gene Features Analysis"):
|
1094 |
gr.Markdown("""
|
1095 |
**Analyze Gene Features**
|
1096 |
-
Upload a FASTA file and corresponding gene features file to analyze
|
1097 |
Gene features should be in the format:
|
1098 |
|
1099 |
>gene_name [gene=X] [locus_tag=Y] [location=start..end] or [location=complement(start..end)]
|
1100 |
SEQUENCE
|
1101 |
-
|
1102 |
The genome viewer will show genes color-coded by their contribution:
|
1103 |
- Red: Genes pushing toward human origin
|
1104 |
- Blue: Genes pushing toward non-human origin
|
@@ -1126,7 +1125,7 @@ with gr.Blocks(css=css) as iface:
|
|
1126 |
with gr.Tab("4) Comparative Analysis"):
|
1127 |
gr.Markdown("""
|
1128 |
**Compare Two Sequences**
|
1129 |
-
Upload or paste two FASTA sequences to compare their
|
1130 |
The sequences will be normalized to the same length for comparison.
|
1131 |
|
1132 |
**Color Scale**:
|
@@ -1144,8 +1143,8 @@ with gr.Blocks(css=css) as iface:
|
|
1144 |
compare_btn = gr.Button("Compare Sequences", variant="primary")
|
1145 |
comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
|
1146 |
with gr.Row():
|
1147 |
-
diff_heatmap = gr.Image(label="
|
1148 |
-
diff_hist = gr.Image(label="Distribution of
|
1149 |
download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
|
1150 |
|
1151 |
compare_btn.click(
|
@@ -1157,8 +1156,8 @@ with gr.Blocks(css=css) as iface:
|
|
1157 |
gr.Markdown("""
|
1158 |
### Interface Features
|
1159 |
- **Overall Classification** (human vs non-human) using k-mer frequencies
|
1160 |
-
- **
|
1161 |
-
- **White-Centered
|
1162 |
- Negative (blue), 0 (white), Positive (red)
|
1163 |
- Symmetrical color range around 0
|
1164 |
- **Identify Subregions** with strongest push for human or non-human
|
@@ -1172,7 +1171,7 @@ with gr.Blocks(css=css) as iface:
|
|
1172 |
- Statistical summary of differences
|
1173 |
- **Data Export**:
|
1174 |
- Download results as CSV files
|
1175 |
-
- Download k-mer
|
1176 |
- Save analysis outputs for further processing
|
1177 |
""")
|
1178 |
|
|
|
82 |
return vec
|
83 |
|
84 |
###############################################################################
|
85 |
+
# 3. FEATURE IMPORTANCE (ABLATION) CALCULATION
|
86 |
###############################################################################
|
87 |
|
88 |
def calculate_shap_values(model, x_tensor):
|
|
|
105 |
|
106 |
|
107 |
###############################################################################
|
108 |
+
# 4. PER-BASE FEATURE IMPORTANCE AGGREGATION
|
109 |
###############################################################################
|
110 |
|
111 |
def compute_positionwise_scores(sequence, shap_values, k=4):
|
|
|
125 |
return shap_means
|
126 |
|
127 |
###############################################################################
|
128 |
+
# 5. FIND EXTREME IMPORTANCE REGIONS
|
129 |
###############################################################################
|
130 |
|
131 |
def find_extreme_subregion(shap_means, window_size=500, mode="max"):
|
|
|
166 |
colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
|
167 |
return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
|
168 |
|
169 |
+
def plot_linear_heatmap(shap_means, title="Per-base Feature Importance Heatmap", start=None, end=None):
|
170 |
if start is not None and end is not None:
|
171 |
local_shap = shap_means[start:end]
|
172 |
subtitle = f" (positions {start}-{end})"
|
|
|
184 |
cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
|
185 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
|
186 |
cbar.ax.tick_params(labelsize=8)
|
187 |
+
cbar.set_label('Feature Importance', fontsize=9, labelpad=5)
|
188 |
ax.set_yticks([])
|
189 |
ax.set_xlabel('Position in Sequence', fontsize=10)
|
190 |
ax.set_title(f"{title}{subtitle}", pad=10)
|
|
|
200 |
colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
|
201 |
plt.barh(range(len(values)), values, color=colors)
|
202 |
plt.yticks(range(len(values)), features)
|
203 |
+
plt.xlabel('Feature Importance (impact on model output)')
|
204 |
plt.title(f'Top {top_k} Most Influential k-mers')
|
205 |
plt.gca().invert_yaxis()
|
206 |
plt.tight_layout()
|
207 |
return fig
|
208 |
|
209 |
+
def plot_shap_histogram(shap_array, title="Feature Importance Distribution in Region", num_bins=30):
|
210 |
fig, ax = plt.subplots(figsize=(6, 4))
|
211 |
ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
|
212 |
ax.axvline(0, color='red', linestyle='--', label='0.0')
|
213 |
+
ax.set_xlabel("Feature Importance Value")
|
214 |
ax.set_ylabel("Count")
|
215 |
ax.set_title(title)
|
216 |
ax.legend()
|
|
|
227 |
# 7. MAIN ANALYSIS STEP (Gradio Step 1)
|
228 |
###############################################################################
|
229 |
def create_kmer_shap_csv(kmers, shap_values):
|
230 |
+
"""Create a CSV file with k-mer importance values and return the filepath"""
|
231 |
+
# Create DataFrame with k-mers and importance values
|
232 |
kmer_df = pd.DataFrame({
|
233 |
'kmer': kmers,
|
234 |
+
'importance_value': shap_values,
|
235 |
+
'abs_importance': np.abs(shap_values)
|
236 |
})
|
237 |
|
238 |
+
# Sort by absolute importance value (most influential first)
|
239 |
+
kmer_df = kmer_df.sort_values('abs_importance', ascending=False)
|
240 |
|
241 |
+
# Drop the abs_importance column used for sorting
|
242 |
+
kmer_df = kmer_df[['kmer', 'importance_value']]
|
243 |
|
244 |
# Save to temporary file
|
245 |
temp_dir = tempfile.gettempdir()
|
246 |
+
temp_path = os.path.join(temp_dir, f"kmer_importance_values_{os.urandom(4).hex()}.csv")
|
247 |
kmer_df.to_csv(temp_path, index=False)
|
248 |
|
249 |
return temp_path # Return only the file path, not a tuple
|
|
|
296 |
f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
|
297 |
f"---\n"
|
298 |
f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
|
299 |
+
f"Start: {max_start}, End: {max_end}, Avg Importance: {max_avg:.4f}\n\n"
|
300 |
f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
|
301 |
+
f"Start: {min_start}, End: {min_end}, Avg Importance: {min_avg:.4f}"
|
302 |
)
|
303 |
|
304 |
kmers = [''.join(p) for p in product("ACGT", repeat=4)]
|
305 |
bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
|
306 |
bar_img = fig_to_image(bar_fig)
|
307 |
|
308 |
+
heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide Feature Importance")
|
309 |
heatmap_img = fig_to_image(heatmap_fig)
|
310 |
|
311 |
+
# Create CSV with k-mer importance values and return the file path
|
312 |
kmer_shap_csv = create_kmer_shap_csv(kmers, shap_values)
|
313 |
|
314 |
# State dictionary for subregion analysis
|
|
|
347 |
f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
|
348 |
f"Region length: {len(region_seq)} bases\n"
|
349 |
f"GC content: {gc_percent:.2f}%\n"
|
350 |
+
f"Average importance in region: {avg_shap:.4f}\n"
|
351 |
+
f"Fraction with importance > 0 (toward human): {positive_fraction:.2f}\n"
|
352 |
+
f"Fraction with importance < 0 (toward non-human): {negative_fraction:.2f}\n"
|
353 |
f"Subregion interpretation: {region_classification}\n"
|
354 |
)
|
355 |
+
heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion Feature Importance", start=region_start, end=region_end)
|
356 |
heatmap_img = fig_to_image(heatmap_fig)
|
357 |
+
hist_fig = plot_shap_histogram(region_shap, title="Feature Importance Distribution in Subregion")
|
358 |
hist_img = fig_to_image(hist_fig)
|
359 |
|
360 |
# For demonstration, returning None for the file download as well
|
|
|
370 |
return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
|
371 |
|
372 |
def compute_shap_difference(shap1_norm, shap2_norm):
|
373 |
+
"""Compute the feature importance difference between normalized sequences"""
|
374 |
return shap2_norm - shap1_norm
|
375 |
|
376 |
+
def plot_comparative_heatmap(shap_diff, title="Feature Importance Difference Heatmap"):
|
377 |
"""
|
378 |
Plot heatmap using relative positions (0-100%)
|
379 |
"""
|
|
|
393 |
|
394 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
|
395 |
cbar.ax.tick_params(labelsize=8)
|
396 |
+
cbar.set_label('Feature Importance Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
|
397 |
|
398 |
ax.set_yticks([])
|
399 |
ax.set_xlabel('Relative Position in Sequence', fontsize=10)
|
|
|
402 |
|
403 |
return fig
|
404 |
|
405 |
+
def plot_shap_histogram(shap_array, title="Feature Importance Distribution", num_bins=30):
|
406 |
"""
|
407 |
+
Plot histogram of feature importance values with configurable number of bins
|
408 |
"""
|
409 |
fig, ax = plt.subplots(figsize=(6, 4))
|
410 |
ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
|
411 |
ax.axvline(0, color='red', linestyle='--', label='0.0')
|
412 |
+
ax.set_xlabel("Feature Importance Value")
|
413 |
ax.set_ylabel("Count")
|
414 |
ax.set_title(title)
|
415 |
ax.legend()
|
|
|
483 |
|
484 |
def normalize_shap_lengths(shap1, shap2):
|
485 |
"""
|
486 |
+
Normalize and smooth feature importance values with dynamic adaptation
|
487 |
"""
|
488 |
# Calculate adaptive parameters
|
489 |
num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
|
|
|
517 |
if isinstance(res2[0], str) and "Error" in res2[0]:
|
518 |
return (f"Error in sequence 2: {res2[0]}", None, None, None)
|
519 |
|
520 |
+
# Extract feature importance values and sequence info
|
521 |
shap1 = res1[3]["shap_means"]
|
522 |
shap2 = res2[3]["shap_means"]
|
523 |
|
|
|
567 |
f"Smoothing Window: {smooth_window} points\n"
|
568 |
f"Adaptive Threshold: {adaptive_threshold:.3f}\n\n"
|
569 |
"Statistics:\n"
|
570 |
+
f"Average feature importance difference: {avg_diff:.4f}\n"
|
571 |
f"Standard deviation: {std_diff:.4f}\n"
|
572 |
f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
|
573 |
f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
|
|
|
582 |
# Generate visualizations
|
583 |
heatmap_fig = plot_comparative_heatmap(
|
584 |
shap_diff,
|
585 |
+
title=f"Feature Importance Difference Heatmap (window: {smooth_window})"
|
586 |
)
|
587 |
heatmap_img = fig_to_image(heatmap_fig)
|
588 |
|
|
|
590 |
num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
|
591 |
hist_fig = plot_shap_histogram(
|
592 |
shap_diff,
|
593 |
+
title="Distribution of Feature Importance Differences",
|
594 |
num_bins=num_bins
|
595 |
)
|
596 |
hist_img = fig_to_image(hist_fig)
|
|
|
680 |
return None, None
|
681 |
|
682 |
def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
|
683 |
+
"""Compute statistical measures for gene feature importance values"""
|
684 |
return {
|
685 |
'avg_shap': float(np.mean(gene_shap)),
|
686 |
'median_shap': float(np.median(gene_shap)),
|
|
|
693 |
def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
|
694 |
"""
|
695 |
Create a simple genome diagram using PIL, forcing a minimum color intensity
|
696 |
+
so that small feature importance values don't appear white.
|
697 |
"""
|
698 |
from PIL import Image, ImageDraw, ImageFont
|
699 |
|
|
|
730 |
title_font = ImageFont.load_default()
|
731 |
|
732 |
# Draw title
|
733 |
+
draw.text((margin, margin // 2), "Genome Feature Importance Analysis", fill='black', font=title_font or font)
|
734 |
|
735 |
# Draw genome line
|
736 |
line_y = height // 2
|
|
|
755 |
], fill='black', width=1)
|
756 |
draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
|
757 |
|
758 |
+
# Sort genes by absolute feature importance value for drawing
|
759 |
sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
|
760 |
|
761 |
# Draw genes
|
|
|
764 |
start_x = margin + int(gene['start'] * scale)
|
765 |
end_x = margin + int(gene['end'] * scale)
|
766 |
|
767 |
+
# Calculate color based on feature importance value
|
768 |
avg_shap = gene['avg_shap']
|
769 |
|
770 |
+
# Convert importance -> color intensity (0 to 255)
|
771 |
# Then clamp to a minimum intensity so it never ends up plain white
|
772 |
intensity = int(abs(avg_shap) * 500)
|
773 |
intensity = max(50, min(255, intensity)) # clamp between 50 and 255
|
|
|
813 |
# Draw legend
|
814 |
legend_x = margin
|
815 |
legend_y = height - margin
|
816 |
+
draw.text((int(legend_x), int(legend_y - 60)), "Feature Importance Values:", fill='black', font=font)
|
817 |
|
818 |
# Draw legend boxes
|
819 |
box_width = 20
|
|
|
858 |
features_file: str,
|
859 |
fasta_text: str = "",
|
860 |
features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
|
861 |
+
"""Analyze feature importance values for each gene feature"""
|
862 |
# First analyze whole sequence
|
863 |
sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
|
864 |
if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
|
865 |
return f"Error in sequence analysis: {sequence_results[0]}", None, None
|
866 |
|
867 |
+
# Get feature importance values
|
868 |
shap_means = sequence_results[3]["shap_means"]
|
869 |
|
870 |
# Parse gene features
|
|
|
889 |
if start is None or end is None:
|
890 |
continue
|
891 |
|
892 |
+
# Get feature importance values for this region
|
893 |
gene_shap = shap_means[start:end]
|
894 |
stats = compute_gene_statistics(gene_shap)
|
895 |
|
|
|
916 |
if not gene_results:
|
917 |
return "No valid genes could be processed", None, None
|
918 |
|
919 |
+
# Sort genes by absolute feature importance value
|
920 |
sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
|
921 |
|
922 |
# Create results text
|
|
|
932 |
f"Location: {gene['location']}\n"
|
933 |
f"Classification: {gene['classification']} "
|
934 |
f"(confidence: {gene['confidence']:.4f})\n"
|
935 |
+
f"Average Feature Importance: {gene['avg_shap']:.4f}\n\n"
|
936 |
)
|
937 |
|
938 |
# Create CSV content
|
939 |
+
csv_content = "gene_name,location,avg_importance,median_importance,std_importance,max_importance,min_importance,"
|
940 |
csv_content += "pos_fraction,classification,confidence,locus_tag\n"
|
941 |
|
942 |
for gene in gene_results:
|
|
|
1020 |
gr.Markdown("""
|
1021 |
# Virus Host Classifier
|
1022 |
**Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
|
1023 |
+
**Step 2**: Explore subregions to see local feature influence, distribution, GC content, etc.
|
1024 |
**Step 3**: Analyze gene features and their contributions.
|
1025 |
**Step 4**: Compare sequences and analyze differences.
|
1026 |
|
1027 |
+
**Color Scale**: Negative values = Blue, Zero = White, Positive values = Red.
|
1028 |
""")
|
1029 |
|
1030 |
with gr.Tab("1) Full-Sequence Analysis"):
|
|
|
1043 |
|
1044 |
with gr.Column(scale=2):
|
1045 |
results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
|
1046 |
+
kmer_img = gr.Image(label="Top k-mer Importance")
|
1047 |
+
genome_img = gr.Image(label="Genome-wide Feature Importance Heatmap (Blue=neg, White=0, Red=pos)")
|
1048 |
|
1049 |
# File components with the correct type parameter
|
1050 |
+
download_kmer_shap = gr.File(label="Download k-mer Importance Values (CSV)", visible=True, type="filepath")
|
1051 |
download_results = gr.File(label="Download Results", visible=True, elem_classes="download-button")
|
1052 |
|
1053 |
seq_state = gr.State()
|
|
|
1071 |
with gr.Tab("2) Subregion Exploration"):
|
1072 |
gr.Markdown("""
|
1073 |
**Subregion Analysis**
|
1074 |
+
Select start/end positions to view local feature importance, distribution, GC content, etc.
|
1075 |
The heatmap uses the same Blue-White-Red scale.
|
1076 |
""")
|
1077 |
with gr.Row():
|
|
|
1080 |
region_btn = gr.Button("Analyze Subregion")
|
1081 |
subregion_info = gr.Textbox(label="Subregion Analysis", lines=7, interactive=False)
|
1082 |
with gr.Row():
|
1083 |
+
subregion_img = gr.Image(label="Subregion Feature Importance Heatmap (B-W-R)")
|
1084 |
+
subregion_hist_img = gr.Image(label="Feature Importance Distribution (Histogram)")
|
1085 |
download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
|
1086 |
|
1087 |
region_btn.click(
|
|
|
1093 |
with gr.Tab("3) Gene Features Analysis"):
|
1094 |
gr.Markdown("""
|
1095 |
**Analyze Gene Features**
|
1096 |
+
Upload a FASTA file and corresponding gene features file to analyze feature importance values per gene.
|
1097 |
Gene features should be in the format:
|
1098 |
|
1099 |
>gene_name [gene=X] [locus_tag=Y] [location=start..end] or [location=complement(start..end)]
|
1100 |
SEQUENCE
|
|
|
1101 |
The genome viewer will show genes color-coded by their contribution:
|
1102 |
- Red: Genes pushing toward human origin
|
1103 |
- Blue: Genes pushing toward non-human origin
|
|
|
1125 |
with gr.Tab("4) Comparative Analysis"):
|
1126 |
gr.Markdown("""
|
1127 |
**Compare Two Sequences**
|
1128 |
+
Upload or paste two FASTA sequences to compare their feature importance patterns.
|
1129 |
The sequences will be normalized to the same length for comparison.
|
1130 |
|
1131 |
**Color Scale**:
|
|
|
1143 |
compare_btn = gr.Button("Compare Sequences", variant="primary")
|
1144 |
comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
|
1145 |
with gr.Row():
|
1146 |
+
diff_heatmap = gr.Image(label="Feature Importance Difference Heatmap")
|
1147 |
+
diff_hist = gr.Image(label="Distribution of Feature Importance Differences")
|
1148 |
download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
|
1149 |
|
1150 |
compare_btn.click(
|
|
|
1156 |
gr.Markdown("""
|
1157 |
### Interface Features
|
1158 |
- **Overall Classification** (human vs non-human) using k-mer frequencies
|
1159 |
+
- **Feature Importance Analysis** shows which k-mers push classification toward or away from human
|
1160 |
+
- **White-Centered Gradient**:
|
1161 |
- Negative (blue), 0 (white), Positive (red)
|
1162 |
- Symmetrical color range around 0
|
1163 |
- **Identify Subregions** with strongest push for human or non-human
|
|
|
1171 |
- Statistical summary of differences
|
1172 |
- **Data Export**:
|
1173 |
- Download results as CSV files
|
1174 |
+
- Download k-mer importance values
|
1175 |
- Save analysis outputs for further processing
|
1176 |
""")
|
1177 |
|