HaileyStorm commited on
Commit
0c33a38
·
verified ·
1 Parent(s): b2567ad

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -223,9 +223,9 @@ class MambaPlayer:
223
  def update_activations(self, result):
224
  for layer_idx in self.activations_sum:
225
  if result == "reset":
226
- self.activations_sum[layer_idx] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
227
- "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
228
- "current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
229
  for bucket in self.move_buckets}
230
  self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
231
  for bucket in self.move_buckets}
@@ -254,7 +254,7 @@ class MambaPlayer:
254
  activations_count[layer_idx][bucket] = {}
255
  for category in ["won", "lost"]:
256
  if category not in activations_sum[layer_idx][bucket]:
257
- activations_sum[layer_idx][bucket][category] = np.zeros((1, self.max_seq_len, self.model.config.d_model))
258
  activations_count[layer_idx][bucket][category] = 0
259
 
260
  activations_sum[layer_idx][bucket][category] += self.activations_sum[layer_idx][bucket][category]
@@ -264,9 +264,9 @@ class MambaPlayer:
264
  pickle.dump((activations_sum, activations_count), f)
265
 
266
  for layer_idx in self.activations_sum:
267
- self.activations_sum[layer_idx] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
268
- "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
269
- "current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
270
  for bucket in self.move_buckets}
271
  self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
272
  for bucket in self.move_buckets}
 
223
  def update_activations(self, result):
224
  for layer_idx in self.activations_sum:
225
  if result == "reset":
226
+ self.activations_sum[layer_idx] = {bucket: {"won": np.zeros((1, 8, self.model.config.d_model)),
227
+ "lost": np.zeros((1, 8, self.model.config.d_model)),
228
+ "current": np.zeros((1, 8, self.model.config.d_model))}
229
  for bucket in self.move_buckets}
230
  self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
231
  for bucket in self.move_buckets}
 
254
  activations_count[layer_idx][bucket] = {}
255
  for category in ["won", "lost"]:
256
  if category not in activations_sum[layer_idx][bucket]:
257
+ activations_sum[layer_idx][bucket][category] = np.zeros((1, 8, self.model.config.d_model))
258
  activations_count[layer_idx][bucket][category] = 0
259
 
260
  activations_sum[layer_idx][bucket][category] += self.activations_sum[layer_idx][bucket][category]
 
264
  pickle.dump((activations_sum, activations_count), f)
265
 
266
  for layer_idx in self.activations_sum:
267
+ self.activations_sum[layer_idx] = {bucket: {"won": np.zeros((1, 8, self.model.config.d_model)),
268
+ "lost": np.zeros((1, 8, self.model.config.d_model)),
269
+ "current": np.zeros((1, 8, self.model.config.d_model))}
270
  for bucket in self.move_buckets}
271
  self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
272
  for bucket in self.move_buckets}