HaileyStorm
committed on
Commit
•
d90c994
1
Parent(s):
627fa95
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -89,18 +89,21 @@ class MambaPlayer:
|
|
89 |
self.ctx = ctx
|
90 |
self.device = device
|
91 |
|
92 |
-
self.activations = {}
|
93 |
self.hooks = []
|
|
|
|
|
94 |
|
95 |
for i, layer in enumerate(self.model.backbone.layers):
|
96 |
-
self.
|
|
|
97 |
|
98 |
def hook(module, input, output, layer_idx=i):
|
99 |
if isinstance(output, tuple):
|
100 |
tensor_output = output[0]
|
101 |
else:
|
102 |
tensor_output = output
|
103 |
-
self.
|
|
|
104 |
|
105 |
self.hooks.append(layer.register_forward_hook(hook))
|
106 |
|
@@ -167,27 +170,37 @@ class MambaPlayer:
|
|
167 |
return {"model": self.model_name}
|
168 |
|
169 |
def update_activations(self, result):
|
170 |
-
for layer_idx in self.
|
171 |
-
self.
|
|
|
|
|
|
|
172 |
|
173 |
def save_activations(self, path):
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
176 |
|
177 |
-
for layer_idx
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
|
|
|
|
186 |
|
187 |
with open(path, "wb") as f:
|
188 |
pickle.dump((activations_sum, activations_count), f)
|
189 |
-
|
190 |
-
self.
|
|
|
191 |
|
192 |
def apply_contrastive_activations(self, path):
|
193 |
if os.path.exists(path):
|
@@ -203,16 +216,5 @@ class MambaPlayer:
|
|
203 |
return output + torch.from_numpy(contrastive_activations).to(output.device)
|
204 |
|
205 |
self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
|
206 |
-
|
207 |
-
def apply_contrastive_activations(self):
|
208 |
-
for layer_idx, layer_activations in self.activations.items():
|
209 |
-
if len(layer_activations["won"]) > 0 and len(layer_activations["lost"]) > 0:
|
210 |
-
won_activations = np.mean(layer_activations["won"], axis=0)
|
211 |
-
lost_activations = np.mean(layer_activations["lost"], axis=0)
|
212 |
-
contrastive_activations = won_activations - lost_activations
|
213 |
-
|
214 |
-
def hook(module, input, output):
|
215 |
-
return output + torch.from_numpy(contrastive_activations).to(output.device)
|
216 |
-
|
217 |
-
self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
|
218 |
|
|
|
89 |
self.ctx = ctx
|
90 |
self.device = device
|
91 |
|
|
|
92 |
self.hooks = []
|
93 |
+
self.activations_sum = {}
|
94 |
+
self.activations_count = {}
|
95 |
|
96 |
for i, layer in enumerate(self.model.backbone.layers):
|
97 |
+
self.activations_sum[i] = {"won": 0, "lost": 0}
|
98 |
+
self.activations_count[i] = {"won": 0, "lost": 0}
|
99 |
|
100 |
def hook(module, input, output, layer_idx=i):
|
101 |
if isinstance(output, tuple):
|
102 |
tensor_output = output[0]
|
103 |
else:
|
104 |
tensor_output = output
|
105 |
+
self.activations_sum[layer_idx]["current"] += tensor_output.detach().cpu().numpy()
|
106 |
+
self.activations_count[layer_idx]["current"] += 1
|
107 |
|
108 |
self.hooks.append(layer.register_forward_hook(hook))
|
109 |
|
|
|
170 |
return {"model": self.model_name}
|
171 |
|
172 |
def update_activations(self, result):
    """Fold the activations accumulated during the current game into the
    running totals for *result*, then reset the per-game accumulators.

    Args:
        result: Outcome bucket to credit — expected to be ``"won"`` or
            ``"lost"`` (a key of the per-layer accumulator dicts).
    """
    for layer_idx in self.activations_sum:
        layer_sum = self.activations_sum[layer_idx]
        layer_count = self.activations_count[layer_idx]
        # "current" holds what the forward hooks gathered since the last
        # reset. Use .get() because the per-layer dicts are initialized with
        # only "won"/"lost" keys, so "current" may be absent until a hook
        # has actually fired — plain indexing would raise KeyError here.
        layer_sum[result] = layer_sum.get(result, 0) + layer_sum.get("current", 0)
        layer_count[result] = layer_count.get(result, 0) + layer_count.get("current", 0)
        # Zero the per-game accumulators for the next game.
        layer_sum["current"] = 0
        layer_count["current"] = 0
|
178 |
|
179 |
def save_activations(self, path):
    """Merge this instance's accumulated won/lost activation statistics into
    the pickle file at *path* (creating it if absent), then clear the
    in-memory accumulators.

    The file stores a ``(activations_sum, activations_count)`` tuple of
    ``{layer_idx: {category: value}}`` dicts; only the ``"won"`` and
    ``"lost"`` categories are persisted (transient ``"current"`` totals
    are intentionally dropped).

    Args:
        path: Filesystem path of the pickle file to merge into / create.
    """
    # NOTE(review): pickle.load executes arbitrary code when fed an
    # untrusted file — only load activation files produced by this tool.
    if os.path.exists(path):
        with open(path, "rb") as f:
            activations_sum, activations_count = pickle.load(f)
    else:
        activations_sum = {}
        activations_count = {}

    for layer_idx in self.activations_sum:
        # Layer presence is per-layer, not per-category: hoist the check
        # out of the category loop instead of re-testing it each pass.
        if layer_idx not in activations_sum:
            activations_sum[layer_idx] = {}
            activations_count[layer_idx] = {}
        for category in ("won", "lost"):
            if category not in activations_sum[layer_idx]:
                activations_sum[layer_idx][category] = 0
                activations_count[layer_idx][category] = 0
            # .get() guards against a layer dict that lacks a category key
            # (defensive; the visible initializer seeds both categories).
            activations_sum[layer_idx][category] += self.activations_sum[layer_idx].get(category, 0)
            activations_count[layer_idx][category] += self.activations_count[layer_idx].get(category, 0)

    with open(path, "wb") as f:
        pickle.dump((activations_sum, activations_count), f)

    # Reset the in-memory accumulators after persisting.
    # NOTE(review): this drops the per-layer key structure the forward
    # hooks index into ("current"), so hooks firing after a save would
    # KeyError on the original code path — confirm the intended lifecycle
    # (save appears to be terminal for this instance).
    self.activations_sum = {}
    self.activations_count = {}
|
204 |
|
205 |
def apply_contrastive_activations(self, path):
|
206 |
if os.path.exists(path):
|
|
|
216 |
return output + torch.from_numpy(contrastive_activations).to(output.device)
|
217 |
|
218 |
self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
|
219 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|