File size: 1,822 Bytes
9dcb348
ee2b517
3479f48
3a76146
f8bdf54
4249eba
 
9ede36f
3a76146
 
 
 
 
 
 
 
 
c51abf0
3479f48
3a76146
 
 
3479f48
3a76146
 
aa43f32
3479f48
 
3a76146
796a2f3
3479f48
4249eba
22bfac3
 
4249eba
22bfac3
 
 
 
3479f48
 
 
 
c51abf0
3479f48
 
c51abf0
3479f48
 
 
 
 
 
 
796a2f3
4249eba
796a2f3
 
 
 
 
 
4249eba
796a2f3
 
9dcb348
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import tempfile
from io import StringIO

import chess
import chess.pgn
import chess.svg
import gradio as gr
import torch
from huggingface_hub import hf_hub_download

from model import DecoderTransformer, Tokenizer


# Hyperparameters handed to DecoderTransformer below. The checkpoint is loaded
# into that model with load_state_dict, so these presumably have to match the
# values the checkpoint was trained with -- verify before changing any of them.
vocab_size = 33
n_embed = 384
context_size = 256
n_layer = 6
n_head = 6
dropout = 0.2

# Inference runs on CPU.
device = 'cpu'

# Hugging Face Hub repository that hosts the weights and tokenizer files.
model_id = "philipp-zettl/chessPT"

# Fetch both artifacts from the Hub; hf_hub_download returns local file paths.
model_path = hf_hub_download(repo_id=model_id, filename="chessPT.pkl")
tokenizer_path = hf_hub_download(repo_id=model_id, filename="tokenizer.json")

model = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
# NOTE(security): torch.load unpickles the downloaded file, which can execute
# arbitrary code. Only load checkpoints from a repository you trust; consider
# passing weights_only=True once confirmed the file holds a plain state dict.
model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
model.to(device)
# Switch to inference mode so dropout layers (dropout=0.2 above) are disabled
# while generating moves.
model.eval()
tokenizer = Tokenizer.from_pretrained(tokenizer_path)

def generate(prompt):
    """Continue a PGN prompt with the model and render the resulting board.

    Args:
        prompt: PGN move text of the game so far, e.g. ``"1. e4 g6 2."``.

    Returns:
        A ``(pgn, filename)`` tuple. ``pgn`` is the decoded output of
        ``model.generate`` (called with ``max_new_tokens=4``; presumably the
        prompt plus the newly generated tokens -- confirm against the model).
        ``filename`` is the path of a temporary SVG file showing the position
        after the last move of the parsed main line, or the initial position
        when no game could be parsed from ``pgn``.
    """
    token_ids = tokenizer.encode(prompt)
    # Add the batch dimension with unsqueeze instead of view((1, len(prompt))):
    # the sequence length is the number of tokens, which only equals the
    # character count if the tokenizer emits one token per character.
    model_input = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)

    # Gradients are never needed for generation.
    with torch.no_grad():
        generated = model.generate(model_input, max_new_tokens=4, context_size=context_size)
    pgn = tokenizer.decode(generated[0].tolist())

    game = chess.pgn.read_game(StringIO(pgn))
    # read_game returns None when it finds no game. On the root Game node,
    # board() is the starting position, so follow the main line to its end to
    # show the position after the moves that were played.
    board = chess.Board() if game is None else game.end().board()

    # PGN text contains spaces, '*' and '/' (e.g. "1/2-1/2"), so it is not a
    # safe file name. Use a unique temp file with a real .svg extension.
    with tempfile.NamedTemporaryFile(
        'w', suffix='.svg', delete=False, encoding='utf-8'
    ) as f:
        f.write(chess.svg.board(board))
        filename = f.name
    return pgn, filename


# Gradio UI: a PGN text box, the model's continuation, and a board preview.
with gr.Blocks() as demo:
    gr.Markdown("""
    # ChessPT
    Welcome to ChessPT.

    The **C**hess-**P**re-trained-**T**ransformer.

    The rules are simple: provide a PGN string of your current game, the engine will predict the next token!
    """)
    prompt = gr.Text(label="PGN")
    output = gr.Text(label="Next turn", interactive=False)

    submit = gr.Button("Submit")
    img = gr.Image()
    # generate returns (pgn, image path), so both components must be listed as
    # outputs; with only [output] the board image was never updated on Submit.
    # The listener is registered after img exists; the on-page order of the
    # components is unchanged.
    submit.click(generate, [prompt], [output, img])
    gr.Examples(
        [
            ["1. e4", ],
            ["1. e4 g6 2."],
        ],
        inputs=[prompt],
        outputs=[output, img],
        fn=generate
    )
demo.launch()