Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
|
|
7 |
import shap
|
8 |
import matplotlib.pyplot as plt
|
9 |
import io
|
|
|
10 |
from PIL import Image
|
11 |
|
12 |
class VirusClassifier(nn.Module):
|
@@ -183,6 +184,7 @@ def predict(file_obj):
|
|
183 |
human_prob = float(probs[0][1])
|
184 |
|
185 |
# Create SHAP explanation
|
|
|
186 |
explanation = shap.Explanation(
|
187 |
values=np.array(top_values),
|
188 |
base_values=0.5, # Start from neutral prediction
|
@@ -196,36 +198,67 @@ def predict(file_obj):
|
|
196 |
explanation.expected_value = 0.5 # Start from neutral prediction
|
197 |
|
198 |
# Calculate step-by-step probabilities
|
199 |
-
step_probs = []
|
200 |
current_prob = 0.5 # Start at neutral
|
201 |
-
|
202 |
|
203 |
# Process each k-mer contribution
|
204 |
-
for
|
205 |
change = kmer['importance']
|
206 |
current_prob += change
|
207 |
-
|
208 |
-
"step": str(i),
|
209 |
-
"probability": current_prob,
|
210 |
-
"kmer": kmer['kmer'],
|
211 |
-
"change": change
|
212 |
-
})
|
213 |
|
214 |
# Add final "Others" contribution
|
215 |
current_prob += others_sum
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
|
223 |
-
|
224 |
-
|
225 |
-
print(f"Steps data: {steps_json}") # For debugging
|
226 |
|
227 |
-
#
|
228 |
-
|
|
|
|
|
|
|
|
|
229 |
|
230 |
# Calculate final probabilities
|
231 |
with torch.no_grad():
|
|
|
7 |
import shap
|
8 |
import matplotlib.pyplot as plt
|
9 |
import io
|
10 |
+
import json
|
11 |
from PIL import Image
|
12 |
|
13 |
class VirusClassifier(nn.Module):
|
|
|
184 |
human_prob = float(probs[0][1])
|
185 |
|
186 |
# Create SHAP explanation
|
187 |
+
# We'll use the actual probabilities for alignment
|
188 |
explanation = shap.Explanation(
|
189 |
values=np.array(top_values),
|
190 |
base_values=0.5, # Start from neutral prediction
|
|
|
198 |
explanation.expected_value = 0.5 # Start from neutral prediction
|
199 |
|
200 |
# Calculate step-by-step probabilities
|
|
|
201 |
current_prob = 0.5 # Start at neutral
|
202 |
+
steps = [('Start', current_prob, 0)]
|
203 |
|
204 |
# Process each k-mer contribution
|
205 |
+
for kmer in important_kmers:
|
206 |
change = kmer['importance']
|
207 |
current_prob += change
|
208 |
+
steps.append((kmer['kmer'], current_prob, change))
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
# Add final "Others" contribution
|
211 |
current_prob += others_sum
|
212 |
+
steps.append(('Others', current_prob, others_sum))
|
213 |
+
|
214 |
+
# Create step plot
|
215 |
+
plt.figure(figsize=(12, 6))
|
216 |
+
x = range(len(steps))
|
217 |
+
y = [step[1] for step in steps]
|
218 |
+
|
219 |
+
# Plot steps
|
220 |
+
plt.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
|
221 |
+
plt.plot(x, y, 'b.', markersize=10)
|
222 |
+
|
223 |
+
# Add reference line
|
224 |
+
plt.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
|
225 |
+
|
226 |
+
# Customize plot
|
227 |
+
plt.grid(True, linestyle='--', alpha=0.7)
|
228 |
+
plt.ylim(0, 1)
|
229 |
+
plt.ylabel('Human Probability')
|
230 |
+
plt.title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')
|
231 |
+
|
232 |
+
# Add labels for each point
|
233 |
+
for i, (kmer, prob, change) in enumerate(steps):
|
234 |
+
# Add k-mer label
|
235 |
+
plt.annotate(kmer,
|
236 |
+
(i, prob),
|
237 |
+
xytext=(0, 10 if i % 2 == 0 else -20), # Alternate up/down
|
238 |
+
textcoords='offset points',
|
239 |
+
ha='center',
|
240 |
+
rotation=45 if len(kmer) > 5 else 0)
|
241 |
+
|
242 |
+
# Add change value
|
243 |
+
if i > 0: # Skip first point (Start)
|
244 |
+
change_text = f'{change:+.3f}'
|
245 |
+
color = 'green' if change > 0 else 'red'
|
246 |
+
plt.annotate(change_text,
|
247 |
+
(i, prob),
|
248 |
+
xytext=(0, -20 if i % 2 == 0 else 10),
|
249 |
+
textcoords='offset points',
|
250 |
+
ha='center',
|
251 |
+
color=color)
|
252 |
|
253 |
+
plt.legend()
|
254 |
+
plt.tight_layout()
|
|
|
255 |
|
256 |
+
# Save plot
|
257 |
+
buf = io.BytesIO()
|
258 |
+
plt.savefig(buf, format='png', bbox_inches='tight', dpi=300)
|
259 |
+
buf.seek(0)
|
260 |
+
plot_image = Image.open(buf)
|
261 |
+
plt.close()
|
262 |
|
263 |
# Calculate final probabilities
|
264 |
with torch.no_grad():
|