|
import openai |
|
import chess |
|
import chess.engine |
|
import os |
|
import csv |
|
import random |
|
import time |
|
import platform |
|
|
|
|
|
|
|
from nanogpt.nanogpt_module import NanoGptPlayer |
|
from mamba_module import MambaPlayer |
|
import gpt_query |
|
from lczero.backends import Weights, Backend, GameState |
|
import numpy as np |
|
|
|
from typing import Optional, Tuple |
|
from dataclasses import dataclass |
|
|
|
|
|
@dataclass
class LegalMoveResponse:
    """Outcome of asking a player for a legal move (produced by get_legal_move).

    Field order matters: callers construct this positionally as
    ``LegalMoveResponse(move_san, move_uci, attempts)``.
    """

    # The accepted move in SAN (prefixed with a space for player two), or None.
    move_san: Optional[str] = None
    # The parsed move; despite the name this is a chess.Move, not a UCI string.
    move_uci: Optional[chess.Move] = None
    # Number of failed attempts that preceded this outcome.
    attempts: int = 0
    # True when the player answered with a game result ("1-0", "0-1", "1/2-1/2").
    is_resignation: bool = False
    # True when no legal move was produced within the allowed attempts.
    is_illegal_move: bool = False
|
|
|
|
|
|
|
class Player:
    """Interface implemented by every participant in a game.

    Subclasses supply moves in SAN (some may return ``None`` when they could
    not produce one) and describe their settings through ``get_config``.
    """

    def get_move(self, board: chess.Board, game_state: str, temperature: float) -> str:
        """Return this player's next move for ``board`` in SAN.

        ``game_state`` is the text transcript so far and ``temperature`` a
        sampling temperature; implementations may ignore either.
        """
        raise NotImplementedError()

    def get_config(self) -> dict:
        """Return a dict describing this player (used to title logged games)."""
        raise NotImplementedError()
|
|
|
|
|
class GPTPlayer(Player):
    """Player that asks ``model`` (through ``get_gpt_response``) for each move.

    The API key is read once, at construction, from ``gpt_inputs/api_key.txt``
    and assigned to the global ``openai.api_key``.
    """

    def __init__(self, model: str):
        with open("gpt_inputs/api_key.txt", "r") as key_file:
            raw_key = key_file.read()
        openai.api_key = raw_key.strip()
        self.model = model

    def get_move(
        self, board: chess.Board, game_state: str, temperature: float
    ) -> Optional[str]:
        """Return the first token of the model's reply to ``game_state``, or None.

        ``board`` is unused: the model only sees the text transcript.
        """
        return get_move_from_gpt_response(
            get_gpt_response(game_state, self.model, temperature)
        )

    def get_config(self) -> dict:
        """Describe this player by its model name."""
        return dict(model=self.model)
|
|
|
|
|
class LC0PLayer(Player):
    """Player that picks moves from a Leela Chess Zero network's policy head.

    Each move is the argmax of the network's softmaxed policy over the legal
    moves, i.e. a single evaluation with no search; ``temperature`` and
    ``game_state`` are ignored.
    """

    # Networks selectable by ``skill`` (index into this list).
    # NOTE(review): presumably ordered roughly weakest -> strongest, judging by
    # the layer sizes in the file names; confirm before relying on it.
    NETWORK_PATHS = [
        "./lc0/build/release/11258-32x4-se.pb.gz",
        "./lc0/build/release/11258-48x5-se.pb.gz",
        "./lc0/build/release/11258-80x7-se.pb.gz",
        "./lc0/build/release/11258-104x9-se.pb.gz",
        "./lc0/build/release/TK-6430 aka 128x10-BPR-64M-6430000.pb.gz",
        "./lc0/build/release/00af53b081e80147172e6f281c01daf5ca19ada173321438914c730370aa4267",
        "./lc0/build/release/b2ec465d0fb5b5eb39d2e1e3f74041a5d2fc92d413b71aa7ea0b6fb082ccba9c",
    ]

    def __init__(self, skill):
        """Load the network at ``NETWORK_PATHS[skill]``.

        Raises:
            ValueError: if ``skill`` is not a valid index into NETWORK_PATHS
                (negative values used to silently select from the end).
        """
        if not 0 <= skill < len(self.NETWORK_PATHS):
            raise ValueError(
                f"LC0 skill must be in [0, {len(self.NETWORK_PATHS) - 1}], got {skill}"
            )
        self.skill = skill
        self.network_path = self.NETWORK_PATHS[skill]

        print(f"\n\nLoading lc0 network: {self.network_path}\n\n")

        self.weights = Weights(self.network_path)
        self.backend = Backend(weights=self.weights)
        self.gamestate = GameState()

    def get_move(self, board: chess.Board, game_state: str, temperature: float) -> str:
        """Return the network's highest-probability legal move for ``board`` in SAN."""
        # NOTE(review): only the FEN is passed, so earlier positions (history /
        # repetitions) are not visible to the network; confirm this is intended.
        self.gamestate = GameState(fen=board.fen())

        input_planes = self.gamestate.as_input(self.backend)
        result = self.backend.evaluate(input_planes)[0]

        moves = self.gamestate.moves()
        policy_indices = self.gamestate.policy_indices()
        # Probabilities are aligned with ``moves`` via ``policy_indices``.
        move_probs = np.array(result.p_softmax(*policy_indices))
        best_move = moves[int(move_probs.argmax())]

        return board.san(chess.Move.from_uci(best_move))

    def get_config(self) -> dict:
        """Describe this player.

        ``play_time`` is always 0 (no search), and ``network`` is the path of
        the loaded weights rather than the opaque ``Weights`` object.
        """
        return {"network": self.network_path, "skill_level": self.skill, "play_time": 0}
