HaileyStorm
committed on
Commit
•
de4b222
1
Parent(s):
0eaef6c
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -11,7 +11,7 @@ import chess
|
|
11 |
BASE_DIR = "mamba/"
|
12 |
|
13 |
class MambaPlayer:
|
14 |
-
def __init__(self, model_name: str, move_num_in_gamestate: bool=False):
|
15 |
self.model_name = model_name
|
16 |
self.move_num_in_gamestate = move_num_in_gamestate
|
17 |
# -----------------------------------------------------------------------------
|
@@ -95,27 +95,43 @@ class MambaPlayer:
|
|
95 |
self.max_seq_len = 1536
|
96 |
self.move_buckets = [10, 20, 30, 40, float('inf')]
|
97 |
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
|
121 |
game_state = game_state.split("\n\n")[-1].strip()
|
@@ -270,3 +286,36 @@ class MambaPlayer:
|
|
270 |
self.hooks.append(self.model.backbone.layers[layer_idx].register_forward_hook(
|
271 |
lambda module, input, output, layer_idx=layer_idx: hook(module, input, output, layer_idx)
|
272 |
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
BASE_DIR = "mamba/"
|
12 |
|
13 |
class MambaPlayer:
|
14 |
+
def __init__(self, model_name: str, move_num_in_gamestate: bool=False, update_contrastive: bool=False, update_linear: bool=False, linear_probe_path: str=None):
|
15 |
self.model_name = model_name
|
16 |
self.move_num_in_gamestate = move_num_in_gamestate
|
17 |
# -----------------------------------------------------------------------------
|
|
|
95 |
self.max_seq_len = 1536
self.move_buckets = [10, 20, 30, 40, float('inf')]

if update_contrastive or update_linear:
    # Running per-layer activation statistics, keyed by
    # layer index -> move-count bucket -> category ("won"/"lost"/"current").
    self.activations_sum = {}
    self.activations_count = {}
if update_linear:
    # BUG FIX: the original called torch.load(linear_probe_data_path), but no
    # such name exists in scope — the constructor parameter is linear_probe_path.
    if linear_probe_path and os.path.exists(linear_probe_path):
        self.linear_probes = torch.load(linear_probe_path)
    else:
        self.linear_probes = {}
    self.linear_probe_targets = {}
if update_contrastive or update_linear:
    for i, layer in enumerate(self.model.backbone.layers):
        self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
                                            "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
                                            "current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
                                   for bucket in self.move_buckets}
        self.activations_count[i] = {bucket: {"won": 0, "lost": 0, "current": 0}
                                     for bucket in self.move_buckets}

        def hook(module, input, output, layer_idx=i):
            # Some Mamba layers return (hidden_states, residual); keep the hidden states.
            if isinstance(output, tuple):
                tensor_output = output[0]
            else:
                tensor_output = output
            seq_len = tensor_output.shape[1]
            bucket = next(b for b in self.move_buckets if self.move_num <= b)
            # Accumulate into the leading seq_len positions of the fixed-size buffer.
            self.activations_sum[layer_idx][bucket]["current"][:, :seq_len, :] += tensor_output.detach().cpu().numpy()
            self.activations_count[layer_idx][bucket]["current"] += 1

        self.hooks.append(layer.register_forward_hook(hook))
        if update_linear:
            # Only build fresh probes when none were loaded from disk above.
            if not linear_probe_path or not os.path.exists(linear_probe_path):
                self.linear_probes[i] = {
                    'q_value': torch.nn.Linear(self.model.config.d_model, 1),
                    'q_value_delta': torch.nn.Linear(self.model.config.d_model, 1),
                    'material_balance': torch.nn.Linear(self.model.config.d_model, 1)
                }
            self.linear_probe_targets[i] = {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets}
|
135 |
|
136 |
def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
|
137 |
game_state = game_state.split("\n\n")[-1].strip()
|
|
|
286 |
self.hooks.append(self.model.backbone.layers[layer_idx].register_forward_hook(
|
287 |
lambda module, input, output, layer_idx=layer_idx: hook(module, input, output, layer_idx)
|
288 |
))
|
289 |
+
|
290 |
+
def update_linear_probe_targets(self, curr_q_value, q_value_delta, material_bal):
    """Record one set of probe supervision targets for the current move.

    The targets are appended under the move-count bucket that covers
    ``self.move_num``, for every layer that has a target store.
    """
    # Smallest bucket whose upper bound covers the current move number.
    active_bucket = next(b for b in self.move_buckets if self.move_num <= b)
    new_targets = {
        'q_value': curr_q_value,
        'q_value_delta': q_value_delta,
        'material_balance': material_bal,
    }
    for layer_targets in self.linear_probe_targets.values():
        slot = layer_targets[active_bucket]
        for key, value in new_targets.items():
            slot[key].append(value)
|
296 |
+
|
297 |
+
def train_linear_probes(self):
    """Train each layer's linear probes on its averaged activations.

    For every (layer, bucket) with recorded activations, the running
    activation sum is normalized by its count, pooled over sequence
    positions, and regressed against the mean of the recorded targets
    for each probe type.

    BUG FIX: ``torch.nn.Linear`` has no scikit-learn style ``.fit()``
    method (the original called one, raising AttributeError), so training
    is done with a short explicit MSE gradient-descent loop. Buckets with
    no recorded targets are skipped rather than crashing on an empty tensor.
    """
    loss_fn = torch.nn.MSELoss()
    for layer_idx in self.linear_probes:
        for bucket in self.move_buckets:
            if self.activations_count[layer_idx][bucket]['current'] > 0:
                X = self.activations_sum[layer_idx][bucket]['current'] / self.activations_count[layer_idx][bucket]['current']
                # (1, max_seq_len, d_model) -> (1, d_model): pool over sequence
                # positions so the probe sees one feature vector per bucket.
                # NOTE(review): pooling strategy assumed — confirm against the
                # rest of the training pipeline.
                X = torch.from_numpy(X).float().mean(dim=1)
                for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
                    targets = self.linear_probe_targets[layer_idx][bucket][probe_type]
                    if not targets:
                        continue  # nothing recorded for this bucket/probe
                    y = torch.tensor(targets).float().mean().reshape(1, 1)
                    probe = self.linear_probes[layer_idx][probe_type]
                    optimizer = torch.optim.SGD(probe.parameters(), lr=1e-3)
                    for _ in range(10):
                        optimizer.zero_grad()
                        loss = loss_fn(probe(X), y)
                        loss.backward()
                        optimizer.step()

    # Reset linear_probe_targets after training
    self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
309 |
+
|
310 |
+
def save_linear_probe_data(self, path):
    """Serialize every layer's linear probes to ``path`` with ``torch.save``."""
    probes = self.linear_probes
    torch.save(probes, path)
|
312 |
+
|
313 |
+
def evaluate_linear_probes(self, board: "chess.Board", game_state: str):
    """Print each linear probe's prediction for the current position.

    ``board`` is accepted for interface symmetry with callers but is not
    read here; predictions come from the accumulated layer activations.
    ``self.move_num`` is refreshed from the move dots in ``game_state``.
    """
    self.move_num = game_state.count('.')
    bucket = next(b for b in self.move_buckets if self.move_num <= b)
    for layer_idx in self.linear_probes:
        count = self.activations_count[layer_idx][bucket]['current']
        if count == 0:
            continue  # no activations recorded for this bucket yet
        # BUG FIX: activations_sum[...]['current'] is a single numpy array,
        # not a sequence of tensors, so the original torch.cat(...) raised
        # TypeError. Normalize the running sum by its count and pool over
        # sequence positions, matching how the probes are trained; the
        # pooled (1, d_model) input also makes prediction.item() valid.
        X = self.activations_sum[layer_idx][bucket]['current'] / count
        X = torch.from_numpy(X).float().mean(dim=1)
        for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
            probe = self.linear_probes[layer_idx][probe_type]
            prediction = probe(X)
            print(f"Layer {layer_idx}, {probe_type}: {prediction.item()}")
|