HaileyStorm commited on
Commit
d0c6814
1 Parent(s): e993689

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -215,9 +215,12 @@ class MambaPlayer:
215
  pickle.dump((activations_sum, activations_count), f)
216
 
217
  for layer_idx in self.activations_sum:
218
- for bucket in self.move_buckets:
219
- self.activations_sum[layer_idx][bucket]["current"].fill(0)
220
- self.activations_count[layer_idx][bucket]["current"] = 0
 
 
 
221
 
222
  def apply_contrastive_activations(self, path):
223
  if os.path.exists(path):
 
215
  pickle.dump((activations_sum, activations_count), f)
216
 
217
  for layer_idx in self.activations_sum:
218
+ self.activations_sum[layer_idx] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
219
+ "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
220
+ "current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
221
+ for bucket in self.move_buckets}
222
+ self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
223
+ for bucket in self.move_buckets}
224
 
225
  def apply_contrastive_activations(self, path):
226
  if os.path.exists(path):