HaileyStorm committed on
Commit
7136964
1 Parent(s): 5ca7ebb

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -110,6 +110,7 @@ class MambaPlayer:
110
  else:
111
  self.linear_probes = {}
112
  if update_contrastive or update_linear:
 
113
  for i, layer in enumerate(self.model.backbone.layers):
114
  self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
115
  "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
@@ -132,9 +133,9 @@ class MambaPlayer:
132
  if update_linear:
133
  if not linear_probe_path or not os.path.exists(linear_probe_path):
134
  self.linear_probes[i] = {
135
- 'q_value': nn.Linear(self.model.config.d_model, 1),
136
- 'q_value_delta': nn.Linear(self.model.config.d_model, 1),
137
- 'material_balance': nn.Linear(self.model.config.d_model, 1)
138
  }
139
  if update_linear:
140
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
@@ -331,7 +332,7 @@ class MambaPlayer:
331
  for layer_idx in self.linear_probes:
332
  for bucket in self.move_buckets:
333
  if self.activations_count[layer_idx][bucket]['current'] > 0:
334
- X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float() #/ self.activations_count[layer_idx][bucket]['current']).float()
335
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
336
  y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
337
  if len(y) > 0:
 
110
  else:
111
  self.linear_probes = {}
112
  if update_contrastive or update_linear:
113
+ linear_size = self.model.config.d_model * self.max_seq_len
114
  for i, layer in enumerate(self.model.backbone.layers):
115
  self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
116
  "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
 
133
  if update_linear:
134
  if not linear_probe_path or not os.path.exists(linear_probe_path):
135
  self.linear_probes[i] = {
136
+ 'q_value': nn.Linear(linear_size, 1),
137
+ 'q_value_delta': nn.Linear(linear_size, 1),
138
+ 'material_balance': nn.Linear(linear_size, 1)
139
  }
140
  if update_linear:
141
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
 
332
  for layer_idx in self.linear_probes:
333
  for bucket in self.move_buckets:
334
  if self.activations_count[layer_idx][bucket]['current'] > 0:
335
+ X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1) #/ self.activations_count[layer_idx][bucket]['current']).float()
336
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
337
  y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
338
  if len(y) > 0: