HaileyStorm
commited on
Commit
•
6d717ee
1
Parent(s):
fbb8b5e
Update chess-gpt-eval-contrastive/main.py
Browse files
chess-gpt-eval-contrastive/main.py
CHANGED
@@ -597,7 +597,7 @@ def play_games(
|
|
597 |
#print(f"|{game_state}|")
|
598 |
#print(f"{current_move_num}", end=" ")
|
599 |
|
600 |
-
if update_linear:
|
601 |
prev_q_value = evaluate_position(board.fen(), player_two.backend)
|
602 |
(
|
603 |
game_state,
|
@@ -610,7 +610,7 @@ def play_games(
|
|
610 |
if illegal_moves_one != 0:
|
611 |
player_one_legal_moves -= 1
|
612 |
illegal_move_numbers.append(board.fullmove_number)
|
613 |
-
if update_activations or update_linear:
|
614 |
player_one.update_activations("current")
|
615 |
if (
|
616 |
board.is_game_over()
|
@@ -619,12 +619,15 @@ def play_games(
|
|
619 |
):
|
620 |
break
|
621 |
|
622 |
-
if update_linear:
|
623 |
curr_q_value = evaluate_position(board.fen(), player_two.backend)
|
624 |
q_value_delta = curr_q_value - prev_q_value
|
625 |
material_bal = material_balance(board)
|
626 |
player_one.update_linear_probe_targets(curr_q_value, q_value_delta, material_bal)
|
627 |
-
|
|
|
|
|
|
|
628 |
player_one.update_activations("reset")
|
629 |
|
630 |
(
|
@@ -744,8 +747,8 @@ save_activations_every = 25
|
|
744 |
contrastive_weight = 0.8
|
745 |
|
746 |
linear_path="linear.pkl"
|
747 |
-
update_linear =
|
748 |
-
eval_linear =
|
749 |
if __name__ == "__main__":
|
750 |
for nanogpt_player in player_ones:
|
751 |
i = 0
|
@@ -753,7 +756,7 @@ if __name__ == "__main__":
|
|
753 |
# for rm in range(5, 36, 5):
|
754 |
for i in [0]: # [3] #range(11):
|
755 |
# for wgt in [0.005, 0.01, 0.025, 0.05]:
|
756 |
-
num_games =
|
757 |
# player_one = GPTPlayer(model="gpt-3.5-turbo-instruct")
|
758 |
# player_one = LocalLlamaPlayer(model_name="meta-llama/Llama-2-7b-hf")
|
759 |
# player_one = LocalLoraLlamaPlayer("meta-llama/Llama-2-7b-hf", "/workspace/axolotl/lora2-out")
|
|
|
597 |
#print(f"|{game_state}|")
|
598 |
#print(f"{current_move_num}", end=" ")
|
599 |
|
600 |
+
if update_linear or eval_linear:
|
601 |
prev_q_value = evaluate_position(board.fen(), player_two.backend)
|
602 |
(
|
603 |
game_state,
|
|
|
610 |
if illegal_moves_one != 0:
|
611 |
player_one_legal_moves -= 1
|
612 |
illegal_move_numbers.append(board.fullmove_number)
|
613 |
+
if update_activations or update_linear or eval_linear:
|
614 |
player_one.update_activations("current")
|
615 |
if (
|
616 |
board.is_game_over()
|
|
|
619 |
):
|
620 |
break
|
621 |
|
622 |
+
if update_linear or eval_linear:
|
623 |
curr_q_value = evaluate_position(board.fen(), player_two.backend)
|
624 |
q_value_delta = curr_q_value - prev_q_value
|
625 |
material_bal = material_balance(board)
|
626 |
player_one.update_linear_probe_targets(curr_q_value, q_value_delta, material_bal)
|
627 |
+
if update_linear:
|
628 |
+
player_one.train_linear_probes()
|
629 |
+
if eval_linear:
|
630 |
+
player_one.evaluate_linear_probes(board)
|
631 |
player_one.update_activations("reset")
|
632 |
|
633 |
(
|
|
|
747 |
contrastive_weight = 0.8
|
748 |
|
749 |
linear_path="linear.pkl"
|
750 |
+
update_linear = False
|
751 |
+
eval_linear = True
|
752 |
if __name__ == "__main__":
|
753 |
for nanogpt_player in player_ones:
|
754 |
i = 0
|
|
|
756 |
# for rm in range(5, 36, 5):
|
757 |
for i in [0]: # [3] #range(11):
|
758 |
# for wgt in [0.005, 0.01, 0.025, 0.05]:
|
759 |
+
num_games = 500
|
760 |
# player_one = GPTPlayer(model="gpt-3.5-turbo-instruct")
|
761 |
# player_one = LocalLlamaPlayer(model_name="meta-llama/Llama-2-7b-hf")
|
762 |
# player_one = LocalLoraLlamaPlayer("meta-llama/Llama-2-7b-hf", "/workspace/axolotl/lora2-out")
|