Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -223,9 +223,9 @@ class MambaPlayer:
|
|
223 |
def update_activations(self, result):
|
224 |
for layer_idx in self.activations_sum:
|
225 |
if result == "reset":
|
226 |
-
self.activations_sum[layer_idx] = {bucket: {"won": np.zeros((1,
|
227 |
-
"lost": np.zeros((1,
|
228 |
-
"current": np.zeros((1,
|
229 |
for bucket in self.move_buckets}
|
230 |
self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
|
231 |
for bucket in self.move_buckets}
|
@@ -254,7 +254,7 @@ class MambaPlayer:
|
|
254 |
activations_count[layer_idx][bucket] = {}
|
255 |
for category in ["won", "lost"]:
|
256 |
if category not in activations_sum[layer_idx][bucket]:
|
257 |
-
activations_sum[layer_idx][bucket][category] = np.zeros((1,
|
258 |
activations_count[layer_idx][bucket][category] = 0
|
259 |
|
260 |
activations_sum[layer_idx][bucket][category] += self.activations_sum[layer_idx][bucket][category]
|
@@ -264,9 +264,9 @@ class MambaPlayer:
|
|
264 |
pickle.dump((activations_sum, activations_count), f)
|
265 |
|
266 |
for layer_idx in self.activations_sum:
|
267 |
-
self.activations_sum[layer_idx] = {bucket: {"won": np.zeros((1,
|
268 |
-
"lost": np.zeros((1,
|
269 |
-
"current": np.zeros((1,
|
270 |
for bucket in self.move_buckets}
|
271 |
self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
|
272 |
for bucket in self.move_buckets}
|
|
|
223 |
def update_activations(self, result):
|
224 |
for layer_idx in self.activations_sum:
|
225 |
if result == "reset":
|
226 |
+
self.activations_sum[layer_idx] = {bucket: {"won": np.zeros((1, 8, self.model.config.d_model)),
|
227 |
+
"lost": np.zeros((1, 8, self.model.config.d_model)),
|
228 |
+
"current": np.zeros((1, 8, self.model.config.d_model))}
|
229 |
for bucket in self.move_buckets}
|
230 |
self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
|
231 |
for bucket in self.move_buckets}
|
|
|
254 |
activations_count[layer_idx][bucket] = {}
|
255 |
for category in ["won", "lost"]:
|
256 |
if category not in activations_sum[layer_idx][bucket]:
|
257 |
+
activations_sum[layer_idx][bucket][category] = np.zeros((1, 8, self.model.config.d_model))
|
258 |
activations_count[layer_idx][bucket][category] = 0
|
259 |
|
260 |
activations_sum[layer_idx][bucket][category] += self.activations_sum[layer_idx][bucket][category]
|
|
|
264 |
pickle.dump((activations_sum, activations_count), f)
|
265 |
|
266 |
for layer_idx in self.activations_sum:
|
267 |
+
self.activations_sum[layer_idx] = {bucket: {"won": np.zeros((1, 8, self.model.config.d_model)),
|
268 |
+
"lost": np.zeros((1, 8, self.model.config.d_model)),
|
269 |
+
"current": np.zeros((1, 8, self.model.config.d_model))}
|
270 |
for bucket in self.move_buckets}
|
271 |
self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
|
272 |
for bucket in self.move_buckets}
|