HaileyStorm committed on
Commit
6d717ee
1 Parent(s): fbb8b5e

Update chess-gpt-eval-contrastive/main.py

Browse files
Files changed (1) hide show
  1. chess-gpt-eval-contrastive/main.py +10 -7
chess-gpt-eval-contrastive/main.py CHANGED
@@ -597,7 +597,7 @@ def play_games(
597
  #print(f"|{game_state}|")
598
  #print(f"{current_move_num}", end=" ")
599
 
600
- if update_linear:
601
  prev_q_value = evaluate_position(board.fen(), player_two.backend)
602
  (
603
  game_state,
@@ -610,7 +610,7 @@ def play_games(
610
  if illegal_moves_one != 0:
611
  player_one_legal_moves -= 1
612
  illegal_move_numbers.append(board.fullmove_number)
613
- if update_activations or update_linear:
614
  player_one.update_activations("current")
615
  if (
616
  board.is_game_over()
@@ -619,12 +619,15 @@ def play_games(
619
  ):
620
  break
621
 
622
- if update_linear:
623
  curr_q_value = evaluate_position(board.fen(), player_two.backend)
624
  q_value_delta = curr_q_value - prev_q_value
625
  material_bal = material_balance(board)
626
  player_one.update_linear_probe_targets(curr_q_value, q_value_delta, material_bal)
627
- player_one.train_linear_probes()
 
 
 
628
  player_one.update_activations("reset")
629
 
630
  (
@@ -744,8 +747,8 @@ save_activations_every = 25
744
  contrastive_weight = 0.8
745
 
746
  linear_path="linear.pkl"
747
- update_linear = True
748
- eval_linear = False
749
  if __name__ == "__main__":
750
  for nanogpt_player in player_ones:
751
  i = 0
@@ -753,7 +756,7 @@ if __name__ == "__main__":
753
  # for rm in range(5, 36, 5):
754
  for i in [0]: # [3] #range(11):
755
  # for wgt in [0.005, 0.01, 0.025, 0.05]:
756
- num_games = 5000
757
  # player_one = GPTPlayer(model="gpt-3.5-turbo-instruct")
758
  # player_one = LocalLlamaPlayer(model_name="meta-llama/Llama-2-7b-hf")
759
  # player_one = LocalLoraLlamaPlayer("meta-llama/Llama-2-7b-hf", "/workspace/axolotl/lora2-out")
 
597
  #print(f"|{game_state}|")
598
  #print(f"{current_move_num}", end=" ")
599
 
600
+ if update_linear or eval_linear:
601
  prev_q_value = evaluate_position(board.fen(), player_two.backend)
602
  (
603
  game_state,
 
610
  if illegal_moves_one != 0:
611
  player_one_legal_moves -= 1
612
  illegal_move_numbers.append(board.fullmove_number)
613
+ if update_activations or update_linear or eval_linear:
614
  player_one.update_activations("current")
615
  if (
616
  board.is_game_over()
 
619
  ):
620
  break
621
 
622
+ if update_linear or eval_linear:
623
  curr_q_value = evaluate_position(board.fen(), player_two.backend)
624
  q_value_delta = curr_q_value - prev_q_value
625
  material_bal = material_balance(board)
626
  player_one.update_linear_probe_targets(curr_q_value, q_value_delta, material_bal)
627
+ if update_linear:
628
+ player_one.train_linear_probes()
629
+ if eval_linear:
630
+ player_one.evaluate_linear_probes(board)
631
  player_one.update_activations("reset")
632
 
633
  (
 
747
  contrastive_weight = 0.8
748
 
749
  linear_path="linear.pkl"
750
+ update_linear = False
751
+ eval_linear = True
752
  if __name__ == "__main__":
753
  for nanogpt_player in player_ones:
754
  i = 0
 
756
  # for rm in range(5, 36, 5):
757
  for i in [0]: # [3] #range(11):
758
  # for wgt in [0.005, 0.01, 0.025, 0.05]:
759
+ num_games = 500
760
  # player_one = GPTPlayer(model="gpt-3.5-turbo-instruct")
761
  # player_one = LocalLlamaPlayer(model_name="meta-llama/Llama-2-7b-hf")
762
  # player_one = LocalLoraLlamaPlayer("meta-llama/Llama-2-7b-hf", "/workspace/axolotl/lora2-out")