# yolochess / app.py
# Author: jrahn
# Commit message: Update app.py
# Commit: 56f19a3 (verified)
import os
import random
from datetime import datetime
import gradio as gr
import chess
import chess.svg
from transformers import DebertaV2ForSequenceClassification, AutoTokenizer, pipeline
# Access token for loading the checkpoint; indexing os.environ directly means
# startup fails with KeyError when the 'auth_token' variable is not set.
token = os.environ['auth_token']

# Tokenizer and classifier weights are loaded from the same checkpoint name
# with the same credentials, so both are spelled out once and reused.
_CHECKPOINT = 'jrahn/chessv6'
_auth_kwargs = {'use_auth_token': token}

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT, **_auth_kwargs)
model = DebertaV2ForSequenceClassification.from_pretrained(_CHECKPOINT, **_auth_kwargs)

# Text-classification pipeline used by predict_move to score candidate moves.
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
def predict_move(fen, top_k=3):
    """Sample one predicted move label for the position given as FEN.

    The module-level ``pipe`` is asked for its ``top_k`` best predictions;
    one of them is then drawn at random, weighted by its score, so
    higher-scored predictions are picked more often but not always.

    Args:
        fen: Position text passed to the classification pipeline.
        top_k: Number of top predictions to sample from.

    Returns:
        The ``'label'`` of the sampled prediction (the caller in this file
        parses it as a UCI move).
    """
    candidates = pipe(fen, top_k=top_k)

    scores = []
    for candidate in candidates:
        scores.append(candidate['score'])

    # random.choices returns a list; with k=1 it holds exactly one element.
    chosen = random.choices(candidates, weights=scores, k=1)[0]
    return chosen['label']
def btn_load(inp_fen):
    """Handle the "Load" button: reset to the standard starting position.

    ``inp_fen`` is only written to the log; the returned board is always a
    freshly constructed ``chess.Board()``.

    Returns:
        A 3-tuple ``(svg, fen, '')``: the rendered board, the FEN of the
        fresh board, and an empty string for the move textbox.
    """
    print(f'** log - load - ts {datetime.now().isoformat()}, fen: {inp_fen}')

    fresh_board = chess.Board()
    rendered = chess.svg.board(fresh_board)
    start_fen = fresh_board.fen()
    return rendered, start_fen, ''
def btn_play(inp_fen, inp_move, inp_notation, inp_k):
    """Handle the "Play" button: apply one move and return the new position.

    If a human move is supplied it is parsed in the selected notation;
    otherwise the engine move is obtained from ``predict_move``. The move is
    pushed onto the board only if it is legal in the given position.

    Args:
        inp_fen: FEN of the position to play from.
        inp_move: Human move text. Empty, ``None`` or whitespace-only means
            "let the engine move". Surrounding whitespace is ignored.
        inp_notation: ``'UCI'`` or ``'SAN'``; selects how ``inp_move`` is parsed.
        inp_k: Forwarded to ``predict_move`` as ``top_k`` for engine moves.

    Returns:
        A 3-tuple ``(svg, fen, '')``: the board rendered with the last move
        highlighted, the FEN after the move, and an empty string that clears
        the move textbox.

    Raises:
        ValueError: If the notation is not recognised or the move is illegal
            in the position. Parsing errors raised by python-chess propagate
            unchanged.
    """
    print(f'** log - play - ts {datetime.now().isoformat()}, fen: {inp_fen}, move: {inp_move}, notation: {inp_notation}, top_k: {inp_k}')
    board = chess.Board(inp_fen)

    # Stray spaces around the move would otherwise make parsing fail; a
    # whitespace-only entry is treated the same as an empty textbox.
    move_text = inp_move.strip() if inp_move else ''

    if move_text:
        if inp_notation == 'UCI':
            mv = chess.Move.from_uci(move_text)
        elif inp_notation == 'SAN':
            mv = board.parse_san(move_text)
        else:
            # Without this branch `mv` would stay unbound and the legality
            # check below would fail with an opaque UnboundLocalError.
            raise ValueError(f'Unknown move notation: {inp_notation!r}')
    else:
        mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))

    # Both human and engine moves are validated against the position.
    if mv not in board.legal_moves:
        raise ValueError(f'Illegal Move: {str(mv)} @ {board.fen()}')
    board.push(mv)

    svg_str = chess.svg.board(board, lastmove=mv)
    return svg_str, board.fen(), ''
# Gradio UI: controls in one column, rendered board in the other.
# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost — confirm the layout against the deployed app.
# The Markdown string contents are kept at column 0 so the literals are
# unchanged.
with gr.Blocks() as block:
    # Page header describing the model.
    gr.Markdown(
        '''
# Play YoloChess - Policy Network v0.6
87M Parameter Transformer (DeBERTaV2-base architecture)
- pre-trained (MLM) from scratch on chess positions in FEN notation
- fine-tuned for text classification (moves) on expert games.
'''
    )
    with gr.Row() as row:
        with gr.Column():
            with gr.Row():
                # Optional human move and the notation used to parse it.
                move = gr.Textbox(label='human player move')
                notation = gr.Radio(["SAN", "UCI"], value="SAN", label='move notation')
            # Current position; pre-filled with the standard starting FEN.
            fen = gr.Textbox(value=chess.Board().fen(), label='FEN')
            # Passed to btn_play as inp_k (number of top predictions to sample from).
            top_k = gr.Number(value=3, label='sample from top_k moves', precision=0)
            with gr.Row():
                load_btn = gr.Button("Load")
                play_btn = gr.Button("Play")
            # Usage instructions shown below the controls.
            gr.Markdown(
                '''
- Click "Load" button to start and reset board.
- Click "Play" button to get Engine move.
- Enter a "human player move" in UCI or SAN notation and click "Play" to move a piece.
- Output "ERROR" generally occurs on illegal moves (Human or Engine).
- Enter "FEN" to start from a custom position.
'''
            )
        with gr.Column():
            # Receives the SVG markup returned by btn_load / btn_play.
            position_output = gr.HTML(label='board')
    # Both handlers return (svg, fen, '') which updates the board, refreshes
    # the FEN textbox and clears the move textbox.
    load_btn.click(fn=btn_load, inputs=fen, outputs=[position_output, fen, move])
    play_btn.click(fn=btn_play, inputs=[fen, move, notation, top_k], outputs=[position_output, fen, move])

# Start the app.
block.launch()