HaileyStorm
committed on
Commit
•
a73c8da
1
Parent(s):
abd7c69
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -101,10 +101,16 @@ class MambaPlayer:
|
|
101 |
self.activations_count = {}
|
102 |
if update_linear:
|
103 |
if linear_probe_path and os.path.exists(linear_probe_path):
|
104 |
-
|
105 |
-
self.linear_probes = pickle.load(f)
|
106 |
else:
|
107 |
self.linear_probes = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
if update_contrastive or update_linear:
|
109 |
for i, layer in enumerate(self.model.backbone.layers):
|
110 |
self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
@@ -128,9 +134,9 @@ class MambaPlayer:
|
|
128 |
if update_linear:
|
129 |
if not linear_probe_path or not os.path.exists(linear_probe_path):
|
130 |
self.linear_probes[i] = {
|
131 |
-
'q_value':
|
132 |
-
'q_value_delta':
|
133 |
-
'material_balance':
|
134 |
}
|
135 |
if update_linear:
|
136 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
@@ -304,27 +310,27 @@ class MambaPlayer:
|
|
304 |
self.linear_probe_targets[layer_idx][bucket]['q_value_delta'].append(q_value_delta)
|
305 |
self.linear_probe_targets[layer_idx][bucket]['material_balance'].append(material_bal)
|
306 |
|
307 |
-
def train_linear_probes(self):
|
|
|
|
|
308 |
for layer_idx in self.linear_probes:
|
309 |
for bucket in self.move_buckets:
|
310 |
if self.activations_count[layer_idx][bucket]['current'] > 0:
|
311 |
-
X = self.activations_sum[layer_idx][bucket]['current']
|
312 |
-
X = X.reshape(X.shape[1], -1) # Reshape X to have shape (sequence_length, flattened_features)
|
313 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
314 |
-
y =
|
315 |
if len(y) > 0:
|
316 |
-
|
317 |
-
|
318 |
-
self.
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
# Reset linear_probe_targets after training
|
323 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
324 |
|
325 |
def save_linear_probe_data(self, path):
|
326 |
-
|
327 |
-
pickle.dump(self.linear_probes, f)
|
328 |
|
329 |
def evaluate_linear_probes(self, board: chess.Board, game_state: str):
|
330 |
self.move_num = game_state.count('.')
|
|
|
101 |
self.activations_count = {}
|
102 |
if update_linear:
|
103 |
if linear_probe_path and os.path.exists(linear_probe_path):
|
104 |
+
self.linear_probes = torch.load(linear_probe_path)
|
|
|
105 |
else:
|
106 |
self.linear_probes = {}
|
107 |
+
self.linear_optimizers = {
|
108 |
+
layer_idx: {
|
109 |
+
probe_type: optim.Adam(self.linear_probes[layer_idx][probe_type].parameters(), lr=lr)
|
110 |
+
for probe_type in ['q_value', 'q_value_delta', 'material_balance']
|
111 |
+
}
|
112 |
+
for layer_idx in self.linear_probes
|
113 |
+
}
|
114 |
if update_contrastive or update_linear:
|
115 |
for i, layer in enumerate(self.model.backbone.layers):
|
116 |
self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
|
|
134 |
if update_linear:
|
135 |
if not linear_probe_path or not os.path.exists(linear_probe_path):
|
136 |
self.linear_probes[i] = {
|
137 |
+
'q_value': nn.Linear(self.model.config.d_model, 1),
|
138 |
+
'q_value_delta': nn.Linear(self.model.config.d_model, 1),
|
139 |
+
'material_balance': nn.Linear(self.model.config.d_model, 1)
|
140 |
}
|
141 |
if update_linear:
|
142 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
|
|
310 |
self.linear_probe_targets[layer_idx][bucket]['q_value_delta'].append(q_value_delta)
|
311 |
self.linear_probe_targets[layer_idx][bucket]['material_balance'].append(material_bal)
|
312 |
|
313 |
+
def train_linear_probes(self, lr=0.01):
|
314 |
+
criterion = nn.MSELoss()
|
315 |
+
|
316 |
for layer_idx in self.linear_probes:
|
317 |
for bucket in self.move_buckets:
|
318 |
if self.activations_count[layer_idx][bucket]['current'] > 0:
|
319 |
+
X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float()  #/ self.activations_count[layer_idx][bucket]['current'] (mean) — division currently disabled
|
|
|
320 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
321 |
+
y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
|
322 |
if len(y) > 0:
|
323 |
+
y_pred = self.linear_probes[layer_idx][probe_type](X)
|
324 |
+
loss = criterion(y_pred, y)
|
325 |
+
self.linear_optimizers[layer_idx][probe_type].zero_grad()
|
326 |
+
loss.backward()
|
327 |
+
self.linear_optimizers[layer_idx][probe_type].step()
|
328 |
+
|
329 |
# Reset linear_probe_targets after training
|
330 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
331 |
|
332 |
def save_linear_probe_data(self, path):
|
333 |
+
torch.save(self.linear_probes, path)
|
|
|
334 |
|
335 |
def evaluate_linear_probes(self, board: chess.Board, game_state: str):
|
336 |
self.move_num = game_state.count('.')
|