"""
Gradio interface for visualising the input encoding planes of a board.
"""

import chess
import chess.pgn
import io

import gradio as gr
from lczerolens.board import LczeroBoard

from ..constants import FIGURE_DIRECTORY


def _board_from_pgn(game_pgn: str) -> LczeroBoard:
    """Replay the mainline of ``game_pgn`` and return the resulting board.

    Raises:
        ValueError: if no game can be read from ``game_pgn`` (python-chess
            returns ``None`` in that case rather than raising).
    """
    game = chess.pgn.read_game(io.StringIO(game_pgn))
    if game is None:
        raise ValueError("No game found in PGN.")
    if game.errors:
        # python-chess records illegal/ambiguous moves instead of raising, so
        # without this the user would silently see a truncated game.
        gr.Warning("PGN contains errors, the position may be incomplete.")
    # ``game.board()`` honours a custom starting position (FEN/SetUp headers);
    # replaying onto the standard start would yield a wrong board, since
    # ``push`` does not check legality.
    board = LczeroBoard(game.board().fen())
    for move in game.mainline_moves():
        board.push(move)
    return board


def make_render(game_pgn: str, board_fen: str, plane_index: int):
    """Build the board from the PGN (or FEN) and render one encoding plane.

    The PGN takes precedence over the FEN whenever it contains non-whitespace
    text. If parsing fails, a Gradio warning is shown and the standard starting
    position is used instead.

    Args:
        game_pgn: PGN text of a game; its mainline is replayed.
        board_fen: FEN string, used only when ``game_pgn`` is blank.
        plane_index: index of the input plane to display.

    Returns:
        ``(board, board_svg_path, colorbar_svg_path)``; the board is stored in
        the session state so the plane slider can re-render it later.
    """
    if game_pgn and game_pgn.strip():
        try:
            board = _board_from_pgn(game_pgn)
        except Exception as e:  # UI boundary: report and fall back, never crash.
            print(e)
            gr.Warning("Error parsing PGN, using starting position.")
            board = LczeroBoard()
    else:
        try:
            board = LczeroBoard(board_fen)
        except Exception as e:  # UI boundary: report and fall back, never crash.
            print(e)
            gr.Warning("Invalid FEN, using starting position.")
            board = LczeroBoard()
    return board, *make_board_plot(board, plane_index)


def make_board_plot(board: LczeroBoard, plane_index: int):
    """Render input plane ``plane_index`` of ``board`` as a heatmap.

    ``render_heatmap`` is given ``encodings.svg`` as ``save_to``; the returned
    paths assume it writes ``encodings_board.svg`` and
    ``encodings_colorbar.svg`` next to it.
    NOTE(review): confirm against the lczerolens version in use.
    NOTE(review): these paths are shared by every session, so concurrent users
    can overwrite each other's figures.

    Returns:
        ``(board_svg_path, colorbar_svg_path)``.
    """
    # Gradio sliders may deliver their value as a float; tensor indexing needs an int.
    plane = int(plane_index)
    input_tensor = board.to_input_tensor()
    board.render_heatmap(
        # Flattened to 64 values, one per square (assumes 8x8 planes).
        input_tensor[plane].view(64),
        save_to=f"{FIGURE_DIRECTORY}/encodings.svg",
        vmin=0,
        vmax=1,
    )
    return (
        f"{FIGURE_DIRECTORY}/encodings_board.svg",
        f"{FIGURE_DIRECTORY}/encodings_colorbar.svg",
    )


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                    value="",
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
            with gr.Group():
                with gr.Row():
                    # Upper bound assumes 112 input planes (indices 0-111).
                    plane_index = gr.Slider(
                        label="Plane index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=0,
                    )
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Per-session board, so changing the plane does not re-parse the PGN/FEN.
    state_board = gr.State(value=LczeroBoard())

    render_inputs = [game_pgn, board_fen, plane_index]
    render_outputs = [state_board, image_board, colorbar]

    interface.load(
        make_render,
        inputs=render_inputs,
        outputs=render_outputs,
    )
    game_pgn.submit(
        make_render,
        inputs=render_inputs,
        outputs=render_outputs,
    )
    board_fen.submit(
        make_render,
        inputs=render_inputs,
        outputs=render_outputs,
    )
    # Only re-render the heatmap; the board itself is taken from the session state.
    plane_index.change(
        make_board_plot,
        inputs=[state_board, plane_index],
        outputs=[image_board, colorbar],
    )