hiyata commited on
Commit
b0fba50
·
verified ·
1 Parent(s): 8773845

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -20
app.py CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
7
  import shap
8
  import matplotlib.pyplot as plt
9
  import io
 
10
  from PIL import Image
11
 
12
  class VirusClassifier(nn.Module):
@@ -183,6 +184,7 @@ def predict(file_obj):
183
  human_prob = float(probs[0][1])
184
 
185
  # Create SHAP explanation
 
186
  explanation = shap.Explanation(
187
  values=np.array(top_values),
188
  base_values=0.5, # Start from neutral prediction
@@ -196,36 +198,67 @@ def predict(file_obj):
196
  explanation.expected_value = 0.5 # Start from neutral prediction
197
 
198
  # Calculate step-by-step probabilities
199
- step_probs = []
200
  current_prob = 0.5 # Start at neutral
201
- step_probs.append({"step": "Start", "probability": current_prob, "kmer": "Initial", "change": 0})
202
 
203
  # Process each k-mer contribution
204
- for i, kmer in enumerate(important_kmers, 1):
205
  change = kmer['importance']
206
  current_prob += change
207
- step_probs.append({
208
- "step": str(i),
209
- "probability": current_prob,
210
- "kmer": kmer['kmer'],
211
- "change": change
212
- })
213
 
214
  # Add final "Others" contribution
215
  current_prob += others_sum
216
- step_probs.append({
217
- "step": "Others",
218
- "probability": current_prob,
219
- "kmer": "Others",
220
- "change": others_sum
221
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
- # Convert to JSON for React
224
- steps_json = json.dumps(step_probs, indent=2)
225
- print(f"Steps data: {steps_json}") # For debugging
226
 
227
- # Create visualization component
228
- plot_image = None # We'll use the React component instead
 
 
 
 
229
 
230
  # Calculate final probabilities
231
  with torch.no_grad():
 
7
  import shap
8
  import matplotlib.pyplot as plt
9
  import io
10
+ import json
11
  from PIL import Image
12
 
13
  class VirusClassifier(nn.Module):
 
184
  human_prob = float(probs[0][1])
185
 
186
  # Create SHAP explanation
187
+ # We'll use the actual probabilities for alignment
188
  explanation = shap.Explanation(
189
  values=np.array(top_values),
190
  base_values=0.5, # Start from neutral prediction
 
198
  explanation.expected_value = 0.5 # Start from neutral prediction
199
 
200
  # Calculate step-by-step probabilities
 
201
  current_prob = 0.5 # Start at neutral
202
+ steps = [('Start', current_prob, 0)]
203
 
204
  # Process each k-mer contribution
205
+ for kmer in important_kmers:
206
  change = kmer['importance']
207
  current_prob += change
208
+ steps.append((kmer['kmer'], current_prob, change))
 
 
 
 
 
209
 
210
  # Add final "Others" contribution
211
  current_prob += others_sum
212
+ steps.append(('Others', current_prob, others_sum))
213
+
214
+ # Create step plot
215
+ plt.figure(figsize=(12, 6))
216
+ x = range(len(steps))
217
+ y = [step[1] for step in steps]
218
+
219
+ # Plot steps
220
+ plt.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
221
+ plt.plot(x, y, 'b.', markersize=10)
222
+
223
+ # Add reference line
224
+ plt.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
225
+
226
+ # Customize plot
227
+ plt.grid(True, linestyle='--', alpha=0.7)
228
+ plt.ylim(0, 1)
229
+ plt.ylabel('Human Probability')
230
+ plt.title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')
231
+
232
+ # Add labels for each point
233
+ for i, (kmer, prob, change) in enumerate(steps):
234
+ # Add k-mer label
235
+ plt.annotate(kmer,
236
+ (i, prob),
237
+ xytext=(0, 10 if i % 2 == 0 else -20), # Alternate up/down
238
+ textcoords='offset points',
239
+ ha='center',
240
+ rotation=45 if len(kmer) > 5 else 0)
241
+
242
+ # Add change value
243
+ if i > 0: # Skip first point (Start)
244
+ change_text = f'{change:+.3f}'
245
+ color = 'green' if change > 0 else 'red'
246
+ plt.annotate(change_text,
247
+ (i, prob),
248
+ xytext=(0, -20 if i % 2 == 0 else 10),
249
+ textcoords='offset points',
250
+ ha='center',
251
+ color=color)
252
 
253
+ plt.legend()
254
+ plt.tight_layout()
 
255
 
256
+ # Save plot
257
+ buf = io.BytesIO()
258
+ plt.savefig(buf, format='png', bbox_inches='tight', dpi=300)
259
+ buf.seek(0)
260
+ plot_image = Image.open(buf)
261
+ plt.close()
262
 
263
  # Calculate final probabilities
264
  with torch.no_grad():