hiyata commited on
Commit
5bf9386
·
verified ·
1 Parent(s): 40fe6da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -150,12 +150,17 @@ def predict(file_obj):
150
  top_features.append("Others")
151
  top_values.append(others_sum)
152
 
 
 
 
 
 
 
153
  # Create SHAP explanation
154
- # Set base_value to 0.5 (neutral prediction)
155
- # Values represent the push towards human (>0.5) or non-human (<0.5)
156
  explanation = shap.Explanation(
157
  values=np.array(top_values),
158
- base_values=0.5, # Start from neutral prediction
159
  data=np.array([
160
  raw_freq_vector[kmer_dict[feat]] if feat != "Others"
161
  else np.sum(raw_freq_vector[others_mask])
@@ -163,7 +168,7 @@ def predict(file_obj):
163
  ]),
164
  feature_names=top_features
165
  )
166
- explanation.expected_value = 0.5
167
 
168
  # Create waterfall plot
169
  plt.figure(figsize=(10, 6))
@@ -172,7 +177,7 @@ def predict(file_obj):
172
  show=False,
173
  max_display=11 # Show all features including "Others"
174
  )
175
- plt.title(f"Impact on prediction (>0.5 pushes toward human, <0.5 toward non-human)")
176
 
177
  # Save plot
178
  buf = io.BytesIO()
@@ -220,4 +225,4 @@ iface = gr.Interface(
220
  )
221
 
222
  if __name__ == "__main__":
223
- iface.launch(share=True)
 
150
  top_features.append("Others")
151
  top_values.append(others_sum)
152
 
153
+ # Calculate final probabilities first
154
+ with torch.no_grad():
155
+ output = model(X_tensor)
156
+ probs = torch.softmax(output, dim=1)
157
+ human_prob = float(probs[0][1])
158
+
159
  # Create SHAP explanation
160
+ # We'll use the actual probabilities for alignment
 
161
  explanation = shap.Explanation(
162
  values=np.array(top_values),
163
+ base_values=human_prob, # Use actual prediction as base
164
  data=np.array([
165
  raw_freq_vector[kmer_dict[feat]] if feat != "Others"
166
  else np.sum(raw_freq_vector[others_mask])
 
168
  ]),
169
  feature_names=top_features
170
  )
171
+ explanation.expected_value = human_prob # Match the actual prediction
172
 
173
  # Create waterfall plot
174
  plt.figure(figsize=(10, 6))
 
177
  show=False,
178
  max_display=11 # Show all features including "Others"
179
  )
180
+ plt.title(f"Feature contributions to human probability (final prob: {human_prob:.3f})")
181
 
182
  # Save plot
183
  buf = io.BytesIO()
 
225
  )
226
 
227
  if __name__ == "__main__":
228
+ iface.launch(share=True)