HaileyStorm committed (verified)
Commit 1fa91dd · 1 Parent(s): 842bafd

Update chess-gpt-eval-contrastive/mamba_module.py

chess-gpt-eval-contrastive/mamba_module.py CHANGED
```diff
@@ -90,11 +90,14 @@ class MambaPlayer:
         self.device = device
 
         self.hooks = []
+        self.max_seq_len = 1536
+
         self.activations_sum = {}
         self.activations_count = {}
-
         for i, layer in enumerate(self.model.backbone.layers):
-            self.activations_sum[i] = {"won": 0, "lost": 0, "current": 0}
+            self.activations_sum[i] = {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
+                                       "lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
+                                       "current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
             self.activations_count[i] = {"won": 0, "lost": 0, "current": 0}
 
             def hook(module, input, output, layer_idx=i):
@@ -102,7 +105,8 @@ class MambaPlayer:
                     tensor_output = output[0]
                 else:
                     tensor_output = output
-                self.activations_sum[layer_idx]["current"] += tensor_output.detach().cpu().numpy()
+                seq_len = tensor_output.shape[1]
+                self.activations_sum[layer_idx]["current"][:, :seq_len, :] += tensor_output.detach().cpu().numpy()
                 self.activations_count[layer_idx]["current"] += 1
 
             self.hooks.append(layer.register_forward_hook(hook))
@@ -188,7 +192,7 @@ class MambaPlayer:
                 activations_sum[layer_idx] = {}
                 activations_count[layer_idx] = {}
             if category not in activations_sum[layer_idx]:
-                activations_sum[layer_idx][category] = 0
+                activations_sum[layer_idx][category] = np.zeros((1, self.max_seq_len, self.model.config.d_model))
                 activations_count[layer_idx][category] = 0
 
             activations_sum[layer_idx][category] += self.activations_sum[layer_idx][category]
@@ -196,12 +200,10 @@ class MambaPlayer:
 
         with open(path, "wb") as f:
             pickle.dump((activations_sum, activations_count), f)
-
-        self.activations_sum = {}
-        self.activations_count = {}
-        for i, layer in enumerate(self.model.backbone.layers):
-            self.activations_sum[i] = {"won": 0, "lost": 0, "current": 0}
-            self.activations_count[i] = {"won": 0, "lost": 0, "current": 0}
+
+        for layer_idx in self.activations_sum:
+            self.activations_sum[layer_idx]["current"].fill(0)
+            self.activations_count[layer_idx]["current"] = 0
 
     def apply_contrastive_activations(self, path):
         if os.path.exists(path):
@@ -214,7 +216,8 @@ class MambaPlayer:
             contrastive_activations = won_activations - lost_activations
 
             def hook(module, input, output):
-                return output + torch.from_numpy(contrastive_activations).to(output.device)
+                seq_len = output.shape[1]
+                return output + torch.from_numpy(contrastive_activations[:, :seq_len, :]).to(output.device)
 
             self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
 
```
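
The commit changes the per-layer activation statistics from scalar counters to fixed-size (1, max_seq_len, d_model) buffers, so games of different lengths can be accumulated position by position, and it truncates the contrastive (won minus lost) direction to the live sequence length before adding it back in the steering hook. Below is a minimal, self-contained sketch of that pattern, assuming an nn.Linear stand-in for a Mamba backbone layer and illustrative values for d_model and max_seq_len; the helper names are hypothetical and not part of the repository.

```python
import numpy as np
import torch
import torch.nn as nn

# Toy stand-ins: an nn.Linear plays the role of one Mamba backbone layer, and
# d_model / max_seq_len are illustrative values, not the model's real config.
d_model, max_seq_len = 16, 1536
layer = nn.Linear(d_model, d_model)

# Fixed-size accumulators, as in the diff: one (1, max_seq_len, d_model) buffer
# per category, so games of different lengths can share the same array.
acts_sum = {k: np.zeros((1, max_seq_len, d_model)) for k in ("won", "lost", "current")}
acts_count = {"won": 0, "lost": 0, "current": 0}

def collect_hook(module, inputs, output):
    out = output[0] if isinstance(output, tuple) else output
    seq_len = out.shape[1]
    # Only the first seq_len positions are updated; the tail stays zero.
    acts_sum["current"][:, :seq_len, :] += out.detach().cpu().numpy()
    acts_count["current"] += 1

handle = layer.register_forward_hook(collect_hook)
layer(torch.randn(1, 7, d_model))  # one short "game" of 7 tokens
handle.remove()

# Steering: add a (won - lost) contrastive direction back in at inference,
# truncated to the live sequence length exactly as the updated hook does.
# (All zeros in this toy; in the module it comes from the pickled won/lost sums.)
contrastive = acts_sum["won"] - acts_sum["lost"]

def steer_hook(module, inputs, output):
    seq_len = output.shape[1]
    return output + torch.from_numpy(contrastive[:, :seq_len, :]).float().to(output.device)

layer.register_forward_hook(steer_hook)
steered = layer(torch.randn(1, 9, d_model))  # works for any length up to max_seq_len
```

The zero-padded buffers mean sums for positions a game never reached simply stay zero, which is why the hooks can slice with [:, :seq_len, :] on both the collection and the steering side.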