HaileyStorm commited on
Commit
454586f
1 Parent(s): 9c9bdee

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -203,9 +203,17 @@ class MambaPlayer:
203
 
204
  def update_activations(self, result):
205
  for layer_idx in self.activations_sum:
206
- for bucket in self.move_buckets:
207
- self.activations_sum[layer_idx][bucket][result] += self.activations_sum[layer_idx][bucket]["current"]
208
- self.activations_count[layer_idx][bucket][result] += 1
 
 
 
 
 
 
 
 
209
 
210
  def save_activations(self, path):
211
  if os.path.exists(path):
 
203
 
204
  def update_activations(self, result):
205
  for layer_idx in self.activations_sum:
206
+ if "result" == "reset":
207
+ self.activations_sum[layer_idx] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
208
+ "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
209
+ "current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
210
+ for bucket in self.move_buckets}
211
+ self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
212
+ for bucket in self.move_buckets}
213
+ else:
214
+ for bucket in self.move_buckets:
215
+ self.activations_sum[layer_idx][bucket][result] += self.activations_sum[layer_idx][bucket]["current"]
216
+ self.activations_count[layer_idx][bucket][result] += 1
217
 
218
  def save_activations(self, path):
219
  if os.path.exists(path):