initial application
Browse files- app.py +81 -0
- requirements.txt +1 -0
app.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import random
|
4 |
+
import chess
|
5 |
+
import chess.svg
|
6 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, pipeline
|
7 |
+
|
8 |
+
token = os.environ['auth_token']
|
9 |
+
|
10 |
+
tokenizer = AutoTokenizer.from_pretrained('jrahn/chessv3', use_auth_token=token)
|
11 |
+
model = AutoModelForSequenceClassification.from_pretrained('jrahn/chessv3', use_auth_token=token)
|
12 |
+
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
|
13 |
+
|
14 |
+
empty_field = '0'
|
15 |
+
board_split = ' | '
|
16 |
+
nums = {str(n): empty_field * n for n in range(1, 9)}
|
17 |
+
nums_rev = {v:k for k,v in reversed(nums.items())}
|
18 |
+
|
19 |
+
board = chess.Board()
|
20 |
+
|
21 |
+
def encode_fen(fen):
|
22 |
+
# decompress fen representation
|
23 |
+
# prepare for sub-word tokenization
|
24 |
+
fen_board, fen_rest = fen.split(' ', 1)
|
25 |
+
for n in nums:
|
26 |
+
fen_board = fen_board.replace(n, nums[n])
|
27 |
+
fen_board = '+' + fen_board
|
28 |
+
fen_board = fen_board.replace('/', ' +')
|
29 |
+
return board_split.join([fen_board, fen_rest])
|
30 |
+
|
31 |
+
def decode_fen_repr(fen_repr):
|
32 |
+
fen_board, fen_rest = fen_repr.split(board_split, 1)
|
33 |
+
for n in nums_rev:
|
34 |
+
fen_board = fen_board.replace(n, nums_rev[n])
|
35 |
+
fen_board = fen_board.replace(' +', '/')
|
36 |
+
fen_board = fen_board.replace('+', '')
|
37 |
+
return ' '.join([fen_board, fen_rest])
|
38 |
+
|
39 |
+
def predict_move(fen, top_k=3):
|
40 |
+
fen_prep = encode_fen(fen)
|
41 |
+
preds = pipe(fen_prep, top_k=top_k)
|
42 |
+
weights = [p['score'] for p in preds]
|
43 |
+
p = random.choices(preds, weights=weights)[0]
|
44 |
+
# discard illegal moves (https://python-chess.readthedocs.io/en/latest/core.html#chess.Board.legal_moves), then select top_k
|
45 |
+
return p['label']
|
46 |
+
|
47 |
+
def play_yolochess(inp_color, inp_notation, inp_move, inp_custom_fen, state):
|
48 |
+
global board
|
49 |
+
if inp_custom_fen:
|
50 |
+
board = chess.Board(fen=inp_custom_fen)
|
51 |
+
if (inp_color == 'white' and board.turn == chess.BLACK) or (inp_color == 'black' and board.turn == chess.WHITE):
|
52 |
+
move = predict_move(board.fen())
|
53 |
+
board.push_uci(move)
|
54 |
+
else:
|
55 |
+
if inp_move:
|
56 |
+
if inp_notation == 'UCI':
|
57 |
+
board.push_uci(inp_move)
|
58 |
+
if inp_notation == 'SAN':
|
59 |
+
board.push_san(inp_move)
|
60 |
+
with open('board.svg', 'w') as f:
|
61 |
+
f.write(str(chess.svg.board(board)))
|
62 |
+
print(state)
|
63 |
+
return 'board.svg', board.fen()
|
64 |
+
|
65 |
+
iface = gr.Interface(
|
66 |
+
fn=play_yolochess,
|
67 |
+
inputs=[
|
68 |
+
gr.Radio(["white", "black"], value="white", label='human player color'),
|
69 |
+
gr.Radio(["SAN", "UCI"], value="SAN", label='move notation'),
|
70 |
+
gr.Textbox(label='human player move'),
|
71 |
+
gr.Textbox(placeholder=board.fen(), label='starting position FEN'),
|
72 |
+
"state"
|
73 |
+
],
|
74 |
+
outputs=[
|
75 |
+
gr.Image(label='position'),
|
76 |
+
"state"
|
77 |
+
],
|
78 |
+
title='Play YoloChess - Policy Network v0.3',
|
79 |
+
allow_flagging="never",
|
80 |
+
)
|
81 |
+
iface.launch()
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
chess==1.9.2
|