|
|
|
|
|
class StockfishPlayer(Player):
    """Player backed by a local Stockfish engine spoken to over UCI.

    Skill levels:
        -2            play a uniformly random legal move (engine unused);
        other < 0     Stockfish "Skill Level" 0 with a minimal search limit;
        >= 0          passed to Stockfish's "Skill Level" option, searching
                      ``play_time`` seconds per move.
    """

    @staticmethod
    def get_stockfish_path() -> str:
        """Return the path of the Stockfish executable.

        The ``STOCKFISH_PATH`` environment variable, when set, takes
        precedence; otherwise a per-OS default is used.

        Returns:
            str: Path to the Stockfish executable.

        Raises:
            OSError: if no override is set and the OS is not Linux, macOS
                or Windows.
        """
        override = os.environ.get("STOCKFISH_PATH")
        if override:
            return override

        system = platform.system()
        if system == 'Linux':
            return "/usr/games/stockfish"
        if system == 'Darwin':
            return "stockfish"
        if system == 'Windows':
            # Machine-specific default; set STOCKFISH_PATH on other machines.
            return r"C:\Users\Haile\Downloads\stockfish\stockfish-windows-x86-64-avx2.exe"
        raise OSError("Unsupported operating system")

    def __init__(self, skill_level: int, play_time: float):
        self._skill_level = skill_level
        self._play_time = play_time

        stockfish_path = StockfishPlayer.get_stockfish_path()
        self._engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)

    def get_move(
        self, board: chess.Board, game_state: str, temperature: float
    ) -> Optional[str]:
        """Return Stockfish's (or a random) move for ``board`` in SAN, or None.

        ``game_state`` and ``temperature`` are ignored.
        """
        if self._skill_level == -2:
            return board.san(random.choice(list(board.legal_moves)))

        if self._skill_level < 0:
            # Weakest setting: lowest skill level and the smallest search possible.
            self._engine.configure({"Skill Level": 0})
            limit = chess.engine.Limit(time=1e-8, depth=1, nodes=1)
        else:
            self._engine.configure({"Skill Level": self._skill_level})
            limit = chess.engine.Limit(time=self._play_time)

        result = self._engine.play(board, limit)
        if result.move is None:
            return None
        return board.san(result.move)

    def get_config(self) -> dict:
        """Describe this player by skill level and per-move search time."""
        return {"skill_level": self._skill_level, "play_time": self._play_time}

    def close(self):
        """Shut down the underlying engine process."""
        self._engine.quit()
|
|
|
|
|
class HumanPlayer(Player):
    """Interactive player that types SAN moves on standard input."""

    def get_move(self, board: chess.Board, game_state: str, temperature: float) -> str:
        """Prompt until a legal SAN move is entered and return it (stripped).

        ``game_state`` and ``temperature`` are ignored. EOF on stdin and
        Ctrl-C now propagate instead of being swallowed, which previously
        caused an endless "Illegal move" loop.
        """
        print(board)

        while True:
            move = input("Enter your move (SAN format): ").strip()
            try:
                parsed = board.parse_san(move)
            except ValueError:
                # python-chess signals invalid, illegal or ambiguous SAN with
                # ValueError (or its subclasses).
                parsed = None

            # parse_san can return a null move ("--"), which is not legal here.
            if parsed is not None and parsed in board.legal_moves:
                return move
            print("Illegal move, try again.")

    def get_config(self) -> dict:
        """Identify this player as a human."""
        return {"player": "human"}
