Update chess-gpt-eval-contrastive/mamba_module.py
Browse files
chess-gpt-eval-contrastive/mamba_module.py
CHANGED
@@ -106,13 +106,6 @@ class MambaPlayer:
|
|
106 |
self.linear_probes = torch.load(linear_probe_path)
|
107 |
else:
|
108 |
self.linear_probes = {}
|
109 |
-
self.linear_optimizers = {
|
110 |
-
layer_idx: {
|
111 |
-
probe_type: optim.Adam(self.linear_probes[layer_idx][probe_type].parameters(), lr=lr)
|
112 |
-
for probe_type in ['q_value', 'q_value_delta', 'material_balance']
|
113 |
-
}
|
114 |
-
for layer_idx in self.linear_probes
|
115 |
-
}
|
116 |
if update_contrastive or update_linear:
|
117 |
for i, layer in enumerate(self.model.backbone.layers):
|
118 |
self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
@@ -142,6 +135,13 @@ class MambaPlayer:
|
|
142 |
}
|
143 |
if update_linear:
|
144 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
|
147 |
game_state = game_state.split("\n\n")[-1].strip()
|
|
|
106 |
self.linear_probes = torch.load(linear_probe_path)
|
107 |
else:
|
108 |
self.linear_probes = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
if update_contrastive or update_linear:
|
110 |
for i, layer in enumerate(self.model.backbone.layers):
|
111 |
self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
|
|
|
135 |
}
|
136 |
if update_linear:
|
137 |
self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
|
138 |
+
self.linear_optimizers = {
|
139 |
+
layer_idx: {
|
140 |
+
probe_type: optim.Adam(self.linear_probes[layer_idx][probe_type].parameters(), lr=lr)
|
141 |
+
for probe_type in ['q_value', 'q_value_delta', 'material_balance']
|
142 |
+
}
|
143 |
+
for layer_idx in self.linear_probes
|
144 |
+
}
|
145 |
|
146 |
def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
|
147 |
game_state = game_state.split("\n\n")[-1].strip()
|