Spaces:
Sleeping
Sleeping
""" | |
Gradio interface for plotting the planes of a board's input encoding.
""" | |
import chess | |
import chess.pgn | |
import io | |
import gradio as gr | |
from lczerolens.board import LczeroBoard, InputEncoding | |
from ..constants import FIGURE_DIRECTORY | |
def _board_from_pgn(game_pgn: str) -> LczeroBoard:
    """Replay the main line of a PGN game and return the resulting board.

    Args:
        game_pgn: PGN text; only the first game and its main line are used.

    Returns:
        The board reached after pushing every main-line move.

    Raises:
        ValueError: If no game can be read from ``game_pgn``.
    """
    game = chess.pgn.read_game(io.StringIO(game_pgn))
    if game is None:
        # read_game yields None when the text contains no game at all;
        # fail explicitly instead of relying on an AttributeError below.
        raise ValueError("No game found in PGN.")

    if "FEN" in game.headers:
        # The game starts from a custom position: replaying its moves on the
        # standard starting position would produce a wrong board.
        board = LczeroBoard(game.board().fen())
    else:
        board = LczeroBoard()

    for move in game.mainline_moves():
        board.push(move)
    return board


def make_render(game_pgn: str, board_fen: str, input_encoding: InputEncoding, plane_index: int):
    """Build the board from user input and render the selected encoding plane.

    A non-empty ``game_pgn`` takes precedence over ``board_fen``. If the chosen
    input cannot be turned into a board, a Gradio warning is shown and the
    starting position is used instead.

    Args:
        game_pgn: PGN text of a game; may be empty.
        board_fen: FEN string, used only when ``game_pgn`` is empty.
        input_encoding: Encoding forwarded to ``make_board_plot``.
        plane_index: Index of the plane forwarded to ``make_board_plot``.

    Returns:
        A tuple of the board followed by the values returned by
        ``make_board_plot`` for that board.
    """
    if game_pgn:
        try:
            board = _board_from_pgn(game_pgn)
        except Exception as e:
            # UI boundary: report the problem and fall back rather than crash.
            print(e)
            gr.Warning("Error parsing PGN, using starting position.")
            board = LczeroBoard()
    else:
        try:
            board = LczeroBoard(board_fen)
        except Exception as e:
            print(e)
            gr.Warning("Invalid FEN, using starting position.")
            board = LczeroBoard()
    return board, *make_board_plot(board, input_encoding, plane_index)
def make_board_plot(board: LczeroBoard, input_encoding: InputEncoding, plane_index: int):
    """Render one plane of the board's input tensor as a heatmap.

    Args:
        board: Board to encode and draw.
        input_encoding: Encoding passed to ``board.to_input_tensor``.
        plane_index: Index of the plane to display.

    Returns:
        A pair of file paths: the board image and the colorbar image.
    """
    planes = board.to_input_tensor(input_encoding=input_encoding)
    # Flatten the selected plane to 64 values before handing it to the renderer.
    flat_plane = planes[plane_index].view(64)
    target = f"{FIGURE_DIRECTORY}/encodings.svg"
    board.render_heatmap(flat_plane, save_to=target, vmin=0, vmax=1)

    # NOTE(review): presumably render_heatmap derives these two files from
    # `save_to` -- confirm against the lczerolens implementation.
    board_svg = f"{FIGURE_DIRECTORY}/encodings_board.svg"
    colorbar_svg = f"{FIGURE_DIRECTORY}/encodings_colorbar.svg"
    return board_svg, colorbar_svg
# NOTE(review): the layout nesting below is reconstructed from the order of the
# `with` statements -- confirm it matches the intended page layout.
with gr.Blocks() as interface:
    with gr.Row():
        # Left column: position input and encoding/plane selection.
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(label="Game PGN", lines=1, value="")
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
                # (label shown to the user, value passed to the callbacks)
                encoding_choices = [
                    ("classical", InputEncoding.INPUT_CLASSICAL_112_PLANE),
                    ("repeated", InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED),
                    ("no history repeated", InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED),
                    ("no history zeros", InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS),
                ]
                input_encoding = gr.Radio(
                    label="Input encoding",
                    choices=encoding_choices,
                    value=InputEncoding.INPUT_CLASSICAL_112_PLANE,
                )
            with gr.Group():
                with gr.Row():
                    plane_index = gr.Slider(
                        label="Plane index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=0,
                    )
        # Right column: rendered outputs.
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Holds the current board so that changing the encoding or the plane can
    # redraw without re-reading the PGN/FEN inputs.
    state_board = gr.State(value=LczeroBoard())

    render_inputs = [game_pgn, board_fen, input_encoding, plane_index]
    render_outputs = [state_board, image_board, colorbar]
    plot_inputs = [state_board, input_encoding, plane_index]
    plot_outputs = [image_board, colorbar]

    # Rebuild the board and redraw on page load and on submitting either textbox.
    for _listener in (interface.load, game_pgn.submit, board_fen.submit):
        _listener(
            make_render,
            inputs=render_inputs,
            outputs=render_outputs,
        )

    # Redraw from the stored board when only the encoding or the plane changes.
    for _listener in (input_encoding.change, plane_index.change):
        _listener(
            make_board_plot,
            inputs=plot_inputs,
            outputs=plot_outputs,
        )