hiyata committed on
Commit
cbacd3e
·
verified ·
1 Parent(s): 2ed8007

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -87,17 +87,18 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
87
  def calculate_shap_values(model, x_tensor):
88
  model.eval()
89
  device = next(model.parameters()).device
90
-
91
  try:
92
- # Create background as a torch tensor
93
  background = torch.zeros((300, x_tensor.shape[1]), device=device)
94
  explainer = shap.DeepExplainer(model, background)
95
  shap_values_all = explainer.shap_values(x_tensor)
96
- # For binary classification, use the second output and then the first sample
97
  shap_values = shap_values_all[1][0]
98
  except Exception as e:
99
  print(f"DeepExplainer failed, falling back to KernelExplainer: {str(e)}")
100
-
 
101
  def model_predict(x):
102
  if not isinstance(x, np.ndarray):
103
  x = np.array(x)
@@ -106,24 +107,26 @@ def calculate_shap_values(model, x_tensor):
106
  with torch.no_grad():
107
  tensor_x = torch.tensor(x, dtype=torch.float, device=device)
108
  output = model(tensor_x)
109
- probs = torch.softmax(output, dim=1)[:, 1]
110
  return probs.cpu().numpy()
111
-
112
- # Use a numpy background for KernelExplainer
113
- background = np.zeros((300, x_tensor.shape[1]))
114
- explainer = shap.KernelExplainer(model_predict, background)
115
  x_numpy = x_tensor.cpu().numpy()
 
 
 
 
116
  shap_values = explainer.shap_values(x_numpy, nsamples=1000)
117
- # If KernelExplainer returns a list, take its first element.
118
  if isinstance(shap_values, list):
119
  shap_values = shap_values[0]
120
-
121
- # Get human probability from model prediction
122
  with torch.no_grad():
123
  output = model(x_tensor)
124
  probs = torch.softmax(output, dim=1)
125
  prob_human = probs[0, 1].item()
126
-
127
  return np.array(shap_values), prob_human
128
 
129
 
 
87
  def calculate_shap_values(model, x_tensor):
88
  model.eval()
89
  device = next(model.parameters()).device
90
+
91
  try:
92
+ # Create background as a torch tensor (using zeros may be acceptable for DeepExplainer)
93
  background = torch.zeros((300, x_tensor.shape[1]), device=device)
94
  explainer = shap.DeepExplainer(model, background)
95
  shap_values_all = explainer.shap_values(x_tensor)
96
+ # For binary classification, get SHAP for class 1 and first sample
97
  shap_values = shap_values_all[1][0]
98
  except Exception as e:
99
  print(f"DeepExplainer failed, falling back to KernelExplainer: {str(e)}")
100
+
101
+ # Define a wrapper that ensures proper input shape and conversion to tensor
102
  def model_predict(x):
103
  if not isinstance(x, np.ndarray):
104
  x = np.array(x)
 
107
  with torch.no_grad():
108
  tensor_x = torch.tensor(x, dtype=torch.float, device=device)
109
  output = model(tensor_x)
110
+ probs = torch.softmax(output, dim=1)[:, 1] # human probability
111
  return probs.cpu().numpy()
112
+
113
+ # Instead of using zeros as background, use the input sample repeated 300 times.
 
 
114
  x_numpy = x_tensor.cpu().numpy()
115
+ background = np.repeat(x_numpy, 300, axis=0)
116
+
117
+ explainer = shap.KernelExplainer(model_predict, background)
118
+ # Increase nsamples for a more robust estimate.
119
  shap_values = explainer.shap_values(x_numpy, nsamples=1000)
120
+ # If a list is returned, select the first element.
121
  if isinstance(shap_values, list):
122
  shap_values = shap_values[0]
123
+
124
+ # Get the human probability from the model output.
125
  with torch.no_grad():
126
  output = model(x_tensor)
127
  probs = torch.softmax(output, dim=1)
128
  prob_human = probs[0, 1].item()
129
+
130
  return np.array(shap_values), prob_human
131
 
132