Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,7 @@ import numpy as np
|
|
5 |
from itertools import product
|
6 |
import torch.nn as nn
|
7 |
import matplotlib.pyplot as plt
|
|
|
8 |
import io
|
9 |
from PIL import Image
|
10 |
|
@@ -144,20 +145,17 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
|
|
144 |
"""
|
145 |
n = len(shap_means)
|
146 |
if n == 0:
|
147 |
-
# Edge case: empty array
|
148 |
return (0, 0, 0.0)
|
149 |
if window_size >= n:
|
150 |
-
#
|
151 |
avg_val = float(np.mean(shap_means))
|
152 |
return (0, n, avg_val)
|
153 |
|
154 |
-
# We'll build csum
|
155 |
-
# That means sum in [start, start+window_size) = csum[start+window_size] - csum[start].
|
156 |
csum = np.zeros(n + 1, dtype=np.float32)
|
157 |
csum[1:] = np.cumsum(shap_means)
|
158 |
|
159 |
best_start = 0
|
160 |
-
# Initialize with the first window: [0, window_size)
|
161 |
best_sum = csum[window_size] - csum[0]
|
162 |
best_avg = best_sum / window_size
|
163 |
|
@@ -188,29 +186,65 @@ def fig_to_image(fig):
|
|
188 |
plt.close(fig)
|
189 |
return img
|
190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
|
192 |
"""
|
193 |
-
Plots a 1D heatmap of per-base SHAP contributions
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
|
|
199 |
"""
|
200 |
if start is not None and end is not None:
|
201 |
-
|
202 |
subtitle = f" (positions {start}-{end})"
|
203 |
else:
|
|
|
204 |
subtitle = ""
|
205 |
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
|
208 |
fig, ax = plt.subplots(figsize=(12, 2))
|
209 |
-
cax = ax.imshow(
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
211 |
-
# Place colorbar below
|
212 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.35)
|
213 |
-
cbar.set_label('SHAP Contribution')
|
214 |
|
215 |
ax.set_yticks([])
|
216 |
ax.set_xlabel('Position in Sequence')
|
@@ -231,7 +265,8 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
|
|
231 |
values = shap_values[indices]
|
232 |
features = [kmers[i] for i in indices]
|
233 |
|
234 |
-
|
|
|
235 |
|
236 |
plt.barh(range(len(values)), values, color=colors)
|
237 |
plt.yticks(range(len(values)), features)
|
@@ -244,7 +279,6 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
|
|
244 |
def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
|
245 |
"""
|
246 |
Simple histogram of SHAP values in the subregion.
|
247 |
-
Helps see how many positions push human vs non-human.
|
248 |
"""
|
249 |
fig, ax = plt.subplots(figsize=(6, 4))
|
250 |
ax.hist(shap_array, bins=30, color='gray', edgecolor='black')
|
@@ -294,12 +328,11 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
|
294 |
# Load model and scaler
|
295 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
296 |
try:
|
297 |
-
# Use weights_only=True
|
298 |
state_dict = torch.load('model.pt', map_location=device, weights_only=True)
|
299 |
model = VirusClassifier(256).to(device)
|
300 |
model.load_state_dict(state_dict)
|
301 |
-
|
302 |
-
# Load scaler (warning if version mismatch)
|
303 |
scaler = joblib.load('scaler.pkl')
|
304 |
except Exception as e:
|
305 |
return (f"Error loading model/scaler: {str(e)}", None, None, None, None)
|
@@ -353,7 +386,6 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
|
353 |
"shap_means": shap_means
|
354 |
}
|
355 |
|
356 |
-
# Return exactly 5 items
|
357 |
return (results_text, bar_img, heatmap_img, state_dict_out, header)
|
358 |
|
359 |
###############################################################################
|
@@ -438,9 +470,11 @@ css = """
|
|
438 |
|
439 |
with gr.Blocks(css=css) as iface:
|
440 |
gr.Markdown("""
|
441 |
-
# Virus Host Classifier
|
442 |
**Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
|
443 |
**Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
|
|
|
|
|
444 |
""")
|
445 |
|
446 |
with gr.Tab("1) Full-Sequence Analysis"):
|
@@ -477,12 +511,12 @@ with gr.Blocks(css=css) as iface:
|
|
477 |
label="Classification Results", lines=12, interactive=False
|
478 |
)
|
479 |
kmer_img = gr.Image(label="Top k-mer SHAP")
|
480 |
-
genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
|
481 |
|
482 |
seq_state = gr.State()
|
483 |
header_state = gr.State()
|
484 |
|
485 |
-
# analyze_sequence(...) returns 5 items
|
486 |
analyze_btn.click(
|
487 |
analyze_sequence,
|
488 |
inputs=[file_input, top_k, text_input, win_size],
|
@@ -492,7 +526,8 @@ with gr.Blocks(css=css) as iface:
|
|
492 |
with gr.Tab("2) Subregion Exploration"):
|
493 |
gr.Markdown("""
|
494 |
**Subregion Analysis**
|
495 |
-
Select start/end positions to view local SHAP signals, distribution, and GC content.
|
|
|
496 |
""")
|
497 |
with gr.Row():
|
498 |
region_start = gr.Number(label="Region Start", value=0)
|
@@ -505,7 +540,7 @@ with gr.Blocks(css=css) as iface:
|
|
505 |
interactive=False
|
506 |
)
|
507 |
with gr.Row():
|
508 |
-
subregion_img = gr.Image(label="Subregion SHAP Heatmap")
|
509 |
subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
|
510 |
|
511 |
region_btn.click(
|
@@ -517,10 +552,15 @@ with gr.Blocks(css=css) as iface:
|
|
517 |
gr.Markdown("""
|
518 |
### Interface Features
|
519 |
- **Overall Classification** (human vs non-human) using k-mer frequencies.
|
520 |
-
- **
|
521 |
-
- **
|
522 |
-
|
523 |
-
- **
|
|
|
|
|
|
|
|
|
|
|
524 |
""")
|
525 |
|
526 |
if __name__ == "__main__":
|
|
|
5 |
from itertools import product
|
6 |
import torch.nn as nn
|
7 |
import matplotlib.pyplot as plt
|
8 |
+
import matplotlib.colors as mcolors
|
9 |
import io
|
10 |
from PIL import Image
|
11 |
|
|
|
145 |
"""
|
146 |
n = len(shap_means)
|
147 |
if n == 0:
|
|
|
148 |
return (0, 0, 0.0)
|
149 |
if window_size >= n:
|
150 |
+
# entire sequence
|
151 |
avg_val = float(np.mean(shap_means))
|
152 |
return (0, n, avg_val)
|
153 |
|
154 |
+
# We'll build csum of length n+1
|
|
|
155 |
csum = np.zeros(n + 1, dtype=np.float32)
|
156 |
csum[1:] = np.cumsum(shap_means)
|
157 |
|
158 |
best_start = 0
|
|
|
159 |
best_sum = csum[window_size] - csum[0]
|
160 |
best_avg = best_sum / window_size
|
161 |
|
|
|
186 |
plt.close(fig)
|
187 |
return img
|
188 |
|
189 |
+
def get_zero_centered_cmap():
    """
    Build the diverging blue-white-red colormap used for SHAP heatmaps.

    The anchors sit at normalized positions 0.0 (blue), 0.5 (white) and
    1.0 (red). White is therefore the midpoint of whatever value range the
    plot is normalized to; it corresponds to a data value of 0 only when
    the caller passes limits that are symmetric around zero
    (vmin=-x, vmax=+x).

    Returns:
        A matplotlib LinearSegmentedColormap named "blue_white_red".
    """
    anchors = [
        (0.0, 'blue'),   # low end of the range -> negative values
        (0.5, 'white'),  # midpoint -> zero under symmetric limits
        (1.0, 'red'),    # high end of the range -> positive values
    ]
    return mcolors.LinearSegmentedColormap.from_list("blue_white_red", anchors)
|
203 |
+
|
204 |
def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
|
205 |
"""
|
206 |
+
Plots a 1D heatmap of per-base SHAP contributions with a custom colormap:
|
207 |
+
- Negative = blue
|
208 |
+
- 0 = white
|
209 |
+
- Positive = red
|
210 |
+
We'll force the range to be symmetrical around 0 by using:
|
211 |
+
vmin=-extent, vmax=+extent
|
212 |
+
so 0 is in the middle.
|
213 |
"""
|
214 |
if start is not None and end is not None:
|
215 |
+
local_shap = shap_means[start:end]
|
216 |
subtitle = f" (positions {start}-{end})"
|
217 |
else:
|
218 |
+
local_shap = shap_means
|
219 |
subtitle = ""
|
220 |
|
221 |
+
if len(local_shap) == 0:
|
222 |
+
# Edge case: no data to plot
|
223 |
+
local_shap = np.array([0.0])
|
224 |
+
|
225 |
+
# Build 2D array for imshow
|
226 |
+
heatmap_data = local_shap.reshape(1, -1)
|
227 |
+
|
228 |
+
# Force symmetrical range
|
229 |
+
min_val = np.min(local_shap)
|
230 |
+
max_val = np.max(local_shap)
|
231 |
+
extent = max(abs(min_val), abs(max_val))
|
232 |
+
|
233 |
+
# Create custom colormap
|
234 |
+
custom_cmap = get_zero_centered_cmap()
|
235 |
|
236 |
fig, ax = plt.subplots(figsize=(12, 2))
|
237 |
+
cax = ax.imshow(
|
238 |
+
heatmap_data,
|
239 |
+
aspect='auto',
|
240 |
+
cmap=custom_cmap,
|
241 |
+
vmin=-extent,
|
242 |
+
vmax=+extent
|
243 |
+
)
|
244 |
|
245 |
+
# Place colorbar below with plenty of margin
|
246 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.35)
|
247 |
+
cbar.set_label('SHAP Contribution (negative=blue, zero=white, positive=red)')
|
248 |
|
249 |
ax.set_yticks([])
|
250 |
ax.set_xlabel('Position in Sequence')
|
|
|
265 |
values = shap_values[indices]
|
266 |
features = [kmers[i] for i in indices]
|
267 |
|
268 |
+
# negative -> blue, positive -> red
|
269 |
+
colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
|
270 |
|
271 |
plt.barh(range(len(values)), values, color=colors)
|
272 |
plt.yticks(range(len(values)), features)
|
|
|
279 |
def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
|
280 |
"""
|
281 |
Simple histogram of SHAP values in the subregion.
|
|
|
282 |
"""
|
283 |
fig, ax = plt.subplots(figsize=(6, 4))
|
284 |
ax.hist(shap_array, bins=30, color='gray', edgecolor='black')
|
|
|
328 |
# Load model and scaler
|
329 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
330 |
try:
|
331 |
+
# Use weights_only=True for safer loading
|
332 |
state_dict = torch.load('model.pt', map_location=device, weights_only=True)
|
333 |
model = VirusClassifier(256).to(device)
|
334 |
model.load_state_dict(state_dict)
|
335 |
+
|
|
|
336 |
scaler = joblib.load('scaler.pkl')
|
337 |
except Exception as e:
|
338 |
return (f"Error loading model/scaler: {str(e)}", None, None, None, None)
|
|
|
386 |
"shap_means": shap_means
|
387 |
}
|
388 |
|
|
|
389 |
return (results_text, bar_img, heatmap_img, state_dict_out, header)
|
390 |
|
391 |
###############################################################################
|
|
|
470 |
|
471 |
with gr.Blocks(css=css) as iface:
|
472 |
gr.Markdown("""
|
473 |
+
# Virus Host Classifier with White-Centered Gradient
|
474 |
**Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
|
475 |
**Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
|
476 |
+
|
477 |
+
**Color Scale**: Negative SHAP = Blue, Zero = White, Positive = Red.
|
478 |
""")
|
479 |
|
480 |
with gr.Tab("1) Full-Sequence Analysis"):
|
|
|
511 |
label="Classification Results", lines=12, interactive=False
|
512 |
)
|
513 |
kmer_img = gr.Image(label="Top k-mer SHAP")
|
514 |
+
genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
|
515 |
|
516 |
seq_state = gr.State()
|
517 |
header_state = gr.State()
|
518 |
|
519 |
+
# analyze_sequence(...) returns 5 items
|
520 |
analyze_btn.click(
|
521 |
analyze_sequence,
|
522 |
inputs=[file_input, top_k, text_input, win_size],
|
|
|
526 |
with gr.Tab("2) Subregion Exploration"):
|
527 |
gr.Markdown("""
|
528 |
**Subregion Analysis**
|
529 |
+
Select start/end positions to view local SHAP signals, distribution, and GC content.
|
530 |
+
The heatmap also uses the same Blue-White-Red scale.
|
531 |
""")
|
532 |
with gr.Row():
|
533 |
region_start = gr.Number(label="Region Start", value=0)
|
|
|
540 |
interactive=False
|
541 |
)
|
542 |
with gr.Row():
|
543 |
+
subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
|
544 |
subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
|
545 |
|
546 |
region_btn.click(
|
|
|
552 |
gr.Markdown("""
|
553 |
### Interface Features
|
554 |
- **Overall Classification** (human vs non-human) using k-mer frequencies.
|
555 |
+
- **SHAP Analysis** to see which k-mers push classification toward or away from human.
|
556 |
+
- **White-Centered SHAP Gradient**:
|
557 |
+
- Negative (blue), 0 (white), Positive (red), with symmetrical color range around 0.
|
558 |
+
- **Identify Subregions** with the strongest push for human or non-human.
|
559 |
+
- **Subregion Exploration**:
|
560 |
+
- Local SHAP heatmap & histogram
|
561 |
+
- GC content
|
562 |
+
- Fraction of positions pushing human vs. non-human
|
563 |
+
- Simple logic-based classification
|
564 |
""")
|
565 |
|
566 |
if __name__ == "__main__":
|