jrahn commited on
Commit
313445a
1 Parent(s): 23220ce

initial application

Browse files
Files changed (2) hide show
  1. app.py +81 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import random
4
+ import chess
5
+ import chess.svg
6
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, pipeline
7
+
8
+ token = os.environ['auth_token']
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained('jrahn/chessv3', use_auth_token=token)
11
+ model = AutoModelForSequenceClassification.from_pretrained('jrahn/chessv3', use_auth_token=token)
12
+ pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
13
+
14
+ empty_field = '0'
15
+ board_split = ' | '
16
+ nums = {str(n): empty_field * n for n in range(1, 9)}
17
+ nums_rev = {v:k for k,v in reversed(nums.items())}
18
+
19
+ board = chess.Board()
20
+
21
+ def encode_fen(fen):
22
+ # decompress fen representation
23
+ # prepare for sub-word tokenization
24
+ fen_board, fen_rest = fen.split(' ', 1)
25
+ for n in nums:
26
+ fen_board = fen_board.replace(n, nums[n])
27
+ fen_board = '+' + fen_board
28
+ fen_board = fen_board.replace('/', ' +')
29
+ return board_split.join([fen_board, fen_rest])
30
+
31
+ def decode_fen_repr(fen_repr):
32
+ fen_board, fen_rest = fen_repr.split(board_split, 1)
33
+ for n in nums_rev:
34
+ fen_board = fen_board.replace(n, nums_rev[n])
35
+ fen_board = fen_board.replace(' +', '/')
36
+ fen_board = fen_board.replace('+', '')
37
+ return ' '.join([fen_board, fen_rest])
38
+
39
+ def predict_move(fen, top_k=3):
40
+ fen_prep = encode_fen(fen)
41
+ preds = pipe(fen_prep, top_k=top_k)
42
+ weights = [p['score'] for p in preds]
43
+ p = random.choices(preds, weights=weights)[0]
44
+ # discard illegal moves (https://python-chess.readthedocs.io/en/latest/core.html#chess.Board.legal_moves), then select top_k
45
+ return p['label']
46
+
47
+ def play_yolochess(inp_color, inp_notation, inp_move, inp_custom_fen, state):
48
+ global board
49
+ if inp_custom_fen:
50
+ board = chess.Board(fen=inp_custom_fen)
51
+ if (inp_color == 'white' and board.turn == chess.BLACK) or (inp_color == 'black' and board.turn == chess.WHITE):
52
+ move = predict_move(board.fen())
53
+ board.push_uci(move)
54
+ else:
55
+ if inp_move:
56
+ if inp_notation == 'UCI':
57
+ board.push_uci(inp_move)
58
+ if inp_notation == 'SAN':
59
+ board.push_san(inp_move)
60
+ with open('board.svg', 'w') as f:
61
+ f.write(str(chess.svg.board(board)))
62
+ print(state)
63
+ return 'board.svg', board.fen()
64
+
65
+ iface = gr.Interface(
66
+ fn=play_yolochess,
67
+ inputs=[
68
+ gr.Radio(["white", "black"], value="white", label='human player color'),
69
+ gr.Radio(["SAN", "UCI"], value="SAN", label='move notation'),
70
+ gr.Textbox(label='human player move'),
71
+ gr.Textbox(placeholder=board.fen(), label='starting position FEN'),
72
+ "state"
73
+ ],
74
+ outputs=[
75
+ gr.Image(label='position'),
76
+ "state"
77
+ ],
78
+ title='Play YoloChess - Policy Network v0.3',
79
+ allow_flagging="never",
80
+ )
81
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ chess==1.9.2