HaileyStorm committed on
Commit
164b5fe
·
verified ·
1 Parent(s): 9284512

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -110,7 +110,7 @@ class MambaPlayer:
110
  else:
111
  self.linear_probes = {}
112
  if update_contrastive or update_linear:
113
- linear_size = self.model.config.d_model * self.max_seq_len
114
  for i, layer in enumerate(self.model.backbone.layers):
115
  self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
116
  "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
@@ -184,6 +184,7 @@ class MambaPlayer:
184
  have_non_space = True
185
  input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
186
 
 
187
  model_response = self.decode(input_ids[0].tolist())
188
  model_response = model_response[len(game_state):].split(";")[0]
189
  return model_response
@@ -344,7 +345,7 @@ class MambaPlayer:
344
  for layer_idx in self.linear_probes:
345
  for bucket in self.move_buckets:
346
  if self.activations_count[layer_idx][bucket]['current'] > 0:
347
- X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1) #/ self.activations_count[layer_idx][bucket]['current']).float()
348
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
349
  y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
350
  if len(y) > 0:
@@ -375,7 +376,7 @@ class MambaPlayer:
375
  self.move_num = board.fullmove_number
376
  bucket = next(b for b in self.move_buckets if self.move_num <= b)
377
  for layer_idx in self.linear_probes:
378
- X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)
379
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
380
  target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
381
  probe = self.linear_probes[layer_idx][probe_type]
 
110
  else:
111
  self.linear_probes = {}
112
  if update_contrastive or update_linear:
113
+ linear_size = self.model.config.d_model * 8 #self.model.config.d_model * self.max_seq_len
114
  for i, layer in enumerate(self.model.backbone.layers):
115
  self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
116
  "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
 
184
  have_non_space = True
185
  input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
186
 
187
+ self.seq_len = input_ids[0].size(dim=0)
188
  model_response = self.decode(input_ids[0].tolist())
189
  model_response = model_response[len(game_state):].split(";")[0]
190
  return model_response
 
345
  for layer_idx in self.linear_probes:
346
  for bucket in self.move_buckets:
347
  if self.activations_count[layer_idx][bucket]['current'] > 0:
348
+ X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)[:self.seq_len][-8:] #/ self.activations_count[layer_idx][bucket]['current']).float()
349
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
350
  y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
351
  if len(y) > 0:
 
376
  self.move_num = board.fullmove_number
377
  bucket = next(b for b in self.move_buckets if self.move_num <= b)
378
  for layer_idx in self.linear_probes:
379
+ X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)[:self.seq_len][-8:]
380
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
381
  target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
382
  probe = self.linear_probes[layer_idx][probe_type]