Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -80,26 +80,56 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
|
80 |
return vec
|
81 |
|
82 |
###############################################################################
|
83 |
-
# 3. SHAP-VALUE
|
84 |
###############################################################################
|
85 |
|
|
|
|
|
86 |
def calculate_shap_values(model, x_tensor):
|
87 |
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
with torch.no_grad():
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
for i in range(x_tensor.shape[1]):
|
95 |
-
original_val = x_zeroed[0, i].item()
|
96 |
-
x_zeroed[0, i] = 0.0
|
97 |
-
output = model(x_zeroed)
|
98 |
-
probs = torch.softmax(output, dim=1)
|
99 |
-
prob = probs[0, 1].item()
|
100 |
-
shap_values.append(baseline_prob - prob)
|
101 |
-
x_zeroed[0, i] = original_val
|
102 |
-
return np.array(shap_values), baseline_prob
|
103 |
|
104 |
###############################################################################
|
105 |
# 4. PER-BASE SHAP AGGREGATION
|
|
|
80 |
return vec
|
81 |
|
82 |
###############################################################################
|
83 |
+
# 3. SHAP-VALUE CALCULATION
|
84 |
###############################################################################
|
85 |
|
86 |
+
import shap
|
87 |
+
|
88 |
def calculate_shap_values(model, x_tensor):
    """Compute per-feature SHAP values for the positive ("human") class.

    Tries ``shap.DeepExplainer`` first (efficient for neural networks) and
    falls back to ``shap.KernelExplainer`` with a model-wrapper function if
    DeepExplainer raises (it is sensitive to model architecture / shap
    version).

    Args:
        model: a PyTorch classifier whose output logits have the human class
            at column index 1.
        x_tensor: input tensor of shape (1, n_features) — presumably the
            k-mer vector from ``sequence_to_kmer_vector``; TODO confirm.

    Returns:
        tuple: (shap_values, prob_human) where ``shap_values`` is a 1-D
        ``np.ndarray`` of length n_features and ``prob_human`` is the
        model's softmax probability for class 1 on ``x_tensor``.
    """
    model.eval()
    device = next(model.parameters()).device

    # Background (baseline) dataset for the explainer: all-zero vectors.
    background = torch.zeros((10, x_tensor.shape[1]), device=device)

    try:
        # Try DeepExplainer first (efficient for neural networks).
        explainer = shap.DeepExplainer(model, background)
        shap_values_all = explainer.shap_values(x_tensor)
        # Multi-class output: pick the human class (index 1), first sample.
        shap_values = shap_values_all[1][0]
    except Exception as e:
        print(f"DeepExplainer failed, falling back to KernelExplainer: {str(e)}")

        def model_predict(x):
            # NumPy-in / NumPy-out wrapper returning P(human) per row,
            # as required by KernelExplainer.
            with torch.no_grad():
                tensor_x = torch.FloatTensor(x).to(device)
                output = model(tensor_x)
                probs = torch.softmax(output, dim=1)[:, 1]  # Human probability
            return probs.cpu().numpy()

        # Single all-zero baseline row for the kernel estimator.
        background = np.zeros((1, x_tensor.shape[1]))
        explainer = shap.KernelExplainer(model_predict, background)
        x_numpy = x_tensor.cpu().numpy()
        # nsamples bounds the number of model evaluations (speed/accuracy
        # trade-off); result has shape (1, n_features) for a 1-output model.
        shap_values = explainer.shap_values(x_numpy, nsamples=100)

    # Human-class probability for the actual input.
    with torch.no_grad():
        output = model(x_tensor)
        probs = torch.softmax(output, dim=1)
        prob_human = probs[0, 1].item()

    # BUGFIX: the DeepExplainer branch produced a 1-D (n_features,) array
    # while the KernelExplainer fallback produced (1, n_features); flatten
    # so callers always receive a consistent 1-D array.
    return np.asarray(shap_values).reshape(-1), prob_human
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
###############################################################################
|
135 |
# 4. PER-BASE SHAP AGGREGATION
|