HaileyStorm committed on
Commit
9ae57a2
1 Parent(s): f207752

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -7,6 +7,7 @@ from contextlib import nullcontext
7
  import numpy as np
8
  from functools import partial
9
  import chess
 
10
 
11
  BASE_DIR = "mamba/"
12
 
@@ -100,10 +101,10 @@ class MambaPlayer:
100
  self.activations_count = {}
101
  if update_linear:
102
  if linear_probe_path and os.path.exists(linear_probe_path):
103
- self.linear_probes = torch.load(linear_probe_data_path)
 
104
  else:
105
  self.linear_probes = {}
106
- self.linear_probe_targets = {}
107
  if update_contrastive or update_linear:
108
  for i, layer in enumerate(self.model.backbone.layers):
109
  self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
@@ -131,7 +132,8 @@ class MambaPlayer:
131
  'q_value_delta': torch.nn.Linear(self.model.config.d_model, 1),
132
  'material_balance': torch.nn.Linear(self.model.config.d_model, 1)
133
  }
134
- self.linear_probe_targets[i] = {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets}
 
135
 
136
  def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
137
  game_state = game_state.split("\n\n")[-1].strip()
@@ -307,16 +309,16 @@ class MambaPlayer:
307
  for bucket in self.move_buckets:
308
  if self.activations_count[layer_idx][bucket]['current'] > 0:
309
  X = self.activations_sum[layer_idx][bucket]['current'] / self.activations_count[layer_idx][bucket]['current']
310
- X = torch.from_numpy(X).float()
311
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
312
- y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
313
  self.linear_probes[layer_idx][probe_type].fit(X, y)
314
-
315
  # Reset linear_probe_targets after training
316
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
317
-
318
  def save_linear_probe_data(self, path):
319
- torch.save(self.linear_probes, path)
 
320
 
321
  def evaluate_linear_probes(self, board: chess.Board, game_state: str):
322
  self.move_num = game_state.count('.')
 
7
  import numpy as np
8
  from functools import partial
9
  import chess
10
+ from sklearn.linear_model import LinearRegression
11
 
12
  BASE_DIR = "mamba/"
13
 
 
101
  self.activations_count = {}
102
  if update_linear:
103
  if linear_probe_path and os.path.exists(linear_probe_path):
104
+ with open(linear_probe_path, 'rb') as f:
105
+ self.linear_probes = pickle.load(f)
106
  else:
107
  self.linear_probes = {}
 
108
  if update_contrastive or update_linear:
109
  for i, layer in enumerate(self.model.backbone.layers):
110
  self.activations_sum[i] = {bucket: {"won": np.zeros((1, self.max_seq_len, self.model.config.d_model)),
 
132
  'q_value_delta': torch.nn.Linear(self.model.config.d_model, 1),
133
  'material_balance': torch.nn.Linear(self.model.config.d_model, 1)
134
  }
135
+ if update_linear:
136
+ self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
137
 
138
  def get_mamba_response(self, game_state: str, temperature: float, max_new_tokens: int, top_k: int):
139
  game_state = game_state.split("\n\n")[-1].strip()
 
309
  for bucket in self.move_buckets:
310
  if self.activations_count[layer_idx][bucket]['current'] > 0:
311
  X = self.activations_sum[layer_idx][bucket]['current'] / self.activations_count[layer_idx][bucket]['current']
 
312
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
313
+ y = self.linear_probe_targets[layer_idx][bucket][probe_type]
314
  self.linear_probes[layer_idx][probe_type].fit(X, y)
315
+
316
  # Reset linear_probe_targets after training
317
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
318
+
319
  def save_linear_probe_data(self, path):
320
+ with open(path, 'wb') as f:
321
+ pickle.dump(self.linear_probes, f)
322
 
323
  def evaluate_linear_probes(self, board: chess.Board, game_state: str):
324
  self.move_num = game_state.count('.')