Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -112,9 +112,9 @@ class MambaPlayer:
|
|
112 |
if update_contrastive or update_linear:
|
113 |
linear_size = self.model.config.d_model * 8 #self.model.config.d_model * self.max_seq_len
|
114 |
for i, layer in enumerate(self.model.backbone.layers):
|
115 |
-
self.activations_sum[i] = {bucket: {"won": np.zeros((1,
|
116 |
-
"lost": np.zeros((1,
|
117 |
-
"current": np.zeros((1,
|
118 |
for bucket in self.move_buckets}
|
119 |
self.activations_count[i] = {bucket: {"won": 0, "lost": 0, "current": 0}
|
120 |
for bucket in self.move_buckets}
|
@@ -157,6 +157,7 @@ class MambaPlayer:
|
|
157 |
# Tokenize the game state
|
158 |
encoded_prompt = self.encode(game_state)
|
159 |
input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=self.device)
|
|
|
160 |
|
161 |
self.model.eval() # Set the model to evaluation mode
|
162 |
with torch.no_grad():
|
@@ -183,8 +184,8 @@ class MambaPlayer:
|
|
183 |
else:
|
184 |
have_non_space = True
|
185 |
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
|
|
|
186 |
|
187 |
-
self.seq_len = input_ids[0].size(dim=0)
|
188 |
model_response = self.decode(input_ids[0].tolist())
|
189 |
model_response = model_response[len(game_state):].split(";")[0]
|
190 |
return model_response
|
|
|
112 |
if update_contrastive or update_linear:
|
113 |
linear_size = self.model.config.d_model * 8 #self.model.config.d_model * self.max_seq_len
|
114 |
for i, layer in enumerate(self.model.backbone.layers):
|
115 |
+
self.activations_sum[i] = {bucket: {"won": np.zeros((1, 8, self.model.config.d_model)),
|
116 |
+
"lost": np.zeros((1, 8, self.model.config.d_model)),
|
117 |
+
"current": np.zeros((1, 8, self.model.config.d_model))}
|
118 |
for bucket in self.move_buckets}
|
119 |
self.activations_count[i] = {bucket: {"won": 0, "lost": 0, "current": 0}
|
120 |
for bucket in self.move_buckets}
|
|
|
157 |
# Tokenize the game state
|
158 |
encoded_prompt = self.encode(game_state)
|
159 |
input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=self.device)
|
160 |
+
self.seq_len = input_ids[0].size(dim=0)
|
161 |
|
162 |
self.model.eval() # Set the model to evaluation mode
|
163 |
with torch.no_grad():
|
|
|
184 |
else:
|
185 |
have_non_space = True
|
186 |
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
|
187 |
+
self.seq_len += 1
|
188 |
|
|
|
189 |
model_response = self.decode(input_ids[0].tolist())
|
190 |
model_response = model_response[len(game_state):].split(";")[0]
|
191 |
return model_response
|