HaileyStorm commited on
Commit
d90c994
1 Parent(s): 627fa95

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -89,18 +89,21 @@ class MambaPlayer:
89
  self.ctx = ctx
90
  self.device = device
91
 
92
- self.activations = {}
93
  self.hooks = []
 
 
94
 
95
  for i, layer in enumerate(self.model.backbone.layers):
96
- self.activations[i] = {"won": [], "lost": []}
 
97
 
98
  def hook(module, input, output, layer_idx=i):
99
  if isinstance(output, tuple):
100
  tensor_output = output[0]
101
  else:
102
  tensor_output = output
103
- self.activations[layer_idx]["current"] = tensor_output.detach().cpu().numpy()
 
104
 
105
  self.hooks.append(layer.register_forward_hook(hook))
106
 
@@ -167,27 +170,37 @@ class MambaPlayer:
167
  return {"model": self.model_name}
168
 
169
  def update_activations(self, result):
170
- for layer_idx in self.activations:
171
- self.activations[layer_idx][result].append(self.activations[layer_idx]["current"])
 
 
 
172
 
173
  def save_activations(self, path):
174
- activations_sum = {}
175
- activations_count = {}
 
 
 
 
176
 
177
- for layer_idx, layer_activations in self.activations.items():
178
- activations_sum[layer_idx] = {
179
- "won": np.sum(layer_activations["won"], axis=0),
180
- "lost": np.sum(layer_activations["lost"], axis=0)
181
- }
182
- activations_count[layer_idx] = {
183
- "won": len(layer_activations["won"]),
184
- "lost": len(layer_activations["lost"])
185
- }
 
 
186
 
187
  with open(path, "wb") as f:
188
  pickle.dump((activations_sum, activations_count), f)
189
-
190
- self.activations = {}
 
191
 
192
  def apply_contrastive_activations(self, path):
193
  if os.path.exists(path):
@@ -203,16 +216,5 @@ class MambaPlayer:
203
  return output + torch.from_numpy(contrastive_activations).to(output.device)
204
 
205
  self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
206
-
207
- def apply_contrastive_activations(self):
208
- for layer_idx, layer_activations in self.activations.items():
209
- if len(layer_activations["won"]) > 0 and len(layer_activations["lost"]) > 0:
210
- won_activations = np.mean(layer_activations["won"], axis=0)
211
- lost_activations = np.mean(layer_activations["lost"], axis=0)
212
- contrastive_activations = won_activations - lost_activations
213
-
214
- def hook(module, input, output):
215
- return output + torch.from_numpy(contrastive_activations).to(output.device)
216
-
217
- self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
218
 
 
89
  self.ctx = ctx
90
  self.device = device
91
 
 
92
  self.hooks = []
93
+ self.activations_sum = {}
94
+ self.activations_count = {}
95
 
96
  for i, layer in enumerate(self.model.backbone.layers):
97
+ self.activations_sum[i] = {"won": 0, "lost": 0}
98
+ self.activations_count[i] = {"won": 0, "lost": 0}
99
 
100
  def hook(module, input, output, layer_idx=i):
101
  if isinstance(output, tuple):
102
  tensor_output = output[0]
103
  else:
104
  tensor_output = output
105
+ self.activations_sum[layer_idx]["current"] += tensor_output.detach().cpu().numpy()
106
+ self.activations_count[layer_idx]["current"] += 1
107
 
108
  self.hooks.append(layer.register_forward_hook(hook))
109
 
 
170
  return {"model": self.model_name}
171
 
172
  def update_activations(self, result):
173
+ for layer_idx in self.activations_sum:
174
+ self.activations_sum[layer_idx][result] += self.activations_sum[layer_idx]["current"]
175
+ self.activations_count[layer_idx][result] += self.activations_count[layer_idx]["current"]
176
+ self.activations_sum[layer_idx]["current"] = 0
177
+ self.activations_count[layer_idx]["current"] = 0
178
 
179
  def save_activations(self, path):
180
+ if os.path.exists(path):
181
+ with open(path, "rb") as f:
182
+ activations_sum, activations_count = pickle.load(f)
183
+ else:
184
+ activations_sum = {}
185
+ activations_count = {}
186
 
187
+ for layer_idx in self.activations_sum:
188
+ for category in ["won", "lost"]:
189
+ if layer_idx not in activations_sum:
190
+ activations_sum[layer_idx] = {}
191
+ activations_count[layer_idx] = {}
192
+ if category not in activations_sum[layer_idx]:
193
+ activations_sum[layer_idx][category] = 0
194
+ activations_count[layer_idx][category] = 0
195
+
196
+ activations_sum[layer_idx][category] += self.activations_sum[layer_idx][category]
197
+ activations_count[layer_idx][category] += self.activations_count[layer_idx][category]
198
 
199
  with open(path, "wb") as f:
200
  pickle.dump((activations_sum, activations_count), f)
201
+
202
+ self.activations_sum = {}
203
+ self.activations_count = {}
204
 
205
  def apply_contrastive_activations(self, path):
206
  if os.path.exists(path):
 
216
  return output + torch.from_numpy(contrastive_activations).to(output.device)
217
 
218
  self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
219
+
 
 
 
 
 
 
 
 
 
 
 
220