|
|
|
|
|
def get_gpt_response(game_state: str, model: str, temperature: float) -> Optional[str]:
    """Forward the transcript ``game_state`` to ``gpt_query`` for ``model``.

    For ``"gpt-4"`` the call is preceded by a 0.4 s pause.
    NOTE(review): presumably to stay under API rate limits; confirm.
    """
    if model == "gpt-4":
        time.sleep(0.4)
    return gpt_query.get_gpt_response(game_state, model, temperature)
|
|
|
|
|
def get_move_from_gpt_response(response: Optional[str]) -> Optional[str]:
    """Extract the candidate move from a model reply.

    Returns the first whitespace-separated token of ``response``, or None when
    the reply is missing or blank.
    """
    if response is None:
        return None
    head = response.split(maxsplit=1)
    return head[0] if head else None
|
|
|
|
|
def _score_game(
    board: chess.Board, player_one_forfeit: bool, player_two_forfeit: bool
) -> Tuple[str, object, object]:
    """Return (result string, player one score, player two score) for a game."""
    # A resignation or failure to find a legal move loses outright; player one
    # is checked first.
    if player_one_forfeit:
        return "0-1", 0, 1
    if player_two_forfeit:
        return "1-0", 1, 0

    result = board.result()
    if "-" in result:
        # Scores stay the raw result strings ("1", "0", "1/2") so rows match
        # previously written CSV files.
        parts = result.split("-")
        return result, parts[0], parts[1]
    if result == "*":
        # Unfinished game (e.g. the MAX_MOVES cap was hit): counted as a loss
        # for player one.
        return result, 0, 1
    # Sentinel for any other result string (board.result() should not produce one).
    return result, -1e10, -1e10


def _results_csv_path() -> str:
    """Return the CSV file this run's results are appended to.

    Reads the module-level ``RUN_FOR_ANALYSIS``, ``player_one_recording_name``,
    ``player_two_recording_name`` and ``recording_file`` settings.
    """
    if RUN_FOR_ANALYSIS:
        # Dots (e.g. from ".pt" checkpoint names) are replaced so the only dot
        # left in the file name is the ".csv" extension.
        name = f"{player_one_recording_name}_vs_{player_two_recording_name}".replace(".", "_")
        return f"logs/{name}.csv"
    return recording_file


def _append_csv_row(csv_file_path: str, row: dict) -> None:
    """Append ``row`` to ``csv_file_path``, writing a header if the file is new."""
    write_headers = not os.path.exists(csv_file_path)

    # os.makedirs("") raises, so only create a directory when the path has one.
    directory = os.path.dirname(csv_file_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(csv_file_path, "a", newline="") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=list(row))
        if write_headers:
            writer.writeheader()
        writer.writerow(row)


def record_results(
    board: chess.Board,
    player_one: Player,
    player_two: Player,
    game_state: str,
    player_one_illegal_moves: int,
    player_two_illegal_moves: int,
    player_one_legal_moves: int,
    player_two_legal_moves: int,
    total_time: float,
    player_one_resignation: bool,
    player_two_resignation: bool,
    player_one_failed_to_find_legal_move: bool,
    player_two_failed_to_find_legal_move: bool,
    total_moves: int,
    illegal_moves: int,
):
    """Append one finished game's summary row to the results CSV.

    The destination depends on the module-level ``RUN_FOR_ANALYSIS`` flag.
    A header row is written when the file does not exist yet. The final
    transcript is also written to ``game.txt``.
    """
    unique_game_id = generate_unique_game_id()

    (
        player_one_title,
        player_two_title,
        player_one_time,
        player_two_time,
    ) = get_player_titles_and_time(player_one, player_two)

    result, player_one_score, player_two_score = _score_game(
        board,
        player_one_resignation or player_one_failed_to_find_legal_move,
        player_two_resignation or player_two_failed_to_find_legal_move,
    )

    info_dict = {
        "game_id": unique_game_id,
        "transcript": game_state,
        "result": result,
        "player_one": player_one_title,
        "player_two": player_two_title,
        "player_one_time": player_one_time,
        "player_two_time": player_two_time,
        "player_one_score": player_one_score,
        "player_two_score": player_two_score,
        "player_one_illegal_moves": player_one_illegal_moves,
        "player_two_illegal_moves": player_two_illegal_moves,
        "player_one_legal_moves": player_one_legal_moves,
        "player_two_legal_moves": player_two_legal_moves,
        "player_one_resignation": player_one_resignation,
        "player_two_resignation": player_two_resignation,
        "player_one_failed_to_find_legal_move": player_one_failed_to_find_legal_move,
        "player_two_failed_to_find_legal_move": player_two_failed_to_find_legal_move,
        "game_title": f"{player_one_title} vs. {player_two_title}",
        "number_of_moves": board.fullmove_number,
        "time_taken": total_time,
        "total_moves": total_moves,
        "illegal_moves": illegal_moves,
    }

    _append_csv_row(_results_csv_path(), info_dict)

    with open("game.txt", "w") as f:
        f.write(game_state)
