HaileyStorm
committed on
Commit
•
f6ed371
1
Parent(s):
bf10919
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -12,6 +12,7 @@ import torch.nn as nn
|
|
12 |
import torch.optim as optim
|
13 |
import wandb
|
14 |
import math
|
|
|
15 |
|
16 |
BASE_DIR = "mamba/"
|
17 |
|
@@ -376,13 +377,33 @@ class MambaPlayer:
|
|
376 |
def evaluate_linear_probes(self, board: chess.Board):
|
377 |
self.move_num = board.fullmove_number
|
378 |
bucket = next(b for b in self.move_buckets if self.move_num <= b)
|
|
|
|
|
|
|
|
|
379 |
for layer_idx in self.linear_probes:
|
380 |
X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)
|
381 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
382 |
target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
|
383 |
probe = self.linear_probes[layer_idx][probe_type]
|
384 |
-
#probe.eval()
|
385 |
prediction = probe(X).item()
|
386 |
-
|
387 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
|
|
|
|
|
|
|
|
|
|
|
12 |
import torch.optim as optim
|
13 |
import wandb
|
14 |
import math
|
15 |
+
import json
|
16 |
|
17 |
BASE_DIR = "mamba/"
|
18 |
|
|
|
377 |
def evaluate_linear_probes(self, board: chess.Board):
    """Score every linear probe for the current move and log the result.

    For each layer in ``self.linear_probes`` and each probe type, the probe
    is run on the flattened ``'current'`` activations of the bucket that
    covers the current move, and its prediction is compared with the single
    stored target. The absolute error is normalised by the probe type's
    value range and turned into an accuracy clamped to ``[0, 1]``.

    Side effects:
        * sets ``self.move_num`` from ``board.fullmove_number``;
        * resets ``self.linear_probe_targets`` to empty lists for every
          layer / bucket / probe type;
        * appends one JSON line, shaped
          ``{probe_type: {layer_idx: {move_num: accuracy}}}``, to
          ``probe_stats.json`` in the current working directory.

    Args:
        board: Board whose ``fullmove_number`` selects the move bucket.

    Raises:
        StopIteration: if no entry of ``self.move_buckets`` is >= the
            current move number.
    """
    probe_types = ('q_value', 'q_value_delta', 'material_balance')
    # Width of the value range used to normalise |prediction - target|:
    # Q-value spans -1..1, Q-value delta spans -2..2, and 35 is the
    # adjustable bound chosen for material balance.
    error_ranges = {'q_value': 2, 'q_value_delta': 4, 'material_balance': 35}

    self.move_num = board.fullmove_number
    # First bucket whose bound covers this move.
    bucket = next(b for b in self.move_buckets if self.move_num <= b)

    # Statistics for the current move only; every slot is filled in below.
    probe_stats = {
        probe_type: {layer_idx: {self.move_num: None} for layer_idx in self.linear_probes}
        for probe_type in probe_types
    }

    # Pure inference: only scalar values are read, so skip autograd tracking.
    with torch.no_grad():
        for layer_idx in self.linear_probes:
            X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)
            for probe_type in probe_types:
                # .item() requires exactly one stored target for this slot.
                target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
                probe = self.linear_probes[layer_idx][probe_type]
                prediction = probe(X).item()
                #print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")

                # A linear probe's output is unbounded, so clamp the
                # normalised error for every probe type; otherwise the
                # accuracy could drop below 0.
                normalised_error = abs(prediction - target) / error_ranges[probe_type]
                probe_stats[probe_type][layer_idx][self.move_num] = 1 - min(normalised_error, 1)

    # Clear the consumed targets for every layer, bucket and probe type.
    self.linear_probe_targets = {
        i: {b: {probe_type: [] for probe_type in probe_types} for b in self.move_buckets}
        for i in self.linear_probes
    }

    # Append the stats as one JSON document per line (one line per move).
    with open('probe_stats.json', 'a', encoding='utf-8') as f:
        json.dump(probe_stats, f)
        f.write('\n')
|