HaileyStorm commited on
Commit
a7153f4
1 Parent(s): 6d717ee

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -358,12 +358,13 @@ class MambaPlayer:
358
  }, step=self.wandb_step)
359
  torch.save(self.linear_probes, path)
360
 
361
- def evaluate_linear_probes(self, board: chess.Board, game_state: str):
362
- self.move_num = game_state.count('.')
363
  bucket = next(b for b in self.move_buckets if self.move_num <= b)
364
  for layer_idx in self.linear_probes:
365
- X = torch.cat(self.activations_sum[layer_idx][bucket]['current'], dim=0)
366
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
 
367
  probe = self.linear_probes[layer_idx][probe_type]
368
- prediction = probe(X)
369
- print(f"Layer {layer_idx}, {probe_type}: {prediction.item()}")
 
358
  }, step=self.wandb_step)
359
  torch.save(self.linear_probes, path)
360
 
361
+ def evaluate_linear_probes(self, board: chess.Board):
362
+ self.move_num = board.fullmove_number
363
  bucket = next(b for b in self.move_buckets if self.move_num <= b)
364
  for layer_idx in self.linear_probes:
365
+ X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float()
366
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
367
+ target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item() #.unsqueeze(1)
368
  probe = self.linear_probes[layer_idx][probe_type]
369
+ prediction = probe(X).item()
370
+ print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")