HaileyStorm committed on
Commit
9284512
·
verified ·
1 Parent(s): bf84c14

Upload chess-gpt-eval-contrastive/mamba_module.py with huggingface_hub

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -321,10 +321,10 @@ class MambaPlayer:
321
 
322
  def train_linear_probes(self):
323
  def get_lr(it):
324
- warmup_iters = 25 * 43
325
  lr_decay_iters = 5000 * 43
326
- learning_rate = 0.025
327
- min_lr = 0.0001
328
  # 1) linear warmup for warmup_iters steps
329
  if it < warmup_iters:
330
  return learning_rate * it / warmup_iters
@@ -365,7 +365,7 @@ class MambaPlayer:
365
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
366
 
367
  def save_linear_probe_data(self, path):
368
- self.linear_save_ct += 1
369
  wandb.log({
370
  "etc/games": self.linear_save_ct
371
  }, step=self.wandb_step)
@@ -382,4 +382,4 @@ class MambaPlayer:
382
  #probe.eval()
383
  prediction = probe(X).item()
384
  print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
385
- self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
 
321
 
322
  def train_linear_probes(self):
323
  def get_lr(it):
324
+ warmup_iters = 150 * 43
325
  lr_decay_iters = 5000 * 43
326
+ learning_rate = 0.000015
327
+ min_lr = 0.000001
328
  # 1) linear warmup for warmup_iters steps
329
  if it < warmup_iters:
330
  return learning_rate * it / warmup_iters
 
365
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
366
 
367
  def save_linear_probe_data(self, path):
368
+ self.linear_save_ct += 25
369
  wandb.log({
370
  "etc/games": self.linear_save_ct
371
  }, step=self.wandb_step)
 
382
  #probe.eval()
383
  prediction = probe(X).item()
384
  print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
385
+ self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}