HaileyStorm
commited on
Commit
•
9ae57a2
1
Parent(s):
f207752
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -7,6 +7,7 @@ from contextlib import nullcontext
|
|
7 |
import numpy as np
|
8 |
from functools import partial
|
9 |
import chess
|
|
|
10 |
|
11 |
BASE_DIR = "mamba/"
|
12 |
|
@@ -100,10 +101,10 @@ class MambaPlayer:
|
|
100 |
self.activations_count = {}
|
101 |
if update_linear:
|
102 |
if linear_probe_path and os.path.exists(linear_probe_path):
|
103 |
-
|
|
|
104 |
else:
|
105 |
self.linear_probes = {}
|
106 |
-
self.linear_probe_targets = {}
|
107 |
if update_contrastive or update_linear:
|
108 |
for i, layer in enumerate(self.model.backbone.layers):
|
109 |
self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
@@ -131,7 +132,8 @@ class MambaPlayer:
|
|
131 |
'q_value_delta': torch.nn.Linear(self.model.config.d_model, 1),
|
132 |
'material_balance': torch.nn.Linear(self.model.config.d_model, 1)
|
133 |
}
|
134 |
-
|
|
|
135 |
|
136 |
def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
|
137 |
game_state = game_state.split("\n\n")[-1].strip()
|
@@ -307,16 +309,16 @@ class MambaPlayer:
|
|
307 |
for bucket in self.move_buckets:
|
308 |
if self.activations_count[layer_idx][bucket]['current'] > 0:
|
309 |
X = self.activations_sum[layer_idx][bucket]['current'] / self.activations_count[layer_idx][bucket]['current']
|
310 |
-
X = torch.from_numpy(X).float()
|
311 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
312 |
-
y =
|
313 |
self.linear_probes[layer_idx][probe_type].fit(X, y)
|
314 |
-
|
315 |
# Reset linear_probe_targets after training
|
316 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
317 |
-
|
318 |
def save_linear_probe_data(self, path):
|
319 |
-
|
|
|
320 |
|
321 |
def evaluate_linear_probes(self, board: chess.Board, game_state: str):
|
322 |
self.move_num = game_state.count('.')
|
|
|
7 |
import numpy as np
|
8 |
from functools import partial
|
9 |
import chess
|
10 |
+
from sklearn.linear_model import LinearRegression
|
11 |
|
12 |
BASE_DIR = "mamba/"
|
13 |
|
|
|
101 |
self.activations_count = {}
|
102 |
if update_linear:
|
103 |
if linear_probe_path and os.path.exists(linear_probe_path):
|
104 |
+
with open(linear_probe_path, 'rb') as f:
|
105 |
+
self.linear_probes = pickle.load(f)
|
106 |
else:
|
107 |
self.linear_probes = {}
|
|
|
108 |
if update_contrastive or update_linear:
|
109 |
for i, layer in enumerate(self.model.backbone.layers):
|
110 |
self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
|
|
132 |
'q_value_delta': torch.nn.Linear(self.model.config.d_model, 1),
|
133 |
'material_balance': torch.nn.Linear(self.model.config.d_model, 1)
|
134 |
}
|
135 |
+
if update_linear:
|
136 |
+
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
137 |
|
138 |
def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
|
139 |
game_state = game_state.split("\n\n")[-1].strip()
|
|
|
309 |
for bucket in self.move_buckets:
|
310 |
if self.activations_count[layer_idx][bucket]['current'] > 0:
|
311 |
X = self.activations_sum[layer_idx][bucket]['current'] / self.activations_count[layer_idx][bucket]['current']
|
|
|
312 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
313 |
+
y = self.linear_probe_targets[layer_idx][bucket][probe_type]
|
314 |
self.linear_probes[layer_idx][probe_type].fit(X, y)
|
315 |
+
|
316 |
# Reset linear_probe_targets after training
|
317 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
318 |
+
|
319 |
def save_linear_probe_data(self, path):
|
320 |
+
with open(path, 'wb') as f:
|
321 |
+
pickle.dump(self.linear_probes, f)
|
322 |
|
323 |
def evaluate_linear_probes(self, board: chess.Board, game_state: str):
|
324 |
self.move_num = game_state.count('.')
|