HaileyStorm committed on
Commit
72c8584
·
verified ·
1 Parent(s): e8c8242

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -89,16 +89,20 @@ class MambaPlayer:
89
  self.ctx = ctx
90
  self.device = device
91
 
 
92
  self.hooks = []
93
  self.max_seq_len = 1536
 
94
 
95
  self.activations_sum = {}
96
  self.activations_count = {}
97
  for i, layer in enumerate(self.model.backbone.layers):
98
- self.activations_sum[i] = {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
99
- "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
100
- "current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
101
- self.activations_count[i] = {"won": 0, "lost": 0, "current": 0}
 
 
102
 
103
  def hook(module, input, output, layer_idx=i):
104
  if isinstance(output, tuple):
@@ -106,8 +110,9 @@ class MambaPlayer:
106
  else:
107
  tensor_output = output
108
  seq_len = tensor_output.shape[1]
109
- self.activations_sum[layer_idx]["current"][:, :seq_len, :] += tensor_output.detach().cpu().numpy()
110
- self.activations_count[layer_idx]["current"] += 1
 
111
 
112
  self.hooks.append(layer.register_forward_hook(hook))
113
 
@@ -167,6 +172,7 @@ class MambaPlayer:
167
  return None
168
 
169
  def get_move(self, board: str, game_state: str, temperature: float) -> str:
 
170
  completion = self.get_mamba_response(game_state, temperature, 8, self.vocab_size)
171
  return self.get_move_from_response(completion)
172
 
@@ -175,9 +181,10 @@ class MambaPlayer:
175
 
176
  def update_activations(self, result):
177
  for layer_idx in self.activations_sum:
178
- self.activations_sum[layer_idx][result] += self.activations_sum[layer_idx]["current"]
179
- self.activations_count[layer_idx][result] += 1
180
-
 
181
  def save_activations(self, path):
182
  if os.path.exists(path):
183
  with open(path, "rb") as f:
@@ -187,38 +194,47 @@ class MambaPlayer:
187
  activations_count = {}
188
 
189
  for layer_idx in self.activations_sum:
190
- for category in ["won", "lost"]:
 
 
191
  if layer_idx not in activations_sum:
192
  activations_sum[layer_idx] = {}
193
  activations_count[layer_idx] = {}
194
- if category not in activations_sum[layer_idx]:
195
- activations_sum[layer_idx][category] = np.zeros((1, self.max_seq_len, self.model.config.d_model))
196
- activations_count[layer_idx][category] = 0
197
-
198
- activations_sum[layer_idx][category] += self.activations_sum[layer_idx][category]
199
- activations_count[layer_idx][category] += self.activations_count[layer_idx][category]
 
 
 
 
200
 
201
  with open(path, "wb") as f:
202
  pickle.dump((activations_sum, activations_count), f)
203
 
204
  for layer_idx in self.activations_sum:
205
- self.activations_sum[layer_idx]["current"].fill(0)
206
- self.activations_count[layer_idx]["current"] = 0
 
207
 
208
  def apply_contrastive_activations(self, path):
209
  if os.path.exists(path):
210
  with open(path, "rb") as f:
211
  activations_sum, activations_count = pickle.load(f)
212
 
213
- for layer_idx in activations_sum:
214
- won_activations = activations_sum[layer_idx]["won"] / activations_count[layer_idx]["won"]
215
- lost_activations = activations_sum[layer_idx]["lost"] / activations_count[layer_idx]["lost"]
 
216
  contrastive_activations = won_activations - lost_activations
217
-
218
- def hook(module, input, output):
219
- seq_len = output.shape[1]
220
- return output + torch.from_numpy(contrastive_activations[:, :seq_len, :]).to(output.device)
221
-
222
- self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
 
223
 
224
 
 
89
  self.ctx = ctx
90
  self.device = device
91
 
92
+ self.move_num = 0
93
  self.hooks = []
94
  self.max_seq_len = 1536
95
+ self.move_buckets = [10, 20, 30, 40, float('inf')]
96
 
97
  self.activations_sum = {}
98
  self.activations_count = {}
99
  for i, layer in enumerate(self.model.backbone.layers):
100
+ self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
101
+ "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
102
+ "current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
103
+ for bucket in self.move_buckets}
104
+ self.activations_count[i] = {bucket: {"won": 0, "lost": 0, "current": 0}
105
+ for bucket in self.move_buckets}
106
 
107
  def hook(module, input, output, layer_idx=i):
108
  if isinstance(output, tuple):
 
110
  else:
111
  tensor_output = output
112
  seq_len = tensor_output.shape[1]
113
+ bucket = next(b for b in self.move_buckets if self.move_num <= b)
114
+ self.activations_sum[layer_idx][bucket]["current"][:, :seq_len, :] += tensor_output.detach().cpu().numpy()
115
+ self.activations_count[layer_idx][bucket]["current"] += 1
116
 
117
  self.hooks.append(layer.register_forward_hook(hook))
118
 
 
172
  return None
173
 
