HaileyStorm committed on
Commit
0eaef6c
1 Parent(s): 1dd6f7d

Update chess-gpt-eval-contrastive/main.py

Browse files
Files changed (1) hide show
  1. chess-gpt-eval-contrastive/main.py +60 -1
chess-gpt-eval-contrastive/main.py CHANGED
@@ -395,6 +395,49 @@ def add_random_moves(
395
  return game_state, board, num_moves
396
 
397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  # Return is (move_san, move_uci, attempts, is_resignation, is_illegal_move)
399
  def get_legal_move(
400
  player: Player,
@@ -554,6 +597,8 @@ def play_games(
554
  #print(f"|{game_state}|")
555
  #print(f"{current_move_num}", end=" ")
556
 
 
 
557
  (
558
  game_state,
559
  player_one_resignation,
@@ -574,6 +619,12 @@ def play_games(
574
  ):
575
  break
576
 
 
 
 
 
 
 
577
  (
578
  game_state,
579
  player_two_resignation,
@@ -639,6 +690,9 @@ def play_games(
639
 
640
  if games_saved % save_activations_every == 0:
641
  player_one.save_activations(activations_path)
 
 
 
642
  else:
643
  print("Duplicate game; not saved.")
644
 
@@ -675,11 +729,16 @@ move_num_in_gamestate = False
675
  book_opening = False
676
  random_opening = True
677
  random_opening_moves = 10
 
678
  activations_path="activations_rdm.pkl"
679
- update_activations = True
680
  apply_activations = False
681
  save_activations_every = 25
682
  contrastive_weight = 0.8
 
 
 
 
683
  if __name__ == "__main__":
684
  for nanogpt_player in player_ones:
685
  i = 0
 
395
  return game_state, board, num_moves
396
 
397
 
398
+ def evaluate_position(fen, backend):
399
+ gamestate = GameState(fen=fen)
400
+ result = backend.evaluate(gamestate.as_input(backend))[0]
401
+ return result.q()
402
+
403
+
404
+ def material_balance(board):
405
+ PV = {
406
+ 'pawn': 1,
407
+ 'knight': 3,
408
+ 'bishop': 3,
409
+ 'rook': 5,
410
+ 'queen': 9,
411
+ 'king': 0
412
+ }
413
+
414
+ if board.is_insufficient_material():
415
+ return 0
416
+
417
+ wp = len(board.pieces(chess.PAWN, chess.WHITE))
418
+ bp = len(board.pieces(chess.PAWN, chess.BLACK))
419
+
420
+ wn = len(board.pieces(chess.KNIGHT, chess.WHITE))
421
+ bn = len(board.pieces(chess.KNIGHT, chess.BLACK))
422
+
423
+ wb = len(board.pieces(chess.BISHOP, chess.WHITE))
424
+ bb = len(board.pieces(chess.BISHOP, chess.BLACK))
425
+
426
+ wr = len(board.pieces(chess.ROOK, chess.WHITE))
427
+ br = len(board.pieces(chess.ROOK, chess.BLACK))
428
+
429
+ wq = len(board.pieces(chess.QUEEN, chess.WHITE))
430
+ bq = len(board.pieces(chess.QUEEN, chess.BLACK))
431
+
432
+ return (
433
+ PV['pawn'] * (wp - bp) +
434
+ PV['knight'] * (wn - bn) +
435
+ PV['bishop'] * (wb - bb) +
436
+ PV['rook'] * (wr - br) +
437
+ PV['queen'] * (wq - bq)
438
+ )
439
+
440
+
441
  # Return is (move_san, move_uci, attempts, is_resignation, is_illegal_move)
442
  def get_legal_move(
443
  player: Player,
 
597
  #print(f"|{game_state}|")
598
  #print(f"{current_move_num}", end=" ")
599
 
600
+ if update_linear:
601
+ prev_q_value = evaluate_position(board.fen(), player_two.backend)
602
  (
603
  game_state,
604
  player_one_resignation,
 
619
  ):
620
  break
621
 
622
+ if update_linear:
623
+ curr_q_value = evaluate_position(board.fen(), player_two.backend)
624
+ q_value_delta = curr_q_value - prev_q_value
625
+ material_bal = material_balance(board)
626
+ player_one.update_linear_probe_targets(curr_q_value, q_value_delta, material_bal)
627
+
628
  (
629
  game_state,
630
  player_two_resignation,
 
690
 
691
  if games_saved % save_activations_every == 0:
692
  player_one.save_activations(activations_path)
693
+
694
+ if update_linear and games_saved % save_activations_every == 0:
695
+ player_one.save_linear_probe_data(linear_path)
696
  else:
697
  print("Duplicate game; not saved.")
698
 
 
729
  book_opening = False
730
  random_opening = True
731
  random_opening_moves = 10
732
+
733
  activations_path="activations_rdm.pkl"
734
+ update_activations = False
735
  apply_activations = False
736
  save_activations_every = 25
737
  contrastive_weight = 0.8
738
+
739
+ linear_path="linear.pkl"
740
+ update_linear = True
741
+ eval_linear = False
742
  if __name__ == "__main__":
743
  for nanogpt_player in player_ones:
744
  i = 0