Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -118,12 +118,38 @@ def predict(file_obj):
|
|
118 |
kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
|
119 |
X_tensor = torch.FloatTensor(kmer_vector).to(device)
|
120 |
|
121 |
-
#
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
-
#
|
126 |
-
kmer_importance =
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
# Get top k-mers by absolute importance
|
129 |
top_k = 10
|
@@ -157,10 +183,9 @@ def predict(file_obj):
|
|
157 |
human_prob = float(probs[0][1])
|
158 |
|
159 |
# Create SHAP explanation
|
160 |
-
# We'll use the actual probabilities for alignment
|
161 |
explanation = shap.Explanation(
|
162 |
values=np.array(top_values),
|
163 |
-
base_values=
|
164 |
data=np.array([
|
165 |
raw_freq_vector[kmer_dict[feat]] if feat != "Others"
|
166 |
else np.sum(raw_freq_vector[others_mask])
|
@@ -168,7 +193,7 @@ def predict(file_obj):
|
|
168 |
]),
|
169 |
feature_names=top_features
|
170 |
)
|
171 |
-
explanation.expected_value =
|
172 |
|
173 |
# Create waterfall plot
|
174 |
plt.figure(figsize=(10, 6))
|
|
|
118 |
kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
|
119 |
X_tensor = torch.FloatTensor(kmer_vector).to(device)
|
120 |
|
121 |
+
# Calculate final probabilities first
|
122 |
+
with torch.no_grad():
|
123 |
+
output = model(X_tensor)
|
124 |
+
probs = torch.softmax(output, dim=1)
|
125 |
+
human_prob = float(probs[0][1])
|
126 |
+
|
127 |
+
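# Get feature importance from gradients averaged along the straight path from the baseline to the input (integrated-gradients style; the average is not multiplied by input - baseline)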
|
128 |
+
baseline = torch.zeros_like(X_tensor) # baseline of zeros
|
129 |
+
steps = 50
|
130 |
+
|
131 |
+
all_importance = []
|
132 |
+
for i in range(steps + 1):
|
133 |
+
alpha = i / steps
|
134 |
+
interpolated = baseline + alpha * (X_tensor - baseline)
|
135 |
+
interpolated.requires_grad_(True)
|
136 |
+
|
137 |
+
output = model(interpolated)
|
138 |
+
probs = torch.softmax(output, dim=1)
|
139 |
+
human_class = probs[..., 1]
|
140 |
+
|
141 |
+
if interpolated.grad is not None:
|
142 |
+
interpolated.grad.zero_()
|
143 |
+
human_class.backward()
|
144 |
+
all_importance.append(interpolated.grad.cpu().numpy())
|
145 |
|
146 |
+
# Average the gradients over all interpolation steps; [0] selects the single input row
|
147 |
+
kmer_importance = np.mean(all_importance, axis=0)[0]
|
148 |
+
# Rescale so the importances sum to human_prob - 0.5 (the offset from a neutral 0.5 prediction)
|
149 |
+
target_diff = human_prob - 0.5 # difference from neutral prediction
|
150 |
+
current_sum = np.sum(kmer_importance)
|
151 |
+
if current_sum != 0: # avoid division by zero
|
152 |
+
kmer_importance = kmer_importance * (target_diff / current_sum)
|
153 |
|
154 |
# Get top k-mers by absolute importance
|
155 |
top_k = 10
|
|
|
183 |
human_prob = float(probs[0][1])
|
184 |
|
185 |
# Create SHAP explanation
|
|
|
186 |
explanation = shap.Explanation(
|
187 |
values=np.array(top_values),
|
188 |
+
base_values=0.5, # Start from neutral prediction
|
189 |
data=np.array([
|
190 |
raw_freq_vector[kmer_dict[feat]] if feat != "Others"
|
191 |
else np.sum(raw_freq_vector[others_mask])
|
|
|
193 |
]),
|
194 |
feature_names=top_features
|
195 |
)
|
196 |
+
explanation.expected_value = 0.5 # Start from neutral prediction
|
197 |
|
198 |
# Create waterfall plot
|
199 |
plt.figure(figsize=(10, 6))
|