|
import os |
|
import gradio as gr |
|
import random |
|
import chess |
|
import chess.svg |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline |
|
|
|
# Hugging Face model repository holding the policy network and its tokenizer.
_MODEL_ID = 'jrahn/chessv3'

# Access token for the model repository; raises KeyError if 'auth_token' is not set.
token = os.environ['auth_token']

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, use_auth_token=token)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_ID, use_auth_token=token)

# Text-classification pipeline: input is an encoded position (see encode_fen),
# labels are the candidate moves returned by predict_move.
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
|
|
|
# Character that stands for one empty square in the model's board encoding.
empty_field = '0'

# Separator placed between the encoded board and the remaining FEN fields.
board_split = ' | '

# FEN empty-square digit -> run of that many empty_field characters
# ('1' -> '0', '2' -> '00', ..., '8' -> '00000000').
nums = {str(count): count * empty_field for count in range(1, 9)}

# Inverse of nums, ordered longest run first so that decoding replaces the
# maximal run of empty squares before any shorter run inside it.
nums_rev = {
    run: digit
    for digit, run in sorted(nums.items(), key=lambda item: len(item[1]), reverse=True)
}
|
|
|
|
|
def encode_fen(fen):
    """Convert a FEN string into the text representation fed to the model.

    In the board field, every empty-square digit is expanded into that many
    ``empty_field`` characters (via ``nums``), and each rank is prefixed with
    ``'+'``, with ranks separated by spaces instead of ``'/'``. All fields after
    the board are kept verbatim and appended after ``board_split``.

    Example (with the module's constants):
        'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1' ->
        '+rnbqkbnr +pppppppp +00000000 +00000000 +00000000 +00000000 '
        '+PPPPPPPP +RNBQKBNR | w KQkq - 0 1'

    Raises:
        ValueError: if ``fen`` contains no space (no fields after the board).
    """
    board_part, rest_part = fen.split(' ', 1)

    for digit, zeros in nums.items():
        board_part = board_part.replace(digit, zeros)

    # '+a +b +c' form: '+' in front of every rank, ranks joined by spaces.
    encoded_board = ' '.join('+' + rank for rank in board_part.split('/'))
    return board_split.join([encoded_board, rest_part])
|
|
|
def decode_fen_repr(fen_repr):
    """Invert encode_fen: turn the model representation back into a FEN string.

    Runs of ``empty_field`` characters are collapsed back into digits, longest
    run first (the order of ``nums_rev``). The ``' +'`` rank separators become
    ``'/'``, and any remaining ``'+'`` is dropped. The text after
    ``board_split`` is re-attached with a single space.

    Raises:
        ValueError: if ``board_split`` does not occur in ``fen_repr``.
    """
    encoded_board, rest_part = fen_repr.split(board_split, 1)

    for zeros, digit in nums_rev.items():
        encoded_board = encoded_board.replace(zeros, digit)

    fen_board = encoded_board.replace(' +', '/').replace('+', '')
    return ' '.join([fen_board, rest_part])
|
|
|
def predict_move(fen, top_k=3):
    """Sample an engine move for the position ``fen``.

    The position is encoded with encode_fen and classified by ``pipe``. One of
    the returned top-``top_k`` labels is then drawn at random, with probability
    proportional to its score.

    Args:
        fen: position in FEN notation.
        top_k: number of candidate labels to sample from. It typically comes
            straight from a numeric UI field, so it is coerced with int() and
            clamped to at least 1 (1 == always pick the best-scoring label).
            None is passed to the pipeline unchanged.

    Returns:
        The chosen label string. The caller (btn_play) parses it with
        chess.Move.from_uci, so it is expected to be a UCI move. Legality is
        not checked here.
    """
    if top_k is not None:
        # 0 or negative values from the UI would presumably make the pipeline
        # return an empty or oddly truncated candidate list, and
        # random.choices raises IndexError on an empty population.
        top_k = max(1, int(top_k))

    preds = pipe(encode_fen(fen), top_k=top_k)
    weights = [pred['score'] for pred in preds]
    chosen = random.choices(preds, weights=weights, k=1)[0]
    return chosen['label']
