HaileyStorm
committed on
Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -160,18 +160,39 @@ class MambaPlayer:
|
|
160 |
def get_config(self) -> dict:
    """Report this player's configuration as a plain dict."""
    return dict(model=self.model_name)
|
162 |
|
163 |
-
def update_activations(self, result):
    """Append each layer's current activation snapshot to its `result` bucket.

    `result` selects the per-layer list to extend (presumably "won" or
    "lost", set by game outcome — confirm against caller).
    """
    for layer_store in self.activations.values():
        layer_store[result].append(layer_store["current"])
|
166 |
-
|
167 |
def save_activations(self, path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
with open(path, "wb") as f:
|
169 |
-
pickle.dump(
|
|
|
|
|
170 |
|
171 |
-
def
|
172 |
if os.path.exists(path):
|
173 |
with open(path, "rb") as f:
|
174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
|
176 |
def apply_contrastive_activations(self):
|
177 |
for layer_idx, layer_activations in self.activations.items():
|
|
|
160 |
def get_config(self) -> dict:
|
161 |
return {"model": self.model_name}
|
162 |
|
|
|
|
|
|
|
|
|
163 |
def save_activations(self, path):
    """Aggregate collected won/lost activations, pickle them to `path`, reset.

    Persists per-layer elementwise sums plus sample counts so mean
    activations can be reconstructed later without keeping every
    individual activation in memory.
    """
    sums = {
        idx: {
            "won": np.sum(acts["won"], axis=0),
            "lost": np.sum(acts["lost"], axis=0),
        }
        for idx, acts in self.activations.items()
    }
    counts = {
        idx: {"won": len(acts["won"]), "lost": len(acts["lost"])}
        for idx, acts in self.activations.items()
    }

    with open(path, "wb") as f:
        pickle.dump((sums, counts), f)

    # Drop the accumulated activations once they have been persisted.
    self.activations = {}
|
181 |
|
182 |
+
def apply_contrastive_activations(self, path):
    """Load activation statistics from `path` and register forward hooks that
    steer each layer's output toward the "won" activation direction.

    The pickle at `path` must hold `(activations_sum, activations_count)` as
    produced by `save_activations`. For every layer, the mean "won" and mean
    "lost" activations are computed and their difference is added to that
    layer's output on every forward pass. No-op when `path` does not exist.
    """
    if not os.path.exists(path):
        return

    with open(path, "rb") as f:
        # NOTE(review): pickle.load executes arbitrary code if the file is
        # attacker-controlled — acceptable only for locally produced files.
        activations_sum, activations_count = pickle.load(f)

    for layer_idx in activations_sum:
        won_n = activations_count[layer_idx]["won"]
        lost_n = activations_count[layer_idx]["lost"]
        # Robustness: a category with zero recorded games would divide by
        # zero and inject NaNs into every forward pass — skip that layer.
        if won_n == 0 or lost_n == 0:
            continue

        won_mean = activations_sum[layer_idx]["won"] / won_n
        lost_mean = activations_sum[layer_idx]["lost"] / lost_n
        contrastive = won_mean - lost_mean

        # Bug fix: the original closure captured `contrastive_activations`
        # late, so every registered hook used the LAST layer's vector.
        # Binding the tensor as a default argument freezes this layer's
        # value at definition time.
        def hook(module, input, output, _shift=torch.from_numpy(contrastive)):
            return output + _shift.to(output.device)

        self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
|
196 |
|
197 |
def apply_contrastive_activations(self):
|
198 |
for layer_idx, layer_activations in self.activations.items():
|