File size: 2,995 Bytes
3333fb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
Gradio interface for plotting the input-encoding planes of a board.
"""

import io
import logging

import chess
import chess.pgn
import gradio as gr
from lczerolens.board import LczeroBoard

from ..constants import FIGURE_DIRECTORY

def _board_from_pgn(game_pgn: str) -> LczeroBoard:
    """Build a board by replaying the mainline of the first game in ``game_pgn``.

    The game's own starting position (``game.board()``, which honours the
    PGN's ``FEN``/``SetUp`` headers) is used rather than always assuming the
    standard start. Every mainline move is pushed, so the returned board
    carries the full move stack.

    Raises:
        ValueError: if no game could be read, or python-chess recorded
            errors (e.g. illegal moves) while parsing it.
    """
    game = chess.pgn.read_game(io.StringIO(game_pgn))
    if game is None:
        raise ValueError("no game found in PGN")
    if game.errors:
        # Per the python-chess docs, read_game records parse problems in
        # game.errors instead of raising; treat them as a failed parse rather
        # than silently showing a truncated game.
        raise ValueError(f"PGN parse errors: {game.errors}")
    board = LczeroBoard(game.board().fen())
    for move in game.mainline_moves():
        board.push(move)
    return board


def make_render(game_pgn: str, board_fen: str, plane_index: int):
    """Build the board from the UI inputs and render the selected input plane.

    A non-blank PGN takes precedence over the FEN. If the chosen input cannot
    be parsed, a Gradio warning is shown and the starting position is used.

    Args:
        game_pgn: PGN text; when blank, ``board_fen`` is used instead.
        board_fen: FEN string for the position.
        plane_index: Index of the input-encoding plane to plot.

    Returns:
        ``(board, board_svg_path, colorbar_svg_path)``: the board (stored in
        the UI's ``gr.State``) followed by the two paths from
        ``make_board_plot``.
    """
    logger = logging.getLogger(__name__)
    if game_pgn and game_pgn.strip():
        try:
            board = _board_from_pgn(game_pgn)
        except Exception:
            # UI boundary: any parse failure falls back to the start position.
            logger.warning("Could not parse PGN", exc_info=True)
            gr.Warning("Error parsing PGN, using starting position.")
            board = LczeroBoard()
    else:
        try:
            board = LczeroBoard(board_fen)
        except Exception:
            logger.warning("Could not parse FEN %r", board_fen, exc_info=True)
            gr.Warning("Invalid FEN, using starting position.")
            board = LczeroBoard()
    return board, *make_board_plot(board, plane_index)

def make_board_plot(board: LczeroBoard, plane_index: int):
    """Render one plane of ``board``'s input encoding as an SVG heatmap.

    Args:
        board: Board whose ``to_input_tensor()`` encoding is plotted.
        plane_index: Index of the plane to show. It is coerced to ``int``
            because a UI slider may deliver its value as a float, which
            cannot index a tensor.

    Returns:
        ``(board_svg_path, colorbar_svg_path)`` under ``FIGURE_DIRECTORY``.
    """
    plane = int(plane_index)
    input_tensor = board.to_input_tensor()
    board.render_heatmap(
        # Flatten the selected plane to 64 values (presumably one per square).
        input_tensor[plane].view(64),
        save_to=f"{FIGURE_DIRECTORY}/encodings.svg",
        vmin=0,
        vmax=1,
    )
    # NOTE(review): presumably render_heatmap writes these two files derived
    # from ``save_to``; confirm. The paths are fixed, so concurrent sessions
    # may overwrite each other's figures.
    return (
        f"{FIGURE_DIRECTORY}/encodings_board.svg",
        f"{FIGURE_DIRECTORY}/encodings_colorbar.svg",
    )

with gr.Blocks() as interface:
    with gr.Row():
        # Left column: position inputs and the plane selector.
        with gr.Column():
            with gr.Group():
                gr.Markdown("Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN).")
                game_pgn = gr.Textbox(label="Game PGN", lines=1, value="")
                board_fen = gr.Textbox(label="Board FEN", lines=1, max_lines=1, value=chess.STARTING_FEN)
            with gr.Group():
                with gr.Row():
                    plane_index = gr.Slider(label="Plane index", minimum=0, maximum=111, step=1, value=0)
        # Right column: the rendered heatmap and its colour scale.
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Board produced by the last full render; reused when only the plane changes.
    state_board = gr.State(value=LczeroBoard())

    render_inputs = [game_pgn, board_fen, plane_index]
    render_outputs = [state_board, image_board, colorbar]

    # Full re-parse + re-render on page load and on submit of either text box
    # (registered in the same order: load, PGN submit, FEN submit).
    for _listen in (interface.load, game_pgn.submit, board_fen.submit):
        _listen(make_render, inputs=render_inputs, outputs=render_outputs)

    # Moving the slider only re-plots the board held in state.
    plane_index.change(make_board_plot, inputs=[state_board, plane_index], outputs=[image_board, colorbar])