HaileyStorm committed on
Commit
de4b222
1 Parent(s): 0eaef6c

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -11,7 +11,7 @@ import chess
11
  BASE_DIR = "mamba/"
12
 
13
  class MambaPlayer:
14
- def __init__(self, model_name: str, move_num_in_gamestate: bool=False):
15
  self.model_name = model_name
16
  self.move_num_in_gamestate = move_num_in_gamestate
17
  # -----------------------------------------------------------------------------
@@ -95,27 +95,43 @@ class MambaPlayer:
95
  self.max_seq_len = 1536
96
  self.move_buckets = [10, 20, 30, 40, float('inf')]
97
 
98
- self.activations_sum = {}
99
- self.activations_count = {}
100
- for i, layer in enumerate(self.model.backbone.layers):
101
- self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
102
- "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
103
- "current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
104
- for bucket in self.move_buckets}
105
- self.activations_count[i] = {bucket: {"won": 0, "lost": 0, "current": 0}
106
- for bucket in self.move_buckets}
107
-
108
- def hook(module, input, output, layer_idx=i):
109
- if isinstance(output, tuple):
110
- tensor_output = output[0]
111
- else:
112
- tensor_output = output
113
- seq_len = tensor_output.shape[1]
114
- bucket = next(b for b in self.move_buckets if self.move_num <= b)
115
- self.activations_sum[layer_idx][bucket]["current"][:, :seq_len, :] += tensor_output.detach().cpu().numpy()
116
- self.activations_count[layer_idx][bucket]["current"] += 1
117
-
118
- self.hooks.append(layer.register_forward_hook(hook))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
121
  game_state = game_state.split("\n\n")[-1].strip()
