HaileyStorm commited on
Commit
bf84c14
·
verified ·
1 Parent(s): 498c07f

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -320,14 +320,26 @@ class MambaPlayer:
320
  self.linear_probe_targets[layer_idx][bucket]['material_balance'].append(material_bal)
321
 
322
  def train_linear_probes(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  criterion = nn.MSELoss()
324
  self.wandb_step += 1
325
- decay_iters = 2000 * 43
326
- learning_rate = 0.00025
327
- min_lr = 0.000025
328
-
329
- coeff = 0.5 * (1.0 + math.cos(math.pi * min(self.wandb_step / decay_iters, 1.0)))
330
- lr = min_lr + coeff * (learning_rate - min_lr)
331
 
332
  for layer_idx in self.linear_probes:
333
  for bucket in self.move_buckets:
@@ -367,7 +379,7 @@ class MambaPlayer:
367
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
368
  target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
369
  probe = self.linear_probes[layer_idx][probe_type]
370
- probe.eval()
371
  prediction = probe(X).item()
372
  print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
373
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
 
320
  self.linear_probe_targets[layer_idx][bucket]['material_balance'].append(material_bal)
321
 
322
  def train_linear_probes(self):
323
+ def get_lr(it):
324
+ warmup_iters = 25 * 43
325
+ lr_decay_iters = 5000 * 43
326
+ learning_rate = 0.025
327
+ min_lr = 0.0001
328
+ # 1) linear warmup for warmup_iters steps
329
+ if it < warmup_iters:
330
+ return learning_rate * it / warmup_iters
331
+ # 2) if it > lr_decay_iters, return min learning rate
332
+ if it > lr_decay_iters:
333
+ return min_lr
334
+ # 3) in between, use cosine decay down to min learning rate
335
+ decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
336
+ assert 0 <= decay_ratio <= 1
337
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
338
+ return min_lr + coeff * (learning_rate - min_lr)
339
+
340
  criterion = nn.MSELoss()
341
  self.wandb_step += 1
342
+ lr = get_lr(self.wandb_step)
 
 
 
 
 
343
 
344
  for layer_idx in self.linear_probes:
345
  for bucket in self.move_buckets:
 
379
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
380
  target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
381
  probe = self.linear_probes[layer_idx][probe_type]
382
+ #probe.eval()
383
  prediction = probe(X).item()
384
  print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
385
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}