update to chessv6 sl policy model
Browse files
app.py
CHANGED
@@ -9,40 +9,14 @@ from transformers import DebertaV2ForSequenceClassification, AutoTokenizer, pipe
|
|
9 |
|
10 |
token = os.environ['auth_token']
|
11 |
|
12 |
-
tokenizer = AutoTokenizer.from_pretrained('jrahn/
|
13 |
-
model = DebertaV2ForSequenceClassification.from_pretrained('jrahn/
|
14 |
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
|
15 |
|
16 |
-
empty_field = '0'
|
17 |
-
board_split = ' | '
|
18 |
-
nums = {str(n): empty_field * n for n in range(1, 9)}
|
19 |
-
nums_rev = {v:k for k,v in reversed(nums.items())}
|
20 |
-
|
21 |
-
|
22 |
-
def encode_fen(fen):
|
23 |
-
# decompress fen representation
|
24 |
-
# prepare for sub-word tokenization
|
25 |
-
fen_board, fen_rest = fen.split(' ', 1)
|
26 |
-
for n in nums:
|
27 |
-
fen_board = fen_board.replace(n, nums[n])
|
28 |
-
fen_board = '+' + fen_board
|
29 |
-
fen_board = fen_board.replace('/', ' +')
|
30 |
-
return board_split.join([fen_board, fen_rest])
|
31 |
-
|
32 |
-
def decode_fen_repr(fen_repr):
|
33 |
-
fen_board, fen_rest = fen_repr.split(board_split, 1)
|
34 |
-
for n in nums_rev:
|
35 |
-
fen_board = fen_board.replace(n, nums_rev[n])
|
36 |
-
fen_board = fen_board.replace(' +', '/')
|
37 |
-
fen_board = fen_board.replace('+', '')
|
38 |
-
return ' '.join([fen_board, fen_rest])
|
39 |
-
|
40 |
def predict_move(fen, top_k=3):
|
41 |
-
|
42 |
-
preds = pipe(fen_prep, top_k=top_k)
|
43 |
weights = [p['score'] for p in preds]
|
44 |
p = random.choices(preds, weights=weights)[0]
|
45 |
-
# discard illegal moves (https://python-chess.readthedocs.io/en/latest/core.html#chess.Board.legal_moves), then select top_k
|
46 |
return p['label']
|
47 |
|
48 |
def btn_load(inp_fen):
|
@@ -76,9 +50,9 @@ def btn_play(inp_fen, inp_move, inp_notation, inp_k):
|
|
76 |
with gr.Blocks() as block:
|
77 |
gr.Markdown(
|
78 |
'''
|
79 |
-
# Play YoloChess - Policy Network v0.
|
80 |
-
|
81 |
-
- pre-trained (MLM) from scratch on chess positions in
|
82 |
- fine-tuned for text classification (moves) on expert games.
|
83 |
'''
|
84 |
)
|
|
|
9 |
|
10 |
token = os.environ['auth_token']
|
11 |
|
12 |
+
tokenizer = AutoTokenizer.from_pretrained('jrahn/chessv6', use_auth_token=token)
|
13 |
+
model = DebertaV2ForSequenceClassification.from_pretrained('jrahn/chessv6', use_auth_token=token)
|
14 |
pipe = pipeline(task="text-classification", model=model, tokenizer=tokenizer)
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
def predict_move(fen, top_k=3):
|
17 |
+
preds = pipe(fen, top_k=top_k)
|
|
|
18 |
weights = [p['score'] for p in preds]
|
19 |
p = random.choices(preds, weights=weights)[0]
|
|
|
20 |
return p['label']
|
21 |
|
22 |
def btn_load(inp_fen):
|
|
|
50 |
with gr.Blocks() as block:
|
51 |
gr.Markdown(
|
52 |
'''
|
53 |
+
# Play YoloChess - Policy Network v0.6
|
54 |
+
87M Parameter Transformer (DeBERTaV2-base architecture)
|
55 |
+
- pre-trained (MLM) from scratch on chess positions in FEN notation
|
56 |
- fine-tuned for text classification (moves) on expert games.
|
57 |
'''
|
58 |
)
|