Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -4,10 +4,8 @@ import joblib
|
|
4 |
import numpy as np
|
5 |
from itertools import product
|
6 |
import torch.nn as nn
|
7 |
-
import shap
|
8 |
import matplotlib.pyplot as plt
|
9 |
import io
|
10 |
-
import json
|
11 |
from PIL import Image
|
12 |
|
13 |
class VirusClassifier(nn.Module):
|
@@ -31,16 +29,16 @@ class VirusClassifier(nn.Module):
|
|
31 |
return self.network(x)
|
32 |
|
33 |
def get_feature_importance(self, x):
|
34 |
-
"""Calculate feature importance using gradient-based method
|
35 |
x.requires_grad_(True)
|
36 |
output = self.network(x)
|
37 |
probs = torch.softmax(output, dim=1)
|
38 |
|
39 |
-
#
|
40 |
human_prob = probs[..., 1]
|
|
|
|
|
41 |
human_prob.backward()
|
42 |
-
|
43 |
-
# The gradient shows how each feature affects the human probability
|
44 |
importance = x.grad
|
45 |
|
46 |
return importance, float(human_prob)
|
@@ -82,6 +80,94 @@ def parse_fasta(text):
|
|
82 |
sequences.append((current_header, ''.join(current_sequence)))
|
83 |
return sequences
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
def predict(file_obj):
|
86 |
if file_obj is None:
|
87 |
return "Please upload a FASTA file", None
|
@@ -119,172 +205,64 @@ def predict(file_obj):
|
|
119 |
kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
|
120 |
X_tensor = torch.FloatTensor(kmer_vector).to(device)
|
121 |
|
122 |
-
#
|
123 |
with torch.no_grad():
|
124 |
output = model(X_tensor)
|
125 |
probs = torch.softmax(output, dim=1)
|
126 |
-
human_prob = float(probs[0][1])
|
127 |
-
|
128 |
-
# Get feature importance using integrated gradients
|
129 |
-
baseline = torch.zeros_like(X_tensor) # baseline of zeros
|
130 |
-
steps = 50
|
131 |
-
|
132 |
-
all_importance = []
|
133 |
-
for i in range(steps + 1):
|
134 |
-
alpha = i / steps
|
135 |
-
interpolated = baseline + alpha * (X_tensor - baseline)
|
136 |
-
interpolated.requires_grad_(True)
|
137 |
-
|
138 |
-
output = model(interpolated)
|
139 |
-
probs = torch.softmax(output, dim=1)
|
140 |
-
human_class = probs[..., 1]
|
141 |
-
|
142 |
-
if interpolated.grad is not None:
|
143 |
-
interpolated.grad.zero_()
|
144 |
-
human_class.backward()
|
145 |
-
all_importance.append(interpolated.grad.cpu().numpy())
|
146 |
|
147 |
-
#
|
148 |
-
|
149 |
-
|
150 |
-
target_diff = human_prob - 0.5 # difference from neutral prediction
|
151 |
-
current_sum = np.sum(kmer_importance)
|
152 |
-
if current_sum != 0: # avoid division by zero
|
153 |
-
kmer_importance = kmer_importance * (target_diff / current_sum)
|
154 |
|
155 |
-
# Get top k-mers
|
156 |
top_k = 10
|
157 |
top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
|
158 |
-
important_kmers = [
|
159 |
-
{
|
160 |
-
'kmer': list(kmer_dict.keys())[list(kmer_dict.values()).index(i)],
|
161 |
-
'importance': float(kmer_importance[i]),
|
162 |
-
'frequency': float(raw_freq_vector[i]),
|
163 |
-
'scaled': float(kmer_vector[0][i])
|
164 |
-
}
|
165 |
-
for i in top_indices
|
166 |
-
]
|
167 |
-
|
168 |
-
# Prepare data for SHAP waterfall plot
|
169 |
-
top_features = [item['kmer'] for item in important_kmers]
|
170 |
-
top_values = [item['importance'] for item in important_kmers]
|
171 |
-
|
172 |
-
# Calculate the impact of remaining features
|
173 |
-
others_mask = np.ones_like(kmer_importance, dtype=bool)
|
174 |
-
others_mask[top_indices] = False
|
175 |
-
others_sum = np.sum(kmer_importance[others_mask])
|
176 |
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
human_prob = float(probs[0][1])
|
185 |
-
|
186 |
-
# Create SHAP explanation
|
187 |
-
# We'll use the actual probabilities for alignment
|
188 |
-
explanation = shap.Explanation(
|
189 |
-
values=np.array(top_values),
|
190 |
-
base_values=0.5, # Start from neutral prediction
|
191 |
-
data=np.array([
|
192 |
-
raw_freq_vector[kmer_dict[feat]] if feat != "Others"
|
193 |
-
else np.sum(raw_freq_vector[others_mask])
|
194 |
-
for feat in top_features
|
195 |
-
]),
|
196 |
-
feature_names=top_features
|
197 |
-
)
|
198 |
-
explanation.expected_value = 0.5 # Start from neutral prediction
|
199 |
-
|
200 |
-
# Calculate step-by-step probabilities
|
201 |
-
current_prob = 0.5 # Start at neutral
|
202 |
-
steps = [('Start', current_prob, 0)]
|
203 |
-
|
204 |
-
# Process each k-mer contribution
|
205 |
-
for kmer in important_kmers:
|
206 |
-
change = kmer['importance']
|
207 |
-
current_prob += change
|
208 |
-
steps.append((kmer['kmer'], current_prob, change))
|
209 |
-
|
210 |
-
# Add final "Others" contribution
|
211 |
-
current_prob += others_sum
|
212 |
-
steps.append(('Others', current_prob, others_sum))
|
213 |
-
|
214 |
-
# Create step plot
|
215 |
-
plt.figure(figsize=(12, 6))
|
216 |
-
x = range(len(steps))
|
217 |
-
y = [step[1] for step in steps]
|
218 |
-
|
219 |
-
# Plot steps
|
220 |
-
plt.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
|
221 |
-
plt.plot(x, y, 'b.', markersize=10)
|
222 |
-
|
223 |
-
# Add reference line
|
224 |
-
plt.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
|
225 |
-
|
226 |
-
# Customize plot
|
227 |
-
plt.grid(True, linestyle='--', alpha=0.7)
|
228 |
-
plt.ylim(0, 1)
|
229 |
-
plt.ylabel('Human Probability')
|
230 |
-
plt.title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')
|
231 |
-
|
232 |
-
# Add labels for each point
|
233 |
-
for i, (kmer, prob, change) in enumerate(steps):
|
234 |
-
# Add k-mer label
|
235 |
-
plt.annotate(kmer,
|
236 |
-
(i, prob),
|
237 |
-
xytext=(0, 10 if i % 2 == 0 else -20), # Alternate up/down
|
238 |
-
textcoords='offset points',
|
239 |
-
ha='center',
|
240 |
-
rotation=45 if len(kmer) > 5 else 0)
|
241 |
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
textcoords='offset points',
|
250 |
-
ha='center',
|
251 |
-
color=color)
|
252 |
-
|
253 |
-
plt.legend()
|
254 |
-
plt.tight_layout()
|
255 |
|
256 |
-
#
|
257 |
-
buf = io.BytesIO()
|
258 |
-
plt.savefig(buf, format='png', bbox_inches='tight', dpi=300)
|
259 |
-
buf.seek(0)
|
260 |
-
plot_image = Image.open(buf)
|
261 |
-
plt.close()
|
262 |
-
|
263 |
-
# Calculate final probabilities
|
264 |
-
with torch.no_grad():
|
265 |
-
output = model(X_tensor)
|
266 |
-
probs = torch.softmax(output, dim=1)
|
267 |
-
|
268 |
pred_class = 1 if probs[0][1] > probs[0][0] else 0
|
269 |
pred_label = 'human' if pred_class == 1 else 'non-human'
|
|
|
270 |
|
271 |
-
|
272 |
-
results_text += f"""Sequence: {header}
|
273 |
Prediction: {pred_label}
|
274 |
Confidence: {float(max(probs[0])):0.4f}
|
275 |
-
Human probability: {
|
276 |
Non-human probability: {float(probs[0][0]):0.4f}
|
277 |
Most influential k-mers (ranked by importance):"""
|
278 |
|
279 |
for kmer in important_kmers:
|
280 |
-
direction = "human" if kmer['importance'] > 0 else "non-human"
|
281 |
results_text += f"\n {kmer['kmer']}: "
|
282 |
-
results_text += f"pushes toward {direction} (impact={
|
283 |
-
results_text += f"occurrence={kmer['
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
|
289 |
except Exception as e:
|
290 |
return f"Error processing sequences: {str(e)}", None
|
@@ -294,7 +272,10 @@ Most influential k-mers (ranked by importance):"""
|
|
294 |
iface = gr.Interface(
|
295 |
fn=predict,
|
296 |
inputs=gr.File(label="Upload FASTA file", type="binary"),
|
297 |
-
outputs=[
|
|
|
|
|
|
|
298 |
title="Virus Host Classifier"
|
299 |
)
|
300 |
|
|
|
4 |
import numpy as np
|
5 |
from itertools import product
|
6 |
import torch.nn as nn
|
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
import io
|
|
|
9 |
from PIL import Image
|
10 |
|
11 |
class VirusClassifier(nn.Module):
|
|
|
29 |
return self.network(x)
|
30 |
|
31 |
def get_feature_importance(self, x):
|
32 |
+
"""Calculate feature importance using gradient-based method"""
|
33 |
x.requires_grad_(True)
|
34 |
output = self.network(x)
|
35 |
probs = torch.softmax(output, dim=1)
|
36 |
|
37 |
+
# Get importance for human class (index 1)
|
38 |
human_prob = probs[..., 1]
|
39 |
+
if x.grad is not None:
|
40 |
+
x.grad.zero_()
|
41 |
human_prob.backward()
|
|
|
|
|
42 |
importance = x.grad
|
43 |
|
44 |
return importance, float(human_prob)
|
|
|
80 |
sequences.append((current_header, ''.join(current_sequence)))
|
81 |
return sequences
|
82 |
|
83 |
+
def create_visualization(important_kmers, human_prob, title):
|
84 |
+
"""Create a comprehensive visualization of k-mer impacts"""
|
85 |
+
fig = plt.figure(figsize=(15, 10))
|
86 |
+
|
87 |
+
# Create grid for subplots
|
88 |
+
gs = plt.GridSpec(2, 1, height_ratios=[1.5, 1], hspace=0.3)
|
89 |
+
|
90 |
+
# 1. Probability Step Plot
|
91 |
+
ax1 = plt.subplot(gs[0])
|
92 |
+
current_prob = 0.5
|
93 |
+
steps = [('Start', current_prob, 0)]
|
94 |
+
|
95 |
+
for kmer in important_kmers:
|
96 |
+
change = kmer['impact'] * (-1 if kmer['direction'] == 'non-human' else 1)
|
97 |
+
current_prob += change
|
98 |
+
steps.append((kmer['kmer'], current_prob, change))
|
99 |
+
|
100 |
+
x = range(len(steps))
|
101 |
+
y = [step[1] for step in steps]
|
102 |
+
|
103 |
+
# Plot steps
|
104 |
+
ax1.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
|
105 |
+
ax1.plot(x, y, 'b.', markersize=10)
|
106 |
+
|
107 |
+
# Add reference line
|
108 |
+
ax1.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
|
109 |
+
|
110 |
+
# Customize plot
|
111 |
+
ax1.grid(True, linestyle='--', alpha=0.7)
|
112 |
+
ax1.set_ylim(0, 1)
|
113 |
+
ax1.set_ylabel('Human Probability')
|
114 |
+
ax1.set_title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')
|
115 |
+
|
116 |
+
# Add labels for each point
|
117 |
+
for i, (kmer, prob, change) in enumerate(steps):
|
118 |
+
# Add k-mer label
|
119 |
+
ax1.annotate(kmer,
|
120 |
+
(i, prob),
|
121 |
+
xytext=(0, 10 if i % 2 == 0 else -20),
|
122 |
+
textcoords='offset points',
|
123 |
+
ha='center',
|
124 |
+
rotation=45)
|
125 |
+
|
126 |
+
# Add change value
|
127 |
+
if i > 0:
|
128 |
+
change_text = f'{change:+.3f}'
|
129 |
+
color = 'green' if change > 0 else 'red'
|
130 |
+
ax1.annotate(change_text,
|
131 |
+
(i, prob),
|
132 |
+
xytext=(0, -20 if i % 2 == 0 else 10),
|
133 |
+
textcoords='offset points',
|
134 |
+
ha='center',
|
135 |
+
color=color)
|
136 |
+
|
137 |
+
ax1.legend()
|
138 |
+
|
139 |
+
# 2. K-mer Frequency and Sigma Plot
|
140 |
+
ax2 = plt.subplot(gs[1])
|
141 |
+
|
142 |
+
# Prepare data
|
143 |
+
kmers = [k['kmer'] for k in important_kmers]
|
144 |
+
frequencies = [k['occurrence'] for k in important_kmers]
|
145 |
+
sigmas = [k['sigma'] for k in important_kmers]
|
146 |
+
colors = ['g' if k['direction'] == 'human' else 'r' for k in important_kmers]
|
147 |
+
|
148 |
+
# Create bar plot for frequencies
|
149 |
+
x = np.arange(len(kmers))
|
150 |
+
width = 0.35
|
151 |
+
|
152 |
+
ax2.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)
|
153 |
+
ax2_twin = ax2.twinx()
|
154 |
+
ax2_twin.bar(x + width/2, sigmas, width, label='σ from mean', color=[c if s > 0 else 'gray' for c, s in zip(colors, sigmas)], alpha=0.3)
|
155 |
+
|
156 |
+
# Customize plot
|
157 |
+
ax2.set_xticks(x)
|
158 |
+
ax2.set_xticklabels(kmers, rotation=45)
|
159 |
+
ax2.set_ylabel('Frequency (%)')
|
160 |
+
ax2_twin.set_ylabel('Standard Deviations (σ) from Mean')
|
161 |
+
ax2.set_title('K-mer Frequencies and Statistical Significance')
|
162 |
+
|
163 |
+
# Add legends
|
164 |
+
lines1, labels1 = ax2.get_legend_handles_labels()
|
165 |
+
lines2, labels2 = ax2_twin.get_legend_handles_labels()
|
166 |
+
ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
|
167 |
+
|
168 |
+
plt.tight_layout()
|
169 |
+
return fig
|
170 |
+
|
171 |
def predict(file_obj):
|
172 |
if file_obj is None:
|
173 |
return "Please upload a FASTA file", None
|
|
|
205 |
kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
|
206 |
X_tensor = torch.FloatTensor(kmer_vector).to(device)
|
207 |
|
208 |
+
# Get model predictions
|
209 |
with torch.no_grad():
|
210 |
output = model(X_tensor)
|
211 |
probs = torch.softmax(output, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
+
# Get feature importance
|
214 |
+
importance, _ = model.get_feature_importance(X_tensor)
|
215 |
+
kmer_importance = importance[0].cpu().numpy()
|
|
|
|
|
|
|
|
|
216 |
|
217 |
+
# Get top k-mers
|
218 |
top_k = 10
|
219 |
top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
+
important_kmers = []
|
222 |
+
for idx in top_indices:
|
223 |
+
kmer = list(kmer_dict.keys())[list(kmer_dict.values()).index(idx)]
|
224 |
+
imp = float(abs(kmer_importance[idx]))
|
225 |
+
direction = 'human' if kmer_importance[idx] > 0 else 'non-human'
|
226 |
+
freq = float(raw_freq_vector[idx] * 100) # Convert to percentage
|
227 |
+
sigma = float(kmer_vector[0][idx])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
|
229 |
+
important_kmers.append({
|
230 |
+
'kmer': kmer,
|
231 |
+
'impact': imp,
|
232 |
+
'direction': direction,
|
233 |
+
'occurrence': freq,
|
234 |
+
'sigma': sigma
|
235 |
+
})
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
237 |
+
# Generate text results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
pred_class = 1 if probs[0][1] > probs[0][0] else 0
|
239 |
pred_label = 'human' if pred_class == 1 else 'non-human'
|
240 |
+
human_prob = float(probs[0][1])
|
241 |
|
242 |
+
results_text = f"""Sequence: {header}
|
|
|
243 |
Prediction: {pred_label}
|
244 |
Confidence: {float(max(probs[0])):0.4f}
|
245 |
+
Human probability: {human_prob:0.4f}
|
246 |
Non-human probability: {float(probs[0][0]):0.4f}
|
247 |
Most influential k-mers (ranked by importance):"""
|
248 |
|
249 |
for kmer in important_kmers:
|
|
|
250 |
results_text += f"\n {kmer['kmer']}: "
|
251 |
+
results_text += f"pushes toward {kmer['direction']} (impact={kmer['impact']:.4f}), "
|
252 |
+
results_text += f"occurrence={kmer['occurrence']:.2f}% of sequence "
|
253 |
+
results_text += f"(appears {abs(kmer['sigma']):.2f}σ "
|
254 |
+
results_text += "more" if kmer['sigma'] > 0 else "less"
|
255 |
+
results_text += " than average)"
|
256 |
+
|
257 |
+
# Create visualization
|
258 |
+
fig = create_visualization(important_kmers, human_prob, header)
|
259 |
+
|
260 |
+
# Save plot
|
261 |
+
buf = io.BytesIO()
|
262 |
+
fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
|
263 |
+
buf.seek(0)
|
264 |
+
plot_image = Image.open(buf)
|
265 |
+
plt.close(fig)
|
266 |
|
267 |
except Exception as e:
|
268 |
return f"Error processing sequences: {str(e)}", None
|
|
|
272 |
iface = gr.Interface(
|
273 |
fn=predict,
|
274 |
inputs=gr.File(label="Upload FASTA file", type="binary"),
|
275 |
+
outputs=[
|
276 |
+
gr.Textbox(label="Results"),
|
277 |
+
gr.Image(label="K-mer Analysis Visualization")
|
278 |
+
],
|
279 |
title="Virus Host Classifier"
|
280 |
)
|
281 |
|