HaileyStorm committed on
Commit
a73c8da
1 Parent(s): abd7c69

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -101,10 +101,16 @@ class MambaPlayer:
101
  self.activations_count = {}
102
  if update_linear:
103
  if linear_probe_path and os.path.exists(linear_probe_path):
104
- with open(linear_probe_path, 'rb') as f:
105
- self.linear_probes = pickle.load(f)
106
  else:
107
  self.linear_probes = {}
 
 
 
 
 
 
 
108
  if update_contrastive or update_linear:
109
  for i, layer in enumerate(self.model.backbone.layers):
110
  self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
@@ -128,9 +134,9 @@ class MambaPlayer:
128
  if update_linear:
129
  if not linear_probe_path or not os.path.exists(linear_probe_path):
130
  self.linear_probes[i] = {
131
- 'q_value': LinearRegression(),
132
- 'q_value_delta': LinearRegression(),
133
- 'material_balance': LinearRegression()
134
  }
135
  if update_linear:
136
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
@@ -304,27 +310,27 @@ class MambaPlayer:
304
  self.linear_probe_targets[layer_idx][bucket]['q_value_delta'].append(q_value_delta)
305
  self.linear_probe_targets[layer_idx][bucket]['material_balance'].append(material_bal)
306
 
307
- def train_linear_probes(self):
 
 
308
  for layer_idx in self.linear_probes:
309
  for bucket in self.move_buckets:
310
  if self.activations_count[layer_idx][bucket]['current'] > 0:
311
- X = self.activations_sum[layer_idx][bucket]['current'] / self.activations_count[layer_idx][bucket]['current']
312
- X = X.reshape(X.shape[1], -1) # Reshape X to have shape (sequence_length, flattened_features)
313
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
314
- y = np.array(self.linear_probe_targets[layer_idx][bucket][probe_type])
315
  if len(y) > 0:
316
- # Repeat X to match the number of target values
317
- X_repeated = np.repeat(X, len(y), axis=0)
318
- self.linear_probes[layer_idx][probe_type].fit(X_repeated, y)
319
- else:
320
- print(f"Skipping training for layer {layer_idx}, bucket {bucket}, probe type {probe_type} due to empty target values.")
321
-
322
  # Reset linear_probe_targets after training
323
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
324
 
325
  def save_linear_probe_data(self, path):
326
- with open(path, 'wb') as f:
327
- pickle.dump(self.linear_probes, f)
328
 
329
  def evaluate_linear_probes(self, board: chess.Board, game_state: str):
330
  self.move_num = game_state.count('.')
 
101
  self.activations_count = {}
102
  if update_linear:
103
  if linear_probe_path and os.path.exists(linear_probe_path):
104
+ self.linear_probes = torch.load(linear_probe_path)
 
105
  else:
106
  self.linear_probes = {}
107
+ self.linear_optimizers = {
108
+ layer_idx: {
109
+ probe_type: optim.Adam(self.linear_probes[layer_idx][probe_type].parameters(), lr=lr)
110
+ for probe_type in ['q_value', 'q_value_delta', 'material_balance']
111
+ }
112
+ for layer_idx in self.linear_probes
113
+ }
114
  if update_contrastive or update_linear:
115
  for i, layer in enumerate(self.model.backbone.layers):
116
  self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
 
134
  if update_linear:
135
  if not linear_probe_path or not os.path.exists(linear_probe_path):
136
  self.linear_probes[i] = {
137
+ 'q_value': nn.Linear(self.model.config.d_model, 1),
138
+ 'q_value_delta': nn.Linear(self.model.config.d_model, 1),
139
+ 'material_balance': nn.Linear(self.model.config.d_model, 1)
140
  }
141
  if update_linear:
142
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
 
310
  self.linear_probe_targets[layer_idx][bucket]['q_value_delta'].append(q_value_delta)
311
  self.linear_probe_targets[layer_idx][bucket]['material_balance'].append(material_bal)
312
 
313
+ def train_linear_probes(self, lr=0.01):
314
+ criterion = nn.MSELoss()
315
+
316
  for layer_idx in self.linear_probes:
317
  for bucket in self.move_buckets:
318
  if self.activations_count[layer_idx][bucket]['current'] > 0:
319
+ X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current'] #/ self.activations_count[layer_idx][bucket]['current']).float()
 
320
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
321
+ y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
322
  if len(y) > 0:
323
+ y_pred = self.linear_probes[layer_idx][probe_type](X)
324
+ loss = criterion(y_pred, y)
325
+ self.linear_optimizers[layer_idx][probe_type].zero_grad()
326
+ loss.backward()
327
+ self.linear_optimizers[layer_idx][probe_type].step()
328
+
329
  # Reset linear_probe_targets after training
330
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
331
 
332
  def save_linear_probe_data(self, path):
333
+ torch.save(self.linear_probes, path)
 
334
 
335
  def evaluate_linear_probes(self, board: chess.Board, game_state: str):
336
  self.move_num = game_state.count('.')