HaileyStorm committed on
Commit
abd7c69
1 Parent(s): 4ca9231

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -309,14 +309,15 @@ class MambaPlayer:
309
  for bucket in self.move_buckets:
310
  if self.activations_count[layer_idx][bucket]['current'] > 0:
311
  X = self.activations_sum[layer_idx][bucket]['current'] / self.activations_count[layer_idx][bucket]['current']
312
- X = X.reshape(X.shape[0], -1) # Reshape X to have 2 dimensions
313
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
314
  y = np.array(self.linear_probe_targets[layer_idx][bucket][probe_type])
315
- if len(y) == X.shape[0]: # Check if the number of samples match
316
- self.linear_probes[layer_idx][probe_type].fit(X, y)
317
- print(f"Fit layer {layer_idx} type {probe_type}.")
 
318
  else:
319
- print(f"Skipping training for layer {layer_idx}, bucket {bucket}, probe type {probe_type} due to inconsistent number of samples. X shaspe {X.shape}, Y shape {y.shape}")
320
 
321
  # Reset linear_probe_targets after training
322
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
 
309
  for bucket in self.move_buckets:
310
  if self.activations_count[layer_idx][bucket]['current'] > 0:
311
  X = self.activations_sum[layer_idx][bucket]['current'] / self.activations_count[layer_idx][bucket]['current']
312
+ X = X.reshape(X.shape[1], -1) # Reshape X to have shape (sequence_length, flattened_features)
313
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
314
  y = np.array(self.linear_probe_targets[layer_idx][bucket][probe_type])
315
+ if len(y) > 0:
316
+ # Repeat X to match the number of target values
317
+ X_repeated = np.repeat(X, len(y), axis=0)
318
+ self.linear_probes[layer_idx][probe_type].fit(X_repeated, y)
319
  else:
320
+ print(f"Skipping training for layer {layer_idx}, bucket {bucket}, probe type {probe_type} due to empty target values.")
321
 
322
  # Reset linear_probe_targets after training
323
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}