|
|
|
def btn_load(inp_fen):
    """Reset to the standard starting position and render it.

    ``inp_fen`` is accepted because the Load button is wired with the FEN
    textbox as its input, but it is ignored: the board is always reset.

    Returns:
        ('board.svg', starting FEN): the path of the SVG just written and the
        FEN of the initial position.
    """
    start_position = chess.Board()

    with open('board.svg', 'w') as svg_file:
        svg_markup = str(chess.svg.board(start_position))
        svg_file.write(svg_markup)

    return 'board.svg', start_position.fen()
|
|
|
def btn_play(inp_fen, inp_move, inp_notation, inp_k):
    """Apply one move to the position in ``inp_fen`` and render the result.

    If ``inp_move`` is non-blank, it is parsed as a human move in the selected
    notation ('UCI' or 'SAN'). Otherwise the engine chooses a move via
    predict_move, sampling among the top ``inp_k`` candidates.

    Returns:
        ('board.svg', new FEN, ''): the SVG path for the board image, the
        updated FEN for the FEN textbox, and an empty string that clears the
        move textbox.

    Raises:
        ValueError: for an invalid FEN, an unparsable or illegal move (human
            or engine), or an unknown notation. This is what shows up as
            "ERROR" in the UI.
    """
    board = chess.Board(inp_fen)
    move_text = (inp_move or '').strip()

    if move_text:
        if inp_notation == 'UCI':
            mv = chess.Move.from_uci(move_text)
        elif inp_notation == 'SAN':
            # parse_san already rejects illegal moves.
            mv = board.parse_san(move_text)
        else:
            # Previously this fell through and left `mv` unbound.
            raise ValueError(f'unknown move notation: {inp_notation!r}')
    else:
        mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))

    # Board.push does not validate legality, so an illegal UCI move (typed by
    # the human or predicted by the model) would otherwise be applied and
    # corrupt the position.
    if mv not in board.legal_moves:
        raise ValueError(f'illegal move in this position: {mv.uci()}')

    board.push(mv)

    with open('board.svg', 'w') as f:
        f.write(str(chess.svg.board(board, lastmove=mv)))

    return 'board.svg', board.fen(), ''
|
|
|
# Header text shown above the controls.
_INTRO_MD = '''
# Play YoloChess - Policy Network v0.3
110M Parameter Transformer (BERT-base architecture) trained for text classification from scratch on expert games in modified FEN notation.
'''

# Usage notes shown below the Load/Play buttons.
_HELP_MD = '''
- Click "Load" button to start and reset board.
- Click "Play" button to get Engine move.
- Enter a "human player move" in UCI or SAN notation and click "Play" to move a piece.
- Output "ERROR" generally occurs on illegal moves (Human or Engine).
- Enter "FEN" to start from a custom position.
'''

with gr.Blocks() as block:
    gr.Markdown(_INTRO_MD)

    with gr.Row() as row:
        # Left column: move entry, position, sampling width and buttons.
        with gr.Column():
            with gr.Row():
                move = gr.Textbox(label='human player move')
                notation = gr.Radio(["SAN", "UCI"], value="SAN", label='move notation')
            fen = gr.Textbox(value=chess.Board().fen(), label='FEN')
            top_k = gr.Number(value=3, label='pick from top_k moves', precision=0)
            with gr.Row():
                load_btn = gr.Button("Load")
                play_btn = gr.Button("Play")
            gr.Markdown(_HELP_MD)

        # Right column: rendered board image.
        with gr.Column():
            position_output = gr.Image(label='board')

    # Event wiring: Load resets the board, Play applies a human or engine move
    # and clears the move textbox.
    load_btn.click(fn=btn_load, inputs=fen, outputs=[position_output, fen])
    play_btn.click(
        fn=btn_play,
        inputs=[fen, move, notation, top_k],
        outputs=[position_output, fen, move],
    )

block.launch()