"""Gradio interface for plotting the input encoding planes of a board."""

import chess
import chess.pgn
import io

import gradio as gr

from lczerolens.board import LczeroBoard, InputEncoding

from ..constants import FIGURE_DIRECTORY


def _board_from_pgn(game_pgn: str) -> LczeroBoard:
    """Replay the mainline of ``game_pgn`` onto a new ``LczeroBoard``.

    Replay starts from the game's own initial position (``game.board()``),
    so PGNs carrying ``[SetUp "1"]`` / ``[FEN "..."]`` headers are handled;
    for a standard game this is the usual starting position.

    Raises:
        ValueError: if no game can be read from the text, or if python-chess
            recorded parse errors (e.g. illegal moves). python-chess does not
            raise on those by itself, it only lists them in ``game.errors``.
    """
    game = chess.pgn.read_game(io.StringIO(game_pgn))
    if game is None:
        raise ValueError("No game found in PGN.")
    if game.errors:
        raise ValueError(f"PGN parse errors: {game.errors}")
    board = LczeroBoard(game.board().fen())
    for move in game.mainline_moves():
        board.push(move)
    return board


def make_render(game_pgn: str, board_fen: str, input_encoding: InputEncoding, plane_index: int):
    """Build the board from the UI inputs and render the selected input plane.

    A non-blank PGN takes precedence over the FEN. If the chosen source
    cannot be parsed, the error is printed, a Gradio warning is shown and
    the starting position is used instead.

    Returns:
        ``(board, board_svg_path, colorbar_svg_path)``: the board (kept in the
        session state for later re-plots) followed by the two paths returned
        by ``make_board_plot``.
    """
    # A whitespace-only PGN box is treated as empty so the FEN is used.
    if game_pgn and game_pgn.strip():
        try:
            board = _board_from_pgn(game_pgn)
        except Exception as e:  # UI boundary: report and fall back.
            print(e)
            gr.Warning("Error parsing PGN, using starting position.")
            board = LczeroBoard()
    else:
        try:
            board = LczeroBoard(board_fen)
        except Exception as e:  # UI boundary: report and fall back.
            print(e)
            gr.Warning("Invalid FEN, using starting position.")
            board = LczeroBoard()
    return board, *make_board_plot(board, input_encoding, plane_index)


def make_board_plot(board: LczeroBoard, input_encoding: InputEncoding, plane_index: int):
    """Render one plane of the board's input tensor as an SVG heatmap.

    Args:
        board: Position to encode.
        input_encoding: Encoding passed to ``board.to_input_tensor``.
        plane_index: Index of the plane to plot; the plane is flattened with
            ``.view(64)``, i.e. it must hold 64 values.

    Returns:
        Paths of the board SVG and the colorbar SVG.
    """
    input_tensor = board.to_input_tensor(input_encoding=input_encoding)
    # NOTE(review): ``render_heatmap`` is asked to save to ``encodings.svg`` but
    # the paths returned are ``encodings_board.svg`` / ``encodings_colorbar.svg``;
    # presumably it writes those two files from ``save_to`` — confirm against
    # lczerolens. These paths are shared by every session, so concurrent
    # requests may overwrite each other's figures.
    board.render_heatmap(
        input_tensor[plane_index].view(64),
        save_to=f"{FIGURE_DIRECTORY}/encodings.svg",
        vmin=0,
        vmax=1,
    )
    return f"{FIGURE_DIRECTORY}/encodings_board.svg", f"{FIGURE_DIRECTORY}/encodings_colorbar.svg"


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                    value="",
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
                input_encoding = gr.Radio(
                    label="Input encoding",
                    choices=[
                        ("classical", InputEncoding.INPUT_CLASSICAL_112_PLANE),
                        ("repeated", InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED),
                        ("no history repeated", InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED),
                        ("no history zeros", InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS),
                    ],
                    value=InputEncoding.INPUT_CLASSICAL_112_PLANE,
                )
            with gr.Group():
                with gr.Row():
                    # 0..111 covers the 112 planes of the *_112_PLANE encodings.
                    plane_index = gr.Slider(
                        label="Plane index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=0,
                    )
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # The current board is kept per session so that changing the encoding or
    # plane only re-plots instead of re-parsing the PGN/FEN.
    state_board = gr.State(value=LczeroBoard())
    render_inputs = [game_pgn, board_fen, input_encoding, plane_index]
    render_outputs = [state_board, image_board, colorbar]

    # Full rebuild (parse + plot) on page load and when the PGN/FEN is submitted.
    interface.load(
        make_render,
        inputs=render_inputs,
        outputs=render_outputs,
    )
    game_pgn.submit(
        make_render,
        inputs=render_inputs,
        outputs=render_outputs,
    )
    board_fen.submit(
        make_render,
        inputs=render_inputs,
        outputs=render_outputs,
    )

    # Re-plot only, using the stored board.
    input_encoding.change(
        make_board_plot,
        inputs=[state_board, input_encoding, plane_index],
        outputs=[image_board, colorbar],
    )
    plane_index.change(
        make_board_plot,
        inputs=[state_board, input_encoding, plane_index],
        outputs=[image_board, colorbar],
    )