HaileyStorm
committed on
Commit
•
9082726
1
Parent(s):
ae3133e
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -222,7 +222,7 @@ class MambaPlayer:
|
|
222 |
self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
|
223 |
for bucket in self.move_buckets}
|
224 |
|
225 |
-
def apply_contrastive_activations(self, path):
|
226 |
if os.path.exists(path):
|
227 |
with open(path, "rb") as f:
|
228 |
activations_sum, activations_count = pickle.load(f)
|
@@ -232,7 +232,7 @@ class MambaPlayer:
|
|
232 |
won_activations = activations_sum[layer_idx][bucket]["won"] / activations_count[layer_idx][bucket]["won"]
|
233 |
lost_activations = activations_sum[layer_idx][bucket]["lost"] / activations_count[layer_idx][bucket]["lost"]
|
234 |
contrastive_activations = won_activations - lost_activations
|
235 |
-
return output + torch.from_numpy(contrastive_activations[:, :seq_len, :]).to(output.device)
|
236 |
|
237 |
for layer_idx in activations_sum:
|
238 |
for bucket in self.move_buckets:
|
|
|
222 |
self.activations_count[layer_idx] = {bucket: {"won": 0, "lost": 0, "current": 0}
|
223 |
for bucket in self.move_buckets}
|
224 |
|
225 |
+
def apply_contrastive_activations(self, path, weight=1.0):
|
226 |
if os.path.exists(path):
|
227 |
with open(path, "rb") as f:
|
228 |
activations_sum, activations_count = pickle.load(f)
|
|
|
232 |
won_activations = activations_sum[layer_idx][bucket]["won"] / activations_count[layer_idx][bucket]["won"]
|
233 |
lost_activations = activations_sum[layer_idx][bucket]["lost"] / activations_count[layer_idx][bucket]["lost"]
|
234 |
contrastive_activations = won_activations - lost_activations
|
235 |
+
return output + torch.from_numpy(contrastive_activations[:, :seq_len, :]).to(output.device) * weight
|
236 |
|
237 |
for layer_idx in activations_sum:
|
238 |
for bucket in self.move_buckets:
|