File size: 3,833 Bytes
313445a b02b814 313445a b02b814 1e7240f b02b814 313445a 0671ff1 799a947 b02b814 0671ff1 b02b814 799a947 0671ff1 799a947 b02b814 0671ff1 799a947 0671ff1 b02b814 1e7240f b02b814 0671ff1 b02b814 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import os
import gradio as gr
import random
import chess
import chess.svg
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
# Hugging Face Hub access token taken from the environment.
# A missing 'auth_token' variable raises KeyError at import time.
token = os.environ['auth_token']

# Tokenizer and sequence-classification model for the 'jrahn/chessv3'
# checkpoint; predicted class labels are used as UCI moves by btn_play.
tokenizer = AutoTokenizer.from_pretrained(
    'jrahn/chessv3',
    use_auth_token=token,
)
model = AutoModelForSequenceClassification.from_pretrained(
    'jrahn/chessv3',
    use_auth_token=token,
)

# Text-classification pipeline wrapping the model; queried by predict_move.
pipe = pipeline(
    task="text-classification",
    tokenizer=tokenizer,
    model=model,
)
# Character written for each empty square once FEN run-lengths are expanded.
empty_field = '0'
# Separator placed between the board field and the remaining FEN fields.
board_split = ' | '
# FEN digit -> run of empty squares, e.g. '3' -> '000' (keys '1'..'8').
nums = {digit: empty_field * int(digit) for digit in '12345678'}
# Run of empty squares -> FEN digit, ordered longest run first so that
# decoding replaces greedily (eight empties become '8', not eight '1's).
nums_rev = {empty_field * count: str(count) for count in range(8, 0, -1)}
def encode_fen(fen):
    """Convert a standard FEN string into the model's text representation.

    The board field is decompressed (each digit 1-8 becomes that many
    ``empty_field`` characters) and each rank is prefixed with '+', with
    ranks separated by spaces, to suit sub-word tokenization. The board and
    the remaining FEN fields are then joined with ``board_split``.

    Raises:
        ValueError: if ``fen`` has no space, i.e. no fields after the board.
    """
    board_part, other_fields = fen.split(' ', 1)
    for digit, run in nums.items():
        board_part = board_part.replace(digit, run)
    # 'a/b/c' -> '+a +b +c'
    encoded_board = ' '.join('+' + rank for rank in board_part.split('/'))
    return board_split.join([encoded_board, other_fields])
def decode_fen_repr(fen_repr):
    """Invert encode_fen: rebuild a standard FEN string from the model text.

    Runs of ``empty_field`` in the board part are compressed back into FEN
    digits, rank markers (' +' / leading '+') are turned back into '/'
    separators, and the remaining fields are re-attached with a space.

    Raises:
        ValueError: if ``fen_repr`` does not contain ``board_split``.
    """
    board_part, other_fields = fen_repr.split(board_split, 1)
    # nums_rev is ordered longest run first, so '000' becomes '3' rather
    # than '111'.
    for run, digit in nums_rev.items():
        board_part = board_part.replace(run, digit)
    board_part = board_part.replace(' +', '/').replace('+', '')
    return f'{board_part} {other_fields}'
def predict_move(fen, top_k=3):
    """Sample an engine move for the position given as a FEN string.

    The model scores every move label. Labels that are not legal moves in
    the position are discarded, and one of the ``top_k`` highest-scoring
    legal moves is drawn at random with probability proportional to its
    score.

    Args:
        fen: position in standard FEN notation.
        top_k: how many of the best legal moves to sample from; ``None`` or
            values below 1 are treated as 1.

    Returns:
        The chosen move as a UCI string (the label of the selected prediction).

    Raises:
        ValueError: if ``fen`` is not a valid FEN, or if none of the model's
            labels is a legal move (e.g. the game is already over).
    """
    board = chess.Board(fen)
    # NOTE(review): assumes labels use the same UCI spelling as Move.uci().
    legal_ucis = {move.uci() for move in board.legal_moves}
    # Request every label rather than only top_k, so the legality filter
    # cannot run out of candidates just because illegal moves ranked highest.
    preds = pipe(encode_fen(fen), top_k=pipe.model.config.num_labels)
    candidates = sorted(
        (p for p in preds if p['label'] in legal_ucis),
        key=lambda p: p['score'],
        reverse=True,
    )
    if not candidates:
        raise ValueError(f'model predicted no legal move for FEN {fen!r}')
    k = max(1, int(top_k)) if top_k else 1
    candidates = candidates[:k]
    weights = [p['score'] for p in candidates]
    if not any(weights):
        # random.choices raises on an all-zero weight vector (Python 3.9+);
        # fall back to the best-ranked legal move.
        return candidates[0]['label']
    return random.choices(candidates, weights=weights)[0]['label']