174
def get_move(self, board: str, game_state: str, temperature: float) -> str:
    """Produce the model's next move for the given game state.

    Records the current move number (estimated from the move-number dots in
    the PGN-style game state — assumes "N." prefixes; TODO confirm format)
    so the forward hooks bucket activations correctly, then samples a
    completion and extracts the move from it.
    """
    # Must be set before the forward pass: the registered hooks read it.
    self.move_num = game_state.count('.')
    return self.get_move_from_response(
        self.get_mamba_response(game_state, temperature, 8, self.vocab_size)
    )
178
 
 
181
 
182
def update_activations(self, result):
    """Fold this game's accumulated "current" activations into the running
    totals for the game's outcome.

    Args:
        result: outcome key for this game, "won" or "lost".

    Buckets the game never reached (zero "current" forward passes) are
    skipped — mirroring save_activations — so their per-outcome counts are
    not inflated. Without the skip, the count grows while the sum does not,
    diluting the per-bucket means that apply_contrastive_activations divides
    out. "current" is intentionally left in place; it is zeroed by
    save_activations.
    """
    for layer_idx in self.activations_sum:
        for bucket in self.move_buckets:
            # Skip buckets with no recorded forward passes this game.
            if self.activations_count[layer_idx][bucket]["current"] == 0:
                continue
            self.activations_sum[layer_idx][bucket][result] += self.activations_sum[layer_idx][bucket]["current"]
            self.activations_count[layer_idx][bucket][result] += 1
187
+
188
  def save_activations(self, path):
189
  if os.path.exists(path):
190
  with open(path, "rb") as f:
 
194
  activations_count = {}
195
 
196
  for layer_idx in self.activations_sum:
197
+ for bucket in self.move_buckets:
198
+ if self.activations_count[layer_idx][bucket]["current"] == 0:
199
+ continue
200
  if layer_idx not in activations_sum:
201
  activations_sum[layer_idx] = {}
202
  activations_count[layer_idx] = {}
203
+ if bucket not in activations_sum[layer_idx]:
204
+ activations_sum[layer_idx][bucket] = {}
205
+ activations_count[layer_idx][bucket] = {}
206
+ for category in ["won", "lost"]:
207
+ if category not in activations_sum[layer_idx][bucket]:
208
+ activations_sum[layer_idx][bucket][category] = np.zeros((1, self.max_seq_len, self.model.config.d_model))
209
+ activations_count[layer_idx][bucket][category] = 0
210
+
211
+ activations_sum[layer_idx][bucket][category] += self.activations_sum[layer_idx][bucket][category]
212
+ activations_count[layer_idx][bucket][category] += self.activations_count[layer_idx][bucket][category]
213
 
214
  with open(path, "wb") as f:
215
  pickle.dump((activations_sum, activations_count), f)
216
 
217
  for layer_idx in self.activations_sum:
218
+ for bucket in self.move_buckets:
219
+ self.activations_sum[layer_idx][bucket]["current"].fill(0)
220
+ self.activations_count[layer_idx][bucket]["current"] = 0
221
 
222
def apply_contrastive_activations(self, path):
    """Load saved won/lost activation statistics and register forward hooks
    that steer each layer's output with the contrastive (won-mean minus
    lost-mean) activations for the CURRENT move bucket.

    Args:
        path: pickle file written by save_activations; silently does nothing
            if it does not exist.

    Fixes over the naive per-bucket registration: one hook per layer selects
    the bucket from self.move_num at forward time (matching the recording
    hook), instead of stacking one hook per bucket — which would add every
    bucket's correction on every forward pass. Buckets missing from the
    saved data or with zero won/lost games are skipped rather than raising
    KeyError/ZeroDivisionError. Tuple outputs are handled like the
    recording hook.
    """
    if not os.path.exists(path):
        return
    with open(path, "rb") as f:
        # NOTE(review): pickle.load on an untrusted path is unsafe; assumed
        # to be a locally produced file.
        activations_sum, activations_count = pickle.load(f)

    def contrastive_hook(module, input, output, layer_idx):
        tensor_output = output[0] if isinstance(output, tuple) else output
        # Same bucket selection as the recording hook in __init__.
        bucket = next(b for b in self.move_buckets if self.move_num <= b)
        buckets = activations_sum.get(layer_idx, {})
        if bucket not in buckets:
            return output  # no statistics recorded for this bucket
        won_n = activations_count[layer_idx][bucket]["won"]
        lost_n = activations_count[layer_idx][bucket]["lost"]
        if won_n == 0 or lost_n == 0:
            return output  # cannot form a contrast without both outcomes
        won_mean = activations_sum[layer_idx][bucket]["won"] / won_n
        lost_mean = activations_sum[layer_idx][bucket]["lost"] / lost_n
        contrastive = won_mean - lost_mean
        seq_len = tensor_output.shape[1]
        steered = tensor_output + torch.from_numpy(contrastive[:, :seq_len, :]).to(tensor_output.device)
        if isinstance(output, tuple):
            return (steered,) + output[1:]
        return steered

    for layer_idx in activations_sum:
        # Bind layer_idx per hook; the bucket is resolved at call time.
        self.hooks.append(self.model.backbone.layers[layer_idx].register_forward_hook(
            lambda module, input, output, layer_idx=layer_idx: contrastive_hook(module, input, output, layer_idx)
        ))
239
 
240