HaileyStorm
committed on
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -110,7 +110,7 @@ class MambaPlayer:
|
|
110 |
else:
|
111 |
self.linear_probes = {}
|
112 |
if update_contrastive or update_linear:
|
113 |
-
linear_size = self.model.config.d_model * self.max_seq_len
|
114 |
for i, layer in enumerate(self.model.backbone.layers):
|
115 |
self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
116 |
"lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
@@ -184,6 +184,7 @@ class MambaPlayer:
|
|
184 |
have_non_space = True
|
185 |
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
|
186 |
|
|
|
187 |
model_response = self.decode(input_ids[0].tolist())
|
188 |
model_response = model_response[len(game_state):].split(";")[0]
|
189 |
return model_response
|
@@ -344,7 +345,7 @@ class MambaPlayer:
|
|
344 |
for layer_idx in self.linear_probes:
|
345 |
for bucket in self.move_buckets:
|
346 |
if self.activations_count[layer_idx][bucket]['current'] > 0:
|
347 |
-
X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1) #/ self.activations_count[layer_idx][bucket]['current']).float()
|
348 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
349 |
y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
|
350 |
if len(y) > 0:
|
@@ -375,7 +376,7 @@ class MambaPlayer:
|
|
375 |
self.move_num = board.fullmove_number
|
376 |
bucket = next(b for b in self.move_buckets if self.move_num <= b)
|
377 |
for layer_idx in self.linear_probes:
|
378 |
-
X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)
|
379 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
380 |
target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
|
381 |
probe = self.linear_probes[layer_idx][probe_type]
|
|
|
110 |
else:
|
111 |
self.linear_probes = {}
|
112 |
if update_contrastive or update_linear:
|
113 |
+
linear_size = self.model.config.d_model * 8 #self.model.config.d_model * self.max_seq_len
|
114 |
for i, layer in enumerate(self.model.backbone.layers):
|
115 |
self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
116 |
"lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
|
|
184 |
have_non_space = True
|
185 |
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
|
186 |
|
187 |
+
self.seq_len = input_ids[0].size(dim=0)
|
188 |
model_response = self.decode(input_ids[0].tolist())
|
189 |
model_response = model_response[len(game_state):].split(";")[0]
|
190 |
return model_response
|
|
|
345 |
for layer_idx in self.linear_probes:
|
346 |
for bucket in self.move_buckets:
|
347 |
if self.activations_count[layer_idx][bucket]['current'] > 0:
|
348 |
+
X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)[:self.seq_len][-8:] #/ self.activations_count[layer_idx][bucket]['current']).float()
|
349 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
350 |
y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
|
351 |
if len(y) > 0:
|
|
|
376 |
self.move_num = board.fullmove_number
|
377 |
bucket = next(b for b in self.move_buckets if self.move_num <= b)
|
378 |
for layer_idx in self.linear_probes:
|
379 |
+
X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)[:self.seq_len][-8:]
|
380 |
for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
|
381 |
target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
|
382 |
probe = self.linear_probes[layer_idx][probe_type]
|