Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -89,16 +89,20 @@ class MambaPlayer:
|
|
89 |
self.ctx = ctx
|
90 |
self.device = device
|
91 |
|
|
|
92 |
self.hooks = []
|
93 |
self.max_seq_len = 1536
|
|
|
94 |
|
95 |
self.activations_sum = {}
|
96 |
self.activations_count = {}
|
97 |
for i, layer in enumerate(self.model.backbone.layers):
|
98 |
-
self.activations_sum[i] = {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
102 |
|
103 |
def hook(module, input, output, layer_idx=i):
|
104 |
if isinstance(output, tuple):
|
@@ -106,8 +110,9 @@ class MambaPlayer:
|
|
106 |
else:
|
107 |
tensor_output = output
|
108 |
seq_len = tensor_output.shape[1]
|
109 |
-
self.
|
110 |
-
self.
|
|
|
111 |
|
112 |
self.hooks.append(layer.register_forward_hook(hook))
|
113 |
|
@@ -167,6 +172,7 @@ class MambaPlayer:
|
|
167 |
return None
|
168 |
|
169 |
def get_move(self, board: str, game_state: str, temperature: float) -> str:
|
|
|
170 |
completion = self.get_mamba_response(game_state, temperature, 8, self.vocab_size)
|
171 |
return self.get_move_from_response(completion)
|
172 |
|
@@ -175,9 +181,10 @@ class MambaPlayer:
|
|
175 |
|
176 |
def update_activations(self, result):
|
177 |
for layer_idx in self.activations_sum:
|
178 |
-
|
179 |
-
|
180 |
-
|
|
|
181 |
def save_activations(self, path):
|
182 |
if os.path.exists(path):
|
183 |
with open(path, "rb") as f:
|
@@ -187,38 +194,47 @@ class MambaPlayer:
|
|
187 |
activations_count = {}
|
188 |
|
189 |
for layer_idx in self.activations_sum:
|
190 |
-
for
|
|
|
|
|
191 |
if layer_idx not in activations_sum:
|
192 |
activations_sum[layer_idx] = {}
|
193 |
activations_count[layer_idx] = {}
|
194 |
-
if
|
195 |
-
activations_sum[layer_idx][
|
196 |
-
activations_count[layer_idx][
|
197 |
-
|
198 |
-
|
199 |
-
|
|
|
|
|
|
|
|
|
200 |
|
201 |
with open(path, "wb") as f:
|
202 |
pickle.dump((activations_sum, activations_count), f)
|
203 |
|
204 |
for layer_idx in self.activations_sum:
|
205 |
-
self.
|
206 |
-
|
|
|
207 |
|
208 |
def apply_contrastive_activations(self, path):
|
209 |
if os.path.exists(path):
|
210 |
with open(path, "rb") as f:
|
211 |
activations_sum, activations_count = pickle.load(f)
|
212 |
|
213 |
-
|
214 |
-
|
215 |
-
|
|
|
216 |
contrastive_activations = won_activations - lost_activations
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
|
|
223 |
|
224 |
|
|
|
89 |
self.ctx = ctx
|
90 |
self.device = device
|
91 |
|
92 |
+
self.move_num = 0
|
93 |
self.hooks = []
|
94 |
self.max_seq_len = 1536
|
95 |
+
self.move_buckets = [10, 20, 30, 40, float('inf')]
|
96 |
|
97 |
self.activations_sum = {}
|
98 |
self.activations_count = {}
|
99 |
for i, layer in enumerate(self.model.backbone.layers):
|
100 |
+
self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
101 |
+
"lost": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
102 |
+
"current": np.zeros((1, self.max_seq_len, self.model.config.d_model))}
|
103 |
+
for bucket in self.move_buckets}
|
104 |
+
self.activations_count[i] = {bucket: {"won": 0, "lost": 0, "current": 0}
|
105 |
+
for bucket in self.move_buckets}
|
106 |
|
107 |
def hook(module, input, output, layer_idx=i):
|
108 |
if isinstance(output, tuple):
|
|
|
110 |
else:
|
111 |
tensor_output = output
|
112 |
seq_len = tensor_output.shape[1]
|
113 |
+
bucket = next(b for b in self.move_buckets if self.move_num <= b)
|
114 |
+
self.activations_sum[layer_idx][bucket]["current"][:, :seq_len, :] += tensor_output.detach().cpu().numpy()
|
115 |
+
self.activations_count[layer_idx][bucket]["current"] += 1
|
116 |
|
117 |
self.hooks.append(layer.register_forward_hook(hook))
|
118 |
|
|
|
172 |
return None
|
173 |
|
174 |
def get_move(self, board: str, game_state: str, temperature: float) -> str:
    """Produce the next move string for the current game transcript.

    Records the current move number (derived from the '.' move separators
    in the PGN-style transcript) so the activation-recording hooks can
    bucket activations by game phase, then samples a model completion and
    extracts a move from it.
    """
    # One '.' appears per numbered move in the transcript.
    self.move_num = game_state.count('.')
    response = self.get_mamba_response(game_state, temperature, 8, self.vocab_size)
    move = self.get_move_from_response(response)
    return move
|
178 |
|
|
|
181 |
|
182 |
def update_activations(self, result):
|
183 |
for layer_idx in self.activations_sum:
|
184 |
+
for bucket in self.move_buckets:
|
185 |
+
self.activations_sum[layer_idx][bucket][result] += self.activations_sum[layer_idx][bucket]["current"]
|
186 |
+
self.activations_count[layer_idx][bucket][result] += 1
|
187 |
+
|
188 |
def save_activations(self, path):
|
189 |
if os.path.exists(path):
|
190 |
with open(path, "rb") as f:
|
|
|
194 |
activations_count = {}
|
195 |
|
196 |
for layer_idx in self.activations_sum:
|
197 |
+
for bucket in self.move_buckets:
|
198 |
+
if self.activations_count[layer_idx][bucket]["current"] == 0:
|
199 |
+
continue
|
200 |
if layer_idx not in activations_sum:
|
201 |
activations_sum[layer_idx] = {}
|
202 |
activations_count[layer_idx] = {}
|
203 |
+
if bucket not in activations_sum[layer_idx]:
|
204 |
+
activations_sum[layer_idx][bucket] = {}
|
205 |
+
activations_count[layer_idx][bucket] = {}
|
206 |
+
for category in ["won", "lost"]:
|
207 |
+
if category not in activations_sum[layer_idx][bucket]:
|
208 |
+
activations_sum[layer_idx][bucket][category] = np.zeros((1, self.max_seq_len, self.model.config.d_model))
|
209 |
+
activations_count[layer_idx][bucket][category] = 0
|
210 |
+
|
211 |
+
activations_sum[layer_idx][bucket][category] += self.activations_sum[layer_idx][bucket][category]
|
212 |
+
activations_count[layer_idx][bucket][category] += self.activations_count[layer_idx][bucket][category]
|
213 |
|
214 |
with open(path, "wb") as f:
|
215 |
pickle.dump((activations_sum, activations_count), f)
|
216 |
|
217 |
for layer_idx in self.activations_sum:
|
218 |
+
for bucket in self.move_buckets:
|
219 |
+
self.activations_sum[layer_idx][bucket]["current"].fill(0)
|
220 |
+
self.activations_count[layer_idx][bucket]["current"] = 0
|
221 |
|
222 |
def apply_contrastive_activations(self, path):
|
223 |
if os.path.exists(path):
|
224 |
with open(path, "rb") as f:
|
225 |
activations_sum, activations_count = pickle.load(f)
|
226 |
|
227 |
+
def hook(module, input, output, layer_idx, bucket):
|
228 |
+
seq_len = output.shape[1]
|
229 |
+
won_activations = activations_sum[layer_idx][bucket]["won"] / activations_count[layer_idx][bucket]["won"]
|
230 |
+
lost_activations = activations_sum[layer_idx][bucket]["lost"] / activations_count[layer_idx][bucket]["lost"]
|
231 |
contrastive_activations = won_activations - lost_activations
|
232 |
+
return output + torch.from_numpy(contrastive_activations[:, :seq_len, :]).to(output.device)
|
233 |
+
|
234 |
+
for layer_idx in activations_sum:
|
235 |
+
for bucket in self.move_buckets:
|
236 |
+
self.hooks.append(self.model.backbone.layers[layer_idx].register_forward_hook(
|
237 |
+
lambda module, input, output, layer_idx=layer_idx, bucket=bucket: hook(module, input, output, layer_idx, bucket)
|
238 |
+
))
|
239 |
|
240 |
|