File size: 2,792 Bytes
343fa36
 
 
 
 
 
 
 
 
5e4365f
343fa36
 
 
 
 
 
 
 
 
5e4365f
343fa36
5e4365f
343fa36
 
 
 
 
 
 
5e4365f
343fa36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Gradio interface for plotting encodings.
"""

import chess
import gradio as gr

from demo import constants, visualisation
from lczerolens import board_encodings
from lczerolens.board import LczeroBoard


def make_encoding_plot(
    board_fen: str,
    action_seq: str,
    plane_index: float,
    color_flip: bool,
):
    """
    Render one plane of the board input encoding as a heatmap.

    A board is built from ``board_fen`` and the whitespace-separated UCI
    moves in ``action_seq`` are pushed onto it. Invalid input never raises:
    a Gradio warning is shown and a default-constructed ``LczeroBoard`` is
    used instead (for an invalid action sequence this also discards the
    position given by ``board_fen``).

    Args:
        board_fen: FEN string of the position the moves start from.
        action_seq: Whitespace-separated UCI moves; empty means no moves.
        plane_index: Index of the encoding plane to display. Coerced to
            ``int`` because slider widgets may deliver the value as a float.
        color_flip: When true and it is Black's turn on the final board,
            the plane is flipped along its first dimension before rendering.

    Returns:
        A ``(svg_path, fig)`` pair: the path of the SVG file written under
        ``constants.FIGURE_DIRECTORY`` and the figure returned by
        ``visualisation.render_heatmap``.
    """
    try:
        board = LczeroBoard(board_fen)
    except ValueError:
        board = LczeroBoard()
        gr.Warning("Invalid FEN, using starting position.")
    if action_seq:
        try:
            for action in action_seq.split():
                board.push_uci(action)
        except ValueError:
            # A failure part-way through leaves some moves applied, so the
            # board is rebuilt from scratch rather than kept half-played.
            gr.Warning("Invalid action sequence, using starting position.")
            board = LczeroBoard()
    board_tensor = board_encodings.board_to_input_tensor(board)
    # A float index is rejected by tensor/array indexing; the slider's step
    # is 1, so int() does not lose information.
    heatmap = board_tensor[int(plane_index)]
    if color_flip and board.turn == chess.BLACK:
        heatmap = heatmap.flip(0)
    # The plane is flattened to 64 values (one per square) for rendering.
    svg_board, fig = visualisation.render_heatmap(board, heatmap.view(64), vmin=0.0, vmax=1.0)
    svg_path = f"{constants.FIGURE_DIRECTORY}/encoding.svg"
    # Explicit UTF-8 so the written SVG does not depend on the host locale.
    with open(svg_path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return svg_path, fig


# Gradio page: the first column holds the inputs and the colorbar plot, the
# second column holds the rendered board image.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            board_fen = gr.Textbox(
                value=chess.STARTING_FEN,
                label="Board starting FEN",
                lines=1,
                max_lines=1,
            )
            action_seq = gr.Textbox(
                value="e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 d4d5 e5e4 f3d4 c6e5 f2f4 e5g6",
                label="Action sequence",
                lines=1,
                max_lines=1,
            )
            # Plane selector and flip toggle share one row inside a group.
            with gr.Group(), gr.Row():
                plane_index = gr.Slider(
                    minimum=0,
                    maximum=111,
                    step=1,
                    value=0,
                    label="Plane index",
                    scale=3,
                )
                color_flip = gr.Checkbox(value=True, label="Color flip", scale=1)

            colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            image = gr.Image(label="Board")

    policy_inputs = [board_fen, action_seq, plane_index, color_flip]
    policy_outputs = [image, colorbar]

    # Every trigger re-renders the plot with the same inputs and outputs:
    # submitting either textbox, changing the slider or checkbox, and the
    # initial page load.
    for _listener in (
        board_fen.submit,
        action_seq.submit,
        plane_index.change,
        color_flip.change,
        interface.load,
    ):
        _listener(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)