HaileyStorm
commited on
Commit
•
454586f
1
Parent(s):
9c9bdee
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -203,9 +203,17 @@ class MambaPlayer:
|
|
203 |
|
204 |
def update_activations(self, result):
|
205 |
for layer_idx in self.activations_sum:
|
206 |
-
|
207 |
-
self.activations_sum[layer_idx]
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
def save_activations(self, path):
|
211 |
if os.path.exists(path):
|
|
|
203 |
|
204 |
def update_activations(self, result):
|
205 |
for layer_idx in self.activations_sum:
|
206 |
+
if "result" == "reset":
|
207 |
+
self.activations_sum[layer_idx] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
208 |
+
"lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
209 |
+
"current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
|
210 |
+
for bucket in self.move_buckets}
|
211 |
+
self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
|
212 |
+
for bucket in self.move_buckets}
|
213 |
+
else:
|
214 |
+
for bucket in self.move_buckets:
|
215 |
+
self.activations_sum[layer_idx][bucket][result] += self.activations_sum[layer_idx][bucket]["current"]
|
216 |
+
self.activations_count[layer_idx][bucket][result] += 1
|
217 |
|
218 |
def save_activations(self, path):
|
219 |
if os.path.exists(path):
|