Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -10,6 +10,7 @@ import chess
|
|
10 |
from sklearn.linear_model import LinearRegression
|
11 |
import torch.nn as nn
|
12 |
import torch.optim as optim
|
|
|
13 |
|
14 |
BASE_DIR = "mamba/"
|
15 |
|
@@ -142,6 +143,7 @@ class MambaPlayer:
|
|
142 |
}
|
143 |
for layer_idx in self.linear_probes
|
144 |
}
|
|
|
145 |
|
146 |
def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
|
147 |
game_state = game_state.split("\n\n")[-1].strip()
|
@@ -327,6 +329,7 @@ class MambaPlayer:
|
|
327 |
self.linear_optimizers[layer_idx][probe_type].zero_grad()
|
328 |
loss.backward()
|
329 |
self.linear_optimizers[layer_idx][probe_type].step()
|
|
|
330 |
|
331 |
# Reset linear_probe_targets after training
|
332 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
|
|
10 |
from sklearn.linear_model import LinearRegression
|
11 |
import torch.nn as nn
|
12 |
import torch.optim as optim
|
13 |
+
import wandb
|
14 |
|
15 |
BASE_DIR = "mamba/"
|
16 |
|
|
|
143 |
}
|
144 |
for layer_idx in self.linear_probes
|
145 |
}
|
146 |
+
wandb.init(project="mamba_linear_probes", name=f"mamba_linear_probes")
|
147 |
|
148 |
def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
|
149 |
game_state = game_state.split("\n\n")[-1].strip()
|
|
|
329 |
self.linear_optimizers[layer_idx][probe_type].zero_grad()
|
330 |
loss.backward()
|
331 |
self.linear_optimizers[layer_idx][probe_type].step()
|
332 |
+
wandb.log({f"{probe_type}/layer_{layer_idx}_{bucket}_loss": loss.item()})
|
333 |
|
334 |
# Reset linear_probe_targets after training
|
335 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|