HaileyStorm
committed on
Commit
•
c72c7b7
1
Parent(s):
b3c0ce4
Update chess-gpt-eval/main.py
Browse files- chess-gpt-eval/main.py +52 -3
chess-gpt-eval/main.py
CHANGED
@@ -173,6 +173,31 @@ def get_move_from_gpt_response(response: Optional[str]) -> Optional[str]:
|
|
173 |
return first_move
|
174 |
|
175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
def record_results(
|
177 |
board: chess.Board,
|
178 |
player_one: Player,
|
@@ -216,8 +241,8 @@ def record_results(
|
|
216 |
# resignation / failed move situation I didn't think of
|
217 |
# -1e10 at least ensures it doesn't fail silently
|
218 |
if "-" in result:
|
219 |
-
player_one_score = result.split("-")[0]
|
220 |
-
player_two_score = result.split("-")[1]
|
221 |
elif result == "*": # Loss due to hitting max moves
|
222 |
player_one_score = 0
|
223 |
player_two_score = 1
|
@@ -553,6 +578,7 @@ def play_games(
|
|
553 |
print(f"Result: {board.result()}")
|
554 |
print(board)
|
555 |
print()
|
|
|
556 |
if game_transcript not in unique_games:
|
557 |
unique_games.add(game_transcript)
|
558 |
record_results(
|
@@ -575,6 +601,7 @@ def play_games(
|
|
575 |
opening_moves,
|
576 |
illegal_move_numbers
|
577 |
)
|
|
|
578 |
else:
|
579 |
print("Duplicate game; not saved.")
|
580 |
if isinstance(player_one, StockfishPlayer):
|
@@ -582,7 +609,17 @@ def play_games(
|
|
582 |
if isinstance(player_two, StockfishPlayer):
|
583 |
player_two.close()
|
584 |
|
585 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
586 |
|
587 |
|
588 |
RUN_FOR_ANALYSIS = True
|
@@ -615,4 +652,16 @@ if __name__ == "__main__":
|
|
615 |
print(f"\n\nSTARTING GAMES AGAINST LC0 LEVEL {i}\n\n")
|
616 |
|
617 |
play_games(player_one, player_two, num_games, book_opening=book_opening, random_opening=random_opening, random_opening_moves=random_opening_moves)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
618 |
print("\n\n\n********\nDONE!\n********\n\n\n")
|
|
|
173 |
return first_move
|
174 |
|
175 |
|
176 |
+
def _parse_stat_number(text):
    """Return *text* as a float, or None when it is not numeric.

    Accepts plain numbers ("1", "0.5") and simple fractions ("1/2"); the
    latter matters because scores taken from a "1/2-1/2" result string may
    have been stored unconverted.
    """
    if not isinstance(text, str):
        # csv.DictReader fills cells missing from a short row with None.
        return None
    text = text.strip()
    try:
        return float(text)
    except ValueError:
        pass
    numerator, slash, denominator = text.partition("/")
    if not slash:
        return None
    try:
        return float(numerator) / float(denominator)
    except (ValueError, ZeroDivisionError):
        return None


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or None when the denominator is zero."""
    return numerator / denominator if denominator else None


def calculate_stats(csv_file_path):
    """Summarise player one's results from a game-log CSV.

    Rows in which any of the numeric columns used below is missing or not
    numeric are ignored, so blank lines or summary rows that were appended
    to the same file do not break (or pollute) the statistics.

    Args:
        csv_file_path: Path of the CSV file; its first line must be the
            header row.

    Returns:
        None when the file contains no usable game rows, otherwise a dict
        with the keys "wins", "draws", "illegal_attempts_ratio",
        "illegal_moves_ratio", "avg_attempts_per_illegal",
        "avg_first_illegal_move", "avg_illegal_move_num",
        "lost_to_illegal_ratio", "avg_game_length" and "max_game_length".
        "wins" and "draws" are integer counts; any ratio or average whose
        denominator is zero (for example "avg_first_illegal_move" when no
        game contained an illegal move) is None.

    Raises:
        KeyError: If the file has data rows but its header lacks one of the
            required columns.
    """
    numeric_columns = (
        "player_one_score",
        "player_two_score",
        "p1_illegal_attempts",
        "player_one_legal_moves",
        "player_one_illegal_moves",
        "p1_avg_attempts_per_illegal",
        "p1_first_illegal_move_num",
        "p1_avg_illegal_move_num",
        "number_of_moves",
    )
    flag_column = "player_one_failed_to_find_legal_move"

    # newline="" is what the csv module asks for when opening files.
    with open(csv_file_path, "r", newline="") as csv_file:
        reader = csv.DictReader(csv_file)
        rows = list(reader)
        header = reader.fieldnames or ()

    if not rows:
        return None

    missing = [
        name for name in (*numeric_columns, flag_column) if name not in header
    ]
    if missing:
        raise KeyError(f"missing CSV columns: {', '.join(missing)}")

    # Keep only rows that look like real game records, converting the numeric
    # cells once so every statistic below works on floats (DictReader yields
    # strings, which never compare equal to the int 1).
    games = []
    for row in rows:
        game = {name: _parse_stat_number(row.get(name)) for name in numeric_columns}
        if None in game.values():
            continue
        game[flag_column] = row.get(flag_column)
        games.append(game)

    if not games:
        return None

    def total(column):
        return sum(game[column] for game in games)

    game_count = len(games)
    legal_moves = total("player_one_legal_moves")
    illegal_attempts = total("p1_illegal_attempts")
    illegal_moves = total("player_one_illegal_moves")
    move_counts = [game["number_of_moves"] for game in games]

    # Only games that actually had an illegal move contribute to these means.
    first_illegal = [
        game["p1_first_illegal_move_num"]
        for game in games
        if game["p1_first_illegal_move_num"] > 0
    ]
    mean_illegal = [
        game["p1_avg_illegal_move_num"]
        for game in games
        if game["p1_avg_illegal_move_num"] > 0
    ]

    lost_to_illegal = sum(1 for game in games if game[flag_column] == "True")
    games_with_moves = sum(1 for count in move_counts if count > 0)

    stats = {
        "wins": sum(1 for game in games if game["player_one_score"] == 1),
        "draws": sum(1 for game in games if game["player_one_score"] == 0.5),
        "illegal_attempts_ratio": _safe_ratio(
            illegal_attempts, illegal_attempts + legal_moves
        ),
        "illegal_moves_ratio": _safe_ratio(
            illegal_moves, illegal_moves + legal_moves
        ),
        "avg_attempts_per_illegal": total("p1_avg_attempts_per_illegal") / game_count,
        "avg_first_illegal_move": _safe_ratio(sum(first_illegal), len(first_illegal)),
        "avg_illegal_move_num": _safe_ratio(sum(mean_illegal), len(mean_illegal)),
        "lost_to_illegal_ratio": _safe_ratio(lost_to_illegal, games_with_moves),
        "avg_game_length": sum(move_counts) / game_count,
        "max_game_length": max(move_counts),
    }

    return stats
|
199 |
+
|
200 |
+
|
201 |
def record_results(
|
202 |
board: chess.Board,
|
203 |
player_one: Player,
|
|
|
241 |
# resignation / failed move situation I didn't think of
|
242 |
# -1e10 at least ensures it doesn't fail silently
|
243 |
if "-" in result:
|
244 |
+
player_one_score = float(result.split("-")[0])
|
245 |
+
player_two_score = float(result.split("-")[1])
|
246 |
elif result == "*": # Loss due to hitting max moves
|
247 |
player_one_score = 0
|
248 |
player_two_score = 1
|
|
|
578 |
print(f"Result: {board.result()}")
|
579 |
print(board)
|
580 |
print()
|
581 |
+
game_transcript = game_state.strip()
|
582 |
if game_transcript not in unique_games:
|
583 |
unique_games.add(game_transcript)
|
584 |
record_results(
|
|
|
601 |
opening_moves,
|
602 |
illegal_move_numbers
|
603 |
)
|
604 |
+
games_saved += 1
|
605 |
else:
|
606 |
print("Duplicate game; not saved.")
|
607 |
if isinstance(player_one, StockfishPlayer):
|
|
|
609 |
if isinstance(player_two, StockfishPlayer):
|
610 |
player_two.close()
|
611 |
|
612 |
+
stats = calculate_stats(csv_file_path)
|
613 |
+
if stats:
|
614 |
+
print("\nStatistics:")
|
615 |
+
for key, value in stats.items():
|
616 |
+
print(f"{key}: {value}")
|
617 |
+
|
618 |
+
with open(csv_file_path, "a") as csv_file:
|
619 |
+
writer = csv.writer(csv_file)
|
620 |
+
writer.writerow([""] * len(info_dict)) # Add empty cells for existing columns
|
621 |
+
writer.writerow(list(stats.keys()))
|
622 |
+
writer.writerow(list(stats.values()))
|
623 |
|
624 |
|
625 |
RUN_FOR_ANALYSIS = True
|
|
|
652 |
print(f"\n\nSTARTING GAMES AGAINST LC0 LEVEL {i}\n\n")
|
653 |
|
654 |
play_games(player_one, player_two, num_games, book_opening=book_opening, random_opening=random_opening, random_opening_moves=random_opening_moves)
|
655 |
+
|
656 |
+
print("\n\n\n********\nFinal Statistics:\n********\n")
|
657 |
+
for nanogpt_player in player_ones:
|
658 |
+
csv_file_path = f"logs/{player_one_recording_name}_vs_{player_two_recording_name}"
|
659 |
+
csv_file_path = csv_file_path.replace(".", "_") # Because I'm using ckpt filenames for nanogpt models
|
660 |
+
csv_file_path += ".csv"
|
661 |
+
|
662 |
+
stats = calculate_stats(csv_file_path)
|
663 |
+
if stats:
|
664 |
+
print(f"\nStatistics for {nanogpt_player}:")
|
665 |
+
for key, value in stats.items():
|
666 |
+
print(f"{key}: {value}")
|
667 |
print("\n\n\n********\nDONE!\n********\n\n\n")
|