HaileyStorm commited on
Commit
f728647
·
verified ·
1 Parent(s): 00aad6e

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -337,7 +337,6 @@ class MambaPlayer:
337
  y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
338
  if len(y) > 0:
339
  y_pred = self.linear_probes[layer_idx][probe_type](X)
340
- print(f'{y_pred}')
341
  loss = criterion(y_pred, y)
342
  for param_group in self.linear_optimizers[layer_idx][probe_type].param_groups:
343
  param_group['lr'] = lr
@@ -364,11 +363,11 @@ class MambaPlayer:
364
  self.move_num = board.fullmove_number
365
  bucket = next(b for b in self.move_buckets if self.move_num <= b)
366
  for layer_idx in self.linear_probes:
367
- X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float()
368
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
369
- target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
370
  probe = self.linear_probes[layer_idx][probe_type]
371
  probe.eval()
372
- prediction = probe(X)#.item()
373
  print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
374
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}
 
337
  y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
338
  if len(y) > 0:
339
  y_pred = self.linear_probes[layer_idx][probe_type](X)
 
340
  loss = criterion(y_pred, y)
341
  for param_group in self.linear_optimizers[layer_idx][probe_type].param_groups:
342
  param_group['lr'] = lr
 
363
  self.move_num = board.fullmove_number
364
  bucket = next(b for b in self.move_buckets if self.move_num <= b)
365
  for layer_idx in self.linear_probes:
366
+ X = torch.from_numpy(self.activations_sum[layer_idx][bucket]['current']).float().flatten(1)
367
  for probe_type in ['q_value', 'q_value_delta', 'material_balance']:
368
+ target = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().item()
369
  probe = self.linear_probes[layer_idx][probe_type]
370
  probe.eval()
371
+ prediction = probe(X).item()
372
  print(f"Layer {layer_idx}, {probe_type}: {prediction} vs {target}")
373
  self.linear_probe_targets = {i: {bucket: {'q_value': [], 'q_value_delta': [], 'material_balance': []} for bucket in self.move_buckets} for i in self.linear_probes}