Update app.py
app.py CHANGED
@@ -33,14 +33,14 @@ class VirusClassifier(nn.Module):
 
     def get_feature_importance(self, x):
         """
-        Calculate gradient-based feature importance
-
+        Calculate gradient-based feature importance, specifically for the
+        'human' class (index=1) by computing gradient of that probability wrt x.
         """
         x.requires_grad_(True)
         output = self.network(x)
         probs = torch.softmax(output, dim=1)
 
-        #
+        # Probability of 'human' class (index=1)
         human_prob = probs[..., 1]
         if x.grad is not None:
             x.grad.zero_()
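The hunk ends at the gradient reset, before any backward pass appears. Based on how predict() later unpacks importance, hum_prob_grad = model.get_feature_importance(X_tensor), a minimal sketch of the assumed continuation; the backward() call and return values are assumptions, not shown in this commit:

    # Assumed continuation of get_feature_importance (not part of the diff):
    human_prob.sum().backward()             # fills x.grad with d(P_human)/d(x_i)
    importance = x.grad.detach()            # shape [batch, 256] for 4-mer features
    return importance, human_prob.detach()  # later unpacked as (importance, hum_prob_grad)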
@@ -94,127 +94,160 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
 ###############################################################################
 # Visualization
 ###############################################################################
-def create_visualization(important_kmers, human_prob, title):
-    gs = plt.GridSpec(2, 2, width_ratios=[1.2, 1], height_ratios=[1.2, 1], hspace=0.35, wspace=0.3)
-    # -------------------------------------------------------------------------
-    # 1. Waterfall-like Plot (top-left subplot)
-    # -------------------------------------------------------------------------
-    ax_waterfall = plt.subplot(gs[0, 0])
-            if i == 0:  # baseline
-                ax_waterfall.annotate(kmer, (i, prob), textcoords="offset points", xytext=(0, -15), ha='center', color='black')
-                continue
-            color = "green" if change > 0 else "red"
-            ax_waterfall.annotate(
-                f"{kmer}\n({change:+.3f})",
-                (i, prob),
-                textcoords="offset points",
-                xytext=(0, -15),
-                ha='center',
-                color=color,
-                fontsize=9
-            )
-    # 2. Frequency & σ from Mean (top-right subplot)
-    ax_bar = plt.subplot(gs[0, 1])
-    kmers = [k["kmer"] for k in important_kmers]
-    frequencies = [k["occurrence"] for k in important_kmers]  # in %
-    sigmas = [k["sigma"] for k in important_kmers]
-    directions = [k["direction"] for k in important_kmers]
-    # X-locations
-    ax_bar.set_title("Frequency vs. σ from Mean")
-    ax_bar.set_xticklabels(kmers, rotation=45, ha='right')
-    ax_bar.legend(lines1 + lines2, labels1 + labels2, loc=
-    # 3. Top Feature Importances (Bottom, spanning both columns)
-    ax_imp = plt.subplot(gs[1, :])
-    plt.suptitle(title, fontsize=14, y=1.02)
+def create_shap_waterfall_plot(important_kmers, all_kmer_importance, human_prob, title):
     """
+    Create a SHAP-like waterfall plot:
+    - Start at baseline = 0.5
+    - Add a bar for "Other" which is the combined effect of all less-important k-mers
+    - Then apply each of the top k-mers in descending order of absolute importance
+    - Show final predicted human probability as the endpoint
     """
 
+    # 1) Sort 'important_kmers' by absolute impact descending
+    sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
 
+    # 2) Compute the total effect of "other" k-mers.
+    #    We have 256 total features and selected the top 10; sum the rest.
+    top_ids = set([km['idx'] for km in sorted_kmers])
+    other_contributions = []
+    for i, val in enumerate(all_kmer_importance):
+        if i not in top_ids:
+            other_contributions.append(val)
+    other_sum = np.sum(other_contributions)
+    # The "impact" for "Other" is the absolute value; direction depends on sign
+    other_impact = float(abs(other_sum))
+    other_direction = "human" if other_sum > 0 else "non-human"
+
+    # 3) Build a list of all bars: first "Other", then each top k-mer.
+    #    We store (label, contribution); the sign indicates direction.
+    bars = []
+    bars.append(("Other", other_sum))  # lumps the leftover k-mers
+
+    for km in sorted_kmers:
+        # Re-inject the sign on the raw gradient (only the absolute value
+        # was stored in "impact", so create a signed value)
+        signed_val = km['impact'] if km['direction'] == 'human' else -km['impact']
+        bars.append((km['kmer'], signed_val))
+
+    # 4) Waterfall plot data: accumulate partial sums from baseline = 0.5
     baseline = 0.5
+    running_val = baseline
+    x_labels = []
+    y_vals = []
+    bar_colors = []
+
+    # Green for positive contributions (pushing toward 'human'),
+    # red for negative contributions (pushing away from 'human')
+    for (label, contrib) in bars:
+        x_labels.append(label)
+        # New value after adding this contribution,
+        # scaled by 0.05 for better display. Adjust as desired.
+        new_val = running_val + (0.05 * contrib)
+        y_vals.append((running_val, new_val))
+        running_val = new_val
+        if contrib >= 0:
+            bar_colors.append("green")
+        else:
+            bar_colors.append("red")
 
+    final_prob = running_val
+    # The endpoint approximates the model's predicted probability (not always
+    # exact, but this is the SHAP-like idea). To force final_prob == human_prob:
+    #     correction = human_prob - running_val
+    #     running_val += correction
+    # For now, keep the waterfall purely additive from the gradient.
 
+    fig, ax = plt.subplots(figsize=(10, 6))
+
+    # Create the bars manually
+    x_positions = np.arange(len(x_labels))
+
+    for i, ((start_val, end_val), color) in enumerate(zip(y_vals, bar_colors)):
+        # The bar's height is the difference
+        height = end_val - start_val
+        ax.bar(i, height, bottom=start_val, color=color, edgecolor='black', alpha=0.7)
+        ax.text(i, (start_val + end_val) / 2, f"{height:+.3f}", ha='center', va='center', color='white', fontsize=8)
+
+    ax.axhline(y=baseline, color='black', linestyle='--', linewidth=1)
+    ax.set_xticks(x_positions)
+    ax.set_xticklabels(x_labels, rotation=45, ha='right')
+    ax.set_ylim(0, 1)
+    ax.set_ylabel("Running Probability (Human)")
+    ax.set_title(f"SHAP-like Waterfall — Final Probability: {final_prob:.3f} (Model Probability: {human_prob:.3f})")
 
+    plt.tight_layout()
+    return fig
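Because each step is scaled by a fixed 0.05, the waterfall's endpoint only approximates the model output, as the comments in the function concede. A hypothetical alternative, not what this commit does, would normalize the contributions so the last bar lands exactly on human_prob:

    # Hypothetical replacement for the fixed 0.05 step scaling:
    total = sum(contrib for _, contrib in bars)
    scale = (human_prob - baseline) / total if total != 0 else 0.0
    # With new_val = running_val + scale * contrib, the contributions sum to
    # human_prob - baseline, so the waterfall ends exactly at human_prob.
    # Caveat: if total and (human_prob - baseline) have opposite signs, the
    # scale is negative and flips bar directions; one reason to keep a fixed scale.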
 
+def create_frequency_sigma_plot(important_kmers, title):
+    """Creates a bar plot of the top k-mers (by importance) showing frequency (%) and σ from mean."""
+    # Sort by absolute impact
+    sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
+    kmers = [k["kmer"] for k in sorted_kmers]
+    frequencies = [k["occurrence"] for k in sorted_kmers]  # in %
+    sigmas = [k["sigma"] for k in sorted_kmers]
+    directions = [k["direction"] for k in sorted_kmers]
 
     x = np.arange(len(kmers))
     width = 0.4
 
+    fig, ax_bar = plt.subplots(figsize=(10, 6))
+
+    # Bar for frequency
+    bars_freq = ax_bar.bar(
+        x - width/2, frequencies, width, alpha=0.7,
+        color=["green" if d == "human" else "red" for d in directions],
+        label="Frequency (%)"
+    )
     ax_bar.set_ylabel("Frequency (%)")
     ax_bar.set_ylim(0, max(frequencies) * 1.2 if frequencies else 1)
 
     # Twin axis for σ
     ax_bar_twin = ax_bar.twinx()
+    bars_sigma = ax_bar_twin.bar(
+        x + width/2, sigmas, width, alpha=0.5, color="gray", label="σ from Mean"
+    )
     ax_bar_twin.set_ylabel("Standard Deviations (σ)")
 
+    ax_bar.set_title(f"Frequency & σ from Mean for Top k-mers — {title}")
     ax_bar.set_xticks(x)
+    ax_bar.set_xticklabels(kmers, rotation=45, ha='right')
+
+    # Combined legend
     lines1, labels1 = ax_bar.get_legend_handles_labels()
     lines2, labels2 = ax_bar_twin.get_legend_handles_labels()
+    ax_bar.legend(lines1 + lines2, labels1 + labels2, loc="upper right")
 
+    plt.tight_layout()
+    return fig
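The σ values on the twin axis are just the scaled features, so each bar reads as "how unusual this k-mer's frequency is relative to the training set", assuming the pickled scaler is a StandardScaler as the code comments suggest. A worked illustration with made-up numbers:

    # Illustration only (hypothetical training statistics):
    # if the training mean frequency of 'ACGT' is 0.0040 with std 0.0010,
    # a sequence where 'ACGT' occurs at frequency 0.0062 plots at
    #     sigma = (0.0062 - 0.0040) / 0.0010 = +2.2   (2.2σ above the mean)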
+
+def create_importance_bar_plot(important_kmers, title):
+    """
+    Create a simple bar chart showing the absolute gradient magnitude
+    for the top k-mers, sorted descending.
+    """
     sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
+    kmers = [k['kmer'] for k in sorted_kmers]
+    impacts = [k['impact'] for k in sorted_kmers]
+    directions = [k["direction"] for k in sorted_kmers]
 
+    x = np.arange(len(kmers))
+
+    fig, ax = plt.subplots(figsize=(10, 6))
+    bar_colors = ["green" if d == "human" else "red" for d in directions]
 
+    ax.bar(x, impacts, color=bar_colors, alpha=0.7)
+    ax.set_xticks(x)
+    ax.set_xticklabels(kmers, rotation=45, ha='right')
+    ax.set_title(f"Absolute Feature Importance (Top k-mers) — {title}")
+    ax.set_ylabel("Gradient Magnitude")
+    ax.grid(axis="y", alpha=0.3)
 
     plt.tight_layout()
     return fig
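The hunk header above shows the signature sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray, but its body sits outside this diff. A plausible minimal implementation, consistent with predict() multiplying entries by 100 to get percentages (so the vector holds normalized frequencies); the sliding-window counting is an assumption:

    from itertools import product

    import numpy as np

    def sequence_to_kmer_vector_sketch(sequence: str, k: int = 4) -> np.ndarray:
        """Hypothetical stand-in for the app's sequence_to_kmer_vector."""
        kmers = [''.join(p) for p in product("ACGT", repeat=k)]
        index = {km: i for i, km in enumerate(kmers)}
        vec = np.zeros(len(kmers), dtype=np.float64)
        for i in range(len(sequence) - k + 1):
            window = sequence[i:i + k].upper()
            if window in index:  # windows containing N or other symbols are skipped
                vec[index[window]] += 1.0
        total = vec.sum()
        return vec / total if total > 0 else vec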
@@ -224,149 +257,213 @@ def create_visualization(important_kmers, human_prob, title):
-    Main function
-    1. Reads the uploaded FASTA file
-    # Read text from file
-    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
-    kmer_dict = {km: i for i, km in enumerate(kmers)}
-    # Load model & scaler
-    except Exception as e:
-        return f"Error loading model or scaler: {str(e)}", None
-    try:
-        # Parse FASTA
-        sequences = parse_fasta(text)
-        if len(sequences) == 0:
-            return "No valid FASTA sequences found. Please check your input.", None
-        header, seq = sequences[0]  # For simplicity, we'll only classify the first sequence
-        raw_freq_vector = sequence_to_kmer_vector(seq)
-        kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
-        X_tensor = torch.FloatTensor(kmer_vector).to(device)
-        # Inference
-            direction = 'human' if kmer_importance[idx] > 0 else 'non-human'
-            freq = float(raw_freq_vector[idx] * 100)  # frequency in %
-            sigma = float(kmer_vector[0][idx])  # scaled value (Z-score if standard scaler)
-        human_prob = float(probs[0][1])
-        non_human_prob = float(probs[0][0])
-        conf = float(max(probs[0]))  # confidence in the predicted class
-        # Generate text results
-        results_text = (
-            f"occurrence = {occurrence_text}, "
-            f"({sigma_text})\n"
-        return results_text, plot_image

 ###############################################################################
 def predict(file_obj):
     """
+    Main function for Gradio:
+    1. Reads the uploaded FASTA file or text.
     2. Loads the model and scaler.
     3. Generates predictions, probabilities, and top k-mers.
+    4. Returns multiple outputs:
+       - A textual summary (Markdown).
+       - Waterfall plot.
+       - Frequency & sigma plot.
+       - Absolute importance bar plot.
     """
+    # 0. Basic file read
     if file_obj is None:
+        return (
+            "Please upload a FASTA file.",
+            None,
+            None,
+            None
+        )
 
     try:
+        # If user provided raw text, use that
         if isinstance(file_obj, str):
             text = file_obj
         else:
+            # If user uploaded a file, decode it
             text = file_obj.decode('utf-8')
     except Exception as e:
+        return (
+            f"Error reading file: {str(e)}",
+            None,
+            None,
+            None
+        )
 
+    # 1. Parse FASTA
+    sequences = parse_fasta(text)
+    if len(sequences) == 0:
+        return (
+            "No valid FASTA sequences found. Please check your input.",
+            None,
+            None,
+            None
+        )
+    # We'll just classify the first sequence for demonstration
+    header, seq = sequences[0]
+
+    # 2. Create k-mer vector & load model
     k = 4
     try:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # Prepare raw freq vector & scale
+        raw_freq_vector = sequence_to_kmer_vector(seq, k=k)
+
+        # Load model & scaler
+        model = VirusClassifier(input_shape=4**k).to(device)
         state_dict = torch.load('model.pt', map_location=device)
         model.load_state_dict(state_dict)
         scaler = joblib.load('scaler.pkl')
         model.eval()
 
+        scaled_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
+        X_tensor = torch.FloatTensor(scaled_vector).to(device)
 
+        # 3. Inference
         with torch.no_grad():
+            logits = model(X_tensor)
+            probs = torch.softmax(logits, dim=1)
+        human_prob = float(probs[0][1])
+        non_human_prob = float(probs[0][0])
+        pred_class = 1 if human_prob >= non_human_prob else 0
+        pred_label = "human" if pred_class == 1 else "non-human"
+        confidence = float(max(probs[0]))
 
+        # 4. Feature importance
         importance, hum_prob_grad = model.get_feature_importance(X_tensor)
+        # shape: [1, 256]
+        kmer_importances = importance[0].cpu().numpy()
+
+        # Build the k-mer strings in canonical order (index -> k-mer)
+        kmers_list = [''.join(p) for p in product("ACGT", repeat=k)]
+        kmer_dict = {km: i for i, km in enumerate(kmers_list)}
 
+        # 5. Get the top 10 k-mers by absolute importance
+        abs_importance = np.abs(kmer_importances)
         top_k = 10
+        top_idxs = np.argsort(abs_importance)[-top_k:][::-1]  # descending
         important_kmers = []
+        for idx in top_idxs:
+            kmer_str = kmers_list[idx]
+            direction = "human" if kmer_importances[idx] > 0 else "non-human"
+            # frequency in % from raw_freq_vector
+            freq_percent = float(raw_freq_vector[idx] * 100)
+            # sigma from the scaled vector (Z-score under a standard scaler)
+            sigma_val = float(scaled_vector[0][idx])
             important_kmers.append({
+                'kmer': kmer_str,
+                'idx': idx,
+                'impact': float(abs_importance[idx]),
                 'direction': direction,
+                'occurrence': freq_percent,
+                'sigma': sigma_val
             })
 
+        # 6. Text Summary
+        summary_text = (
             f"**Sequence Header**: {header}\n\n"
             f"**Predicted Label**: {pred_label}\n"
+            f"**Confidence**: {confidence:.4f}\n\n"
             f"**Human Probability**: {human_prob:.4f}\n"
             f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
             "### Most Influential k-mers:\n"
         )
+        for km in important_kmers:
+            direction_text = f"(pushes toward {km['direction']})"
+            freq_text = f"{km['occurrence']:.2f}%"
+            sigma_text = f"{abs(km['sigma']):.2f}σ " + ("above" if km['sigma'] > 0 else "below") + " mean"
+            summary_text += (
+                f"- **{km['kmer']}**: impact={km['impact']:.4f}, {direction_text}, "
+                f"occurrence={freq_text}, ({sigma_text})\n"
             )
 
+        # 7. Plots
+        # a) SHAP-like Waterfall Plot
+        fig_waterfall = create_shap_waterfall_plot(
+            important_kmers,
+            kmer_importances,
+            human_prob,
+            f"{header}"
+        )
+        buf1 = io.BytesIO()
+        fig_waterfall.savefig(buf1, format='png', bbox_inches='tight', dpi=120)
+        buf1.seek(0)
+        waterfall_img = Image.open(buf1)
+        plt.close(fig_waterfall)
+
+        # b) Frequency & σ Plot (top 10 k-mers)
+        fig_freq_sigma = create_frequency_sigma_plot(
+            important_kmers,
+            f"{header}"
+        )
+        buf2 = io.BytesIO()
+        fig_freq_sigma.savefig(buf2, format='png', bbox_inches='tight', dpi=120)
+        buf2.seek(0)
+        freq_sigma_img = Image.open(buf2)
+        plt.close(fig_freq_sigma)
+
+        # c) Absolute Importance Bar Plot
+        fig_imp = create_importance_bar_plot(
+            important_kmers,
+            f"{header}"
+        )
+        buf3 = io.BytesIO()
+        fig_imp.savefig(buf3, format='png', bbox_inches='tight', dpi=120)
+        buf3.seek(0)
+        importance_img = Image.open(buf3)
+        plt.close(fig_imp)
+
+        return summary_text, waterfall_img, freq_sigma_img, importance_img
+
     except Exception as e:
+        return (
+            f"Error during prediction or visualization: {str(e)}",
+            None,
+            None,
+            None
+        )
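parse_fasta(text) is called above but defined outside the changed region. A minimal sketch of what it presumably does, returning a list of (header, sequence) tuples as the unpacking header, seq = sequences[0] requires; the implementation details are assumptions:

    from typing import List, Tuple

    def parse_fasta_sketch(text: str) -> List[Tuple[str, str]]:
        """Hypothetical stand-in for the app's parse_fasta."""
        sequences, header, chunks = [], None, []
        for line in text.splitlines():
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                if header is not None:
                    sequences.append((header, "".join(chunks)))
                header, chunks = line[1:].strip(), []
            else:
                chunks.append(line.upper())
        if header is not None:
            sequences.append((header, "".join(chunks)))
        return sequences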
 
 ###############################################################################
 # Gradio Interface
 ###############################################################################
+with gr.Blocks(title="Advanced Virus Host Classifier") as demo:
+    gr.Markdown(
+        """
+        # Advanced Virus Host Classifier
+        **Upload a FASTA file** containing a single nucleotide sequence.
+        The model will predict whether this sequence is **human** or **non-human**,
+        provide a confidence score, and highlight the most influential k-mers
+        (using a SHAP-like waterfall plot) along with two additional plots.
+        """
+    )
+
+    with gr.Row():
+        file_in = gr.File(label="Upload FASTA", type="binary")
+        btn = gr.Button("Run Prediction")
+
+    # We will create multiple tabs for our outputs
+    with gr.Tabs():
+        with gr.Tab("Prediction Results"):
+            md_out = gr.Markdown()
+        with gr.Tab("SHAP-like Waterfall Plot"):
+            water_out = gr.Image()
+        with gr.Tab("Frequency & σ Plot"):
+            freq_out = gr.Image()
+        with gr.Tab("Importance Bar Plot"):
+            imp_out = gr.Image()
+
+    # Link the button
+    btn.click(
+        fn=predict,
+        inputs=[file_in],
+        outputs=[md_out, water_out, freq_out, imp_out]
+    )
 
 if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
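For completeness, a Space running this app would also need its dependencies declared. An assumed requirements.txt, not part of this commit, inferred only from the imports visible above:

    # requirements.txt (assumed; not shown in this commit)
    torch
    gradio
    numpy
    matplotlib
    scikit-learn   # provides the pickled StandardScaler loaded via joblib
    joblib
    Pillow         # Image.open on the rendered PNG buffers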