File size: 2,817 Bytes
313445a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import gradio as gr
import random
import chess
import chess.svg
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, pipeline

# Hugging Face Hub access token read from the environment; a missing
# 'auth_token' variable raises KeyError at import time.
token = os.environ['auth_token']

_MODEL_REPO = 'jrahn/chessv3'

# Policy network: a sequence classifier over encode_fen() strings whose labels
# are consumed as UCI moves (see predict_move / play_yolochess).
tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO, use_auth_token=token)
model = AutoModelForSequenceClassification.from_pretrained(
    _MODEL_REPO,
    use_auth_token=token,
)
pipe = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-classification",
)

# Placeholder character that each empty square expands to, so every square of
# the FEN board field becomes exactly one character.
empty_field = '0'
# Separator placed between the expanded board field and the other FEN fields.
board_split = ' | '
# FEN run-length digit -> run of placeholders: '1' -> '0', ..., '8' -> '00000000'.
nums = {digit: empty_field * int(digit) for digit in '12345678'}
# Inverse mapping, ordered longest run first: decoding must collapse
# '00000000' to '8' before a shorter run could match part of it.
nums_rev = {
    run: digit
    for digit, run in sorted(nums.items(), key=lambda item: len(item[1]), reverse=True)
}

# Module-wide game position, starting from the standard initial setup.
board = chess.Board(chess.STARTING_FEN)

def encode_fen(fen):
    """Rewrite a FEN string into the model's tokenizer-friendly form.

    The piece-placement field is decompressed: each run-length digit becomes
    that many ``empty_field`` characters (per ``nums``). Every rank is then
    prefixed with ``'+'`` and ranks are separated by single spaces. The rest
    of the FEN (side to move, castling, ...) follows after ``board_split``.

    Example: ``'8/8/8/8/8/8/8/K6k w - - 0 1'`` ->
    ``'+00000000 ... +K000000k | w - - 0 1'``.

    Raises:
        ValueError: If ``fen`` contains no space (no fields after the board).
    """
    placement, remainder = fen.split(' ', 1)
    # Single pass: every digit key of nums maps to its run of placeholders.
    expanded = placement.translate(str.maketrans(nums))
    marked_ranks = ' '.join('+' + rank for rank in expanded.split('/'))
    return board_split.join((marked_ranks, remainder))

def decode_fen_repr(fen_repr):
    """Invert ``encode_fen``: turn the model representation back into FEN.

    Runs of ``empty_field`` are collapsed back into digits using ``nums_rev``.
    The ``' +'`` rank markers become ``'/'`` and the leading ``'+'`` is
    dropped. The trailing fields are re-joined to the board with one space.

    Raises:
        ValueError: If ``fen_repr`` does not contain ``board_split``.
    """
    placement, remainder = fen_repr.split(board_split, 1)
    # nums_rev is ordered longest run first, so an 8-square gap becomes '8'
    # instead of being split into shorter digits.
    for run, digit in nums_rev.items():
        placement = placement.replace(run, digit)
    placement = '/'.join(placement.split(' +'))
    return placement.replace('+', '') + ' ' + remainder

def predict_move(fen, top_k=3):
    """Sample a legal move for the side to move in ``fen`` from the policy model.

    The model scores every move label. Labels that are not legal UCI moves in
    this position are discarded, the ``top_k`` best-scoring legal moves are
    kept, and one is drawn at random weighted by its score.

    Args:
        fen: Position in Forsyth-Edwards Notation.
        top_k: Number of best legal candidates to sample from.

    Returns:
        The chosen move as a UCI string, suitable for ``Board.push_uci``.

    Raises:
        ValueError: If ``fen`` is invalid or the position has no legal moves.
    """
    legal_moves = [move.uci() for move in chess.Board(fen).legal_moves]
    if not legal_moves:
        raise ValueError(f'no legal moves in position {fen!r}')
    legal = set(legal_moves)

    # The model does not know the rules, so request a score for every label.
    # That way, filtering out illegal moves still leaves up to top_k candidates.
    preds = pipe(encode_fen(fen), top_k=pipe.model.config.num_labels)
    ranked = sorted(preds, key=lambda p: p['score'], reverse=True)
    candidates = [p for p in ranked if p['label'] in legal][:top_k]
    if not candidates:
        # None of the model's labels is legal here: fall back to any legal move.
        return random.choice(legal_moves)

    weights = [p['score'] for p in candidates]
    # random.choices rejects an all-zero weight total; sample uniformly then.
    chosen = random.choices(candidates, weights=weights if sum(weights) > 0 else None)[0]
    return chosen['label']

# Custom starting FEN most recently loaded into `board`. It is used so that the
# same position is not reloaded on every submit.
_last_custom_fen = None


def play_yolochess(inp_color, inp_notation, inp_move, inp_custom_fen, state):
    """Advance the module-level game by one half-move and render it.

    If a non-empty ``inp_custom_fen`` differs from the one loaded previously,
    it first replaces the current position. Then, if it is the engine's turn
    (the side opposite ``inp_color``), the model plays a move. Otherwise the
    human move ``inp_move`` (if any) is applied in the chosen notation.

    Args:
        inp_color: 'white' or 'black', the human player's side.
        inp_notation: 'SAN' or 'UCI', the notation ``inp_move`` is written in.
        inp_move: The human's move; ignored when empty or on the engine's turn.
        inp_custom_fen: Optional FEN to (re)start the game from.
        state: Gradio session state. It is not read here; the module-level
            ``board`` holds the game.

    Returns:
        Tuple ``('board.svg', fen)``: the path of the rendered board image and
        the resulting position's FEN (stored back into the Gradio state).

    Raises:
        ValueError: On an invalid FEN or an illegal/unparsable move, as raised
            by python-chess.
    """
    global board, _last_custom_fen
    # The FEN textbox keeps its contents between submits. Only load the
    # position when the value changes, otherwise every submit would reset the
    # game to the custom start. If loading fails, the tracker is not updated,
    # so the next submit retries.
    if inp_custom_fen and inp_custom_fen != _last_custom_fen:
        board = chess.Board(fen=inp_custom_fen)
    _last_custom_fen = inp_custom_fen

    engine_color = {'white': chess.BLACK, 'black': chess.WHITE}.get(inp_color)
    if board.turn == engine_color:
        # Nothing to predict once the game has ended (mate, stalemate, ...).
        if not board.is_game_over():
            board.push_uci(predict_move(board.fen()))
    elif inp_move:
        if inp_notation == 'UCI':
            board.push_uci(inp_move)
        elif inp_notation == 'SAN':
            board.push_san(inp_move)

    with open('board.svg', 'w', encoding='utf-8') as f:
        f.write(str(chess.svg.board(board)))
    return 'board.svg', board.fen()

# UI components. The inputs map positionally onto play_yolochess's parameters;
# the outputs receive its (svg path, fen) return value.
_inputs = [
    gr.Radio(["white", "black"], value="white", label='human player color'),
    gr.Radio(["SAN", "UCI"], value="SAN", label='move notation'),
    gr.Textbox(label='human player move'),
    gr.Textbox(placeholder=board.fen(), label='starting position FEN'),
    "state",
]
_outputs = [
    gr.Image(label='position'),
    "state",
]

iface = gr.Interface(
    fn=play_yolochess,
    inputs=_inputs,
    outputs=_outputs,
    allow_flagging="never",
    title='Play YoloChess - Policy Network v0.3',
)
iface.launch()