def btn_load(inp_fen):
    """Reset to the standard starting position and render it.

    ``inp_fen`` is accepted only to match the Gradio ``inputs`` wiring; it is
    not used, so "Load" always resets the board.

    Returns:
        ('board.svg', FEN of the starting position) — the SVG file path for
        the image component and the FEN for the FEN textbox.
    """
    start = chess.Board()
    with open('board.svg', 'w') as out:
        out.write(str(chess.svg.board(start)))
    return 'board.svg', start.fen()
def btn_play(inp_fen, inp_move, inp_notation, inp_k):
    """Apply one move to the position and render the resulting board.

    If ``inp_move`` is non-blank it is parsed as the human move in
    ``inp_notation`` ('UCI' or 'SAN'); otherwise the engine picks a move via
    predict_move, sampling from its ``inp_k`` best predictions.

    Returns:
        ('board.svg', new FEN, '') — the SVG path for the image component,
        the updated FEN, and an empty string that clears the move textbox.

    Raises:
        ValueError: for an invalid FEN, an unsupported notation, or a move
            that is malformed or illegal in the current position.
    """
    board = chess.Board(inp_fen)
    human_move = (inp_move or '').strip()
    if human_move:
        if inp_notation == 'UCI':
            # parse_uci (unlike Move.from_uci) also rejects illegal moves.
            mv = board.parse_uci(human_move)
        elif inp_notation == 'SAN':
            mv = board.parse_san(human_move)
        else:
            raise ValueError(f'unsupported move notation: {inp_notation!r}')
    else:
        # Validate the engine's move too; the model may suggest illegal ones.
        mv = board.parse_uci(predict_move(board.fen(), top_k=inp_k))
    # Board.push does not check legality itself, hence the parsing above.
    board.push(mv)
    with open('board.svg', 'w') as f:
        f.write(str(chess.svg.board(board, lastmove=mv)))
    return 'board.svg', board.fen(), ''
# Gradio UI: a left column with the move/FEN/top_k inputs, the Load/Play
# buttons and usage notes, and a right column showing the rendered board.
# Components are placed in creation order within the enclosing Row/Column
# contexts, so statement order here defines the page layout.
with gr.Blocks() as block:
    # Page header; the Markdown text is kept flush-left so it is not
    # rendered as an indented code block.
    gr.Markdown(
'''
# Play YoloChess - Policy Network v0.3
110M Parameter Transformer (BERT-base architecture) trained for text classification from scratch on expert games in modified FEN notation.
'''
    )
    # NOTE(review): `row` is never referenced.
    with gr.Row() as row:
        # Left column: inputs and controls.
        with gr.Column():
            with gr.Row():
                # Optional human move; when blank, "Play" asks the engine (see btn_play).
                move = gr.Textbox(label='human player move')
                notation = gr.Radio(["SAN", "UCI"], value="SAN", label='move notation')
            # Current position; starts at the standard initial position and is
            # overwritten by both button handlers.
            fen = gr.Textbox(value=chess.Board().fen(), label='FEN')
            # Number of top model predictions the engine samples from.
            top_k = gr.Number(value=3, label='pick from top_k moves', precision=0)
            with gr.Row():
                load_btn = gr.Button("Load")
                play_btn = gr.Button("Play")
            gr.Markdown(
'''
- Click "Load" button to start and reset board.
- Click "Play" button to get Engine move.
- Enter a "human player move" in UCI or SAN notation and click "Play" to move a piece.
- Output "ERROR" generally occurs on illegal moves (Human or Engine).
- Enter "FEN" to start from a custom position.
'''
            )
        # Right column: board image written by the handlers to 'board.svg'.
        with gr.Column():
            position_output = gr.Image(label='board')
    # Load: btn_load ignores the FEN input and resets to the starting position.
    load_btn.click(fn=btn_load, inputs=fen, outputs=[position_output, fen])
    # Play: outputs map to (board image, new FEN, cleared move textbox).
    play_btn.click(fn=btn_play, inputs=[fen, move, notation, top_k], outputs=[position_output, fen, move])
# Start the web server only when this file is run as a script (e.g.
# `python app.py`), not as a side effect of importing the module.
if __name__ == "__main__":
    block.launch()