@@ -270,3 +286,36 @@ class MambaPlayer:
270
  self.hooks.append(self.model.backbone.layers[layer_idx].register_forward_hook(
271
  lambda module, input, output, layer_idx=layer_idx: hook(module, input, output, layer_idx)
272
  ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  BASE_DIR = "mamba/"
12
 
13
  class MambaPlayer:
14
+ def __init__(self, model_name: str, move_num_in_gamestate: bool=False, update_contrastive: bool=False, update_linear: bool=False, linear_probe_path: str=None):
15
  self.model_name = model_name
16
  self.move_num_in_gamestate = move_num_in_gamestate
17
  # -----------------------------------------------------------------------------
 
95
  self.max_seq_len = 1536
96
  self.move_buckets = [10, 20, 30, 40, float('inf')]
97
 
98
# Contrastive-activation buffers and linear-probe initialization.
# NOTE(review): fragment of MambaPlayer.__init__ — assumes self.model,
# self.max_seq_len, self.move_buckets and self.hooks are already set,
# and that update_contrastive / update_linear / linear_probe_path are
# the enclosing __init__'s parameters.
if update_contrastive or update_linear:
    self.activations_sum = {}
    self.activations_count = {}
if update_linear:
    if linear_probe_path and os.path.exists(linear_probe_path):
        # BUG FIX: original read torch.load(linear_probe_data_path) — an
        # undefined name (NameError); the parameter is linear_probe_path.
        self.linear_probes = torch.load(linear_probe_path)
    else:
        self.linear_probes = {}
    self.linear_probe_targets = {}
if update_contrastive or update_linear:
    for i, layer in enumerate(self.model.backbone.layers):
        # Per-bucket running sums/counts of layer activations, split by
        # eventual game outcome ("won"/"lost") plus the in-progress game.
        self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
                                            "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
                                            "current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
                                   for bucket in self.move_buckets}
        self.activations_count[i] = {bucket: {"won": 0, "lost": 0, "current": 0}
                                     for bucket in self.move_buckets}

        def hook(module, input, output, layer_idx=i):
            # Accumulate this layer's activations into the bucket for the
            # current move number; layer_idx is bound at def time (avoids
            # the late-binding-closure pitfall).
            if isinstance(output, tuple):
                tensor_output = output[0]
            else:
                tensor_output = output
            seq_len = tensor_output.shape[1]
            bucket = next(b for b in self.move_buckets if self.move_num <= b)
            self.activations_sum[layer_idx][bucket]["current"][:, :seq_len, :] += tensor_output.detach().cpu().numpy()
            self.activations_count[layer_idx][bucket]["current"] += 1

        self.hooks.append(layer.register_forward_hook(hook))
        if update_linear:
            # Fresh probes only when none were loaded from disk above.
            if not linear_probe_path or not os.path.exists(linear_probe_path):
                self.linear_probes[i] = {
                    'q_value': torch.nn.Linear(self.model.config.d_model, 1),
                    'q_value_delta': torch.nn.Linear(self.model.config.d_model, 1),
                    'material_balance': torch.nn.Linear(self.model.config.d_model, 1)
                }
            self.linear_probe_targets[i] = {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets}
135
 
136
  def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
137
  game_state = game_state.split("\n\n")[-1].strip()
 
286
  self.hooks.append(self.model.backbone.layers[layer_idx].register_forward_hook(
287
  lambda module, input, output, layer_idx=layer_idx: hook(module, input, output, layer_idx)
288
  ))
289
+
290
def update_linear_probe_targets(self, curr_q_value, q_value_delta, material_bal):
    """Append the current move's probe targets to every layer's buffer.

    The targets are filed under the move-number bucket that covers
    self.move_num; one value per probe type is recorded per call.
    """
    move = self.move_num
    bucket = next(b for b in self.move_buckets if move <= b)
    fresh_targets = {
        'q_value': curr_q_value,
        'q_value_delta': q_value_delta,
        'material_balance': material_bal,
    }
    for per_layer in self.linear_probe_targets.values():
        slot = per_layer[bucket]
        for probe_type, value in fresh_targets.items():
            slot[probe_type].append(value)
296
+
297
def train_linear_probes(self):
    """Fit each layer's linear probes on bucket-averaged activations.

    For every layer/bucket that accumulated activations this game, the
    mean activation is regressed against the mean recorded target for
    each probe type via one explicit SGD step with MSE loss. All
    recorded targets are cleared afterwards.

    BUG FIX: the original called self.linear_probes[...][...].fit(X, y);
    torch.nn.Linear has no .fit method (AttributeError), so training
    never ran. Replaced with an optimizer step.
    """
    probe_types = ('q_value', 'q_value_delta', 'material_balance')
    for layer_idx in self.linear_probes:
        for bucket in self.move_buckets:
            count = self.activations_count[layer_idx][bucket]['current']
            if count <= 0:
                continue
            # Mean activation over accumulated passes: (1, max_seq_len, d_model).
            X = self.activations_sum[layer_idx][bucket]['current'] / count
            X = torch.from_numpy(X).float()
            # NOTE(review): probes map d_model -> 1, so pool over the
            # sequence dimension; pairing pooled activations with the mean
            # target is an assumption — confirm intended granularity.
            x_vec = X.mean(dim=1)  # (1, d_model)
            for probe_type in probe_types:
                targets = self.linear_probe_targets[layer_idx][bucket][probe_type]
                if not targets:  # nothing recorded for this bucket/type
                    continue
                y = torch.tensor(targets).float().unsqueeze(1).mean(dim=0, keepdim=True)
                probe = self.linear_probes[layer_idx][probe_type]
                optimizer = torch.optim.SGD(probe.parameters(), lr=1e-3)
                optimizer.zero_grad()
                loss = ((probe(x_vec) - y) ** 2).mean()  # MSE
                loss.backward()
                optimizer.step()

    # Reset linear_probe_targets after training.
    self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
309
+
310
def save_linear_probe_data(self, path):
    """Serialize the per-layer linear-probe heads to disk.

    Args:
        path: destination file; written with torch.save.
    """
    probe_state = self.linear_probes
    torch.save(probe_state, path)
312
+
313
def evaluate_linear_probes(self, board: "chess.Board", game_state: str):
    """Run each layer's probes on the current bucket's activations and print predictions.

    BUG FIX: the original passed a numpy ndarray straight to torch.cat
    (TypeError — activations_sum holds np.zeros arrays, not a list of
    tensors) and then called .item() on a (1, seq_len, 1) output (only
    valid for one-element tensors). The array is now converted with
    torch.from_numpy, averaged over accumulated passes, and mean-pooled
    over the sequence dimension so each probe yields a single scalar.

    NOTE(review): `board` is currently unused — kept for caller
    compatibility. The annotation is quoted so the module does not
    require `chess` at import time.
    NOTE(review): sequence-mean pooling is an assumption — confirm the
    intended reduction.
    """
    self.move_num = game_state.count('.')
    bucket = next(b for b in self.move_buckets if self.move_num <= b)
    for layer_idx in self.linear_probes:
        count = self.activations_count[layer_idx][bucket]['current']
        acts = self.activations_sum[layer_idx][bucket]['current']
        if count > 0:
            acts = acts / count  # average the accumulated activations
        X = torch.from_numpy(acts).float()  # (1, max_seq_len, d_model)
        x_vec = X.mean(dim=1)  # (1, d_model)
        for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
            probe = self.linear_probes[layer_idx][probe_type]
            with torch.no_grad():  # evaluation only — no gradients needed
                prediction = probe(x_vec)
            print(f"Layer {layer_idx}, {probe_type}: {prediction.item()}")