File size: 3,560 Bytes
3333fb8
 
 
 
 
 
 
 
 
 
 
 
 
980eda6
 
3333fb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
980eda6
 
 
 
 
3333fb8
 
 
 
 
 
980eda6
 
3333fb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""Interface to play against the model.
"""

import os

import chess
import chess.pgn
import random
import gradio as gr

from lczerolens import LczeroBoard, LczeroModel
from lczerolens.play import PolicySampler

from demo import constants
from demo.utils import create_board_figure


def get_sampler(model_name: str):
    """Build a policy sampler for the named ONNX model.

    Args:
        model_name: File name of the model inside
            ``constants.ONNX_MODEL_DIRECTORY``.

    Returns:
        A ``PolicySampler`` wrapping the loaded ``LczeroModel``.
    """
    onnx_path = os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name)
    loaded_model = LczeroModel.from_onnx_path(onnx_path)
    return PolicySampler(loaded_model)

def get_pgn(board: LczeroBoard):
    """Return the moves played on *board* as single-line PGN movetext.

    The moves in ``board.move_stack`` are replayed from the standard starting
    position into a ``chess.pgn.Game`` and exported without headers, e.g.
    ``"1. e4 e5 2. Nf3 *"``. The trailing token is ``board.result()``
    (``"*"`` while the game is still in progress).

    Args:
        board: Board whose move history should be exported.

    Returns:
        The PGN movetext on a single line.
    """
    game = chess.pgn.Game()
    game.headers["Result"] = board.result()
    # Each move must hang off the previous node; adding every move to the
    # root would turn moves 2..n into alternative variations of move 1.
    node = game
    for move in board.move_stack:
        node = node.add_variation(move)
    # columns=None disables line wrapping, so long games are not truncated.
    exporter = chess.pgn.StringExporter(headers=False, columns=None)
    return game.accept(exporter)

def render_board(
    board: LczeroBoard,
):
    """Draw *board* from the point of view of the side to move.

    The last move (if any) is passed as an arrow, and the king of the side
    to move is highlighted when it is in check.

    Args:
        board: Position to render.

    Returns:
        Whatever ``create_board_figure`` returns; it is fed to the
        ``gr.Image`` output (presumably a file path — confirm in demo.utils).
    """
    side_to_move = board.turn
    last_move = board.peek().uci() if board.move_stack else None
    checked_king_square = board.king(board.turn) if board.is_check() else None
    return create_board_figure(
        board,
        orientation=side_to_move,
        arrows=last_move,
        square=checked_king_square,
        name="play_board",
    )

def gather_outputs(board: LczeroBoard, sampler: PolicySampler):
    """Assemble the values for the interface's output components.

    The tuple order is: sampler state, board state, FEN text, PGN text,
    rendered board image, and an empty string that clears the move textbox.
    """
    fen_text = board.fen()
    pgn_text = get_pgn(board)
    board_image = render_board(board)
    cleared_move_box = ""
    return sampler, board, fen_text, pgn_text, board_image, cleared_move_box

def get_init(model_name: str):
    """Start a new game against the selected model.

    A fresh sampler is built for *model_name*. A coin flip decides whether
    the AI takes White, in which case it plays the opening move before the
    initial outputs are returned.
    """
    new_sampler = get_sampler(model_name)
    ai_plays_white = random.choice([True, False])
    fresh_board = LczeroBoard()
    if ai_plays_white:
        # White moves first, so the AI opens the game.
        play_ai_move(fresh_board, new_sampler)
    return gather_outputs(fresh_board, new_sampler)

def play_user_move_then_ai_move(
    uci_move: str,
    board: LczeroBoard,
    sampler: PolicySampler,
):
    """Apply the user's move, let the AI reply, and refresh the outputs.

    Args:
        uci_move: Move typed by the user in UCI notation (e.g. ``"e2e4"``).
            Surrounding whitespace is ignored.
        board: Current game position; mutated in place.
        sampler: Sampler used to pick the AI's reply.

    Returns:
        The output tuple produced by ``gather_outputs``.

    Raises:
        gr.Error: If *uci_move* is not a valid, legal move in the current
            position. The message is shown to the user in the interface.
    """
    try:
        board.push_uci(uci_move.strip())
    except ValueError as err:
        # python-chess signals malformed and illegal moves with ValueError
        # subclasses; convert them into a message the UI can display.
        raise gr.Error(f"Invalid move {uci_move!r}: {err}") from err
    # If the user's move ended the game (mate, stalemate, ...), the AI
    # must not try to play a reply.
    if not board.is_game_over():
        play_ai_move(board, sampler)
    return gather_outputs(board, sampler)
    

def play_ai_move(
    board: LczeroBoard,
    sampler: PolicySampler,
):
    """Let the sampler choose a move for *board* and push it in place.

    Only the first item yielded by ``sampler.get_next_moves`` is used; its
    second element is ignored.
    """
    proposals = iter(sampler.get_next_moves([board]))
    chosen_move, _ = next(proposals)
    board.push(chosen_move)

with gr.Blocks() as interface:
    with gr.Row():
        # Left side: textual game state, model picker and move controls.
        with gr.Column():
            current_fen = gr.Textbox(
                value=chess.STARTING_FEN,
                label="Board FEN",
                max_lines=1,
                lines=1,
            )
            current_pgn = gr.Textbox(
                value="",
                label="Action sequence",
                lines=1,
            )
            model_name = gr.Dropdown(
                choices=constants.ONNX_MODEL_NAMES,
                label="Model",
            )
            with gr.Column():
                move_to_play = gr.Textbox(
                    value="",
                    label="Move to play (UCI)",
                    max_lines=1,
                    lines=1,
                )
                play_button = gr.Button("Play")
                reset_button = gr.Button("Reset")
        # Right side: rendered board image (display only).
        with gr.Column():
            image_board = gr.Image(interactive=False, label="Board")

    # Per-session state: the sampler for the chosen model and the live board.
    sampler = gr.State(value=None)
    board = gr.State(value=None)

    # Must match the order of the tuple returned by gather_outputs.
    outputs = [sampler, board, current_fen, current_pgn, image_board, move_to_play]

    # Clicking "Play" and pressing Enter in the move box behave identically.
    for play_trigger in (play_button.click, move_to_play.submit):
        play_trigger(
            play_user_move_then_ai_move,
            inputs=[move_to_play, board, sampler],
            outputs=outputs,
        )

    # Switching models swaps the sampler but keeps the current game.
    model_name.change(get_sampler, inputs=[model_name], outputs=[sampler])

    # "Reset" and the initial page load both start a fresh game.
    for init_trigger in (reset_button.click, interface.load):
        init_trigger(get_init, inputs=[model_name], outputs=outputs)