HaileyStorm committed
Commit f67254a (verified)
Parent: 5238925

Update chess-gpt-eval-contrastive/mamba_module.py

chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -145,6 +145,8 @@ class MambaPlayer:
             for layer_idx in self.linear_probes
         }
         wandb.init(project="mamba_linear_probes", name=f"mamba_linear_probes")
+        self.wandb_step = 0
+        self.linear_save_ct += 1

     def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
         game_state = game_state.split("\n\n")[-1].strip()
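Note on the two new fields: `self.wandb_step` is a manually managed global step, so that `wandb.log` calls made from different methods (`train_linear_probes`, `save_linear_probe_data` below) land on one shared x-axis. A minimal sketch of that pattern, reusing the project name from the diff (the metric name is illustrative):

import wandb

wandb.init(project="mamba_linear_probes", name="mamba_linear_probes")

wandb_step = 0
for _ in range(3):
    wandb_step += 1
    # An explicit step= keeps metrics logged from different call sites
    # aligned on one x-axis, instead of wandb's implicit auto-increment.
    wandb.log({"etc/example_metric": 0.1 * wandb_step}, step=wandb_step)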
@@ -317,6 +319,13 @@ class MambaPlayer:

     def train_linear_probes(self):
         criterion = nn.MSELoss()
+        self.wandb_step += 1
+        decay_iters = 2000 * 40
+        learning_rate = 0.01
+        min_lr = 0.0001
+
+        coeff = 0.5 * (1.0 + math.cos(math.pi * min(self.wandb_step / decay_iters, 1.0)))
+        lr = min_lr + coeff * (learning_rate - min_lr)

         for layer_idx in self.linear_probes:
             for bucket in self.move_buckets:
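The added lines are a standard cosine decay schedule: the probe learning rate starts at 0.01 and falls to 0.0001 over 2000 * 40 = 80,000 calls to `train_linear_probes`, then stays clamped at the floor. A self-contained sketch of the same formula (the function name `cosine_lr` is mine, not from the commit):

import math

def cosine_lr(step: int, max_lr: float = 0.01, min_lr: float = 0.0001,
              decay_iters: int = 2000 * 40) -> float:
    """Cosine decay from max_lr to min_lr over decay_iters steps, then flat."""
    coeff = 0.5 * (1.0 + math.cos(math.pi * min(step / decay_iters, 1.0)))
    return min_lr + coeff * (max_lr - min_lr)

print(cosine_lr(0))        # 0.01    (coeff = 1.0, start of schedule)
print(cosine_lr(40_000))   # 0.00505 (coeff = 0.5, midpoint)
print(cosine_lr(80_000))   # 0.0001  (coeff = 0.0, floor reached)
print(cosine_lr(120_000))  # 0.0001  (min() clamps steps past decay_iters)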
@@ -327,16 +336,25 @@ class MambaPlayer:
                     if len(y) > 0:
                         y_pred = self.linear_probes[layer_idx][probe_type](X)
                         loss = criterion(y_pred, y)
+                        for param_group in self.linear_optimizers[layer_idx][probe_type].param_groups:
+                            param_group['lr'] = lr
                         self.linear_optimizers[layer_idx][probe_type].zero_grad()
                         loss.backward()
                         self.linear_optimizers[layer_idx][probe_type].step()
                         #wandb.log({f"{probe_type}/layer_{layer_idx}_{bucket}_loss": loss.item()})
-                        wandb.log({f"{probe_type}/layer_{layer_idx}_loss": loss.item()})
+                        wandb.log({
+                            "etc/lr": lr,
+                            f"{probe_type}/layer_{layer_idx}_loss": loss.item()
+                        }, step=self.wandb_step)

         # Reset linear_probe_targets after training
         self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}

     def save_linear_probe_data(self, path):
+        self.linear_save_ct += 1
+        wandb.log({
+            "etc/games": self.linear_save_ct
+        }, step=self.wandb_step)
         torch.save(self.linear_probes, path)

     def evaluate_linear_probes(self, board: chess.Board, game_state: str):
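Writing the scheduled rate into `param_group['lr']` is the standard PyTorch way to change an existing optimizer's learning rate in place, without rebuilding the optimizer (which would discard Adam's moment estimates). A self-contained sketch of the pattern; the 512-dim probe input and the Adam choice are assumed placeholders:

import torch
import torch.nn as nn

probe = nn.Linear(512, 1)   # 512 is an assumed hidden size
optimizer = torch.optim.Adam(probe.parameters(), lr=0.01)

lr = 0.005                  # e.g. the value produced by the cosine schedule
for param_group in optimizer.param_groups:
    param_group['lr'] = lr  # mutating param_groups changes the rate in place

x, y = torch.randn(8, 512), torch.randn(8, 1)
loss = nn.MSELoss()(probe(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()            # this update already runs at the new rate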
 
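Taken together, the surrounding class appears to keep one linear probe and one optimizer per (layer, probe type) pair, with regression targets bucketed by move count and cleared after each training pass. A minimal end-to-end sketch of that structure; the dimensions, layer indices, and bucket names are invented for illustration and do not come from the commit:

import torch
import torch.nn as nn

probe_types = ['q_value', 'q_value_delta', 'material_balance']
layers, buckets, d_model = [0, 1], ['early', 'late'], 512   # placeholder values

linear_probes = {l: {p: nn.Linear(d_model, 1) for p in probe_types} for l in layers}
linear_optimizers = {
    l: {p: torch.optim.Adam(linear_probes[l][p].parameters(), lr=0.01) for p in probe_types}
    for l in layers
}
# Targets accumulate between calls, keyed layer -> bucket -> probe type,
# matching the reset expression in the diff.
linear_probe_targets = {
    l: {b: {p: [] for p in probe_types} for b in buckets} for l in layers
}
linear_probe_targets[0]['early']['q_value'] = [0.1, -0.3, 0.5]  # toy data

criterion = nn.MSELoss()
for l in layers:
    for b in buckets:
        for p in probe_types:
            targets = linear_probe_targets[l][b][p]
            if len(targets) > 0:
                X = torch.randn(len(targets), d_model)  # stand-in for cached activations
                y = torch.tensor(targets).float().unsqueeze(1)
                loss = criterion(linear_probes[l][p](X), y)
                linear_optimizers[l][p].zero_grad()
                loss.backward()
                linear_optimizers[l][p].step()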