HaileyStorm
commited on
Commit
•
4b17b1c
1
Parent(s):
be25621
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -96,7 +96,7 @@ class MambaPlayer:
|
|
96 |
self.activations[i] = {"won": [], "lost": []}
|
97 |
|
98 |
def hook(module, input, output, layer_idx=i):
|
99 |
-
self.activations[layer_idx]["current"] = output[0]
|
100 |
|
101 |
hook_function = partial(hook, layer_idx=i)
|
102 |
self.hooks.append(layer.register_forward_hook(hook_function))
|
@@ -170,30 +170,20 @@ class MambaPlayer:
|
|
170 |
def save_activations(self, path):
|
171 |
activations_sum = {}
|
172 |
activations_count = {}
|
173 |
-
|
174 |
for layer_idx, layer_activations in self.activations.items():
|
175 |
-
if layer_activations["won"]: # Check if not empty to avoid unnecessary operations
|
176 |
-
won_sum = torch.stack(layer_activations["won"]).sum(dim=0).cpu().numpy()
|
177 |
-
else:
|
178 |
-
return
|
179 |
-
|
180 |
-
if layer_activations["lost"]:
|
181 |
-
lost_sum = torch.stack(layer_activations["lost"]).sum(dim=0).cpu().numpy()
|
182 |
-
else:
|
183 |
-
return
|
184 |
-
|
185 |
activations_sum[layer_idx] = {
|
186 |
-
"won":
|
187 |
-
"lost":
|
188 |
}
|
189 |
activations_count[layer_idx] = {
|
190 |
"won": len(layer_activations["won"]),
|
191 |
"lost": len(layer_activations["lost"])
|
192 |
}
|
193 |
-
|
194 |
with open(path, "wb") as f:
|
195 |
pickle.dump((activations_sum, activations_count), f)
|
196 |
-
|
197 |
self.activations = {}
|
198 |
|
199 |
def apply_contrastive_activations(self, path):
|
|
|
96 |
self.activations[i] = {"won": [], "lost": []}
|
97 |
|
98 |
def hook(module, input, output, layer_idx=i):
|
99 |
+
self.activations[layer_idx]["current"] = output[0].detach().cpu().numpy()
|
100 |
|
101 |
hook_function = partial(hook, layer_idx=i)
|
102 |
self.hooks.append(layer.register_forward_hook(hook_function))
|
|
|
170 |
def save_activations(self, path):
|
171 |
activations_sum = {}
|
172 |
activations_count = {}
|
173 |
+
|
174 |
for layer_idx, layer_activations in self.activations.items():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
activations_sum[layer_idx] = {
|
176 |
+
"won": np.sum(layer_activations["won"], axis=0),
|
177 |
+
"lost": np.sum(layer_activations["lost"], axis=0)
|
178 |
}
|
179 |
activations_count[layer_idx] = {
|
180 |
"won": len(layer_activations["won"]),
|
181 |
"lost": len(layer_activations["lost"])
|
182 |
}
|
183 |
+
|
184 |
with open(path, "wb") as f:
|
185 |
pickle.dump((activations_sum, activations_count), f)
|
186 |
+
|
187 |
self.activations = {}
|
188 |
|
189 |
def apply_contrastive_activations(self, path):
|