|
|
|
|
|
def generate_unique_game_id() -> str:
    """Build a game id of the form ``"<unix seconds>-<random 1000..9999>"``."""
    return "{}-{}".format(int(time.time()), random.randint(1000, 9999))
|
|
|
|
|
def get_player_titles_and_time(
    player_one: "Player", player_two: "Player"
) -> Tuple[str, str, Optional[float], Optional[float]]:
    """Return (title one, title two, play_time one, play_time two) for logging.

    Each player is described from its ``get_config()`` dict. Configs without
    a ``model`` or ``player`` key are assumed to be engine configs carrying
    ``skill_level`` and ``play_time``.
    """
    player_one_title, player_one_time = _player_title_and_time(player_one.get_config())
    player_two_title, player_two_time = _player_title_and_time(player_two.get_config())
    return (player_one_title, player_two_title, player_one_time, player_two_time)


def _player_title_and_time(config: dict) -> Tuple[str, Optional[float]]:
    """Map a single player's config dict to its (title, play_time) pair."""
    if "model" in config:
        return config["model"], None
    if "player" in config:
        # e.g. HumanPlayer's {"player": "human"}; previously this hit the
        # engine branch and raised KeyError('skill_level').
        return config["player"], None
    # Stockfish configs land here, and so do LC0 configs, which also carry
    # skill_level / play_time; both keep the historical "Stockfish N" label.
    return f"Stockfish {config['skill_level']}", config["play_time"]
|
|
|
|
|
# Openings already handed out in this process, so repeats are avoided until
# every opening in openings.csv has been used once.
used_openings = []


def initialize_game_with_opening(
    game_state: str, board: chess.Board
) -> Tuple[str, chess.Board]:
    """Seed a game with a random, not-yet-used opening from ``openings.csv``.

    The opening's moves are pushed onto ``board``, and the returned transcript
    is the opening itself. With ``move_num_in_gamestate`` off, each move
    number is reduced to a bare dot (``"1.e4 e5"`` -> ``".e4 e5"``).

    NOTE(review): the incoming ``game_state`` (the prompt read by play_game)
    is discarded and replaced by the opening transcript; confirm this is
    intended.

    Raises:
        ValueError: if ``openings.csv`` contains no openings.
    """
    with open("openings.csv", "r") as file:
        # Skip the header row and any blank lines (e.g. a trailing newline).
        openings = [line.rstrip() for line in file.readlines()[1:] if line.strip()]
    if not openings:
        raise ValueError("openings.csv contains no openings")

    already_used = set(used_openings)
    candidates = [opening for opening in openings if opening not in already_used]
    if not candidates:
        # Every opening has been played: start a new cycle rather than
        # looping forever looking for an unused one.
        used_openings.clear()
        candidates = openings

    moves_string = random.choice(candidates)
    used_openings.append(moves_string)

    tokens = moves_string.split()
    if move_num_in_gamestate:
        game_state = moves_string
    else:
        game_state = " ".join(
            "." + token.split(".")[-1] if "." in token else token for token in tokens
        )

    for token in tokens:
        # "12.Nf3" -> "Nf3"; tokens without a move number pass through unchanged.
        board.push_san(token.split(".")[-1])

    return game_state, board
