Update chess-gpt-eval-contrastive/mamba_module.py
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -145,6 +145,8 @@ class MambaPlayer:
             for layer_idx in self.linear_probes
         }
         wandb.init(project="mamba_linear_probes", name=f"mamba_linear_probes")
+        self.wandb_step = 0
+        self.linear_save_ct = 0

     def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
         game_state = game_state.split("\n\n")[-1].strip()
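The two counters added in this hunk back the explicit step= arguments used by the logging calls further down: wandb_step advances once per train_linear_probes call, and linear_save_ct counts saved games, so both metric families share one x-axis. Note that when wandb.log is given an explicit step, wandb requires the steps to be monotonically non-decreasing and drops data logged against an earlier step. A minimal standalone sketch of the pattern (mode="offline" is added only so the snippet runs without credentials; it is not part of this commit):

import wandb

run = wandb.init(project="mamba_linear_probes", name="mamba_linear_probes", mode="offline")

wandb_step = 0
for _ in range(3):
    wandb_step += 1                               # one tick per training call
    wandb.log({"etc/lr": 0.01}, step=wandb_step)  # explicit steps must never decrease
wandb.log({"etc/games": 1}, step=wandb_step)      # re-logging at the current step is fine
run.finish()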
@@ -317,6 +319,13 @@ class MambaPlayer:

     def train_linear_probes(self):
         criterion = nn.MSELoss()
+        self.wandb_step += 1
+        decay_iters = 2000 * 40
+        learning_rate = 0.01
+        min_lr = 0.0001
+
+        coeff = 0.5 * (1.0 + math.cos(math.pi * min(self.wandb_step / decay_iters, 1.0)))
+        lr = min_lr + coeff * (learning_rate - min_lr)

         for layer_idx in self.linear_probes:
             for bucket in self.move_buckets:
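The added lines implement a standard cosine decay from learning_rate down to min_lr over decay_iters training calls, clamped at min_lr afterwards; the schedule relies on an "import math" at the top of the file, which this diff does not show. A standalone sketch of the same schedule, with the commit's constants lifted into defaults:

import math

def cosine_lr(step: int, learning_rate: float = 0.01,
              min_lr: float = 0.0001, decay_iters: int = 2000 * 40) -> float:
    """Cosine decay from learning_rate to min_lr, flat at min_lr past decay_iters."""
    coeff = 0.5 * (1.0 + math.cos(math.pi * min(step / decay_iters, 1.0)))
    return min_lr + coeff * (learning_rate - min_lr)

# Endpoints: full rate at step 0, floor once decay_iters is reached.
assert abs(cosine_lr(0) - 0.01) < 1e-12
assert abs(cosine_lr(2000 * 40) - 0.0001) < 1e-12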
@@ -327,16 +336,25 @@
                     if len(y) > 0:
                         y_pred = self.linear_probes[layer_idx][probe_type](X)
                         loss = criterion(y_pred, y)
+                        for param_group in self.linear_optimizers[layer_idx][probe_type].param_groups:
+                            param_group['lr'] = lr
                         self.linear_optimizers[layer_idx][probe_type].zero_grad()
                         loss.backward()
                         self.linear_optimizers[layer_idx][probe_type].step()
                         #wandb.log({f"{probe_type}/layer_{layer_idx}_{bucket}_loss": loss.item()})
-                        wandb.log({
+                        wandb.log({
+                            "etc/lr": lr,
+                            f"{probe_type}/layer_{layer_idx}_loss": loss.item()
+                        }, step=self.wandb_step)

         # Reset linear_probe_targets after training
         self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}

     def save_linear_probe_data(self, path):
+        self.linear_save_ct += 1
+        wandb.log({
+            "etc/games": self.linear_save_ct
+        }, step=self.wandb_step)
         torch.save(self.linear_probes, path)

     def evaluate_linear_probes(self, board: chess.Board, game_state: str):
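Because every probe has its own optimizer, there is no single scheduler object to step; the schedule is applied by writing the computed lr into each optimizer's param_groups right before the update. A minimal sketch of that PyTorch pattern (the probe and optimizer here are illustrative stand-ins, not objects from this file):

import torch
import torch.nn as nn

probe = nn.Linear(512, 1)  # stands in for one linear probe head
optimizer = torch.optim.Adam(probe.parameters(), lr=0.01)

def apply_lr(optimizer: torch.optim.Optimizer, lr: float) -> None:
    # Same in-place update the diff performs before zero_grad()/step().
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

apply_lr(optimizer, 0.0005)
assert all(group['lr'] == 0.0005 for group in optimizer.param_groups)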