HaileyStorm committed on
Commit
f4a6bfa
·
verified ·
1 Parent(s): 0955f14

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -87,6 +87,17 @@ class MambaPlayer:
87
  self.ctx = ctx
88
  self.device = device
89
 
 
 
 
 
 
 
 
 
 
 
 
90
  def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
91
  game_state = game_state.split("\n\n")[-1].strip()
92
  #game_state = ";" + game_state
@@ -149,3 +160,28 @@ class MambaPlayer:
149
  def get_config(self) -> dict:
150
  return {"model": self.model_name}
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  self.ctx = ctx
88
  self.device = device
89
 
90
+ self.activations = {}
91
+ self.hooks = []
92
+
93
+ for i, layer in enumerate(self.model.backbone.layers):
94
+ self.activations[i] = {"won": [], "lost": []}
95
+
96
+ def hook(module, input, output, layer_idx=i):
97
+ self.activations[layer_idx]["current"] = output.detach().cpu().numpy()
98
+
99
+ self.hooks.append(layer.register_forward_hook(hook))
100
+
101
  def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
102
  game_state = game_state.split("\n\n")[-1].strip()
103
  #game_state = ";" + game_state
 
160
  def get_config(self) -> dict:
161
  return {"model": self.model_name}
162
 
163
+ def update_activations(self, result):
164
+ for layer_idx in self.activations:
165
+ self.activations[layer_idx][result].append(self.activations[layer_idx]["current"])
166
+
167
+ def save_activations(self, path):
168
+ with open(path, "wb") as f:
169
+ pickle.dump(self.activations, f)
170
+
171
+ def load_activations(self, path):
172
+ if os.path.exists(path):
173
+ with open(path, "rb") as f:
174
+ self.activations = pickle.load(f)
175
+
176
+ def apply_contrastive_activations(self):
177
+ for layer_idx, layer_activations in self.activations.items():
178
+ if len(layer_activations["won"]) > 0 and len(layer_activations["lost"]) > 0:
179
+ won_activations = np.mean(layer_activations["won"], axis=0)
180
+ lost_activations = np.mean(layer_activations["lost"], axis=0)
181
+ contrastive_activations = won_activations - lost_activations
182
+
183
+ def hook(module, input, output):
184
+ return output + torch.from_numpy(contrastive_activations).to(output.device)
185
+
186
+ self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
187
+