HaileyStorm committed on
Commit
9082726
1 Parent(s): ae3133e

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -222,7 +222,7 @@ class MambaPlayer:
222
  self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
223
  for bucket in self.move_buckets}
224
 
225
- def apply_contrastive_activations(self, path):
226
  if os.path.exists(path):
227
  with open(path, "rb") as f:
228
  activations_sum, activations_count = pickle.load(f)
@@ -232,7 +232,7 @@ class MambaPlayer:
232
  won_activations = activations_sum[layer_idx][bucket]["won"] / activations_count[layer_idx][bucket]["won"]
233
  lost_activations = activations_sum[layer_idx][bucket]["lost"] / activations_count[layer_idx][bucket]["lost"]
234
  contrastive_activations = won_activations - lost_activations
235
- return output + torch.from_numpy(contrastive_activations[:, :seq_len, :]).to(output.device)
236
 
237
  for layer_idx in activations_sum:
238
  for bucket in self.move_buckets:
 
222
  self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
223
  for bucket in self.move_buckets}
224
 
225
+ def apply_contrastive_activations(self, path, weight=1.0):
226
  if os.path.exists(path):
227
  with open(path, "rb") as f:
228
  activations_sum, activations_count = pickle.load(f)
 
232
  won_activations = activations_sum[layer_idx][bucket]["won"] / activations_count[layer_idx][bucket]["won"]
233
  lost_activations = activations_sum[layer_idx][bucket]["lost"] / activations_count[layer_idx][bucket]["lost"]
234
  contrastive_activations = won_activations - lost_activations
235
+ return output + torch.from_numpy(contrastive_activations[:, :seq_len, :]).to(output.device) * weight
236
 
237
  for layer_idx in activations_sum:
238
  for bucket in self.move_buckets: