HaileyStorm committed on
Commit
7878a45
·
verified ·
1 Parent(s): 45d2b20

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -10,6 +10,7 @@ import chess
10
  from sklearn.linear_model import LinearRegression
11
  import torch.nn as nn
12
  import torch.optim as optim
 
13
 
14
  BASE_DIR = "mamba/"
15
 
@@ -142,6 +143,7 @@ class MambaPlayer:
142
  }
143
  for layer_idx in self.linear_probes
144
  }
 
145
 
146
  def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
147
  game_state = game_state.split("\n\n")[-1].strip()
@@ -327,6 +329,7 @@ class MambaPlayer:
327
  self.linear_optimizers[layer_idx][probe_type].zero_grad()
328
  loss.backward()
329
  self.linear_optimizers[layer_idx][probe_type].step()
 
330
 
331
  # Reset linear_probe_targets after training
332
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
 
10
  from sklearn.linear_model import LinearRegression
11
  import torch.nn as nn
12
  import torch.optim as optim
13
+ import wandb
14
 
15
  BASE_DIR = "mamba/"
16
 
 
143
  }
144
  for layer_idx in self.linear_probes
145
  }
146
+ wandb.init(project="mamba_linear_probes", name=f"mamba_linear_probes")
147
 
148
  def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
149
  game_state = game_state.split("\n\n")[-1].strip()
 
329
  self.linear_optimizers[layer_idx][probe_type].zero_grad()
330
  loss.backward()
331
  self.linear_optimizers[layer_idx][probe_type].step()
332
+ wandb.log({f"{probe_type}/layer_{layer_idx}_{bucket}_loss": loss.item()})
333
 
334
  # Reset linear_probe_targets after training
335
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}