File size: 3,236 Bytes
313445a 72998f1 313445a 4a50f91 313445a 4a02647 313445a 4a02647 313445a b02b814 72998f1 1e7240f b02b814 313445a 3efeb97 313445a 0671ff1 7763e11 799a947 b02b814 4a50f91 b02b814 0671ff1 b02b814 0c203c2 b02b814 799a947 4a02647 b836b8f 799a947 b02b814 0671ff1 799a947 ded9643 b02b814 1e7240f b02b814 3efeb97 0671ff1 b02b814 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import os
import random
from datetime import datetime
import gradio as gr
import chess
import chess.svg
from transformers import DebertaV2ForSequenceClassification, AutoTokenizer, pipeline
# Hub access token for the model repository, taken from the environment
# (raises KeyError when `auth_token` is not set).
token = os.environ['auth_token']
_repo_id = 'jrahn/chessv6'
tokenizer = AutoTokenizer.from_pretrained(_repo_id, use_auth_token=token)
model = DebertaV2ForSequenceClassification.from_pretrained(_repo_id, use_auth_token=token)
# Classifier over FEN strings; its labels are parsed as UCI moves by btn_play.
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
def predict_move(fen, top_k=3):
    """Sample one move label for position `fen` from the model's top predictions.

    The pipeline's `top_k` highest-scoring labels are drawn with probability
    proportional to their scores, so stronger predictions are chosen more often.

    Args:
        fen: the position in FEN notation, fed to the classifier as text.
        top_k: how many top labels to sample from. Coerced to an int and clamped
            to at least 1; ``None`` is forwarded to the pipeline unchanged.

    Returns:
        The sampled label string (btn_play parses it with chess.Move.from_uci).

    Raises:
        ValueError: if the pipeline returns no predictions to sample from.
    """
    if top_k is not None:
        # The value comes from a UI number field: a 0, negative or fractional
        # entry is not a meaningful "top k" and could leave nothing to sample.
        top_k = max(1, int(top_k))
    preds = pipe(fen, top_k=top_k)
    if not preds:
        # random.choices on an empty population fails with an opaque IndexError.
        raise ValueError(f'No move predictions for position: {fen}')
    weights = [p['score'] for p in preds]
    p = random.choices(preds, weights=weights)[0]
    return p['label']
def btn_load(inp_fen):
    """Reset to the standard starting position and render it to 'board.svg'.

    `inp_fen` is only written to the log; the board is always the initial
    position, which is what the "Load" button's reset semantics require.

    Returns:
        ('board.svg', start_fen, '') for the image, FEN and move-textbox outputs;
        the empty string clears the move textbox.
    """
    stamp = datetime.now().isoformat()
    print(f'** log - load - ts {stamp}, fen: {inp_fen}')
    start = chess.Board()
    svg_markup = str(chess.svg.board(start))
    with open('board.svg', 'w') as out:
        out.write(svg_markup)
    return 'board.svg', start.fen(), ''
def btn_play(inp_fen, inp_move, inp_notation, inp_k):
    """Apply one move to position `inp_fen` and render the result to 'board.svg'.

    If `inp_move` is non-blank it is parsed as a human move in `inp_notation`
    ('UCI' or 'SAN'). Otherwise the engine picks a move by sampling from its
    top `inp_k` predictions via predict_move.

    Returns:
        ('board.svg', new_fen, '') for the image, FEN and move-textbox outputs;
        the empty string clears the move textbox.

    Raises:
        ValueError: on an unknown notation or an illegal move. Parsing errors
            raised by python-chess for malformed input also propagate.
    """
    print(f'** log - play - ts {datetime.now().isoformat()}, fen: {inp_fen}, move: {inp_move}, notation: {inp_notation}, top_k: {inp_k}')
    board = chess.Board(inp_fen)
    # Tolerate None and stray whitespace from the textbox ("e2e4 " is still a move).
    human_move = (inp_move or '').strip()
    if human_move:
        if inp_notation == 'UCI':
            mv = chess.Move.from_uci(human_move)
        elif inp_notation == 'SAN':
            mv = board.parse_san(human_move)
        else:
            # Without this branch `mv` would be unbound and fail with UnboundLocalError.
            raise ValueError(f'Unknown move notation: {inp_notation}')
    else:
        mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))
    if mv not in board.legal_moves:
        raise ValueError(f'Illegal Move: {str(mv)} @ {board.fen()}')
    board.push(mv)
    with open('board.svg', 'w') as f:
        # Highlight the move just played on the rendered board.
        f.write(str(chess.svg.board(board, lastmove=mv)))
    return 'board.svg', board.fen(), ''
# Gradio UI: move/FEN controls and instructions in the left column,
# the rendered board image in the right column.
with gr.Blocks() as block:
    gr.Markdown(
'''
# Play YoloChess - Policy Network v0.6
87M Parameter Transformer (DeBERTaV2-base architecture)
- pre-trained (MLM) from scratch on chess positions in FEN notation
- fine-tuned for text classification (moves) on expert games.
'''
    )
    with gr.Row():
        with gr.Column():
            with gr.Row():
                move = gr.Textbox(label='human player move')
                notation = gr.Radio(["SAN", "UCI"], value="SAN", label='move notation')
            fen = gr.Textbox(value=chess.Board().fen(), label='FEN')
            top_k = gr.Number(value=3, label='sample from top_k moves', precision=0)
            with gr.Row():
                load_btn = gr.Button("Load")
                play_btn = gr.Button("Play")
            gr.Markdown(
'''
- Click "Load" button to start and reset board.
- Click "Play" button to get Engine move.
- Enter a "human player move" in UCI or SAN notation and click "Play" to move a piece.
- Output "ERROR" generally occurs on illegal moves (Human or Engine).
- Enter "FEN" to start from a custom position.
'''
            )
        with gr.Column():
            position_output = gr.Image(label='board')
    # Both handlers return (image path, FEN, '') -> board image, FEN box, cleared move box.
    load_btn.click(fn=btn_load, inputs=fen, outputs=[position_output, fen, move])
    play_btn.click(fn=btn_play, inputs=[fen, move, notation, top_k], outputs=[position_output, fen, move])
block.launch()