|
import os |
|
import gradio as gr |
|
import random |
|
import chess |
|
import chess.svg |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, pipeline |
|
|
|
# Hugging Face access token for the model repository, read from the
# `auth_token` environment variable (raises KeyError if it is unset).
# NOTE(review): presumably needed because the repo is private — confirm.
token = os.environ['auth_token']

# Repository id and shared hub arguments used for both tokenizer and model.
_repo_id = 'jrahn/chessv3'
_hub_kwargs = {'use_auth_token': token}

tokenizer = AutoTokenizer.from_pretrained(_repo_id, **_hub_kwargs)
model = AutoModelForSequenceClassification.from_pretrained(_repo_id, **_hub_kwargs)

# Text-classification pipeline whose labels are used as moves by predict_move.
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
|
|
|
# Character representing one empty square in the model's board encoding.
empty_field = '0'

# Separator between the encoded piece placement and the remaining FEN fields.
board_split = ' | '

# FEN run-length digit -> run of empty squares, e.g. '3' -> '000'.
nums = {str(count): count * empty_field for count in range(1, 9)}

# Inverse of `nums`, ordered longest run first ('00000000' -> '8', ..., '0' -> '1')
# so that sequential replacement collapses whole runs rather than single squares.
nums_rev = {run: digit for digit, run in sorted(nums.items(), reverse=True)}
|
|
|
# Module-level board in the standard starting position; it is declared
# `global` by play_yolochess and its FEN is used as the placeholder of the
# starting-position textbox in the interface.
board = chess.Board(chess.STARTING_FEN)
|
|
|
def encode_fen(fen):
    """Convert a FEN string into the text representation fed to the model.

    The piece-placement field has each run-length digit expanded into that
    many ``empty_field`` characters. Every rank is prefixed with ``+``, and
    ranks are separated by a single space. The rest of the FEN follows after
    ``board_split``, for example::

        'rnbqkbnr/pppppppp/8/... w KQkq - 0 1'
        -> '+rnbqkbnr +pppppppp +00000000 ... | w KQkq - 0 1'

    Raises ValueError if *fen* contains no space.
    """
    placement, remainder = fen.split(' ', 1)
    # Map each character on its own; only the digits '1'-'8' are expanded.
    expanded = ''.join(nums.get(char, char) for char in placement)
    encoded = ' '.join('+' + rank for rank in expanded.split('/'))
    return board_split.join([encoded, remainder])
|
|
|
def decode_fen_repr(fen_repr):
    """Turn the model's board representation back into a FEN string.

    This is the inverse of ``encode_fen``. Runs of ``empty_field`` characters
    become FEN digits, ``' +'`` rank prefixes become ``'/'``, and any remaining
    ``'+'`` is dropped. The part after ``board_split`` is re-attached with a
    single space.

    Raises ValueError if *fen_repr* does not contain ``board_split``.
    """
    encoded_board, remainder = fen_repr.split(board_split, 1)
    # nums_rev yields the longest run first, so runs of up to eight empty
    # squares collapse into a single digit.
    for run, digit in nums_rev.items():
        encoded_board = encoded_board.replace(run, digit)
    placement = encoded_board.replace(' +', '/').replace('+', '')
    return f'{placement} {remainder}'
|
|
|
def predict_move(fen, top_k=3):
    """Sample one of the model's top-*top_k* labels for the position *fen*.

    The candidates are drawn with probability proportional to their pipeline
    scores. The returned label is used by the caller as a UCI move string.
    Nothing here checks it for legality.
    """
    candidates = pipe(encode_fen(fen), top_k=top_k)
    (chosen,) = random.choices(candidates, weights=[c['score'] for c in candidates])
    return chosen['label']
|
|
|
def _choose_engine_move(position, attempts=5):
    """Return a legal move for the side to move in *position*.

    ``predict_move`` samples from the model's labels without checking them
    against the position, so a suggestion may be malformed or illegal. Up to
    *attempts* suggestions are tried. If none is legal, a uniformly random
    legal move is returned instead. The caller must ensure the game is not
    over, so that at least one legal move exists.
    """
    for _ in range(attempts):
        try:
            candidate = chess.Move.from_uci(predict_move(position.fen()))
        except ValueError:
            continue  # label is not a well-formed UCI move; sample again
        if position.is_legal(candidate):
            return candidate
    return random.choice(list(position.legal_moves))


def play_yolochess(inp_color, inp_notation, inp_move, inp_custom_fen, state):
    """Advance the game by one move and render the resulting position.

    If it is the engine's turn (the side opposite ``inp_color``), the engine
    plays a move. Otherwise ``inp_move`` is applied for the human player.

    Args:
        inp_color: 'white' or 'black', the human player's side.
        inp_notation: 'SAN' or 'UCI', the notation of ``inp_move``.
        inp_move: the human's move. It is ignored when blank or when it is
            the engine's turn.
        inp_custom_fen: if non-empty, the game is (re)started from this FEN.
        state: the FEN returned by this session's previous call, or None on
            the first call.

    Returns:
        Tuple ``('board.svg', fen)``: the path of the rendered SVG and the FEN
        of the new position. The FEN is stored back as the session state.

    Raises:
        ValueError: if a FEN or the human move is invalid or illegal.
    """
    global board

    if inp_custom_fen:
        game = chess.Board(fen=inp_custom_fen)
    elif state:
        # Resume this session's own game. The single module-level board would
        # otherwise be shared (and mutated) by every concurrent visitor.
        game = chess.Board(fen=state)
    else:
        game = chess.Board()

    engine_to_move = (
        (inp_color == 'white' and game.turn == chess.BLACK)
        or (inp_color == 'black' and game.turn == chess.WHITE)
    )

    if engine_to_move:
        # No legal reply exists once the game is over.
        if not game.is_game_over():
            game.push(_choose_engine_move(game))
    else:
        move_text = (inp_move or '').strip()
        if move_text:
            if inp_notation == 'UCI':
                game.push_uci(move_text)
            elif inp_notation == 'SAN':
                game.push_san(move_text)

    # NOTE(review): this path is shared by all sessions; concurrent requests
    # can overwrite each other's image. Consider a per-call temp file.
    with open('board.svg', 'w', encoding='utf-8') as f:
        f.write(str(chess.svg.board(game)))

    # Keep the module-level board pointing at the latest position, as before.
    board = game
    return 'board.svg', game.fen()
|
|
|
# User-facing inputs, in the positional order play_yolochess expects them.
_human_inputs = [
    gr.Radio(["white", "black"], value="white", label='human player color'),
    gr.Radio(["SAN", "UCI"], value="SAN", label='move notation'),
    gr.Textbox(label='human player move'),
    gr.Textbox(placeholder=board.fen(), label='starting position FEN'),
]

# The trailing "state" input/output pair carries the FEN that play_yolochess
# returns from one call to the next.
iface = gr.Interface(
    fn=play_yolochess,
    inputs=_human_inputs + ["state"],
    outputs=[gr.Image(label='position'), "state"],
    title='Play YoloChess - Policy Network v0.3',
    allow_flagging="never",
)

iface.launch()