HaileyStorm
committed on
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -126,7 +126,7 @@ class MambaPlayer:
|
|
126 |
tensor_output = output
|
127 |
seq_len = tensor_output.shape[1]
|
128 |
bucket = next(b for b in self.move_buckets if self.move_num <= b)
|
129 |
-
self.activations_sum[layer_idx][bucket]["current"][:, :8, :] += tensor_output.detach().cpu().numpy()[:self.seq_len][-8:]
|
130 |
self.activations_count[layer_idx][bucket]["current"] += 1
|
131 |
|
132 |
self.hooks.append(layer.register_forward_hook(hook))
|
@@ -377,7 +377,7 @@ class MambaPlayer:
|
|
377 |
self.move_num = board.fullmove_number
|
378 |
bucket = next(b for b in self.move_buckets if self.move_num <= b)
|
379 |
for layer_idx in self.linear_probes:
|
380 |
-
X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)
|
381 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
382 |
target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
|
383 |
probe = self.linear_probes[layer_idx][probe_type]
|
|
|
126 |
tensor_output = output
|
127 |
seq_len = tensor_output.shape[1]
|
128 |
bucket = next(b for b in self.move_buckets if self.move_num <= b)
|
129 |
+
self.activations_sum[layer_idx][bucket]["current"][:, :8, :] += tensor_output.detach().cpu().numpy()[:, :self.seq_len, :][:, -8:, :]
|
130 |
self.activations_count[layer_idx][bucket]["current"] += 1
|
131 |
|
132 |
self.hooks.append(layer.register_forward_hook(hook))
|
|
|
377 |
self.move_num = board.fullmove_number
|
378 |
bucket = next(b for b in self.move_buckets if self.move_num <= b)
|
379 |
for layer_idx in self.linear_probes:
|
380 |
+
X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)
|
381 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
382 |
target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
|
383 |
probe = self.linear_probes[layer_idx][probe_type]
|