HaileyStorm
commited on
Update chess-gpt-eval-contrastive/main.py
Browse files
chess-gpt-eval-contrastive/main.py
CHANGED
@@ -610,7 +610,7 @@ def play_games(
|
|
610 |
if illegal_moves_one != 0:
|
611 |
player_one_legal_moves -= 1
|
612 |
illegal_move_numbers.append(board.fullmove_number)
|
613 |
-
if update_activations:
|
614 |
player_one.update_activations("current")
|
615 |
if (
|
616 |
board.is_game_over()
|
@@ -677,6 +677,9 @@ def play_games(
|
|
677 |
)
|
678 |
games_saved += 1
|
679 |
|
|
|
|
|
|
|
680 |
if update_activations:
|
681 |
if player_one_resignation or player_one_failed_to_find_legal_move:
|
682 |
player_one.update_activations("lost")
|
@@ -690,6 +693,8 @@ def play_games(
|
|
690 |
|
691 |
if games_saved % save_activations_every == 0:
|
692 |
player_one.save_activations(activations_path)
|
|
|
|
|
693 |
|
694 |
if update_linear and games_saved % save_activations_every == 0:
|
695 |
player_one.save_linear_probe_data(linear_path)
|
@@ -756,7 +761,7 @@ if __name__ == "__main__":
|
|
756 |
player_one_recording_name = nanogpt_player
|
757 |
# player_one = NanoGptPlayer(model_name=player_one_recording_name, move_num_in_gamestate=move_num_in_gamestate)
|
758 |
#player_one_recording_name = f"xformer_rdm_{rm}"
|
759 |
-
player_one = MambaPlayer(model_name=player_one_recording_name, move_num_in_gamestate=move_num_in_gamestate)
|
760 |
player_one_recording_name = f'contrastive_weights_rdm/mamba_rdm_wgt_{wgt}' #f'contrastive_rdm/mamba_rdm_{rm}'
|
761 |
if apply_activations:
|
762 |
player_one.apply_contrastive_activations(path=activations_path, weight=wgt)
|
|
|
610 |
if illegal_moves_one != 0:
|
611 |
player_one_legal_moves -= 1
|
612 |
illegal_move_numbers.append(board.fullmove_number)
|
613 |
+
if update_activations or update_linear:
|
614 |
player_one.update_activations("current")
|
615 |
if (
|
616 |
board.is_game_over()
|
|
|
677 |
)
|
678 |
games_saved += 1
|
679 |
|
680 |
+
if update_linear:
|
681 |
+
player_one.train_linear_probes()
|
682 |
+
|
683 |
if update_activations:
|
684 |
if player_one_resignation or player_one_failed_to_find_legal_move:
|
685 |
player_one.update_activations("lost")
|
|
|
693 |
|
694 |
if games_saved % save_activations_every == 0:
|
695 |
player_one.save_activations(activations_path)
|
696 |
+
elif update_linear:
|
697 |
+
player_one.update_activations("reset")
|
698 |
|
699 |
if update_linear and games_saved % save_activations_every == 0:
|
700 |
player_one.save_linear_probe_data(linear_path)
|
|
|
761 |
player_one_recording_name = nanogpt_player
|
762 |
# player_one = NanoGptPlayer(model_name=player_one_recording_name, move_num_in_gamestate=move_num_in_gamestate)
|
763 |
#player_one_recording_name = f"xformer_rdm_{rm}"
|
764 |
+
player_one = MambaPlayer(model_name=player_one_recording_name, move_num_in_gamestate=move_num_in_gamestate, update_contrastive=update_activations, update_linear=update_linear, linear_probe_path=linear_path)
|
765 |
player_one_recording_name = f'contrastive_weights_rdm/mamba_rdm_wgt_{wgt}' #f'contrastive_rdm/mamba_rdm_{rm}'
|
766 |
if apply_activations:
|
767 |
player_one.apply_contrastive_activations(path=activations_path, weight=wgt)
|