HaileyStorm
commited on
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -337,7 +337,6 @@ class MambaPlayer:
|
|
337 |
y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
|
338 |
if len(y) > 0:
|
339 |
y_pred = self.linear_probes[layer_idx][probe_type](X)
|
340 |
-
print(f'{y_pred}')
|
341 |
loss = criterion(y_pred, y)
|
342 |
for param_group in self.linear_optimizers[layer_idx][probe_type].param_groups:
|
343 |
param_group['lr'] = lr
|
@@ -364,11 +363,11 @@ class MambaPlayer:
|
|
364 |
self.move_num = board.fullmove_number
|
365 |
bucket = next(b for b in self.move_buckets if self.move_num <= b)
|
366 |
for layer_idx in self.linear_probes:
|
367 |
-
X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float()
|
368 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
369 |
-
target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().
|
370 |
probe = self.linear_probes[layer_idx][probe_type]
|
371 |
probe.eval()
|
372 |
-
prediction = probe(X)
|
373 |
print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
|
374 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
|
|
337 |
y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
|
338 |
if len(y) > 0:
|
339 |
y_pred = self.linear_probes[layer_idx][probe_type](X)
|
|
|
340 |
loss = criterion(y_pred, y)
|
341 |
for param_group in self.linear_optimizers[layer_idx][probe_type].param_groups:
|
342 |
param_group['lr'] = lr
|
|
|
363 |
self.move_num = board.fullmove_number
|
364 |
bucket = next(b for b in self.move_buckets if self.move_num <= b)
|
365 |
for layer_idx in self.linear_probes:
|
366 |
+
X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)
|
367 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
368 |
+
target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
|
369 |
probe = self.linear_probes[layer_idx][probe_type]
|
370 |
probe.eval()
|
371 |
+
prediction = probe(X).item()
|
372 |
print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
|
373 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|