hiyata committed on
Commit
f6763a9
·
verified ·
1 Parent(s): 26706f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -15
app.py CHANGED
@@ -80,26 +80,56 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
80
  return vec
81
 
82
  ###############################################################################
83
- # 3. SHAP-VALUE (ABLATION) CALCULATION
84
  ###############################################################################
85
 
 
 
86
def calculate_shap_values(model, x_tensor):
    """Approximate per-feature SHAP values by single-feature ablation.

    Each feature of ``x_tensor`` is zeroed in turn and the resulting drop
    in the model's 'human' class probability (class index 1) is recorded.

    Parameters:
        model: torch classifier producing logits of shape (batch, 2).
        x_tensor: input of shape (1, n_features).

    Returns:
        (np.ndarray, float): length-n_features array of probability drops
        (baseline minus ablated), and the baseline 'human' probability.
    """
    model.eval()
    with torch.no_grad():
        # Reference prediction on the untouched input.
        ref_probs = torch.softmax(model(x_tensor), dim=1)
        ref_human = ref_probs[0, 1].item()  # Prob of 'human'

        contributions = []
        perturbed = x_tensor.clone()
        for idx in range(x_tensor.shape[1]):
            saved = perturbed[0, idx].item()
            perturbed[0, idx] = 0.0
            ablated = torch.softmax(model(perturbed), dim=1)[0, 1].item()
            contributions.append(ref_human - ablated)
            perturbed[0, idx] = saved  # restore before ablating the next feature
    return np.array(contributions), ref_human
103
 
104
  ###############################################################################
105
  # 4. PER-BASE SHAP AGGREGATION
 
80
  return vec
81
 
82
  ###############################################################################
83
+ # 3. SHAP-VALUE CALCULATION
84
  ###############################################################################
85
 
86
import shap  # NOTE(review): consider moving this to the top-of-file import block

def calculate_shap_values(model, x_tensor):
    """Compute per-feature SHAP values for the 'human' class (index 1).

    Tries ``shap.DeepExplainer`` first (efficient for torch networks) and
    falls back to the model-agnostic ``shap.KernelExplainer`` on failure.

    Parameters:
        model: torch classifier producing logits of shape (batch, 2).
        x_tensor: single input of shape (1, n_features).

    Returns:
        (np.ndarray, float): 1-D array of length n_features with SHAP
        values for class 1, and the model's probability of class 1 on
        ``x_tensor``.
    """
    model.eval()
    device = next(model.parameters()).device

    # All-zero background serves as the baseline ("feature absent").
    background = torch.zeros((10, x_tensor.shape[1]), device=device)

    try:
        # DeepExplainer is the fast path for neural networks.
        explainer = shap.DeepExplainer(model, background)
        shap_values_all = explainer.shap_values(x_tensor)

        # Older shap versions return a list with one array per class;
        # newer versions return a single ndarray of shape
        # (n_samples, n_features, n_classes). Handle both, selecting the
        # 'human' class (index 1) for the single sample.
        if isinstance(shap_values_all, list):
            shap_values = np.asarray(shap_values_all[1][0])
        else:
            shap_values = np.asarray(shap_values_all)[0, :, 1]

    except Exception as e:
        print(f"DeepExplainer failed, falling back to KernelExplainer: {str(e)}")

        def model_predict(x):
            # numpy -> torch bridge returning P(human) for each row.
            with torch.no_grad():
                tensor_x = torch.FloatTensor(x).to(device)
                probs = torch.softmax(model(tensor_x), dim=1)[:, 1]
            return probs.cpu().numpy()

        # Single all-zero row as the KernelExplainer baseline.
        background_np = np.zeros((1, x_tensor.shape[1]))
        explainer = shap.KernelExplainer(model_predict, background_np)

        x_numpy = x_tensor.cpu().numpy()
        # Bug fix: KernelExplainer returns shape (1, n_features) for a
        # single-output function; flatten to a 1-D vector so both branches
        # return the same shape (the previous code returned a 2-D array
        # from this branch but 1-D from the DeepExplainer branch).
        shap_values = np.asarray(
            explainer.shap_values(x_numpy, nsamples=100)
        ).reshape(-1)

    # Human-class probability on the unmodified input.
    with torch.no_grad():
        output = model(x_tensor)
        probs = torch.softmax(output, dim=1)
        prob_human = probs[0, 1].item()

    return np.asarray(shap_values), prob_human
 
 
 
 
 
 
 
 
 
133
 
134
  ###############################################################################
135
  # 4. PER-BASE SHAP AGGREGATION