hiyata committed (verified)
Commit 17c9ecb · 1 Parent(s): 5bf9386

Update app.py

Files changed (1)
  1. app.py +33 -8
app.py CHANGED
@@ -118,12 +118,38 @@ def predict(file_obj):
     kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
     X_tensor = torch.FloatTensor(kmer_vector).to(device)
 
-    # Get feature importance and human probability
-    importance, human_prob = model.get_feature_importance(X_tensor)
-    kmer_importance = importance[0].cpu().numpy()
+    # Calculate final probabilities first
+    with torch.no_grad():
+        output = model(X_tensor)
+        probs = torch.softmax(output, dim=1)
+        human_prob = float(probs[0][1])
+
+    # Get feature importance using integrated gradients
+    baseline = torch.zeros_like(X_tensor)  # baseline of zeros
+    steps = 50
+
+    all_importance = []
+    for i in range(steps + 1):
+        alpha = i / steps
+        interpolated = baseline + alpha * (X_tensor - baseline)
+        interpolated.requires_grad_(True)
+
+        output = model(interpolated)
+        probs = torch.softmax(output, dim=1)
+        human_class = probs[..., 1]
+
+        if interpolated.grad is not None:
+            interpolated.grad.zero_()
+        human_class.backward()
+        all_importance.append(interpolated.grad.cpu().numpy())
 
-    # Scale importance values relative to the prediction
-    kmer_importance = kmer_importance * human_prob
+    # Average the gradients
+    kmer_importance = np.mean(all_importance, axis=0)[0]
+    # Scale to match probability difference
+    target_diff = human_prob - 0.5  # difference from neutral prediction
+    current_sum = np.sum(kmer_importance)
+    if current_sum != 0:  # avoid division by zero
+        kmer_importance = kmer_importance * (target_diff / current_sum)
 
     # Get top k-mers by absolute importance
     top_k = 10
@@ -157,10 +183,9 @@ def predict(file_obj):
     human_prob = float(probs[0][1])
 
     # Create SHAP explanation
-    # We'll use the actual probabilities for alignment
     explanation = shap.Explanation(
         values=np.array(top_values),
-        base_values=human_prob,  # Use actual prediction as base
+        base_values=0.5,  # Start from neutral prediction
         data=np.array([
             raw_freq_vector[kmer_dict[feat]] if feat != "Others"
             else np.sum(raw_freq_vector[others_mask])
@@ -168,7 +193,7 @@ def predict(file_obj):
         ]),
         feature_names=top_features
     )
-    explanation.expected_value = human_prob  # Match the actual prediction
+    explanation.expected_value = 0.5  # Start from neutral prediction
 
     # Create waterfall plot
     plt.figure(figsize=(10, 6))
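
For comparison, the textbook integrated-gradients attribution (Sundararajan et al., 2017) multiplies the path-averaged gradient by (x - baseline) instead of rescaling the raw gradient average to human_prob - 0.5 as the loop in this commit does. A minimal sketch of that standard form follows; Classifier, integrated_gradients, and the 256-feature input are illustrative stand-ins, not code from app.py.

import torch
import torch.nn as nn

class Classifier(nn.Module):
    # Stand-in two-class model; the real model used in app.py is not shown in this diff.
    def __init__(self, n_features: int = 256, n_classes: int = 2):
        super().__init__()
        self.net = nn.Linear(n_features, n_classes)

    def forward(self, x):
        return self.net(x)

def integrated_gradients(model, x, target_class=1, steps=50):
    # Attribution = (x - baseline) * mean gradient along the straight-line path.
    baseline = torch.zeros_like(x)
    grads = []
    for i in range(steps + 1):
        alpha = i / steps
        point = (baseline + alpha * (x - baseline)).requires_grad_(True)
        prob = torch.softmax(model(point), dim=1)[:, target_class].sum()
        grads.append(torch.autograd.grad(prob, point)[0])
    avg_grad = torch.stack(grads).mean(dim=0)
    # Completeness (approximately): attributions sum to f(x) - f(baseline).
    return (x - baseline) * avg_grad

model = Classifier()
x = torch.rand(1, 256)
attributions = integrated_gradients(model, x)
print(attributions.shape)  # torch.Size([1, 256])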
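
The rescaling step in the first hunk forces the attributions to sum to human_prob - 0.5, so a waterfall plot anchored at a base value of 0.5 should terminate at the predicted human probability, provided the "Others" bucket absorbs the contributions outside the top 10. A quick arithmetic check, using made-up numbers in place of the real top_values and human_prob from app.py:

import numpy as np

human_prob = 0.85  # hypothetical prediction, not a real model output
top_values = np.array([0.12, 0.08, 0.05, 0.03, 0.02,
                       0.01, 0.01, 0.005, 0.005, 0.02])  # hypothetical contributions
assert np.isclose(0.5 + top_values.sum(), human_prob)  # waterfall ends at human_prob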