HaileyStorm committed on
Commit
4b17b1c
1 Parent(s): be25621

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -96,7 +96,7 @@ class MambaPlayer:
96
  self.activations[i] = {"won": [], "lost": []}
97
 
98
  def hook(module, input, output, layer_idx=i):
99
- self.activations[layer_idx]["current"] = output[0]#.detach().cpu().numpy()
100
 
101
  hook_function = partial(hook, layer_idx=i)
102
  self.hooks.append(layer.register_forward_hook(hook_function))
@@ -170,30 +170,20 @@ class MambaPlayer:
170
  def save_activations(self, path):
171
  activations_sum = {}
172
  activations_count = {}
173
-
174
  for layer_idx, layer_activations in self.activations.items():
175
- if layer_activations["won"]: # Check if not empty to avoid unnecessary operations
176
- won_sum = torch.stack(layer_activations["won"]).sum(dim=0).cpu().numpy()
177
- else:
178
- return
179
-
180
- if layer_activations["lost"]:
181
- lost_sum = torch.stack(layer_activations["lost"]).sum(dim=0).cpu().numpy()
182
- else:
183
- return
184
-
185
  activations_sum[layer_idx] = {
186
- "won": won_sum,
187
- "lost": lost_sum
188
  }
189
  activations_count[layer_idx] = {
190
  "won": len(layer_activations["won"]),
191
  "lost": len(layer_activations["lost"])
192
  }
193
-
194
  with open(path, "wb") as f:
195
  pickle.dump((activations_sum, activations_count), f)
196
-
197
  self.activations = {}
198
 
199
  def apply_contrastive_activations(self, path):
 
96
  self.activations[i] = {"won": [], "lost": []}
97
 
98
  def hook(module, input, output, layer_idx=i):
99
+ self.activations[layer_idx]["current"] = output[0].detach().cpu().numpy()
100
 
101
  hook_function = partial(hook, layer_idx=i)
102
  self.hooks.append(layer.register_forward_hook(hook_function))
 
170
  def save_activations(self, path):
171
  activations_sum = {}
172
  activations_count = {}
173
+
174
  for layer_idx, layer_activations in self.activations.items():
 
 
 
 
 
 
 
 
 
 
175
  activations_sum[layer_idx] = {
176
+ "won": np.sum(layer_activations["won"], axis=0),
177
+ "lost": np.sum(layer_activations["lost"], axis=0)
178
  }
179
  activations_count[layer_idx] = {
180
  "won": len(layer_activations["won"]),
181
  "lost": len(layer_activations["lost"])
182
  }
183
+
184
  with open(path, "wb") as f:
185
  pickle.dump((activations_sum, activations_count), f)
186
+
187
  self.activations = {}
188
 
189
  def apply_contrastive_activations(self, path):