HaileyStorm commited on
Commit
00aad6e
·
verified ·
1 Parent(s): 7136964

Update chess-gpt-eval-contrastive/mamba_module.py

Browse files
chess-gpt-eval-contrastive/mamba_module.py CHANGED
@@ -337,8 +337,7 @@ class MambaPlayer:
337
  y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
338
  if len(y) > 0:
339
  y_pred = self.linear_probes[layer_idx][probe_type](X)
340
- print(f'{X.shape} vs {y_pred.shape}')
341
- exit()
342
  loss = criterion(y_pred, y)
343
  for param_group in self.linear_optimizers[layer_idx][probe_type].param_groups:
344
  param_group['lr'] = lr
 
337
  y = torch.tensor(self.linear_probe_targets[layer_idx][bucket][probe_type]).float().unsqueeze(1)
338
  if len(y) > 0:
339
  y_pred = self.linear_probes[layer_idx][probe_type](X)
340
+ print(f'{y_pred}')
 
341
  loss = criterion(y_pred, y)
342
  for param_group in self.linear_optimizers[layer_idx][probe_type].param_groups:
343
  param_group['lr'] = lr