Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -90,11 +90,14 @@ class MambaPlayer:
|
|
90 |
self.device = device
|
91 |
|
92 |
self.hooks = []
|
|
|
|
|
93 |
self.activations_sum = {}
|
94 |
self.activations_count = {}
|
95 |
-
|
96 |
for i, layer in enumerate(self.model.backbone.layers):
|
97 |
-
self.activations_sum[i] = {"won":
|
|
|
|
|
98 |
self.activations_count[i] = {"won": 0, "lost": 0, "current": 0}
|
99 |
|
100 |
def hook(module, input, output, layer_idx=i):
|
@@ -102,7 +105,8 @@ class MambaPlayer:
|
|
102 |
tensor_output = output[0]
|
103 |
else:
|
104 |
tensor_output = output
|
105 |
-
|
|
|
106 |
self.activations_count[layer_idx]["current"] += 1
|
107 |
|
108 |
self.hooks.append(layer.register_forward_hook(hook))
|
@@ -188,7 +192,7 @@ class MambaPlayer:
|
|
188 |
activations_sum[layer_idx] = {}
|
189 |
activations_count[layer_idx] = {}
|
190 |
if category not in activations_sum[layer_idx]:
|
191 |
-
activations_sum[layer_idx][category] =
|
192 |
activations_count[layer_idx][category] = 0
|
193 |
|
194 |
activations_sum[layer_idx][category] += self.activations_sum[layer_idx][category]
|
@@ -196,12 +200,10 @@ class MambaPlayer:
|
|
196 |
|
197 |
with open(path, "wb") as f:
|
198 |
pickle.dump((activations_sum, activations_count), f)
|
199 |
-
|
200 |
-
self.activations_sum
|
201 |
-
|
202 |
-
|
203 |
-
self.activations_sum[i] = {"won": 0, "lost": 0, "current": 0}
|
204 |
-
self.activations_count[i] = {"won": 0, "lost": 0, "current": 0}
|
205 |
|
206 |
def apply_contrastive_activations(self, path):
|
207 |
if os.path.exists(path):
|
@@ -214,7 +216,8 @@ class MambaPlayer:
|
|
214 |
contrastive_activations = won_activations - lost_activations
|
215 |
|
216 |
def hook(module, input, output):
|
217 |
-
|
|
|
218 |
|
219 |
self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
|
220 |
|
|
|
90 |
self.device = device
|
91 |
|
92 |
self.hooks = []
|
93 |
+
self.max_seq_len = 1536
|
94 |
+
|
95 |
self.activations_sum = {}
|
96 |
self.activations_count = {}
|
|
|
97 |
for i, layer in enumerate(self.model.backbone.layers):
|
98 |
+
self.activations_sum[i] = {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
99 |
+
"lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
100 |
+
"current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
|
101 |
self.activations_count[i] = {"won": 0, "lost": 0, "current": 0}
|
102 |
|
103 |
def hook(module, input, output, layer_idx=i):
|
|
|
105 |
tensor_output = output[0]
|
106 |
else:
|
107 |
tensor_output = output
|
108 |
+
seq_len = tensor_output.shape[1]
|
109 |
+
self.activations_sum[layer_idx]["current"][:, :seq_len, :] += tensor_output.detach().cpu().numpy()
|
110 |
self.activations_count[layer_idx]["current"] += 1
|
111 |
|
112 |
self.hooks.append(layer.register_forward_hook(hook))
|
|
|
192 |
activations_sum[layer_idx] = {}
|
193 |
activations_count[layer_idx] = {}
|
194 |
if category not in activations_sum[layer_idx]:
|
195 |
+
activations_sum[layer_idx][category] = np.zeros((1, self.max_seq_len, self.model.config.d_model))
|
196 |
activations_count[layer_idx][category] = 0
|
197 |
|
198 |
activations_sum[layer_idx][category] += self.activations_sum[layer_idx][category]
|
|
|
200 |
|
201 |
with open(path, "wb") as f:
|
202 |
pickle.dump((activations_sum, activations_count), f)
|
203 |
+
|
204 |
+
for layer_idx in self.activations_sum:
|
205 |
+
self.activations_sum[layer_idx]["current"].fill(0)
|
206 |
+
self.activations_count[layer_idx]["current"] = 0
|
|
|
|
|
207 |
|
208 |
def apply_contrastive_activations(self, path):
|
209 |
if os.path.exists(path):
|
|
|
216 |
contrastive_activations = won_activations - lost_activations
|
217 |
|
218 |
def hook(module, input, output):
|
219 |
+
seq_len = output.shape[1]
|
220 |
+
return output + torch.from_numpy(contrastive_activations[:, :seq_len, :]).to(output.device)
|
221 |
|
222 |
self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
|
223 |
|