HaileyStorm
commited on
Upload chess-gpt-eval-contrastive/mamba_module.py with huggingface_hub
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -321,10 +321,10 @@ class MambaPlayer:
|
|
321 |
|
322 |
def train_linear_probes(self):
|
323 |
def get_lr(it):
|
324 |
-
warmup_iters =
|
325 |
lr_decay_iters = 5000 * 43
|
326 |
-
learning_rate = 0.
|
327 |
-
min_lr = 0.
|
328 |
# 1) linear warmup for warmup_iters steps
|
329 |
if it < warmup_iters:
|
330 |
return learning_rate * it / warmup_iters
|
@@ -365,7 +365,7 @@ class MambaPlayer:
|
|
365 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
366 |
|
367 |
def save_linear_probe_data(self, path):
|
368 |
-
self.linear_save_ct +=
|
369 |
wandb.log({
|
370 |
"etc/games": self.linear_save_ct
|
371 |
}, step=self.wandb_step)
|
@@ -382,4 +382,4 @@ class MambaPlayer:
|
|
382 |
#probe.eval()
|
383 |
prediction = probe(X).item()
|
384 |
print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
|
385 |
-
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
|
|
321 |
|
322 |
def train_linear_probes(self):
|
323 |
def get_lr(it):
|
324 |
+
warmup_iters = 150 * 43
|
325 |
lr_decay_iters = 5000 * 43
|
326 |
+
learning_rate = 0.000015
|
327 |
+
min_lr = 0.000001
|
328 |
# 1) linear warmup for warmup_iters steps
|
329 |
if it < warmup_iters:
|
330 |
return learning_rate * it / warmup_iters
|
|
|
365 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
366 |
|
367 |
def save_linear_probe_data(self, path):
|
368 |
+
self.linear_save_ct += 25
|
369 |
wandb.log({
|
370 |
"etc/games": self.linear_save_ct
|
371 |
}, step=self.wandb_step)
|
|
|
382 |
#probe.eval()
|
383 |
prediction = probe(X).item()
|
384 |
print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
|
385 |
+
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|