Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -32,7 +32,6 @@ class VirusClassifier(nn.Module):
|
|
32 |
def forward(self, x):
|
33 |
return self.network(x)
|
34 |
|
35 |
-
|
36 |
###############################################################################
|
37 |
# 2. FASTA PARSING & K-MER FEATURE ENGINEERING
|
38 |
###############################################################################
|
@@ -59,7 +58,7 @@ def parse_fasta(text):
|
|
59 |
return sequences
|
60 |
|
61 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
62 |
-
"""Convert a sequence to a k-mer frequency vector."""
|
63 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
64 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
65 |
vec = np.zeros(len(kmers), dtype=np.float32)
|
@@ -75,7 +74,6 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
|
75 |
|
76 |
return vec
|
77 |
|
78 |
-
|
79 |
###############################################################################
|
80 |
# 3. SHAP-VALUE (ABLATION) CALCULATION
|
81 |
###############################################################################
|
@@ -83,30 +81,29 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
|
83 |
def calculate_shap_values(model, x_tensor):
|
84 |
"""
|
85 |
Calculate SHAP values using a simple ablation approach.
|
86 |
-
Returns
|
87 |
"""
|
88 |
model.eval()
|
89 |
with torch.no_grad():
|
90 |
-
#
|
91 |
baseline_output = model(x_tensor)
|
92 |
baseline_probs = torch.softmax(baseline_output, dim=1)
|
93 |
baseline_prob = baseline_probs[0, 1].item() # Probability of 'human' class
|
94 |
|
95 |
-
#
|
96 |
shap_values = []
|
97 |
x_zeroed = x_tensor.clone()
|
98 |
for i in range(x_tensor.shape[1]):
|
99 |
-
|
100 |
x_zeroed[0, i] = 0.0
|
101 |
output = model(x_zeroed)
|
102 |
probs = torch.softmax(output, dim=1)
|
103 |
prob = probs[0, 1].item()
|
104 |
impact = baseline_prob - prob
|
105 |
shap_values.append(impact)
|
106 |
-
x_zeroed[0, i] =
|
107 |
return np.array(shap_values), baseline_prob
|
108 |
|
109 |
-
|
110 |
###############################################################################
|
111 |
# 4. PER-BASE SHAP AGGREGATION
|
112 |
###############################################################################
|
@@ -116,7 +113,6 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
|
|
116 |
Returns an array of per-base SHAP contributions by averaging
|
117 |
the k-mer SHAP values of all k-mers covering that base.
|
118 |
"""
|
119 |
-
# Create the list of k-mers (in lexicographic order)
|
120 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
121 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
122 |
|
@@ -136,79 +132,44 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
|
|
136 |
|
137 |
return shap_means
|
138 |
|
139 |
-
|
140 |
###############################################################################
|
141 |
-
# 5.
|
142 |
###############################################################################
|
143 |
|
144 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
"""
|
146 |
Plots a 1D heatmap of per-base SHAP contributions.
|
147 |
Negative = push toward Non-Human, Positive = push toward Human.
|
|
|
148 |
"""
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
151 |
|
|
|
|
|
|
|
152 |
cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r')
|
153 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.2)
|
154 |
cbar.set_label('SHAP Contribution')
|
155 |
|
156 |
ax.set_yticks([])
|
157 |
ax.set_xlabel('Position in Sequence')
|
158 |
-
ax.set_title(title)
|
159 |
-
plt.tight_layout()
|
160 |
-
return fig
|
161 |
-
|
162 |
-
def get_top_signal_region(shap_means, window_size=500):
|
163 |
-
"""
|
164 |
-
Find the window of length `window_size` that has the highest
|
165 |
-
sum of absolute SHAP values. Returns (start_index, end_index).
|
166 |
-
"""
|
167 |
-
seq_len = len(shap_means)
|
168 |
-
if window_size >= seq_len:
|
169 |
-
return 0, seq_len # entire sequence if window too large
|
170 |
-
|
171 |
-
abs_values = np.abs(shap_means)
|
172 |
-
max_sum = -1
|
173 |
-
max_start = 0
|
174 |
-
|
175 |
-
# Slide a window over shap_means
|
176 |
-
current_sum = np.sum(abs_values[:window_size])
|
177 |
-
max_sum = current_sum
|
178 |
-
for start in range(1, seq_len - window_size + 1):
|
179 |
-
# Remove the leftmost base, add the new rightmost base
|
180 |
-
current_sum = current_sum - abs_values[start-1] + abs_values[start + window_size - 1]
|
181 |
-
if current_sum > max_sum:
|
182 |
-
max_sum = current_sum
|
183 |
-
max_start = start
|
184 |
-
|
185 |
-
return max_start, max_start + window_size
|
186 |
-
|
187 |
-
def plot_zoomed_heatmap(shap_means, window_size=500, title="Zoomed SHAP Region"):
|
188 |
-
"""
|
189 |
-
Finds the region with the largest absolute SHAP sum in a fixed window,
|
190 |
-
then plots a 1D heatmap of just that sub-region.
|
191 |
-
"""
|
192 |
-
start, end = get_top_signal_region(shap_means, window_size)
|
193 |
-
sub_means = shap_means[start:end].reshape(1, -1)
|
194 |
-
|
195 |
-
fig, ax = plt.subplots(figsize=(12, 2))
|
196 |
-
cax = ax.imshow(sub_means, aspect='auto', cmap='RdBu_r')
|
197 |
-
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.2)
|
198 |
-
cbar.set_label('SHAP Contribution')
|
199 |
-
|
200 |
-
ax.set_yticks([])
|
201 |
-
ax.set_xlabel(f'Position in Sequence (zoomed in {start} - {end})')
|
202 |
-
ax.set_title(title)
|
203 |
-
|
204 |
plt.tight_layout()
|
205 |
return fig
|
206 |
|
207 |
-
|
208 |
-
###############################################################################
|
209 |
-
# 6. OTHER PLOT: TOP-K K-MER BAR PLOT
|
210 |
-
###############################################################################
|
211 |
-
|
212 |
def create_importance_bar_plot(shap_values, kmers, top_k=10):
|
213 |
"""Create a bar plot of the most important k-mers."""
|
214 |
plt.rcParams.update({'font.size': 10})
|
@@ -223,31 +184,24 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
|
|
223 |
|
224 |
plt.barh(range(len(values)), values, color=colors)
|
225 |
plt.yticks(range(len(values)), features)
|
226 |
-
plt.xlabel('SHAP
|
227 |
plt.title(f'Top {top_k} Most Influential k-mers')
|
228 |
plt.gca().invert_yaxis()
|
229 |
return fig
|
230 |
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
import io
|
238 |
-
buf = io.BytesIO()
|
239 |
-
fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
|
240 |
-
buf.seek(0)
|
241 |
-
img = Image.open(buf)
|
242 |
-
plt.close(fig)
|
243 |
-
return img
|
244 |
|
245 |
###############################################################################
|
246 |
-
#
|
247 |
###############################################################################
|
248 |
|
249 |
-
def
|
250 |
-
"""
|
251 |
# Handle input
|
252 |
if fasta_text.strip():
|
253 |
text = fasta_text.strip()
|
@@ -256,14 +210,14 @@ def predict(file_obj, top_kmers=10, fasta_text="", zoom_window=500):
|
|
256 |
with open(file_obj, 'r') as f:
|
257 |
text = f.read()
|
258 |
except Exception as e:
|
259 |
-
return f"Error reading file: {str(e)}", None, None, None
|
260 |
else:
|
261 |
-
return "Please provide a FASTA sequence.", None, None, None
|
262 |
|
263 |
# Parse FASTA
|
264 |
sequences = parse_fasta(text)
|
265 |
if not sequences:
|
266 |
-
return "No valid FASTA sequences found.", None, None, None
|
267 |
|
268 |
header, seq = sequences[0]
|
269 |
|
@@ -274,49 +228,101 @@ def predict(file_obj, top_kmers=10, fasta_text="", zoom_window=500):
|
|
274 |
model.load_state_dict(torch.load('model.pt', map_location=device))
|
275 |
scaler = joblib.load('scaler.pkl')
|
276 |
except Exception as e:
|
277 |
-
return f"Error loading model: {str(e)}", None, None, None
|
278 |
|
279 |
-
#
|
280 |
freq_vector = sequence_to_kmer_vector(seq)
|
281 |
scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
|
282 |
x_tensor = torch.FloatTensor(scaled_vector).to(device)
|
283 |
|
284 |
-
#
|
285 |
shap_values, prob_human = calculate_shap_values(model, x_tensor)
|
|
|
286 |
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
f"
|
293 |
-
|
294 |
-
|
295 |
-
|
|
|
|
|
|
|
|
|
296 |
kmers = [''.join(p) for p in product("ACGT", repeat=4)]
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
# 2) Full-genome per-base SHAP heatmap
|
303 |
shap_means = compute_positionwise_scores(seq, shap_values, k=4)
|
304 |
-
heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide
|
305 |
heatmap_img = fig_to_image(heatmap_fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
|
315 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
|
317 |
|
318 |
###############################################################################
|
319 |
-
#
|
320 |
###############################################################################
|
321 |
|
322 |
css = """
|
@@ -327,57 +333,87 @@ css = """
|
|
327 |
|
328 |
with gr.Blocks(css=css) as iface:
|
329 |
gr.Markdown("""
|
330 |
-
# Virus Host Classifier
|
331 |
-
|
|
|
332 |
""")
|
333 |
|
334 |
-
with gr.
|
335 |
-
with gr.
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
|
374 |
gr.Markdown("""
|
375 |
-
###
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
|
|
|
|
|
|
|
|
|
|
381 |
""")
|
382 |
|
383 |
if __name__ == "__main__":
|
|
|
32 |
def forward(self, x):
|
33 |
return self.network(x)
|
34 |
|
|
|
35 |
###############################################################################
|
36 |
# 2. FASTA PARSING & K-MER FEATURE ENGINEERING
|
37 |
###############################################################################
|
|
|
58 |
return sequences
|
59 |
|
60 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
61 |
+
"""Convert a sequence to a k-mer frequency vector for classification."""
|
62 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
63 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
64 |
vec = np.zeros(len(kmers), dtype=np.float32)
|
|
|
74 |
|
75 |
return vec
|
76 |
|
|
|
77 |
###############################################################################
|
78 |
# 3. SHAP-VALUE (ABLATION) CALCULATION
|
79 |
###############################################################################
|
|
|
81 |
def calculate_shap_values(model, x_tensor):
|
82 |
"""
|
83 |
Calculate SHAP values using a simple ablation approach.
|
84 |
+
Returns shap_values, prob_human
|
85 |
"""
|
86 |
model.eval()
|
87 |
with torch.no_grad():
|
88 |
+
# Baseline
|
89 |
baseline_output = model(x_tensor)
|
90 |
baseline_probs = torch.softmax(baseline_output, dim=1)
|
91 |
baseline_prob = baseline_probs[0, 1].item() # Probability of 'human' class
|
92 |
|
93 |
+
# Zeroing each feature to measure impact
|
94 |
shap_values = []
|
95 |
x_zeroed = x_tensor.clone()
|
96 |
for i in range(x_tensor.shape[1]):
|
97 |
+
original_val = x_zeroed[0, i].item()
|
98 |
x_zeroed[0, i] = 0.0
|
99 |
output = model(x_zeroed)
|
100 |
probs = torch.softmax(output, dim=1)
|
101 |
prob = probs[0, 1].item()
|
102 |
impact = baseline_prob - prob
|
103 |
shap_values.append(impact)
|
104 |
+
x_zeroed[0, i] = original_val # restore
|
105 |
return np.array(shap_values), baseline_prob
|
106 |
|
|
|
107 |
###############################################################################
|
108 |
# 4. PER-BASE SHAP AGGREGATION
|
109 |
###############################################################################
|
|
|
113 |
Returns an array of per-base SHAP contributions by averaging
|
114 |
the k-mer SHAP values of all k-mers covering that base.
|
115 |
"""
|
|
|
116 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
117 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
118 |
|
|
|
132 |
|
133 |
return shap_means
|
134 |
|
|
|
135 |
###############################################################################
|
136 |
+
# 5. PLOTTING / UTILITIES
|
137 |
###############################################################################
|
138 |
|
139 |
+
def fig_to_image(fig):
|
140 |
+
"""Convert a Matplotlib figure to a PIL Image for Gradio."""
|
141 |
+
buf = io.BytesIO()
|
142 |
+
fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
|
143 |
+
buf.seek(0)
|
144 |
+
img = Image.open(buf)
|
145 |
+
plt.close(fig)
|
146 |
+
return img
|
147 |
+
|
148 |
+
def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
|
149 |
"""
|
150 |
Plots a 1D heatmap of per-base SHAP contributions.
|
151 |
Negative = push toward Non-Human, Positive = push toward Human.
|
152 |
+
Optionally can show only a subrange (start:end).
|
153 |
"""
|
154 |
+
if start is not None and end is not None:
|
155 |
+
shap_means = shap_means[start:end]
|
156 |
+
subtitle = f" (positions {start}-{end})"
|
157 |
+
else:
|
158 |
+
subtitle = ""
|
159 |
|
160 |
+
heatmap_data = shap_means.reshape(1, -1) # shape (1, region_length)
|
161 |
+
|
162 |
+
fig, ax = plt.subplots(figsize=(12, 2))
|
163 |
cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r')
|
164 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.2)
|
165 |
cbar.set_label('SHAP Contribution')
|
166 |
|
167 |
ax.set_yticks([])
|
168 |
ax.set_xlabel('Position in Sequence')
|
169 |
+
ax.set_title(f"{title}{subtitle}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
plt.tight_layout()
|
171 |
return fig
|
172 |
|
|
|
|
|
|
|
|
|
|
|
173 |
def create_importance_bar_plot(shap_values, kmers, top_k=10):
|
174 |
"""Create a bar plot of the most important k-mers."""
|
175 |
plt.rcParams.update({'font.size': 10})
|
|
|
184 |
|
185 |
plt.barh(range(len(values)), values, color=colors)
|
186 |
plt.yticks(range(len(values)), features)
|
187 |
+
plt.xlabel('SHAP Value (impact on model output)')
|
188 |
plt.title(f'Top {top_k} Most Influential k-mers')
|
189 |
plt.gca().invert_yaxis()
|
190 |
return fig
|
191 |
|
192 |
+
def compute_gc_content(sequence):
|
193 |
+
"""Compute %GC in the sequence (A, C, G, T)."""
|
194 |
+
if not sequence:
|
195 |
+
return 0
|
196 |
+
gc_count = sequence.count('G') + sequence.count('C')
|
197 |
+
return (gc_count / len(sequence)) * 100.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
###############################################################################
|
200 |
+
# 6. MAIN ANALYSIS STEP (Gradio Step 1)
|
201 |
###############################################################################
|
202 |
|
203 |
+
def analyze_sequence(file_obj, top_kmers=10, fasta_text=""):
|
204 |
+
"""Analyzes the entire genome, returning classification and a heatmap."""
|
205 |
# Handle input
|
206 |
if fasta_text.strip():
|
207 |
text = fasta_text.strip()
|
|
|
210 |
with open(file_obj, 'r') as f:
|
211 |
text = f.read()
|
212 |
except Exception as e:
|
213 |
+
return (f"Error reading file: {str(e)}", None, None, None, None)
|
214 |
else:
|
215 |
+
return ("Please provide a FASTA sequence.", None, None, None, None)
|
216 |
|
217 |
# Parse FASTA
|
218 |
sequences = parse_fasta(text)
|
219 |
if not sequences:
|
220 |
+
return ("No valid FASTA sequences found.", None, None, None, None)
|
221 |
|
222 |
header, seq = sequences[0]
|
223 |
|
|
|
228 |
model.load_state_dict(torch.load('model.pt', map_location=device))
|
229 |
scaler = joblib.load('scaler.pkl')
|
230 |
except Exception as e:
|
231 |
+
return (f"Error loading model: {str(e)}", None, None, None, None)
|
232 |
|
233 |
+
# Vectorize + scale
|
234 |
freq_vector = sequence_to_kmer_vector(seq)
|
235 |
scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
|
236 |
x_tensor = torch.FloatTensor(scaled_vector).to(device)
|
237 |
|
238 |
+
# SHAP + classification
|
239 |
shap_values, prob_human = calculate_shap_values(model, x_tensor)
|
240 |
+
prob_nonhuman = 1.0 - prob_human
|
241 |
|
242 |
+
classification = "Human" if prob_human > 0.5 else "Non-human"
|
243 |
+
confidence = max(prob_human, prob_nonhuman)
|
244 |
+
|
245 |
+
# Build results text
|
246 |
+
results_text = (
|
247 |
+
f"Sequence: {header}\n"
|
248 |
+
f"Length: {len(seq):,} bases\n"
|
249 |
+
f"Classification: {classification}\n"
|
250 |
+
f"Confidence: {confidence:.3f}\n"
|
251 |
+
f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})"
|
252 |
+
)
|
253 |
+
|
254 |
+
# K-mer importance plot
|
255 |
kmers = [''.join(p) for p in product("ACGT", repeat=4)]
|
256 |
+
bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
|
257 |
+
bar_img = fig_to_image(bar_fig)
|
258 |
+
|
259 |
+
# Per-base SHAP for entire genome
|
|
|
|
|
260 |
shap_means = compute_positionwise_scores(seq, shap_values, k=4)
|
261 |
+
heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
|
262 |
heatmap_img = fig_to_image(heatmap_fig)
|
263 |
+
|
264 |
+
# Return:
|
265 |
+
# 1) results text
|
266 |
+
# 2) k-mer bar image
|
267 |
+
# 3) full-genome heatmap
|
268 |
+
# 4) the "state" we need for step 2: (sequence, shap_means)
|
269 |
+
# We'll store these in a dictionary so we can pass it around in Gradio.
|
270 |
+
state_dict = {
|
271 |
+
"seq": seq,
|
272 |
+
"shap_means": shap_means
|
273 |
+
}
|
274 |
+
|
275 |
+
return (results_text, bar_img, heatmap_img, state_dict, header)
|
276 |
+
|
277 |
+
###############################################################################
|
278 |
+
# 7. SUBREGION ANALYSIS (Gradio Step 2)
|
279 |
+
###############################################################################
|
280 |
+
|
281 |
+
def analyze_subregion(state, header, region_start, region_end):
|
282 |
+
"""
|
283 |
+
Takes stored data from step 1 and a user-chosen region.
|
284 |
+
Returns a subregion heatmap and some stats (like GC content, average SHAP).
|
285 |
+
"""
|
286 |
+
if not state or "seq" not in state or "shap_means" not in state:
|
287 |
+
return ("No sequence data found. Please run Step 1 first.", None)
|
288 |
|
289 |
+
seq = state["seq"]
|
290 |
+
shap_means = state["shap_means"]
|
291 |
+
|
292 |
+
# Validate bounds
|
293 |
+
region_start = max(0, min(region_start, len(seq)))
|
294 |
+
region_end = max(0, min(region_end, len(seq)))
|
295 |
+
if region_end <= region_start:
|
296 |
+
return ("Invalid region range. End must be > Start.", None)
|
297 |
+
|
298 |
+
# Subsequence
|
299 |
+
region_seq = seq[region_start:region_end]
|
300 |
+
region_shap = shap_means[region_start:region_end]
|
301 |
+
|
302 |
+
# Some stats
|
303 |
+
gc_percent = compute_gc_content(region_seq)
|
304 |
+
avg_shap = float(np.mean(region_shap))
|
305 |
|
306 |
+
region_info = (
|
307 |
+
f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
|
308 |
+
f"Region length: {len(region_seq)} bases\n"
|
309 |
+
f"GC content: {gc_percent:.2f}%\n"
|
310 |
+
f"Average SHAP in region: {avg_shap:.4f} "
|
311 |
+
f"({'toward human' if avg_shap > 0 else 'toward non-human' if avg_shap < 0 else 'neutral'})"
|
312 |
+
)
|
313 |
+
|
314 |
+
# Plot region as small heatmap
|
315 |
+
fig = plot_linear_heatmap(shap_means,
|
316 |
+
title="Subregion SHAP",
|
317 |
+
start=region_start,
|
318 |
+
end=region_end)
|
319 |
+
heatmap_img = fig_to_image(fig)
|
320 |
+
|
321 |
+
return (region_info, heatmap_img)
|
322 |
|
323 |
|
324 |
###############################################################################
|
325 |
+
# 8. BUILD GRADIO INTERFACE
|
326 |
###############################################################################
|
327 |
|
328 |
css = """
|
|
|
333 |
|
334 |
with gr.Blocks(css=css) as iface:
|
335 |
gr.Markdown("""
|
336 |
+
# Virus Host Classifier (with Interactive Region Viewer)
|
337 |
+
**Step 1**: Predict overall viral sequence origin (human vs non-human)
|
338 |
+
**Step 2**: Explore subregions to see local SHAP signals and GC content
|
339 |
""")
|
340 |
|
341 |
+
with gr.Tab("1) Full-Sequence Analysis"):
|
342 |
+
with gr.Row():
|
343 |
+
with gr.Column(scale=1):
|
344 |
+
file_input = gr.File(
|
345 |
+
label="Upload FASTA file",
|
346 |
+
file_types=[".fasta", ".fa", ".txt"],
|
347 |
+
type="filepath"
|
348 |
+
)
|
349 |
+
text_input = gr.Textbox(
|
350 |
+
label="Or paste FASTA sequence",
|
351 |
+
placeholder=">sequence_name\nACGTACGT...",
|
352 |
+
lines=5
|
353 |
+
)
|
354 |
+
top_k = gr.Slider(
|
355 |
+
minimum=5,
|
356 |
+
maximum=30,
|
357 |
+
value=10,
|
358 |
+
step=1,
|
359 |
+
label="Number of top k-mers to display"
|
360 |
+
)
|
361 |
+
analyze_btn = gr.Button("Analyze Sequence", variant="primary")
|
362 |
+
|
363 |
+
with gr.Column(scale=2):
|
364 |
+
results_box = gr.Textbox(
|
365 |
+
label="Classification Results", lines=7, interactive=False
|
366 |
+
)
|
367 |
+
kmer_img = gr.Image(label="Top k-mer SHAP")
|
368 |
+
genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
|
369 |
+
|
370 |
+
# Hidden states that store data for step 2
|
371 |
+
# "state" will hold (sequence, shap_means).
|
372 |
+
# "header" is optional meta info
|
373 |
+
seq_state = gr.State()
|
374 |
+
header_state = gr.State()
|
375 |
+
|
376 |
+
# The "analyze_sequence" function returns 5 values, which we map here:
|
377 |
+
analyze_btn.click(
|
378 |
+
analyze_sequence,
|
379 |
+
inputs=[file_input, top_k, text_input],
|
380 |
+
outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
|
381 |
+
)
|
382 |
|
383 |
+
with gr.Tab("2) Subregion Exploration"):
|
384 |
+
gr.Markdown("""
|
385 |
+
Select start/end positions to view local SHAP signals.
|
386 |
+
""")
|
387 |
+
with gr.Row():
|
388 |
+
region_start = gr.Number(label="Region Start", value=0)
|
389 |
+
region_end = gr.Number(label="Region End", value=500)
|
390 |
+
region_btn = gr.Button("Analyze Subregion")
|
391 |
+
|
392 |
+
subregion_info = gr.Textbox(
|
393 |
+
label="Subregion Analysis",
|
394 |
+
lines=4,
|
395 |
+
interactive=False
|
396 |
+
)
|
397 |
+
subregion_img = gr.Image(label="Subregion SHAP Heatmap")
|
398 |
+
|
399 |
+
region_btn.click(
|
400 |
+
analyze_subregion,
|
401 |
+
inputs=[seq_state, header_state, region_start, region_end],
|
402 |
+
outputs=[subregion_info, subregion_img]
|
403 |
+
)
|
404 |
|
405 |
gr.Markdown("""
|
406 |
+
### What does this interface provide?
|
407 |
+
1. **Overall Classification** (human vs non-human), using a learned model on k-mer frequencies.
|
408 |
+
2. **SHAP Analysis** (ablation-based) to see which k-mer features push classification toward or away from "human".
|
409 |
+
3. **Genome-Wide SHAP Heatmap**: Each base's average SHAP across overlapping k-mers.
|
410 |
+
4. **Subregion Exploration**:
|
411 |
+
- View SHAP signals in a user-chosen region.
|
412 |
+
- Calculate local GC content, average SHAP, etc.
|
413 |
+
|
414 |
+
### Tips
|
415 |
+
- For very large sequences (e.g., >100k bases), the full heatmap might be large; consider downsampling if needed.
|
416 |
+
- Adjust *Region Start* and *End* to explore different parts of the genome.
|
417 |
""")
|
418 |
|
419 |
if __name__ == "__main__":
|