import torch
import torch.nn as nn
import torch.optim as optim
import chess
import chess.engine as eng
import torch.multiprocessing as mp

CONFIG = {
    "stockfish_path": "/Users/aaronvattay/Downloads/stockfish/stockfish-macos-m1-apple-silicon",
    "model_path": "chessy_model.pth",
    "backup_model_path": "chessy_modelt-1.pth",
    "device": torch.device("mps"),
    "learning_rate": 1e-4,
    "num_games": 30,
    "num_epochs": 10,
    "stockfish_time_limit": 1.0,
    "search_depth": 1,
    "epsilon": 4,
}
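# Reading of the knobs above (descriptive notes, not tuned results):
#   - "stockfish_time_limit" is the per-move think time handed to Stockfish
#     via chess.engine.Limit(time=...).
#   - "search_depth" is carried through to search(), which currently looks
#     one ply ahead regardless of the value.
#   - "epsilon" is the number of candidate moves sampled (without
#     replacement) from the softmax over move scores during exploration.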

device = CONFIG["device"]


def board_to_tensor(board):
    # Map piece symbols to integer ids; 0 is reserved for empty squares.
    piece_encoding = {
        'P': 1, 'N': 2, 'B': 3, 'R': 4, 'Q': 5, 'K': 6,
        'p': 7, 'n': 8, 'b': 9, 'r': 10, 'q': 11, 'k': 12
    }

    tensor = torch.zeros(64, dtype=torch.long)
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            tensor[square] = piece_encoding[piece.symbol()]

    return tensor.unsqueeze(0)  # shape (1, 64)
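
# A minimal sanity check for board_to_tensor (an illustrative sketch; it is
# not called anywhere during training):
def _check_board_to_tensor():
    t = board_to_tensor(chess.Board())
    assert t.shape == (1, 64)
    assert t[0, chess.A1].item() == 4  # 'R' -> 4 on White's back rank
    assert t[0, chess.E4].item() == 0  # e4 is empty in the starting position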


class NN1(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(13, 64)
        self.attention = nn.MultiheadAttention(embed_dim=64, num_heads=16)
        self.neu = 512
        # Flattened 64x64 attention output (4096) feeds a deep MLP:
        # one 4096->512 layer, twelve 512->512 layers, then a 512->64->4 head.
        layers = [nn.Linear(4096, self.neu), nn.ReLU()]
        for _ in range(12):
            layers += [nn.Linear(self.neu, self.neu), nn.ReLU()]
        layers += [nn.Linear(self.neu, 64), nn.ReLU(), nn.Linear(64, 4)]
        self.neurons = nn.Sequential(*layers)

    def forward(self, x):
        x = self.embedding(x)                          # (B, 64) -> (B, 64, 64)
        x = x.permute(1, 0, 2)                         # (seq, batch, embed) for attention
        attn_output, _ = self.attention(x, x, x)
        x = attn_output.permute(1, 0, 2).contiguous()  # back to (B, 64, 64)
        x = x.view(x.size(0), -1)                      # flatten to (B, 4096)
        x = self.neurons(x)
        return x
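
# Note on NN1's head: the final Linear emits 4 values, but get_evaluation()
# and train() below read only output[0][0] as the scalar evaluation; the
# remaining three outputs are unused as written.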


class Policy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(13, 32)
        self.attention = nn.MultiheadAttention(embed_dim=32, num_heads=16)
        self.neu = 256
        self.neurons = nn.Sequential(
            nn.Linear(64 * 32, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, 128),
            nn.ReLU(),
            nn.Linear(128, 29275),
        )

    def forward(self, fen):
        board = chess.Board(fen)
        # Sign-flip so the logits read from the side to move.
        color = 1.0 if board.turn == chess.WHITE else -1.0
        x = board_to_tensor(board).to(self.embedding.weight.device)
        x = self.embedding(x)
        x = x.permute(1, 0, 2)
        attn_output, _ = self.attention(x, x, x)
        x = attn_output.permute(1, 0, 2).contiguous()
        x = x.view(x.size(0), -1)
        return self.neurons(x) * color
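
# Perspective convention assumed for Policy: logits are multiplied by +1 when
# White is to move and -1 when Black is, so downstream code can always read
# them from the mover's point of view.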

model = NN1().to(device)
optimizer = optim.Adam(model.parameters(), lr=CONFIG["learning_rate"])
policy = Policy().to(device)
polweight = torch.load("NeoChess/chessy_policy.pth", map_location=device, weights_only=False)
policy.load_state_dict(polweight)

try:
    model.load_state_dict(torch.load(CONFIG["model_path"], map_location=device))
    print(f"Loaded model from {CONFIG['model_path']}")
except FileNotFoundError:
    try:
        model.load_state_dict(torch.load(CONFIG["backup_model_path"], map_location=device))
        print(f"Loaded backup model from {CONFIG['backup_model_path']}")
    except FileNotFoundError:
        print("No model file found, starting from scratch.")

model.train()
criterion = nn.MSELoss()
engine = eng.SimpleEngine.popen_uci(CONFIG["stockfish_path"])
lim = eng.Limit(time=CONFIG["stockfish_time_limit"])

def get_evaluation(board):
    """
    Return the evaluation of the board from the perspective of the player to
    move. The model's output is from White's perspective, so it is negated
    when Black is to move.
    """
    tensor = board_to_tensor(board).to(device)
    with torch.no_grad():
        evaluation = model(tensor)[0][0].item()

    if board.turn == chess.WHITE:
        return evaluation
    else:
        return -evaluation


with open("/usr/local/python/3.12.1/lib/python3.12/site-packages/torchrl/envs/custom/san_moves.txt", "r") as f:
    uci_to_index = {line.strip(): i for i, line in enumerate(f)}
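
# uci_to_index maps each move string in san_moves.txt to a class index; the
# file is expected to contain 29275 moves so that probs[move_index] in
# search() lines up with Policy's 29275-way output head.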


def search(board, depth, policy_net=policy, simulations=100, temperature=1.0):
    """
    One-ply search: every legal move is scored with the value network via
    get_evaluation(), and each score is then weighted by the policy network's
    prior probability for that move. `depth` is accepted for interface
    compatibility, but the lookahead is currently a single ply.
    """
    with torch.no_grad():
        # Policy.forward takes a FEN string and returns logits over the
        # 29275-move vocabulary.
        logits = policy_net(board.fen()).squeeze(0)
        probs = torch.softmax(logits / temperature, dim=-1).cpu().numpy()

    move_scores = {move: 0.0 for move in board.legal_moves}

    for move in board.legal_moves:
        total_eval = 0.0
        for _ in range(simulations):
            board.push(move)
            # After the push it is the opponent's turn, so negate to score
            # the position from the mover's perspective.
            total_eval += -get_evaluation(board)
            board.pop()
        move_scores[move] = total_eval / simulations

    # Weight each averaged evaluation by the policy prior for that move.
    for move in move_scores:
        move_index = uci_to_index[str(move)]
        move_scores[move] *= probs[move_index]

    best_move = max(move_scores, key=move_scores.get)
    return best_move, move_scores
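
# Example use of search() (illustrative only, not executed at import time):
#   best, scores = search(chess.Board(), depth=CONFIG["search_depth"])
#   print(best.uci(), scores[best])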


def game_gen(engine_side):
    """Play one game (bot vs. Stockfish, or bot self-play when engine_side is
    None) and record the bot's positions for training."""
    data = []
    mc = 0
    board = chess.Board()
    while not board.is_game_over():
        is_bot_turn = board.turn != engine_side

        if is_bot_turn:
            # Score every legal move with the policy-weighted one-ply search.
            _, evaling = search(board, depth=CONFIG["search_depth"])

            if not evaling:
                break

            # Exploration: sample up to "epsilon" candidate moves from the
            # softmax over scores, then play the strongest of the sample.
            keys = list(evaling.keys())
            logits = torch.tensor(list(evaling.values()), dtype=torch.float32).to(device)
            probs = torch.softmax(logits, dim=0)
            epsilon = min(CONFIG["epsilon"], len(keys))
            bests = torch.multinomial(probs, num_samples=epsilon, replacement=False)
            best_idx = bests[torch.argmax(logits[bests])]
            move = keys[best_idx.item()]

            data.append({
                'fen': board.fen(),
                'move_number': mc,
            })
        else:
            result = engine.play(board, lim)
            move = result.move

        board.push(move)
        mc += 1

    result = board.result()
    c = 0.0
    if result == '1-0':
        c = 10.0
    elif result == '0-1':
        c = -10.0
    return data, c, mc
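
# Reward target used by train(): a position recorded at move m of an mc-move
# game gets the target c * m / mc, ramping linearly from 0 toward the final
# reward (+10 White win, -10 Black win, 0 draw), so positions near the end
# receive credit closest to the game outcome.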
def train(data, c, mc):
    for entry in data:
        tensor = board_to_tensor(chess.Board(entry['fen'])).to(device)
        target = torch.tensor(c * entry['move_number'] / mc, dtype=torch.float32).to(device)
        output = model(tensor)[0][0]
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Saving model to {CONFIG['model_path']}")
    torch.save(model.state_dict(), CONFIG["model_path"])


def main():
    # "spawn" re-imports this module in each worker, so every worker builds
    # its own model, policy, and Stockfish process from the saved files.
    mp.set_start_method('spawn', force=True)
    for i in range(CONFIG["num_epochs"]):
        num_games = CONFIG['num_games']
        num_instances = mp.cpu_count()
        print(f"Saving backup model to {CONFIG['backup_model_path']}")
        torch.save(model.state_dict(), CONFIG["backup_model_path"])
        with mp.Pool(processes=num_instances) as pool:
            # A third of the games each: self-play, Stockfish as White, and
            # Stockfish as Black.
            results_self = pool.starmap(game_gen, [(None,) for _ in range(num_games // 3)])
            results_white = pool.starmap(game_gen, [(chess.WHITE,) for _ in range(num_games // 3)])
            results_black = pool.starmap(game_gen, [(chess.BLACK,) for _ in range(num_games // 3)])
        results = []
        for s, w, b in zip(results_self, results_white, results_black):
            results.extend([s, w, b])
        for batch in results:
            data, c, mc = batch
            print(f"Saving backup model to {CONFIG['backup_model_path']}")
            torch.save(model.state_dict(), CONFIG["backup_model_path"])
            if data:
                train(data, c, mc)
    print("Training complete.")
    engine.quit()


if __name__ == "__main__":
    main()