HaileyStorm
commited on
Commit
•
a7153f4
1
Parent(s):
6d717ee
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -358,12 +358,13 @@ class MambaPlayer:
|
|
358 |
}, step=self.wandb_step)
|
359 |
torch.save(self.linear_probes, path)
|
360 |
|
361 |
-
def evaluate_linear_probes(self, board: chess.Board
|
362 |
-
self.move_num =
|
363 |
bucket = next(b for b in self.move_buckets if self.move_num <= b)
|
364 |
for layer_idx in self.linear_probes:
|
365 |
-
X = torch.
|
366 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
|
|
367 |
probe = self.linear_probes[layer_idx][probe_type]
|
368 |
-
prediction = probe(X)
|
369 |
-
print(f"Layer {layer_idx}, {probe_type}: {prediction
|
|
|
358 |
}, step=self.wandb_step)
|
359 |
torch.save(self.linear_probes, path)
|
360 |
|
361 |
+
def evaluate_linear_probes(self, board: chess.Board):
|
362 |
+
self.move_num = board.fullmove_number
|
363 |
bucket = next(b for b in self.move_buckets if self.move_num <= b)
|
364 |
for layer_idx in self.linear_probes:
|
365 |
+
X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float()
|
366 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
367 |
+
target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item() #.unsqueeze(1)
|
368 |
probe = self.linear_probes[layer_idx][probe_type]
|
369 |
+
prediction = probe(X).item()
|
370 |
+
print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
|