Spaces:
Sleeping
Sleeping
File size: 1,822 Bytes
9dcb348 ee2b517 3479f48 3a76146 f8bdf54 4249eba 9ede36f 3a76146 c51abf0 3479f48 3a76146 3479f48 3a76146 aa43f32 3479f48 3a76146 796a2f3 3479f48 4249eba 22bfac3 4249eba 22bfac3 3479f48 c51abf0 3479f48 c51abf0 3479f48 796a2f3 4249eba 796a2f3 4249eba 796a2f3 9dcb348 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import tempfile
from io import StringIO

import chess
import chess.pgn
import chess.svg
import gradio as gr
import torch
from huggingface_hub import hf_hub_download

from model import DecoderTransformer, Tokenizer
# Architecture settings passed to DecoderTransformer. They must describe the
# same network the checkpoint was saved from, otherwise load_state_dict below
# rejects the weights.
vocab_size = 33
n_embed = 384
context_size = 256
n_layer = 6
n_head = 6
dropout = 0.2
device = 'cpu'

# Fetch the weights and tokenizer definition from the Hugging Face Hub.
model_id = "philipp-zettl/chessPT"
model_path = hf_hub_download(repo_id=model_id, filename="chessPT.pkl")
tokenizer_path = hf_hub_download(repo_id=model_id, filename="tokenizer.json")

model = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
# NOTE(security): torch.load unpickles the downloaded file, which can execute
# arbitrary code if the repository is compromised. Consider weights_only=True
# (if the installed torch supports it) or a safetensors checkpoint.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
# The demo only serves predictions: switch to evaluation mode so the dropout
# configured above is not applied at inference time.
model.eval()
tokenizer = Tokenizer.from_pretrained(tokenizer_path)
def generate(prompt, max_new_tokens=4):
    """Extend a PGN prompt with the model and render the resulting position.

    Args:
        prompt: PGN movetext entered by the user, e.g. ``"1. e4 g6 2."``.
        max_new_tokens: Number of tokens the model appends to the prompt.
            Defaults to 4, the amount the demo has always generated.

    Returns:
        ``(pgn, filename)``: ``pgn`` is the decoded prompt plus the model's
        continuation. ``filename`` is the path of an SVG file showing the
        board after the last move python-chess could parse from ``pgn``.
    """
    token_ids = tokenizer.encode(prompt)
    # Batch of one. Size the sequence axis from the encoded ids rather than
    # len(prompt), which only agrees when every character is exactly one token.
    model_input = torch.tensor(token_ids, dtype=torch.long, device=device).view(1, -1)
    with torch.no_grad():  # pure inference, no autograd graph needed
        output_ids = model.generate(
            model_input, max_new_tokens=max_new_tokens, context_size=context_size
        )
    pgn = tokenizer.decode(output_ids[0].tolist())

    game = chess.pgn.read_game(StringIO(pgn))
    # read_game returns None if it reaches the end of input without a game;
    # fall back to the initial position. On the root node, game.board() is the
    # *starting* position, so follow the mainline to its last node to show
    # the position after the moves.
    board = chess.Board() if game is None else game.end().board()
    img = chess.svg.board(board)

    # Write to a fresh temp file instead of a name built from model output
    # (which contains spaces, dots and possibly path separators). The .svg
    # suffix is presumably what lets the front end recognise the image type.
    with tempfile.NamedTemporaryFile(
        "w", prefix="moves-", suffix=".svg", delete=False, encoding="utf-8"
    ) as f:
        f.write(img)
    return pgn, f.name
# Gradio front end: a PGN text box, the model's continuation, and the board.
with gr.Blocks() as demo:
    gr.Markdown("""
# ChessPT
Welcome to ChessPT.
The **C**hess-**P**re-trained-**T**ransformer.
The rules are simple: provide a PGN string of your current game, the engine will predict the next token!
""")
    prompt = gr.Text(label="PGN")
    output = gr.Text(label="Next turn", interactive=False)
    submit = gr.Button("Submit")
    # generate() returns (pgn, svg_path), so the image component must exist
    # before the click handler is wired and be listed as its second output.
    img = gr.Image()
    submit.click(generate, [prompt], [output, img])
    gr.Examples(
        [
            ["1. e4"],
            ["1. e4 g6 2."],
        ],
        inputs=[prompt],
        outputs=[output, img],
        fn=generate,
    )
demo.launch()
|