HaileyStorm commited on
Commit
9c9bdee
·
verified ·
1 Parent(s): de4b222

Update chess-gpt-eval-contrastive/main.py

Browse files
Files changed (1) hide show
  1. chess-gpt-eval-contrastive/main.py +7 -2
chess-gpt-eval-contrastive/main.py CHANGED
@@ -610,7 +610,7 @@ def play_games(
610
  if illegal_moves_one != 0:
611
  player_one_legal_moves -= 1
612
  illegal_move_numbers.append(board.fullmove_number)
613
- if update_activations:
614
  player_one.update_activations("current")
615
  if (
616
  board.is_game_over()
@@ -677,6 +677,9 @@ def play_games(
677
  )
678
  games_saved += 1
679
 
 
 
 
680
  if update_activations:
681
  if player_one_resignation or player_one_failed_to_find_legal_move:
682
  player_one.update_activations("lost")
@@ -690,6 +693,8 @@ def play_games(
690
 
691
  if games_saved % save_activations_every == 0:
692
  player_one.save_activations(activations_path)
 
 
693
 
694
  if update_linear and games_saved % save_activations_every == 0:
695
  player_one.save_linear_probe_data(linear_path)
@@ -756,7 +761,7 @@ if __name__ == "__main__":
756
  player_one_recording_name = nanogpt_player
757
  # player_one = NanoGptPlayer(model_name=player_one_recording_name, move_num_in_gamestate=move_num_in_gamestate)
758
  #player_one_recording_name = f"xformer_rdm_{rm}"
759
- player_one = MambaPlayer(model_name=player_one_recording_name, move_num_in_gamestate=move_num_in_gamestate)
760
  player_one_recording_name = f'contrastive_weights_rdm/mamba_rdm_wgt_{wgt}' #f'contrastive_rdm/mamba_rdm_{rm}'
761
  if apply_activations:
762
  player_one.apply_contrastive_activations(path=activations_path, weight=wgt)
 
610
  if illegal_moves_one != 0:
611
  player_one_legal_moves -= 1
612
  illegal_move_numbers.append(board.fullmove_number)
613
+ if update_activations or update_linear:
614
  player_one.update_activations("current")
615
  if (
616
  board.is_game_over()
 
677
  )
678
  games_saved += 1
679
 
680
+ if update_linear:
681
+ player_one.train_linear_probes()
682
+
683
  if update_activations:
684
  if player_one_resignation or player_one_failed_to_find_legal_move:
685
  player_one.update_activations("lost")
 
693
 
694
  if games_saved % save_activations_every == 0:
695
  player_one.save_activations(activations_path)
696
+ elif update_linear:
697
+ player_one.update_activations("reset")
698
 
699
  if update_linear and games_saved % save_activations_every == 0:
700
  player_one.save_linear_probe_data(linear_path)
 
761
  player_one_recording_name = nanogpt_player
762
  # player_one = NanoGptPlayer(model_name=player_one_recording_name, move_num_in_gamestate=move_num_in_gamestate)
763
  #player_one_recording_name = f"xformer_rdm_{rm}"
764
+ player_one = MambaPlayer(model_name=player_one_recording_name, move_num_in_gamestate=move_num_in_gamestate, update_contrastive=update_activations, update_linear=update_linear, linear_probe_path=linear_path)
765
  player_one_recording_name = f'contrastive_weights_rdm/mamba_rdm_wgt_{wgt}' #f'contrastive_rdm/mamba_rdm_{rm}'
766
  if apply_activations:
767
  player_one.apply_contrastive_activations(path=activations_path, weight=wgt)