File size: 3,833 Bytes
313445a
 
 
 
 
b02b814
313445a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b02b814
1e7240f
b02b814
313445a
 
 
 
0671ff1
799a947
b02b814
 
 
 
 
0671ff1
b02b814
 
 
 
 
 
 
 
 
799a947
 
 
0671ff1
799a947
 
b02b814
 
0671ff1
 
 
799a947
0671ff1
b02b814
 
 
1e7240f
 
 
 
 
 
 
 
 
b02b814
 
 
 
0671ff1
b02b814
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import gradio as gr
import random
import chess
import chess.svg
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# Hugging Face access token, read from the environment (KeyError if the
# 'auth_token' variable is missing). Presumably required because the model
# repository is private — TODO confirm.
token = os.environ['auth_token']

# Tokenizer and classifier are both loaded from the same Hub repository.
_repo_id = 'jrahn/chessv3'
tokenizer = AutoTokenizer.from_pretrained(_repo_id, use_auth_token=token)
model = AutoModelForSequenceClassification.from_pretrained(_repo_id, use_auth_token=token)

# Text-classification pipeline over encoded positions; the predicted labels
# are later parsed as UCI moves by btn_play.
pipe = pipeline(task="text-classification", tokenizer=tokenizer, model=model)

# Placeholder written for every empty square when a FEN board is expanded.
empty_field = '0'
# Separator inserted between the encoded board and the remaining FEN fields.
board_split = ' | '
# FEN digit -> run of that many placeholders: '1' -> '0', ..., '8' -> '00000000'.
nums = {str(count): empty_field * count for count in range(1, 9)}
# Inverse of ``nums``, ordered from the longest run ('8') down to the shortest
# ('1'); decoding relies on this order so longer runs are collapsed first.
nums_rev = {nums[digit]: digit for digit in sorted(nums, key=int, reverse=True)}


def encode_fen(fen):
    """Convert a FEN string into the model's input text.

    Digits in the board field are expanded into runs of ``empty_field`` so
    each rank is spelled out square by square. Every rank gets a leading '+'
    and ranks are separated by spaces instead of '/', which prepares the
    text for sub-word tokenization. Everything after the first space of the
    FEN (the remaining fields) is appended unchanged after ``board_split``.
    """
    board_part, remainder = fen.split(' ', 1)
    # One pass is enough: the expansions only contain '0', never a digit 1-8.
    expanded = board_part.translate(str.maketrans(nums))
    ranks = ' '.join('+' + rank for rank in expanded.split('/'))
    return f'{ranks}{board_split}{remainder}'


def decode_fen_repr(fen_repr):
    """Rebuild a standard FEN string from text produced by ``encode_fen``.

    Runs of ``empty_field`` are collapsed back into digits, iterating
    ``nums_rev`` so the longest runs are replaced first. The '+'-prefixed
    ranks are re-joined with '/', and the remaining fields are re-attached
    with a single space.
    """
    board_part, remainder = fen_repr.split(board_split, 1)
    for run, digit in nums_rev.items():
        board_part = board_part.replace(run, digit)
    board_part = board_part.replace(' +', '/').replace('+', '')
    return f'{board_part} {remainder}'

def predict_move(fen, top_k=3):
    """Sample a legal engine move for the position given as a FEN string.

    The position is encoded with ``encode_fen`` and scored by the
    text-classification pipeline. Predicted labels that are not legal moves
    (in UCI form) in this position are discarded. One of the ``top_k``
    highest-scoring remaining moves is then drawn at random, weighted by its
    score.

    Args:
        fen: Position in standard FEN notation.
        top_k: Number of best legal predictions to sample from; must be >= 1.

    Returns:
        The chosen move as a UCI string (e.g. ``'e2e4'``).

    Raises:
        ValueError: if ``top_k`` < 1, if ``fen`` is not a valid FEN, or if
            none of the predicted moves is legal (e.g. the game is over).
    """
    top_k = int(top_k)
    if top_k < 1:
        raise ValueError(f'top_k must be at least 1, got {top_k}')

    legal = {move.uci() for move in chess.Board(fen).legal_moves}

    # Score every label, not just the top_k, so illegal predictions can be
    # dropped before the cut without leaving fewer than top_k (or zero)
    # candidates. The pipeline returns predictions sorted by descending score.
    preds = pipe(encode_fen(fen), top_k=pipe.model.config.num_labels)
    candidates = [p for p in preds if p['label'] in legal][:top_k]
    if not candidates:
        raise ValueError(f'no legal move predicted for position {fen!r}')

    weights = [p['score'] for p in candidates]
    return random.choices(candidates, weights=weights)[0]['label']

