Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -320,14 +320,26 @@ class MambaPlayer:
|
|
320 |
self.linear_probe_targets[layer_idx][bucket]['material_balance'].append(material_bal)
|
321 |
|
322 |
def train_linear_probes(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
criterion = nn.MSELoss()
|
324 |
self.wandb_step += 1
|
325 |
-
|
326 |
-
learning_rate = 0.00025
|
327 |
-
min_lr = 0.000025
|
328 |
-
|
329 |
-
coeff = 0.5 * (1.0 + math.cos(math.pi * min(self.wandb_step / decay_iters, 1.0)))
|
330 |
-
lr = min_lr + coeff * (learning_rate - min_lr)
|
331 |
|
332 |
for layer_idx in self.linear_probes:
|
333 |
for bucket in self.move_buckets:
|
@@ -367,7 +379,7 @@ class MambaPlayer:
|
|
367 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
368 |
target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
|
369 |
probe = self.linear_probes[layer_idx][probe_type]
|
370 |
-
probe.eval()
|
371 |
prediction = probe(X).item()
|
372 |
print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
|
373 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
|
|
320 |
self.linear_probe_targets[layer_idx][bucket]['material_balance'].append(material_bal)
|
321 |
|
322 |
def train_linear_probes(self):
|
323 |
+
def get_lr(it):
|
324 |
+
warmup_iters = 25 * 43
|
325 |
+
lr_decay_iters = 5000 * 43
|
326 |
+
learning_rate = 0.025
|
327 |
+
min_lr = 0.0001
|
328 |
+
# 1) linear warmup for warmup_iters steps
|
329 |
+
if it < warmup_iters:
|
330 |
+
return learning_rate * it / warmup_iters
|
331 |
+
# 2) if it > lr_decay_iters, return min learning rate
|
332 |
+
if it > lr_decay_iters:
|
333 |
+
return min_lr
|
334 |
+
# 3) in between, use cosine decay down to min learning rate
|
335 |
+
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
|
336 |
+
assert 0 <= decay_ratio <= 1
|
337 |
+
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
|
338 |
+
return min_lr + coeff * (learning_rate - min_lr)
|
339 |
+
|
340 |
criterion = nn.MSELoss()
|
341 |
self.wandb_step += 1
|
342 |
+
lr = get_lr(self.wandb_step)
|
|
|
|
|
|
|
|
|
|
|
343 |
|
344 |
for layer_idx in self.linear_probes:
|
345 |
for bucket in self.move_buckets:
|
|
|
379 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
380 |
target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
|
381 |
probe = self.linear_probes[layer_idx][probe_type]
|
382 |
+
#probe.eval()
|
383 |
prediction = probe(X).item()
|
384 |
print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
|
385 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|