Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -8,10 +8,6 @@ import matplotlib.pyplot as plt
|
|
8 |
import io
|
9 |
from PIL import Image
|
10 |
|
11 |
-
##############################################################################
|
12 |
-
# MODEL DEFINITION
|
13 |
-
##############################################################################
|
14 |
-
|
15 |
class VirusClassifier(nn.Module):
|
16 |
def __init__(self, input_shape: int):
|
17 |
super(VirusClassifier, self).__init__()
|
@@ -32,10 +28,6 @@ class VirusClassifier(nn.Module):
|
|
32 |
def forward(self, x):
|
33 |
return self.network(x)
|
34 |
|
35 |
-
##############################################################################
|
36 |
-
# UTILITIES
|
37 |
-
##############################################################################
|
38 |
-
|
39 |
def parse_fasta(text):
|
40 |
"""
|
41 |
Parses FASTA formatted text into a list of (header, sequence).
|
@@ -61,7 +53,7 @@ def parse_fasta(text):
|
|
61 |
|
62 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
63 |
"""
|
64 |
-
Convert a sequence to a k-mer frequency vector
|
65 |
"""
|
66 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
67 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
@@ -78,355 +70,218 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
|
78 |
|
79 |
return vec
|
80 |
|
81 |
-
def
|
82 |
"""
|
83 |
-
|
84 |
-
1. Compute baseline human probability p_base.
|
85 |
-
2. For each feature i, set x[i] = 0, re-run inference, compute new p, and
|
86 |
-
measure delta = p_base - p.
|
87 |
-
3. Return array of deltas (positive means that removing that feature
|
88 |
-
*decreases* the probability => that feature was pushing it higher).
|
89 |
"""
|
90 |
model.eval()
|
91 |
with torch.no_grad():
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
with torch.no_grad():
|
105 |
-
output_ablation = model(x_copy)
|
106 |
-
probs_ablation = torch.softmax(output_ablation, dim=1)
|
107 |
-
p_ablation = probs_ablation[0, 1].item()
|
108 |
-
# Delta
|
109 |
-
importances[i] = p_base - p_ablation
|
110 |
-
|
111 |
-
return importances, p_base
|
112 |
-
|
113 |
-
##############################################################################
|
114 |
-
# PLOTTING
|
115 |
-
##############################################################################
|
116 |
|
117 |
-
def
|
118 |
"""
|
119 |
-
|
120 |
-
and a frequency vs. sigma bar chart.
|
121 |
"""
|
122 |
-
|
|
|
123 |
|
124 |
-
#
|
125 |
-
|
|
|
|
|
126 |
|
127 |
-
#
|
128 |
-
ax1 = plt.subplot(gs[0])
|
129 |
-
current_prob = 0.5
|
130 |
-
steps = [('Start', current_prob, 0)]
|
131 |
-
|
132 |
-
for kmer_info in important_kmers:
|
133 |
-
change = kmer_info['impact'] # positive => pushes up, negative => pushes down
|
134 |
-
current_prob += change
|
135 |
-
steps.append((kmer_info['kmer'], current_prob, change))
|
136 |
-
|
137 |
-
x = range(len(steps))
|
138 |
-
y = [step[1] for step in steps]
|
139 |
-
|
140 |
-
# Plot steps
|
141 |
-
ax1.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
|
142 |
-
ax1.plot(x, y, 'b.', markersize=10)
|
143 |
-
|
144 |
-
# Add reference line
|
145 |
-
ax1.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
|
146 |
-
|
147 |
-
# Customize plot
|
148 |
-
ax1.grid(True, linestyle='--', alpha=0.7)
|
149 |
-
ax1.set_ylim(0, 1)
|
150 |
-
ax1.set_ylabel('Human Probability')
|
151 |
-
ax1.set_title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')
|
152 |
-
|
153 |
-
# Add labels for each point
|
154 |
-
for i, (kmer, prob, change) in enumerate(steps):
|
155 |
-
# Add k-mer label
|
156 |
-
ax1.annotate(kmer,
|
157 |
-
(i, prob),
|
158 |
-
xytext=(0, 10 if i % 2 == 0 else -20),
|
159 |
-
textcoords='offset points',
|
160 |
-
ha='center',
|
161 |
-
rotation=45)
|
162 |
-
|
163 |
-
# Add change value
|
164 |
-
if i > 0:
|
165 |
-
change_text = f'{change:+.3f}'
|
166 |
-
color = 'green' if change > 0 else 'red'
|
167 |
-
ax1.annotate(change_text,
|
168 |
-
(i, prob),
|
169 |
-
xytext=(0, -20 if i % 2 == 0 else 10),
|
170 |
-
textcoords='offset points',
|
171 |
-
ha='center',
|
172 |
-
color=color)
|
173 |
|
174 |
-
|
|
|
|
|
|
|
|
|
175 |
|
176 |
-
# 2. K-mer Frequency and Sigma Plot
|
177 |
-
ax2 = plt.subplot(gs[1])
|
178 |
-
|
179 |
-
# Prepare data
|
180 |
-
kmers = [k['kmer'] for k in important_kmers]
|
181 |
-
frequencies = [k['occurrence'] for k in important_kmers]
|
182 |
-
sigmas = [k['sigma'] for k in important_kmers]
|
183 |
-
|
184 |
-
# Color the bars: if impact>0 => green, else red
|
185 |
-
colors = ['g' if k['impact'] > 0 else 'r' for k in important_kmers]
|
186 |
-
|
187 |
-
# Create bar plot for frequencies
|
188 |
-
x = np.arange(len(kmers))
|
189 |
-
width = 0.35
|
190 |
-
|
191 |
-
ax2.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)
|
192 |
-
|
193 |
-
# Twin axis for sigma
|
194 |
-
ax2_twin = ax2.twinx()
|
195 |
-
# To highlight positive or negative sigma, pick color accordingly
|
196 |
-
sigma_colors = []
|
197 |
-
for s, c in zip(sigmas, colors):
|
198 |
-
if s >= 0:
|
199 |
-
sigma_colors.append('blue') # above average
|
200 |
-
else:
|
201 |
-
sigma_colors.append('gray') # below average
|
202 |
-
|
203 |
-
ax2_twin.bar(x + width/2, sigmas, width, label='σ from Mean', color=sigma_colors, alpha=0.3)
|
204 |
-
|
205 |
-
# Customize plot
|
206 |
-
ax2.set_xticks(x)
|
207 |
-
ax2.set_xticklabels(kmers, rotation=45)
|
208 |
-
ax2.set_ylabel('Frequency (%)')
|
209 |
-
ax2_twin.set_ylabel('Standard Deviations (σ) from Mean')
|
210 |
-
ax2.set_title('K-mer Frequencies and Statistical Significance')
|
211 |
-
|
212 |
-
# Add legends
|
213 |
-
lines1, labels1 = ax2.get_legend_handles_labels()
|
214 |
-
lines2, labels2 = ax2_twin.get_legend_handles_labels()
|
215 |
-
ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
|
216 |
-
|
217 |
-
plt.tight_layout()
|
218 |
return fig
|
219 |
|
220 |
-
def
|
221 |
"""
|
222 |
-
|
223 |
-
impact_values: array of float (length=256).
|
224 |
-
kmer_list: list of all k=4 kmers in order.
|
225 |
-
top_k: integer, how many top features to display.
|
226 |
"""
|
227 |
-
|
228 |
-
|
229 |
-
top_indices = indices_sorted[:top_k]
|
230 |
|
231 |
-
|
232 |
-
|
|
|
233 |
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
plt.
|
239 |
-
plt.
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
"""
|
247 |
-
fig = plt.figure(figsize=(12, 6))
|
248 |
-
indices_sorted = np.argsort(np.abs(impact_values))[::-1]
|
249 |
-
sorted_impacts = impact_values[indices_sorted]
|
250 |
-
sorted_kmers = [kmer_list[i] for i in indices_sorted]
|
251 |
|
252 |
-
plt.bar(range(len(sorted_impacts)), sorted_impacts,
|
253 |
-
color=['green' if i > 0 else 'red' for i in sorted_impacts])
|
254 |
-
plt.title("Global Impact of All 256 K-mers (Ablation Method)")
|
255 |
-
plt.xlabel("K-mer (sorted by |impact|)")
|
256 |
-
plt.ylabel("Impact on Human Probability")
|
257 |
-
# Optionally, we can skip labeling all 256 on x-axis.
|
258 |
-
# But we can show only the top/bottom or none for clarity.
|
259 |
-
plt.tight_layout()
|
260 |
return fig
|
261 |
|
262 |
-
|
263 |
-
# MAIN PREDICTION FUNCTION
|
264 |
-
##############################################################################
|
265 |
-
|
266 |
-
def predict(file_obj, top_kmers=10, advanced_plots=False, fasta_text=""):
|
267 |
"""
|
268 |
-
Main prediction function
|
269 |
-
- file_obj: optional uploaded FASTA file
|
270 |
-
- top_kmers: number of top k-mers to display in the main SHAP-like plot
|
271 |
-
- advanced_plots: bool, whether to return global bar plots
|
272 |
-
- fasta_text: optional direct-pasted FASTA text
|
273 |
"""
|
274 |
-
#
|
275 |
if fasta_text.strip():
|
276 |
text = fasta_text.strip()
|
277 |
-
|
278 |
-
if file_obj is None:
|
279 |
-
return "No FASTA input provided", None, None, None
|
280 |
try:
|
281 |
-
|
282 |
-
text = file_obj
|
283 |
-
else:
|
284 |
-
text = file_obj.decode('utf-8')
|
285 |
except Exception as e:
|
286 |
-
return f"Error reading file: {str(e)}", None, None
|
|
|
|
|
287 |
|
288 |
# Parse FASTA
|
289 |
sequences = parse_fasta(text)
|
290 |
-
if
|
291 |
-
return "No valid FASTA sequences found", None, None
|
|
|
292 |
header, seq = sequences[0]
|
293 |
|
294 |
-
#
|
295 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
296 |
-
model = VirusClassifier(256).to(device)
|
297 |
try:
|
298 |
-
|
299 |
-
model.load_state_dict(
|
300 |
scaler = joblib.load('scaler.pkl')
|
301 |
except Exception as e:
|
302 |
-
return f"Error loading model
|
303 |
|
304 |
-
#
|
305 |
-
|
306 |
-
scaled_vector = scaler.transform(
|
307 |
-
|
308 |
|
309 |
-
#
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
# We'll gather info in a list of dicts for each k-mer.
|
317 |
-
kmers_4 = [''.join(p) for p in product("ACGT", repeat=4)]
|
318 |
-
kmer_dict = {km: i for i, km in enumerate(kmers_4)}
|
319 |
-
|
320 |
-
# We'll sort by absolute impact to get the top 10 by default.
|
321 |
-
abs_sorted_idx = np.argsort(np.abs(importances))[::-1]
|
322 |
-
# But for the final step/frequency plot we only show top_kmers
|
323 |
-
top_indices = abs_sorted_idx[:top_kmers]
|
324 |
-
|
325 |
-
# Build a list of the top k-mers
|
326 |
important_kmers = []
|
327 |
-
for idx in
|
328 |
-
# "impact" is how much that feature changed the probability
|
329 |
-
impact = importances[idx]
|
330 |
-
# raw frequency => raw_freq_vector[idx] * 100 for %
|
331 |
-
freq_pct = float(raw_freq_vector[idx] * 100.0)
|
332 |
-
# sigma => scaled_vector[0][idx]
|
333 |
-
sigma_val = float(scaled_vector[0][idx])
|
334 |
-
|
335 |
important_kmers.append({
|
336 |
-
'kmer':
|
337 |
-
'impact':
|
338 |
-
'
|
339 |
-
'
|
340 |
})
|
341 |
-
|
342 |
-
# For text output
|
343 |
-
# We decide final class based on model's direct output
|
344 |
-
with torch.no_grad():
|
345 |
-
output = model(X_tensor)
|
346 |
-
probs = torch.softmax(output, dim=1)
|
347 |
-
pred_class = 1 if probs[0,1] > probs[0,0] else 0
|
348 |
-
pred_label = 'human' if pred_class == 1 else 'non-human'
|
349 |
-
human_prob = probs[0,1].item()
|
350 |
-
nonhuman_prob = probs[0,0].item()
|
351 |
-
confidence = max(human_prob, nonhuman_prob)
|
352 |
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
|
|
|
|
359 |
|
360 |
-
for
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
f"
|
365 |
-
f"
|
366 |
-
f"Occ={kmer_info['occurrence']:.2f}% of seq, "
|
367 |
-
f"{abs(kmer_info['sigma']):.2f}σ "
|
368 |
-
+ ("above" if kmer_info['sigma']>0 else "below")
|
369 |
-
+ " mean\n"
|
370 |
)
|
371 |
|
372 |
-
#
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
# PLOT 3 (optional advanced): global bar plot of all 256 features
|
379 |
-
global_fig = None
|
380 |
-
if advanced_plots:
|
381 |
-
global_fig = create_global_bar_plot(importances, kmers_4)
|
382 |
-
|
383 |
-
# Convert figures to PIL Images
|
384 |
def fig_to_image(fig):
|
385 |
buf = io.BytesIO()
|
386 |
-
fig.savefig(buf, format='png', bbox_inches='tight', dpi=
|
387 |
buf.seek(0)
|
388 |
-
|
389 |
plt.close(fig)
|
390 |
-
return
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
title_text = "Virus Host Classifier"
|
406 |
-
description_text = """
|
407 |
-
Upload or paste a FASTA sequence to predict if it's likely **human** or **non-human** origin.
|
408 |
-
- **k=4** k-mers are used as features.
|
409 |
-
- We display ablation-based feature importance for interpretability.
|
410 |
-
- Advanced plots can be toggled to see the global distribution of all 256 k-mer impacts.
|
411 |
"""
|
412 |
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
gr.
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
|
431 |
if __name__ == "__main__":
|
432 |
-
iface.launch(
|
|
|
8 |
import io
|
9 |
from PIL import Image
|
10 |
|
|
|
|
|
|
|
|
|
11 |
class VirusClassifier(nn.Module):
|
12 |
def __init__(self, input_shape: int):
|
13 |
super(VirusClassifier, self).__init__()
|
|
|
28 |
def forward(self, x):
|
29 |
return self.network(x)
|
30 |
|
|
|
|
|
|
|
|
|
31 |
def parse_fasta(text):
|
32 |
"""
|
33 |
Parses FASTA formatted text into a list of (header, sequence).
|
|
|
53 |
|
54 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
55 |
"""
|
56 |
+
Convert a sequence to a k-mer frequency vector.
|
57 |
"""
|
58 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
59 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
|
|
70 |
|
71 |
return vec
|
72 |
|
73 |
+
def calculate_shap_values(model, x_tensor):
    """
    Estimate per-feature importance with single-feature ablation.

    These are "SHAP-like" scores, not true Shapley values: each feature of
    the first sample is set to 0 on its own and the drop in the class-1
    probability is recorded.

    Args:
        model: torch module mapping a (batch, n_features) tensor to
            (batch, n_classes) logits with at least two classes.
        x_tensor: 2-D input tensor; only row 0 is explained.

    Returns:
        (values, baseline_prob) where ``values`` is a float64 numpy array of
        length n_features holding ``baseline_prob - ablated_prob`` (positive
        means the feature was pushing the class-1 probability up) and
        ``baseline_prob`` is the un-ablated class-1 probability as a float.

    Side effects:
        Leaves ``model`` in eval mode. ``x_tensor`` is not modified.
    """
    model.eval()
    n_features = x_tensor.shape[1]

    with torch.no_grad():
        baseline_output = model(x_tensor)
        baseline_prob = torch.softmax(baseline_output, dim=1)[0, 1].item()

        if n_features == 0:
            return np.array([]), baseline_prob

        # Build every ablated variant at once: row i is the sample with
        # feature i zeroed. A single batched forward pass replaces
        # n_features separate passes (and as many host syncs via .item()).
        ablated = x_tensor[:1].repeat(n_features, 1)
        ablated.fill_diagonal_(0)
        ablated_probs = torch.softmax(model(ablated), dim=1)[:, 1]

    shap_values = baseline_prob - ablated_probs.double().cpu().numpy()
    return shap_values, baseline_prob
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
+
def create_importance_plot(shap_values, kmers, top_k=10):
    """
    Create a horizontal bar plot of the most influential features.

    Args:
        shap_values: 1-D array-like of signed importance scores.
        kmers: sequence of feature labels, indexable by the same positions
            as ``shap_values``.
        top_k: how many features to show; coerced to int and clamped to
            ``[0, len(shap_values)]``.

    Returns:
        The matplotlib Figure. Bars are ordered by decreasing absolute
        importance with the most influential feature at the top; positive
        scores are green, non-positive scores red.
    """
    # The bare 'seaborn' style was renamed 'seaborn-v0_8' in Matplotlib 3.6
    # and removed in 3.8; use whichever this installation provides and fall
    # back to the current style if neither exists.
    for style_name in ('seaborn-v0_8', 'seaborn'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break

    shap_values = np.asarray(shap_values, dtype=float)
    # UI sliders may deliver floats; a plain [-top_k:] slice would also
    # return *everything* for top_k == 0.
    top_k = max(0, min(int(top_k), shap_values.size))

    fig = plt.figure(figsize=(10, 8))

    # Sort by absolute importance, largest first. Combined with the inverted
    # y-axis below, this puts the most influential k-mer on the top row.
    indices = np.argsort(np.abs(shap_values))[::-1][:top_k]
    values = shap_values[indices]
    features = [kmers[i] for i in indices]

    colors = ['#2ecc71' if v > 0 else '#e74c3c' for v in values]

    plt.barh(range(len(values)), values, color=colors)
    plt.yticks(range(len(values)), features)
    plt.xlabel('Impact on Prediction (SHAP value)')
    plt.title(f'Top {top_k} Most Influential k-mers')
    plt.gca().invert_yaxis()

    return fig
|
113 |
|
114 |
+
def create_contribution_plot(important_kmers, final_prob):
    """
    Create a line plot of cumulative feature contributions.

    The curve starts at a neutral probability of 0.5 and adds each k-mer's
    ``'impact'`` in the order given, so the x-axis reads 'Base' followed by
    the k-mer labels.

    Args:
        important_kmers: iterable of dicts with at least the keys ``'kmer'``
            (label) and ``'impact'`` (signed contribution).
        final_prob: the model's predicted probability. Currently not drawn;
            kept so existing callers keep working.

    Returns:
        The matplotlib Figure.
    """
    # The bare 'seaborn' style was renamed 'seaborn-v0_8' in Matplotlib 3.6
    # and removed in 3.8; use whichever this installation provides and fall
    # back to the current style if neither exists.
    for style_name in ('seaborn-v0_8', 'seaborn'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break

    fig = plt.figure(figsize=(12, 6))

    base_prob = 0.5
    cumulative = [base_prob]
    labels = ['Base']

    for kmer_info in important_kmers:
        cumulative.append(cumulative[-1] + kmer_info['impact'])
        labels.append(kmer_info['kmer'])

    plt.plot(range(len(cumulative)), cumulative, 'b-o', linewidth=2)
    # Dashed reference line at the neutral starting probability.
    plt.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)

    plt.xticks(range(len(labels)), labels, rotation=45)
    plt.ylim(0, 1)
    plt.grid(True, alpha=0.3)
    plt.title('Cumulative Feature Contributions')
    plt.ylabel('Probability of Human Origin')

    return fig
|
139 |
|
140 |
+
def _read_fasta_upload(file_obj):
    """Return the text of an uploaded FASTA file in whatever form it arrives."""
    if isinstance(file_obj, (bytes, bytearray)):
        return bytes(file_obj).decode('utf-8')
    if isinstance(file_obj, str) and file_obj.lstrip().startswith('>'):
        # Already FASTA text rather than a path.
        return file_obj
    # Otherwise expect a filesystem path, or a tempfile-like wrapper that
    # exposes the path as `.name`.
    path = file_obj if isinstance(file_obj, str) else getattr(file_obj, 'name', None)
    if not isinstance(path, str):
        raise TypeError(f"unsupported upload type: {type(file_obj).__name__}")
    with open(path, 'r', encoding='utf-8') as handle:
        return handle.read()


def predict(file_obj, top_kmers=10, fasta_text=""):
    """
    Main prediction function for the Gradio interface.

    Args:
        file_obj: optional uploaded FASTA file. Accepts raw bytes, a file
            path, or an object exposing the path as ``.name``.
        top_kmers: number of most influential k-mers to report and plot.
        fasta_text: optional pasted FASTA text; takes precedence over
            ``file_obj`` when non-blank.

    Returns:
        A 3-tuple ``(results_text, importance_image, contribution_image)``.
        On any input or loading error the first element is a human-readable
        message and both images are ``None``.
    """
    # Handle input
    fasta_text = (fasta_text or "").strip()
    if fasta_text:
        text = fasta_text
    elif file_obj is not None:
        try:
            text = _read_fasta_upload(file_obj)
        except (OSError, UnicodeError, TypeError) as e:
            return f"Error reading file: {str(e)}", None, None
    else:
        return "Please provide a FASTA sequence either by file upload or text input.", None, None

    # Parse FASTA; only the first record is classified.
    sequences = parse_fasta(text)
    if not sequences:
        return "No valid FASTA sequences found in input.", None, None

    header, seq = sequences[0]

    # Process sequence
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): the model and scaler are reloaded from disk on every
    # request; consider caching them at module level if latency matters.
    try:
        model = VirusClassifier(256).to(device)
        model.load_state_dict(torch.load('model.pt', map_location=device))
        scaler = joblib.load('scaler.pkl')
    except Exception as e:
        return f"Error loading model: {str(e)}", None, None

    # Generate features
    freq_vector = sequence_to_kmer_vector(seq)
    scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
    x_tensor = torch.FloatTensor(scaled_vector).to(device)

    # Calculate SHAP values and predictions
    shap_values, human_prob = calculate_shap_values(model, x_tensor)

    # Generate k-mer information, most influential first. The slider may
    # hand over a float, and a [-0:] slice would select every feature.
    kmers = [''.join(p) for p in product("ACGT", repeat=4)]
    top_kmers = max(0, int(top_kmers))
    important_indices = np.argsort(np.abs(shap_values))[::-1][:top_kmers]

    important_kmers = []
    for idx in important_indices:
        important_kmers.append({
            'kmer': kmers[idx],
            'impact': shap_values[idx],
            # presumably freq_vector holds fractions, hence the percent scaling
            'frequency': freq_vector[idx] * 100,
            'significance': scaled_vector[0][idx]
        })

    # Format results text
    results = [
        f"Sequence: {header}",
        f"Prediction: {'Human' if human_prob > 0.5 else 'Non-human'} Origin",
        f"Confidence: {max(human_prob, 1-human_prob):.3f}",
        f"Human Probability: {human_prob:.3f}",
        "\nTop Contributing k-mers:",
    ]

    for kmer in important_kmers:
        direction = "→ Human" if kmer['impact'] > 0 else "→ Non-human"
        results.append(
            f"• {kmer['kmer']}: {direction} "
            f"(impact: {kmer['impact']:.3f}, "
            f"freq: {kmer['frequency']:.2f}%)"
        )

    # Generate plots
    shap_plot = create_importance_plot(shap_values, kmers, top_kmers)
    contribution_plot = create_contribution_plot(important_kmers, human_prob)

    # Convert plots to images
    def fig_to_image(fig):
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
        buf.seek(0)
        img = Image.open(buf)
        # Close the figure so repeated requests do not accumulate open figures.
        plt.close(fig)
        return img

    return "\n".join(results), fig_to_image(shap_plot), fig_to_image(contribution_plot)
|
223 |
+
|
224 |
+
# Create Gradio interface
|
225 |
+
# Custom CSS passed to the Gradio Blocks container below.
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.interpretation-container {
    margin-top: 20px;
    padding: 15px;
    border-radius: 8px;
    background-color: #f8f9fa;
}
"""

# Two-column layout: inputs on the left, results and plots on the right.
with gr.Blocks(css=css) as iface:
    gr.Markdown("""
    # Virus Host Classifier
    This tool predicts whether a viral sequence is likely of human or non-human origin using k-mer frequency analysis.

    ### Instructions
    1. Upload a FASTA file or paste your sequence in FASTA format
    2. Adjust the number of top k-mers to display (default: 10)
    3. View the prediction results and feature importance visualizations
    """)

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload FASTA file",
                file_types=[".fasta", ".fa", ".txt"],
                # predict() decodes the upload as UTF-8 bytes; without this
                # the component hands over a path/tempfile wrapper instead.
                type="binary"
            )
            text_input = gr.Textbox(
                label="Or paste FASTA sequence",
                placeholder=">sequence_name\nACGTACGT...",
                lines=5
            )
            top_k = gr.Slider(
                minimum=5,
                maximum=20,
                value=10,
                step=1,
                label="Number of top k-mers to display"
            )
            submit_btn = gr.Button("Analyze Sequence", variant="primary")

        with gr.Column(scale=2):
            results = gr.Textbox(label="Analysis Results", lines=10)
            shap_plot = gr.Image(label="Feature Importance Plot")
            contribution_plot = gr.Image(label="Cumulative Contribution Plot")

    # Input order must match predict(file_obj, top_kmers, fasta_text).
    submit_btn.click(
        predict,
        inputs=[file_input, top_k, text_input],
        outputs=[results, shap_plot, contribution_plot]
    )

    gr.Markdown("""
    ### About
    - Uses 4-mer frequencies as sequence features
    - Employs SHAP-like values for feature importance interpretation
    - Visualizes cumulative feature contributions to the final prediction
    """)

if __name__ == "__main__":
    iface.launch()
|