def btn_load(inp_fen):
    """Reset to the standard starting position and render it to 'board.svg'.

    ``inp_fen`` (the FEN textbox value passed in by the Load button) is not
    used: Load always starts over from the initial position.

    Returns:
        ('board.svg', starting FEN) for the board image and the FEN textbox.
    """
    start = chess.Board()
    svg_markup = str(chess.svg.board(start))
    with open('board.svg', 'w') as svg_file:
        svg_file.write(svg_markup)
    return 'board.svg', start.fen()

def btn_play(inp_fen, inp_move, inp_notation, inp_k):
    """Apply one move to the position in ``inp_fen`` and render the result.

    If ``inp_move`` is non-blank it is parsed as the human player's move in
    ``inp_notation`` ('UCI' or 'SAN'). Otherwise the engine picks a move via
    ``predict_move``, sampling from the ``inp_k`` best predictions.

    Returns:
        ('board.svg', new FEN, ''): the rendered board image path, the
        updated position, and an empty string that clears the move textbox.

    Raises:
        ValueError: for an invalid FEN, an unparsable move, an unknown
            notation, or a move that is not legal in the position. The UI
            text says such cases show up as "ERROR".
    """
    board = chess.Board(inp_fen)
    move_text = (inp_move or '').strip()

    if not move_text:
        mv = chess.Move.from_uci(predict_move(board.fen(), top_k=inp_k))
    elif inp_notation == 'UCI':
        mv = chess.Move.from_uci(move_text)
    elif inp_notation == 'SAN':
        mv = board.parse_san(move_text)
    else:
        raise ValueError(f'unknown move notation: {inp_notation!r}')

    # Board.push() does not check legality: an illegal move would silently
    # corrupt the position (or trip an internal assertion), so reject it here.
    # This also covers engine predictions.
    if mv not in board.legal_moves:
        raise ValueError(f'illegal move {mv.uci()} in position {board.fen()}')
    board.push(mv)

    with open('board.svg', 'w') as f:
        f.write(str(chess.svg.board(board, lastmove=mv)))

    return 'board.svg', board.fen(), ''

# Build the Gradio UI: a two-column layout with the move/FEN controls and
# usage notes on the left and the rendered board image on the right.
with gr.Blocks() as block:
    gr.Markdown(
    '''
    # Play YoloChess - Policy Network v0.3
    110M Parameter Transformer (BERT-base architecture) trained for text classification from scratch on expert games in modified FEN notation. 
    '''
    )
    with gr.Row() as row:
        # Left column: move input, FEN, sampling width, buttons, instructions.
        with gr.Column():
            with gr.Row():
                move = gr.Textbox(label='human player move')
                notation = gr.Radio(["SAN", "UCI"], value="SAN", label='move notation')
            # Current position; starts as the standard initial FEN and is
            # overwritten by both button handlers after every click.
            fen = gr.Textbox(value=chess.Board().fen(), label='FEN')
            # Forwarded to the engine as its top_k; precision=0 asks Gradio
            # to round the value to an integer.
            top_k = gr.Number(value=3, label='pick from top_k moves', precision=0)
            with gr.Row():
                load_btn = gr.Button("Load")
                play_btn = gr.Button("Play")
            gr.Markdown(
                '''
                - Click "Load" button to start and reset board.  
                - Click "Play" button to get Engine move.  
                - Enter a "human player move" in UCI or SAN notation and click "Play" to move a piece.  
                - Output "ERROR" generally occurs on illegal moves (Human or Engine).
                - Enter "FEN" to start from a custom position.
                '''
            )
        # Right column: displays the 'board.svg' file returned by the handlers.
        with gr.Column():
            position_output = gr.Image(label='board')

    # Load: render the starting position and write its FEN back to the textbox.
    load_btn.click(fn=btn_load, inputs=fen, outputs=[position_output, fen])
    # Play: apply the typed move (or an engine move when the move box is
    # empty), then refresh the board image and FEN and clear the move box.
    play_btn.click(fn=btn_play, inputs=[fen, move, notation, top_k], outputs=[position_output, fen, move])
    

# Start serving the Gradio app.
block.launch()