HaileyStorm committed on
Commit
e8aba5c
·
verified ·
1 Parent(s): f2ce2e2

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -106,13 +106,6 @@ class MambaPlayer:
106
  self.linear_probes = torch.load(linear_probe_path)
107
  else:
108
  self.linear_probes = {}
109
- self.linear_optimizers = {
110
- layer_idx: {
111
- probe_type: optim.Adam(self.linear_probes[layer_idx][probe_type].parameters(), lr=lr)
112
- for probe_type in ['q_value', 'q_value_delta', 'material_balance']
113
- }
114
- for layer_idx in self.linear_probes
115
- }
116
  if update_contrastive or update_linear:
117
  for i, layer in enumerate(self.model.backbone.layers):
118
  self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
@@ -142,6 +135,13 @@ class MambaPlayer:
142
  }
143
  if update_linear:
144
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
 
 
 
 
 
 
 
145
 
146
  def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
147
  game_state = game_state.split("\n\n")[-1].strip()
 
106
  self.linear_probes = torch.load(linear_probe_path)
107
  else:
108
  self.linear_probes = {}
 
 
 
 
 
 
 
109
  if update_contrastive or update_linear:
110
  for i, layer in enumerate(self.model.backbone.layers):
111
  self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
 
135
  }
136
  if update_linear:
137
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
138
+ self.linear_optimizers = {
139
+ layer_idx: {
140
+ probe_type: optim.Adam(self.linear_probes[layer_idx][probe_type].parameters(), lr=lr)
141
+ for probe_type in ['q_value', 'q_value_delta', 'material_balance']
142
+ }
143
+ for layer_idx in self.linear_probes
144
+ }
145
 
146
  def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
147
  game_state = game_state.split("\n\n")[-1].strip()