|
|
|
|
|
|
|
def get_legal_move(
    player: Player,
    board: chess.Board,
    game_state: str,
    player_one: bool,
    max_attempts: int = 5,
) -> LegalMoveResponse:
    """Request a move from the player and ensure it's legal.

    The player is asked up to ``max_attempts`` times. A game-result string
    ("1-0", "0-1", "1/2-1/2") is treated as a resignation. An accepted move is
    normalised for the transcript: player two's SAN gets a leading space,
    player one's has one removed.

    Returns:
        LegalMoveResponse whose ``attempts`` is the number of failed attempts
        before the outcome (``max_attempts`` when no legal move was found).
    """
    resignation_results = ("1-0", "0-1", "1/2-1/2")

    for attempt in range(max_attempts):
        # Sampling temperature rises with each failed attempt, capped at 0.75.
        temperature = min((attempt / max_attempts) + 0.001, 0.75)
        move_san = player.get_move(board, game_state, temperature)
        if move_san is None:
            continue

        # Strip before comparing so a " 1-0" from player two still counts.
        if move_san.strip() in resignation_results:
            print(f"{move_san}, player has resigned")
            return LegalMoveResponse(
                move_san=None,
                move_uci=None,
                attempts=attempt,
                is_resignation=True,
            )

        try:
            move_uci = board.parse_san(move_san)
        except ValueError as e:
            # python-chess reports invalid/illegal/ambiguous SAN as ValueError.
            print(f"Error parsing move {move_san}: {e}")
            # Only model-backed players have a "model" key; .get avoids a
            # KeyError for engine/human players.
            if player.get_config().get("model") == "gpt-3.5-turbo-instruct":
                with open("gpt-3.5-turbo-instruct-illegal-moves.txt", "a") as f:
                    f.write(f"{game_state}\n{move_san}\n")
            continue

        if move_uci in board.legal_moves:
            if player_one:
                if move_san.startswith(" "):
                    move_san = move_san[1:]
            elif not move_san.startswith(" "):
                move_san = " " + move_san
            return LegalMoveResponse(move_san, move_uci, attempt)

        print(f"Illegal move: {move_san}")

    print(f"{player} provided illegal moves for {max_attempts} attempts.")
    return LegalMoveResponse(
        move_san=None, move_uci=None, attempts=max_attempts, is_illegal_move=True
    )
|
|
|
|
|
def play_turn(
    player: Player,
    board: chess.Board,
    game_state: str,
    player_one: bool,
    max_attempts: int = 5,
) -> Tuple[str, bool, bool, int]:
    """Let ``player`` make one move, pushing it onto ``board`` if it is legal.

    Args:
        max_attempts: How many tries the player gets to produce a legal move.

    Returns:
        (updated transcript, resigned, failed to find a legal move,
        number of failed attempts this turn).
    """
    result = get_legal_move(player, board, game_state, player_one, max_attempts)
    illegal_moves = result.attempts
    move_san = result.move_san
    move_uci = result.move_uci
    resignation = result.is_resignation
    failed_to_find_legal_move = result.is_illegal_move

    if resignation:
        print(f"{player} resigned with result: {board.result()}")
    elif failed_to_find_legal_move:
        print(f"Game over: {max_attempts} consecutive illegal moves from {player}")
    elif move_san is None or move_uci is None:
        print(f"Game over: {player} failed to find a legal move")
    else:
        board.push(move_uci)
        game_state += move_san
        print(move_san, end=" ")

    return game_state, resignation, failed_to_find_legal_move, illegal_moves
