File size: 2,817 Bytes
313445a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import os
import gradio as gr
import random
import chess
import chess.svg
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, pipeline
# Hugging Face access token for the model repo (presumably private -- confirm);
# a missing `auth_token` environment variable raises KeyError at import time.
token = os.environ['auth_token']
_REPO_ID = 'jrahn/chessv3'
tokenizer = AutoTokenizer.from_pretrained(_REPO_ID, use_auth_token=token)
model = AutoModelForSequenceClassification.from_pretrained(_REPO_ID, use_auth_token=token)
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)

# Tables for converting between FEN and the model's input text.
empty_field = '0'    # character written once per empty square
board_split = ' | '  # separates the piece-placement field from the other FEN fields
# Digit -> run of empty squares: '1' -> '0', '2' -> '00', ..., '8' -> '00000000'.
nums = {str(count): empty_field * count for count in range(1, 9)}
# Run -> digit, ordered longest run first so decoding collapses an 8-square
# gap before any shorter run can match part of it.
nums_rev = {
    run: digit
    for digit, run in sorted(nums.items(), key=lambda item: len(item[1]), reverse=True)
}

# The game in progress; play_yolochess reassigns and mutates this object.
board = chess.Board()
def encode_fen(fen):
    """Convert a FEN string into the text form fed to the policy network.

    In the piece-placement field, each empty-square count digit ('1'-'8') is
    expanded into that many `empty_field` characters. Every rank is prefixed
    with '+' and ranks are separated by spaces, to suit sub-word tokenization.
    The remaining FEN fields follow after `board_split`, for example:

        'rnbqkbnr/pppppppp/8/... w KQkq - 0 1'
        -> '+rnbqkbnr +pppppppp +00000000 ... | w KQkq - 0 1'

    Raises ValueError if `fen` contains no space (nothing after the board).
    """
    placement, other_fields = fen.split(' ', 1)
    # One translation pass: digits -> runs of empty squares, '/' -> ' +'.
    expand = str.maketrans({**nums, '/': ' +'})
    return board_split.join(('+' + placement.translate(expand), other_fields))
def decode_fen_repr(fen_repr):
    """Invert `encode_fen`: rebuild a standard FEN string from the model-input form.

    Raises ValueError if `fen_repr` does not contain `board_split`.
    """
    encoded_board, other_fields = fen_repr.split(board_split, 1)
    # nums_rev is ordered longest run first, so a full empty rank becomes '8'
    # before any shorter run could match a piece of it.
    for run, digit in nums_rev.items():
        encoded_board = encoded_board.replace(run, digit)
    # ' +' marks a rank boundary ('/'); the leftover leading '+' is dropped.
    placement = encoded_board.replace(' +', '/').replace('+', '')
    return f'{placement} {other_fields}'
def predict_move(fen, top_k=3):
    """Sample a legal move for the side to move in `fen` from the policy network.

    The position is encoded with `encode_fen` and scored by `pipe`. Labels
    that are not legal UCI moves in the position are discarded. One of the
    `top_k` highest-scoring legal moves is then drawn at random, weighted by
    its score.

    Args:
        fen: FEN string of the position to move from.
        top_k: how many of the best-scoring legal moves to sample from.

    Returns:
        The chosen move as a UCI string (as accepted by `Board.push_uci`).

    Raises:
        ValueError: if the position has no legal moves.
    """
    legal_moves = {move.uci() for move in chess.Board(fen).legal_moves}
    if not legal_moves:
        raise ValueError(f'no legal moves in position: {fen}')

    # Score every label, not just the top few. Otherwise, when illegal labels
    # outrank the legal ones, the filter below could leave nothing to choose.
    preds = pipe(encode_fen(fen), top_k=model.config.num_labels)
    candidates = sorted(
        (p for p in preds if p['label'] in legal_moves),
        key=lambda p: p['score'],
        reverse=True,
    )[:top_k]
    if not candidates:
        # No model label is legal here: play a uniformly random legal move
        # rather than handing the caller a move that push_uci would reject.
        return random.choice(sorted(legal_moves))

    weights = [p['score'] for p in candidates]
    # random.choices rejects all-zero weights; fall back to uniform sampling.
    if sum(weights) <= 0:
        weights = None
    return random.choices(candidates, weights=weights)[0]['label']
def play_yolochess(inp_color, inp_notation, inp_move, inp_custom_fen, state):
    """Play one half-move of the game and render the resulting position.

    If the side to move is not the human's colour, the policy network moves
    and `inp_move` is ignored. Otherwise `inp_move` (if non-empty) is applied
    in the chosen notation. Each call makes at most one move, so the engine
    replies on the following submission.

    Args:
        inp_color: 'white' or 'black' -- the colour the human plays.
        inp_notation: 'UCI' or 'SAN' -- the notation `inp_move` is written in.
        inp_move: the human's move text; may be empty.
        inp_custom_fen: optional FEN. A new game starts from it whenever it
            differs from the position the current game started from.
        state: Gradio session state; not read here.

    Returns:
        ('board.svg', fen): path of the rendered board image, and the FEN of
        the current position (which becomes the new session state).

    Raises:
        ValueError: for an unparsable FEN or an illegal/unparsable human move
            (raised by python-chess).
    """
    global board
    custom_fen = (inp_custom_fen or '').strip()
    if custom_fen:
        requested_start = chess.Board(fen=custom_fen)
        # The FEN textbox keeps its value between submissions. Restarting on
        # every call would undo each move, so a game from a custom position
        # could never advance. Restart only when this FEN is not already
        # where the current game began.
        if board.root().fen() != requested_start.fen():
            board = requested_start

    # The model plays the colour opposite to the human's choice.
    engine_color = {'white': chess.BLACK, 'black': chess.WHITE}.get(inp_color)
    if board.turn == engine_color:
        # Never ask the model for a move once the game has ended.
        if not board.is_game_over():
            board.push_uci(predict_move(board.fen()))
    else:
        move_text = (inp_move or '').strip()
        if move_text:
            if inp_notation == 'UCI':
                board.push_uci(move_text)
            elif inp_notation == 'SAN':
                board.push_san(move_text)

    # NOTE(review): `board` and 'board.svg' are process-wide, so every
    # connected session shares one game and one image file.
    with open('board.svg', 'w', encoding='utf-8') as f:
        f.write(str(chess.svg.board(board)))
    return 'board.svg', board.fen()
# Gradio UI. The four controls plus the session state feed play_yolochess;
# its two return values fill the board image and the next session state.
_ui_inputs = [
    gr.Radio(["white", "black"], value="white", label='human player color'),
    gr.Radio(["SAN", "UCI"], value="SAN", label='move notation'),
    gr.Textbox(label='human player move'),
    # The placeholder shows the FEN of `board` as it is when the UI is built.
    gr.Textbox(placeholder=board.fen(), label='starting position FEN'),
    "state",
]
_ui_outputs = [
    gr.Image(label='position'),
    "state",
]
iface = gr.Interface(
    fn=play_yolochess,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    title='Play YoloChess - Policy Network v0.3',
    allow_flagging="never",
)
iface.launch()