HaileyStorm committed
Commit: 7136964
Parent(s): 5ca7ebb
Update chess-gpt-eval-contrastive/mamba_module.py
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -110,6 +110,7 @@ class MambaPlayer:
         else:
             self.linear_probes = {}
         if update_contrastive or update_linear:
+            linear_size = self.model.config.d_model * self.max_seq_len
             for i, layer in enumerate(self.model.backbone.layers):
                 self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
                                                     "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
@@ -132,9 +133,9 @@ class MambaPlayer:
             if update_linear:
                 if not linear_probe_path or not os.path.exists(linear_probe_path):
                     self.linear_probes[i] = {
-                        'q_value': nn.Linear(
-                        'q_value_delta': nn.Linear(
-                        'material_balance': nn.Linear(
+                        'q_value': nn.Linear(linear_size, 1),
+                        'q_value_delta': nn.Linear(linear_size, 1),
+                        'material_balance': nn.Linear(linear_size, 1)
                     }
         if update_linear:
             self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
@@ -331,7 +332,7 @@ class MambaPlayer:
         for layer_idx in self.linear_probes:
             for bucket in self.move_buckets:
                 if self.activations_count[layer_idx][bucket]['current'] > 0:
-                    X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float() #/ self.activations_count[layer_idx][bucket]['current']).float()
+                    X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1) #/ self.activations_count[layer_idx][bucket]['current']).float()
                     for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
                         y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
                         if len(y) > 0:
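Taken together, the three hunks make the probe input width agree with the flattened activation tensor: each nn.Linear probe now expects max_seq_len * d_model input features, and X is flattened from (1, max_seq_len, d_model) to (1, max_seq_len * d_model) before being fed to it. A minimal, self-contained sketch of that shape arithmetic (the concrete d_model and max_seq_len values below are illustrative, not taken from the repository):

import numpy as np
import torch
import torch.nn as nn

# Illustrative sizes; in MambaPlayer they come from self.model.config.d_model
# and self.max_seq_len.
d_model, max_seq_len = 512, 1024
linear_size = d_model * max_seq_len  # probe input width introduced by this commit

probe = nn.Linear(linear_size, 1)  # stands in for the 'q_value' probe

# Accumulated activations for one (layer, bucket) slot: shape (1, max_seq_len, d_model).
activations_sum = np.zeros((1, max_seq_len, d_model))

# .flatten(1) merges the sequence and feature axes into (1, max_seq_len * d_model),
# matching the probe's in_features.
X = torch.from_numpy(activations_sum).float().flatten(1)
assert X.shape[1] == probe.in_features
pred = probe(X)  # -> shape (1, 1)

Without the .flatten(1), X would stay three-dimensional with a trailing size of d_model, and a probe sized to linear_size would raise a shape-mismatch error on the matrix multiply, which appears to be the failure this commit addresses.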