HaileyStorm
commited on
Commit
•
abd7c69
1
Parent(s):
4ca9231
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -309,14 +309,15 @@ class MambaPlayer:
|
|
309 |
for bucket in self.move_buckets:
|
310 |
if self.activations_count[layer_idx][bucket]['current'] > 0:
|
311 |
X = self.activations_sum[layer_idx][bucket]['current'] / self.activations_count[layer_idx][bucket]['current']
|
312 |
-
X = X.reshape(X.shape[
|
313 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
314 |
y = np.array(self.linear_probe_targets[layer_idx][bucket][probe_type])
|
315 |
-
if len(y)
|
316 |
-
|
317 |
-
|
|
|
318 |
else:
|
319 |
-
print(f"Skipping training for layer {layer_idx}, bucket {bucket}, probe type {probe_type} due to
|
320 |
|
321 |
# Reset linear_probe_targets after training
|
322 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
|
|
309 |
for bucket in self.move_buckets:
|
310 |
if self.activations_count[layer_idx][bucket]['current'] > 0:
|
311 |
X = self.activations_sum[layer_idx][bucket]['current'] / self.activations_count[layer_idx][bucket]['current']
|
312 |
+
X = X.reshape(X.shape[1], -1) # Reshape X to have shape (sequence_length, flattened_features)
|
313 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
314 |
y = np.array(self.linear_probe_targets[layer_idx][bucket][probe_type])
|
315 |
+
if len(y) > 0:
|
316 |
+
# Repeat X to match the number of target values
|
317 |
+
X_repeated = np.repeat(X, len(y), axis=0)
|
318 |
+
self.linear_probes[layer_idx][probe_type].fit(X_repeated, y)
|
319 |
else:
|
320 |
+
print(f"Skipping training for layer {layer_idx}, bucket {bucket}, probe type {probe_type} due to empty target values.")
|
321 |
|
322 |
# Reset linear_probe_targets after training
|
323 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|