Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -140,28 +140,29 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
|
|
140 |
"""
|
141 |
Finds the subregion of length `window_size` that has the maximum
|
142 |
(mode="max") or minimum (mode="min") average SHAP.
|
143 |
-
Returns (best_start, best_end,
|
144 |
"""
|
145 |
n = len(shap_means)
|
|
|
|
|
|
|
146 |
if window_size >= n:
|
147 |
-
# If the window is bigger than the entire sequence, return
|
148 |
-
avg_val = np.mean(shap_means)
|
149 |
return (0, n, avg_val)
|
150 |
|
151 |
-
#
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
return csum[end] - csum[start]
|
157 |
-
|
158 |
best_start = 0
|
159 |
-
# Initialize
|
160 |
-
best_sum =
|
161 |
best_avg = best_sum / window_size
|
162 |
|
163 |
for start in range(1, n - window_size + 1):
|
164 |
-
wsum =
|
165 |
wavg = wsum / window_size
|
166 |
if mode == "max":
|
167 |
if wavg > best_avg:
|
@@ -172,7 +173,7 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
|
|
172 |
best_avg = wavg
|
173 |
best_start = start
|
174 |
|
175 |
-
return (best_start, best_start + window_size, best_avg)
|
176 |
|
177 |
###############################################################################
|
178 |
# 6. PLOTTING / UTILITIES
|
@@ -192,10 +193,9 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
|
|
192 |
Plots a 1D heatmap of per-base SHAP contributions.
|
193 |
Negative = push toward Non-Human, Positive = push toward Human.
|
194 |
Optionally can show only a subrange (start:end).
|
195 |
-
|
196 |
-
We adjust layout so the colorbar is well below the x-axis:
|
197 |
- orientation='horizontal', pad=0.35
|
198 |
-
- plt.subplots_adjust(bottom=0.4)
|
199 |
"""
|
200 |
if start is not None and end is not None:
|
201 |
shap_means = shap_means[start:end]
|
@@ -294,11 +294,15 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
|
294 |
# Load model and scaler
|
295 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
296 |
try:
|
|
|
|
|
297 |
model = VirusClassifier(256).to(device)
|
298 |
-
model.load_state_dict(
|
|
|
|
|
299 |
scaler = joblib.load('scaler.pkl')
|
300 |
except Exception as e:
|
301 |
-
return (f"Error loading model: {str(e)}", None, None, None, None)
|
302 |
|
303 |
# Vectorize + scale
|
304 |
freq_vector = sequence_to_kmer_vector(seq)
|
@@ -344,13 +348,13 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
|
344 |
heatmap_img = fig_to_image(heatmap_fig)
|
345 |
|
346 |
# Store data for subregion analysis
|
347 |
-
|
348 |
"seq": seq,
|
349 |
"shap_means": shap_means
|
350 |
}
|
351 |
|
352 |
-
#
|
353 |
-
return (results_text, bar_img, heatmap_img,
|
354 |
|
355 |
###############################################################################
|
356 |
# 8. SUBREGION ANALYSIS (Gradio Step 2)
|
@@ -475,16 +479,10 @@ with gr.Blocks(css=css) as iface:
|
|
475 |
kmer_img = gr.Image(label="Top k-mer SHAP")
|
476 |
genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
|
477 |
|
478 |
-
# State for step 2
|
479 |
seq_state = gr.State()
|
480 |
header_state = gr.State()
|
481 |
|
482 |
-
# analyze_sequence(...)
|
483 |
-
# 1) results_text
|
484 |
-
# 2) bar_img
|
485 |
-
# 3) heatmap_img
|
486 |
-
# 4) state_dict
|
487 |
-
# 5) header
|
488 |
analyze_btn.click(
|
489 |
analyze_sequence,
|
490 |
inputs=[file_input, top_k, text_input, win_size],
|
@@ -517,17 +515,12 @@ with gr.Blocks(css=css) as iface:
|
|
517 |
)
|
518 |
|
519 |
gr.Markdown("""
|
520 |
-
###
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
- GC content, fraction of bases pushing "human" vs "non-human"
|
527 |
-
- Simple logic-based interpretation based on average SHAP
|
528 |
-
5. **Identification of the most 'human-pushing' subregion** (max average SHAP)
|
529 |
-
and the most 'non-human–pushing' subregion (min average SHAP),
|
530 |
-
each of a chosen window size.
|
531 |
""")
|
532 |
|
533 |
if __name__ == "__main__":
|
|
|
140 |
"""
|
141 |
Finds the subregion of length `window_size` that has the maximum
|
142 |
(mode="max") or minimum (mode="min") average SHAP.
|
143 |
+
Returns (best_start, best_end, best_avg).
|
144 |
"""
|
145 |
n = len(shap_means)
|
146 |
+
if n == 0:
|
147 |
+
# Edge case: empty array
|
148 |
+
return (0, 0, 0.0)
|
149 |
if window_size >= n:
|
150 |
+
# If the window is bigger than the entire sequence, return entire seq
|
151 |
+
avg_val = float(np.mean(shap_means))
|
152 |
return (0, n, avg_val)
|
153 |
|
154 |
+
# Build a prefix-sum array of length n+1 so that csum[i] == sum(shap_means[:i]) (csum[0] == 0).
|
155 |
+
# The sum over the half-open window [start, start + window_size) is csum[start + window_size] - csum[start].
|
156 |
+
csum = np.zeros(n + 1, dtype=np.float32)
|
157 |
+
csum[1:] = np.cumsum(shap_means)
|
158 |
+
|
|
|
|
|
159 |
best_start = 0
|
160 |
+
# Initialize with the first window: [0, window_size)
|
161 |
+
best_sum = csum[window_size] - csum[0]
|
162 |
best_avg = best_sum / window_size
|
163 |
|
164 |
for start in range(1, n - window_size + 1):
|
165 |
+
wsum = csum[start + window_size] - csum[start]
|
166 |
wavg = wsum / window_size
|
167 |
if mode == "max":
|
168 |
if wavg > best_avg:
|
|
|
173 |
best_avg = wavg
|
174 |
best_start = start
|
175 |
|
176 |
+
return (best_start, best_start + window_size, float(best_avg))
|
177 |
|
178 |
###############################################################################
|
179 |
# 6. PLOTTING / UTILITIES
|
|
|
193 |
Plots a 1D heatmap of per-base SHAP contributions.
|
194 |
Negative = push toward Non-Human, Positive = push toward Human.
|
195 |
Optionally can show only a subrange (start:end).
|
196 |
+
Adjust layout so the colorbar is well below the x-axis:
|
|
|
197 |
- orientation='horizontal', pad=0.35
|
198 |
+
- plt.subplots_adjust(bottom=0.4)
|
199 |
"""
|
200 |
if start is not None and end is not None:
|
201 |
shap_means = shap_means[start:end]
|
|
|
294 |
# Load model and scaler
|
295 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
296 |
try:
|
297 |
+
# weights_only=True restricts torch.load's unpickler to tensors and plain containers; this silences the FutureWarning and refuses arbitrary pickled objects.
|
298 |
+
state_dict = torch.load('model.pt', map_location=device, weights_only=True)
|
299 |
model = VirusClassifier(256).to(device)
|
300 |
+
model.load_state_dict(state_dict)
|
301 |
+
|
302 |
+
# Load scaler (unpickling may emit a version-mismatch warning; a warning is not an exception, so it does not trigger the except branch below)
|
303 |
scaler = joblib.load('scaler.pkl')
|
304 |
except Exception as e:
|
305 |
+
return (f"Error loading model/scaler: {str(e)}", None, None, None, None)
|
306 |
|
307 |
# Vectorize + scale
|
308 |
freq_vector = sequence_to_kmer_vector(seq)
|
|
|
348 |
heatmap_img = fig_to_image(heatmap_fig)
|
349 |
|
350 |
# Store data for subregion analysis
|
351 |
+
state_dict_out = {
|
352 |
"seq": seq,
|
353 |
"shap_means": shap_means
|
354 |
}
|
355 |
|
356 |
+
# Return exactly 5 items
|
357 |
+
return (results_text, bar_img, heatmap_img, state_dict_out, header)
|
358 |
|
359 |
###############################################################################
|
360 |
# 8. SUBREGION ANALYSIS (Gradio Step 2)
|
|
|
479 |
kmer_img = gr.Image(label="Top k-mer SHAP")
|
480 |
genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
|
481 |
|
|
|
482 |
seq_state = gr.State()
|
483 |
header_state = gr.State()
|
484 |
|
485 |
+
# analyze_sequence(...) returns 5 items.
|
|
|
|
|
|
|
|
|
|
|
486 |
analyze_btn.click(
|
487 |
analyze_sequence,
|
488 |
inputs=[file_input, top_k, text_input, win_size],
|
|
|
515 |
)
|
516 |
|
517 |
gr.Markdown("""
|
518 |
+
### Interface Features
|
519 |
+
- **Overall Classification** (human vs non-human) using k-mer frequencies.
|
520 |
+
- **Top k-mer SHAP**: which k-mers push the classifier output.
|
521 |
+
- **Genome-Wide SHAP Heatmap**: each base's average SHAP across overlapping k-mers.
|
522 |
+
- **Identify Subregions** (sliding window) with the strongest push for human or non-human.
|
523 |
+
- **Subregion Exploration**: local SHAP heatmap & histogram, GC content, fraction of positions pushing human vs. non-human.
|
|
|
|
|
|
|
|
|
|
|
524 |
""")
|
525 |
|
526 |
if __name__ == "__main__":
|