|
|
|
|
|
def _play_single_game(
    player_one: Player, player_two: Player, random_opening_seed: bool
) -> None:
    """Play one game to completion (or the MAX_MOVES cap) and record it."""
    with open("gpt_inputs/prompt.txt", "r") as f:
        game_state = f.read()
    board = chess.Board()

    if random_opening_seed:
        game_state, board = initialize_game_with_opening(game_state, board)

    player_one_illegal_moves = 0
    player_two_illegal_moves = 0
    player_one_legal_moves = 0
    player_two_legal_moves = 0
    player_one_resignation = False
    player_two_resignation = False
    player_one_failed_to_find_legal_move = False
    player_two_failed_to_find_legal_move = False
    start_time = time.time()

    total_moves = 0
    print_for_human = isinstance(player_one, HumanPlayer) or isinstance(
        player_two, HumanPlayer
    )

    while not board.is_game_over():
        if print_for_human:
            print(board)

        # Snapshot of the transcript so far, overwritten every turn.
        with open("game.txt", "w") as f:
            f.write(game_state)
        current_move_num = f"{board.fullmove_number if move_num_in_gamestate else ''}."
        total_moves += 1

        # Separate this turn from the previous one, then prefix it with the
        # move number ("12.") or a bare "." when numbers are disabled.
        if board.fullmove_number != 1:
            game_state += " "
        game_state += current_move_num

        # A move counts as "legal" only if it was found on the first attempt.
        player_one_legal_moves += 1
        (
            game_state,
            player_one_resignation,
            player_one_failed_to_find_legal_move,
            illegal_moves_one,
        ) = play_turn(player_one, board, game_state, player_one=True)
        player_one_illegal_moves += illegal_moves_one
        if illegal_moves_one != 0:
            player_one_legal_moves -= 1
        if (
            board.is_game_over()
            or player_one_resignation
            or player_one_failed_to_find_legal_move
        ):
            break

        # Counted only now that player two actually gets to move (previously it
        # was counted even when player one had already ended the game).
        player_two_legal_moves += 1
        (
            game_state,
            player_two_resignation,
            player_two_failed_to_find_legal_move,
            illegal_moves_two,
        ) = play_turn(player_two, board, game_state, player_one=False)
        player_two_illegal_moves += illegal_moves_two
        if illegal_moves_two != 0:
            player_two_legal_moves -= 1
        if (
            board.is_game_over()
            or player_two_resignation
            or player_two_failed_to_find_legal_move
        ):
            break

        print("\n", end="")

        if total_moves > MAX_MOVES:
            break

    total_time = time.time() - start_time
    print(f"\nGame over. Total time: {total_time} seconds")
    print(f"Result: {board.result()}")
    print(board)
    print()

    record_results(
        board,
        player_one,
        player_two,
        game_state,
        player_one_illegal_moves,
        player_two_illegal_moves,
        player_one_legal_moves,
        player_two_legal_moves,
        total_time,
        player_one_resignation,
        player_two_resignation,
        player_one_failed_to_find_legal_move,
        player_two_failed_to_find_legal_move,
        total_moves,
        # Total failed attempts by both players (this was previously always 0).
        player_one_illegal_moves + player_two_illegal_moves,
    )


def play_game(
    player_one: Player,
    player_two: Player,
    max_games: int = 10,
    random_opening_seed: bool = False,
):
    """Play ``max_games`` games between the same two players, recording each.

    Args:
        player_one: Player that moves first in every turn of the game loop.
        player_two: Player that moves second.
        max_games: Number of games to play.
        random_opening_seed: Start each game from a random unused opening
            taken from ``openings.csv``.

    Stockfish engines are shared by all games, so they are closed once after
    the whole series, and also if a game raises.
    """
    try:
        for z in range(max_games):
            print(f"\nGame {z} of {max_games}\n")
            _play_single_game(player_one, player_two, random_opening_seed)
    finally:
        if isinstance(player_one, StockfishPlayer):
            player_one.close()
        if isinstance(player_two, StockfishPlayer):
            player_two.close()
|
|
|
|
|
|
|
|
|
# ---- Run configuration: module-level settings read by the functions above ----

# True: results go to logs/<player one>_vs_<player two>.csv;
# False: they are appended to ``recording_file``.
RUN_FOR_ANALYSIS = True

# Cap on turns per game; play_game stops once its turn counter exceeds this.
MAX_MOVES = 999

recording_file = "logs/determine.csv"

player_ones = ["50M/ckpt_9715500b.pt"]

player_two_recording_name = "lc0_sweep"

# Whether transcripts carry move numbers ("1.e4") or bare dots (".e4").
move_num_in_gamestate = False

if __name__ == "__main__":
    for nanogpt_player in player_ones:
        # Module-level on purpose: record_results reads it to name the CSV file.
        player_one_recording_name = nanogpt_player

        for i in range(1):
            num_games = 100

            # Pass the shared setting so the model's transcript format always
            # matches the one the game loop builds.
            player_one = MambaPlayer(
                model_name=player_one_recording_name,
                move_num_in_gamestate=move_num_in_gamestate,
            )
            player_two = LC0PLayer(skill=i)

            print(f"\n\nSTARTING GAMES AGAINST LC0 LEVEL {i}\n\n")
            play_game(player_one, player_two, num_games, random_opening_seed=True)

    print("\n\n\n********\nDONE!\n********\n\n\n")
|
|