HaileyStorm committed on
Commit
4560751
·
verified ·
1 Parent(s): d29de63

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -112,9 +112,9 @@ class MambaPlayer:
112
  if update_contrastive or update_linear:
113
  linear_size = self.model.config.d_model * 8 #self.model.config.d_model * self.max_seq_len
114
  for i, layer in enumerate(self.model.backbone.layers):
115
- self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
116
- "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
117
- "current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
118
  for bucket in self.move_buckets}
119
  self.activations_count[i] = {bucket: {"won": 0, "lost": 0, "current": 0}
120
  for bucket in self.move_buckets}
@@ -157,6 +157,7 @@ class MambaPlayer:
157
  # Tokenize the game state
158
  encoded_prompt = self.encode(game_state)
159
  input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=self.device)
 
160
 
161
  self.model.eval() # Set the model to evaluation mode
162
  with torch.no_grad():
@@ -183,8 +184,8 @@ class MambaPlayer:
183
  else:
184
  have_non_space = True
185
  input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
 
186
 
187
- self.seq_len = input_ids[0].size(dim=0)
188
  model_response = self.decode(input_ids[0].tolist())
189
  model_response = model_response[len(game_state):].split(";")[0]
190
  return model_response
 
112
  if update_contrastive or update_linear:
113
  linear_size = self.model.config.d_model * 8 #self.model.config.d_model * self.max_seq_len
114
  for i, layer in enumerate(self.model.backbone.layers):
115
+ self.activations_sum[i] = {bucket: {"won": np.zeros((1, 8, self.model.config.d_model)),
116
+ "lost": np.zeros((1, 8, self.model.config.d_model)),
117
+ "current": np.zeros((1, 8, self.model.config.d_model))}
118
  for bucket in self.move_buckets}
119
  self.activations_count[i] = {bucket: {"won": 0, "lost": 0, "current": 0}
120
  for bucket in self.move_buckets}
 
157
  # Tokenize the game state
158
  encoded_prompt = self.encode(game_state)
159
  input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=self.device)
160
+ self.seq_len = input_ids[0].size(dim=0)
161
 
162
  self.model.eval() # Set the model to evaluation mode
163
  with torch.no_grad():
 
184
  else:
185
  have_non_space = True
186
  input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
187
+ self.seq_len += 1
188
 
 
189
  model_response = self.decode(input_ids[0].tolist())
190
  model_response = model_response[len(game_state):].split(";")[0]
191
  return model_response