HaileyStorm committed on
Commit
7b1a352
1 Parent(s): f4a6bfa

Update chess-gpt-eval-contrastive/main.py

Browse files
Files changed (1) hide show
  1. chess-gpt-eval-contrastive/main.py +30 -4
chess-gpt-eval-contrastive/main.py CHANGED
@@ -625,6 +625,29 @@ def play_games(
625
  games_saved += 1
626
  else:
627
  print("Duplicate game; not saved.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
628
  if isinstance(player_one, StockfishPlayer):
629
  player_one.close()
630
  if isinstance(player_two, StockfishPlayer):
@@ -658,11 +681,12 @@ move_num_in_gamestate = False
658
  book_opening = True
659
  random_opening = True
660
  random_opening_moves = 10
 
 
 
661
  if __name__ == "__main__":
662
  for nanogpt_player in player_ones:
663
- i = 0
664
- for rm in [25]: #range(5, 25, 5):
665
- #for i in [0]: # [3] #range(11):
666
  num_games = 500
667
  # player_one = GPTPlayer(model="gpt-3.5-turbo-instruct")
668
  # player_one = LocalLlamaPlayer(model_name="meta-llama/Llama-2-7b-hf")
@@ -675,6 +699,8 @@ if __name__ == "__main__":
675
  #player_one_recording_name = f"xformer_rdm_{rm}"
676
  player_one = MambaPlayer(model_name=player_one_recording_name, move_num_in_gamestate=move_num_in_gamestate)
677
  player_one_recording_name = f"random_mamba_start/mamba_rdmstart_{rm}"
 
 
678
 
679
  #player_two = StockfishPlayer(skill_level=i, play_time=0.1)
680
  player_two = LC0PLayer(skill=i)
@@ -685,5 +711,5 @@ if __name__ == "__main__":
685
  #print(f"\n\nSTARTING GAMES AGAINST STOCKFISH LEVEL {i}\n\n")
686
  print(f"\n\nSTARTING GAMES AGAINST LC0 LEVEL {i}\n\n")
687
 
688
- play_games(player_one, player_two, num_games, book_opening=book_opening, random_opening=random_opening, random_opening_moves=random_opening_moves, random_move_start=rm)
689
 
 
625
  games_saved += 1
626
  else:
627
  print("Duplicate game; not saved.")
628
+
629
+ if update_activations:
630
+ if player_one_resignation or player_one_failed_to_find_legal_move:
631
+ player_one.update_activations("lost")
632
+ player_two.update_activations("won")
633
+ elif player_two_resignation or player_two_failed_to_find_legal_move:
634
+ player_one.update_activations("won")
635
+ player_two.update_activations("lost")
636
+ else:
637
+ if board.result() == "1-0":
638
+ player_one.update_activations("won")
639
+ player_two.update_activations("lost")
640
+ elif board.result() == "0-1":
641
+ player_one.update_activations("lost")
642
+ player_two.update_activations("won")
643
+ else: # Draw
644
+ player_one.update_activations("draw")
645
+ player_two.update_activations("draw")
646
+
647
+ if games_saved % contrastive_activation_save_interval == 0:
648
+ player_one.save_activations(activations_path)
649
+ player_two.save_activations(activations_path)
650
+
651
  if isinstance(player_one, StockfishPlayer):
652
  player_one.close()
653
  if isinstance(player_two, StockfishPlayer):
 
681
  book_opening = True
682
  random_opening = True
683
  random_opening_moves = 10
684
+ contrastive_activation_save_interval = 10
685
+ activations_path="activations.pkl"
686
+ update_activations = True # False = use them
687
  if __name__ == "__main__":
688
  for nanogpt_player in player_ones:
689
+ for i in [0]: # [3] #range(11):
 
 
690
  num_games = 500
691
  # player_one = GPTPlayer(model="gpt-3.5-turbo-instruct")
692
  # player_one = LocalLlamaPlayer(model_name="meta-llama/Llama-2-7b-hf")
 
699
  #player_one_recording_name = f"xformer_rdm_{rm}"
700
  player_one = MambaPlayer(model_name=player_one_recording_name, move_num_in_gamestate=move_num_in_gamestate)
701
  player_one_recording_name = f"random_mamba_start/mamba_rdmstart_{rm}"
702
+ if not update_activations:
703
+ player_one.apply_contrastive_activations()
704
 
705
  #player_two = StockfishPlayer(skill_level=i, play_time=0.1)
706
  player_two = LC0PLayer(skill=i)
 
711
  #print(f"\n\nSTARTING GAMES AGAINST STOCKFISH LEVEL {i}\n\n")
712
  print(f"\n\nSTARTING GAMES AGAINST LC0 LEVEL {i}\n\n")
713
 
714
+ play_games(player_one, player_two, num_games, book_opening=book_opening, random_opening=random_opening, random_opening_moves=random_opening_moves)
715