HaileyStorm committed on
Commit
94c3d48
·
verified ·
1 Parent(s): cdde761

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -160,18 +160,39 @@ class MambaPlayer:
160
  def get_config(self) -> dict:
161
  return {"model": self.model_name}
162
 
163
- def update_activations(self, result):
164
- for layer_idx in self.activations:
165
- self.activations[layer_idx][result].append(self.activations[layer_idx]["current"])
166
-
167
  def save_activations(self, path):
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  with open(path, "wb") as f:
169
- pickle.dump(self.activations, f)
 
 
170
 
171
- def load_activations(self, path):
172
  if os.path.exists(path):
173
  with open(path, "rb") as f:
174
- self.activations = pickle.load(f)
 
 
 
 
 
 
 
 
 
 
175
 
176
  def apply_contrastive_activations(self):
177
  for layer_idx, layer_activations in self.activations.items():
 
160
  def get_config(self) -> dict:
161
  return {"model": self.model_name}
162
 
 
 
 
 
163
  def save_activations(self, path):
164
+ activations_sum = {}
165
+ activations_count = {}
166
+
167
+ for layer_idx, layer_activations in self.activations.items():
168
+ activations_sum[layer_idx] = {
169
+ "won": np.sum(layer_activations["won"], axis=0),
170
+ "lost": np.sum(layer_activations["lost"], axis=0)
171
+ }
172
+ activations_count[layer_idx] = {
173
+ "won": len(layer_activations["won"]),
174
+ "lost": len(layer_activations["lost"])
175
+ }
176
+
177
  with open(path, "wb") as f:
178
+ pickle.dump((activations_sum, activations_count), f)
179
+
180
+ self.activations = {}
181
 
182
+ def apply_contrastive_activations(self, path):
183
  if os.path.exists(path):
184
  with open(path, "rb") as f:
185
+ activations_sum, activations_count = pickle.load(f)
186
+
187
+ for layer_idx in activations_sum:
188
+ won_activations = activations_sum[layer_idx]["won"] / activations_count[layer_idx]["won"]
189
+ lost_activations = activations_sum[layer_idx]["lost"] / activations_count[layer_idx]["lost"]
190
+ contrastive_activations = won_activations - lost_activations
191
+
192
+ def hook(module, input, output):
193
+ return output + torch.from_numpy(contrastive_activations).to(output.device)
194
+
195
+ self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
196
 
197
  def apply_contrastive_activations(self):
198
  for layer_idx, layer_activations in self.activations.items():