HaileyStorm
committed on
Update chess-gpt-eval-contrastive/mamba_module.py
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -87,6 +87,17 @@ class MambaPlayer:
         self.ctx = ctx
         self.device = device
 
+        self.activations = {}
+        self.hooks = []
+
+        for i, layer in enumerate(self.model.backbone.layers):
+            self.activations[i] = {"won": [], "lost": []}
+
+            def hook(module, input, output, layer_idx=i):
+                self.activations[layer_idx]["current"] = output.detach().cpu().numpy()
+
+            self.hooks.append(layer.register_forward_hook(hook))
+
     def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
         game_state = game_state.split("\n\n")[-1].strip()
         #game_state = ";" + game_state
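The recording hook in this hunk binds the loop index through a default argument (layer_idx=i) so each closure keeps its own layer index instead of all sharing the final value of i. A minimal standalone sketch of the same pattern, using a toy torch.nn stack in place of the Mamba backbone (the module names and shapes here are illustrative assumptions, not taken from this repo):

import torch
import torch.nn as nn

# Toy stand-in for self.model.backbone.layers; any iterable of nn.Module works.
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
activations = {i: {"won": [], "lost": []} for i in range(len(layers))}
hooks = []

for i, layer in enumerate(layers):
    # The default argument pins the current value of i for this closure.
    def hook(module, inputs, output, layer_idx=i):
        activations[layer_idx]["current"] = output.detach().cpu().numpy()

    hooks.append(layer.register_forward_hook(hook))

# Each forward pass now refreshes activations[i]["current"] for every layer.
x = torch.randn(1, 8)
for layer in layers:
    x = layer(x)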
@@ -149,3 +160,28 @@ class MambaPlayer:
     def get_config(self) -> dict:
         return {"model": self.model_name}
 
+    def update_activations(self, result):
+        for layer_idx in self.activations:
+            self.activations[layer_idx][result].append(self.activations[layer_idx]["current"])
+
+    def save_activations(self, path):
+        with open(path, "wb") as f:
+            pickle.dump(self.activations, f)
+
+    def load_activations(self, path):
+        if os.path.exists(path):
+            with open(path, "rb") as f:
+                self.activations = pickle.load(f)
+
+    def apply_contrastive_activations(self):
+        for layer_idx, layer_activations in self.activations.items():
+            if len(layer_activations["won"]) > 0 and len(layer_activations["lost"]) > 0:
+                won_activations = np.mean(layer_activations["won"], axis=0)
+                lost_activations = np.mean(layer_activations["lost"], axis=0)
+                contrastive_activations = won_activations - lost_activations
+
+                def hook(module, input, output):
+                    return output + torch.from_numpy(contrastive_activations).to(output.device)
+
+                self.hooks[layer_idx] = self.model.backbone.layers[layer_idx].register_forward_hook(hook)
+
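Taken together with the hooks registered in __init__, these methods support a record-then-steer loop. A hypothetical driver sketch (record_and_steer, player_won, the path, and the sampling parameters are illustrative assumptions; the actual eval loop lives elsewhere in this repo):

def record_and_steer(player, game_state, player_won, path="activations.pkl"):
    # get_mamba_response() runs the model, so the forward hooks cache each
    # layer's latest output under activations[layer]["current"].
    move = player.get_mamba_response(game_state, temperature=0.0, max_new_tokens=10, top_k=5)

    # Once the result is known, file the cached activations under "won" or "lost".
    player.update_activations("won" if player_won else "lost")

    # Persist / reload the accumulated activations between eval runs.
    player.save_activations(path)
    player.load_activations(path)

    # Register hooks that add mean(won) - mean(lost) to each layer's output
    # on subsequent forward passes.
    player.apply_contrastive_activations()
    return move

The steering step is a contrastive activation addition: the difference between the mean "won" and mean "lost" activations becomes a per-layer offset applied through a forward hook.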