"""
Gradio interface for plotting input-gradient heatmaps of lczerolens models.
"""

import io
import logging
import os

import chess
import chess.pgn
import gradio as gr
from lczerolens import LczeroBoard, LczeroModel, Lens

from demo import constants

logger = logging.getLogger(__name__)

# Column of the model's "output/wdl" output selected by each WDL target.
_WDL_INDEX = {"win": 0, "draw": 1, "loss": 2}


def get_model(model_name: str) -> LczeroModel:
    """Load the ONNX model ``model_name`` from ``constants.ONNX_MODEL_DIRECTORY``.

    Raises:
        gr.Error: if no model is selected (``model_name`` is empty or None).
    """
    if not model_name:
        # Without this guard os.path.join fails with an opaque TypeError on None.
        raise gr.Error("Please select a model.")
    return LczeroModel.from_onnx_path(os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name))


def get_gradients(model: LczeroModel, board: LczeroBoard, target: str):
    """Compute input gradients of ``target`` with lczerolens' "gradient" lens.

    Args:
        model: the loaded model.
        board: the position to analyse.
        target: "win", "draw" or "loss" select that column of the model's
            "output/wdl"; "best_move" uses the maximum of "output/policy" over dim 1.

    Returns:
        ``results["input_grad"]`` from the lens. NOTE(review): presumably a tensor
        shaped (batch, planes, 8, 8) -- render_gradients indexes it as [0, plane]
        and views each plane as 64 values; confirm against lczerolens.

    Raises:
        KeyError: for a target other than the four listed above.
    """
    lens = Lens.from_name("gradient")

    def init_target(model):
        # Presumably invoked by the lens with the traced model to pick the scalar(s)
        # whose gradient is taken.
        if target == "best_move":
            return getattr(model, "output/policy").output.max(dim=1).values
        return getattr(model, "output/wdl").output[:, _WDL_INDEX[target]]

    results = lens.analyse(model, board, init_target=init_target)
    return results["input_grad"]


def get_board(game_pgn: str, board_fen: str) -> LczeroBoard:
    """Build the position to analyse; a non-blank PGN takes precedence over the FEN.

    The PGN's mainline is replayed from the game's own initial position (its
    FEN/SetUp headers, or the standard start). Unparsable input falls back to the
    starting position with a UI warning instead of raising.
    """
    if game_pgn and game_pgn.strip():
        try:
            game = chess.pgn.read_game(io.StringIO(game_pgn))
            if game is None:
                raise ValueError("no game found in PGN")
            # Start from the PGN's own initial position: Board.push does not check
            # legality, so replaying moves from the wrong start silently corrupts
            # the position instead of failing.
            board = LczeroBoard(game.board().fen())
            for move in game.mainline_moves():
                board.push(move)
        except Exception as e:
            logger.warning("Error parsing PGN: %s", e)
            gr.Warning("Error parsing PGN, using starting position.")
            return LczeroBoard()
        if game.errors:
            # read_game records bad moves instead of raising; tell the user the
            # replayed position may not reflect the whole game.
            logger.warning("PGN parsed with errors: %s", game.errors)
            gr.Warning("PGN contains errors, the position may be incomplete.")
        return board

    try:
        return LczeroBoard(board_fen)
    except Exception as e:
        logger.warning("Invalid FEN %r: %s", board_fen, e)
        gr.Warning("Invalid FEN, using starting position.")
        return LczeroBoard()


def render_gradients(
    board: LczeroBoard,
    gradients,
    average_over_planes: bool,
    begin_average_index: int,
    end_average_index: int,
    plane_index: int,
):
    """Render one gradient plane (or an average of planes) as a heatmap on ``board``.

    When ``average_over_planes`` is set, planes ``begin_average_index`` through
    ``end_average_index`` are averaged. Both ends are inclusive, and the order of
    the two ends does not matter. Otherwise only ``plane_index`` is shown.

    Returns:
        ``(board_svg_path, colorbar_svg_path)``, or ``(None, None)`` when nothing has
        been computed yet (so the images are cleared instead of raising).
    """
    if board is None or gradients is None:
        return None, None

    if average_over_planes:
        # Slider values may arrive as floats; cast before indexing. An inclusive,
        # sorted range keeps the last plane reachable and never yields an empty
        # slice (whose mean would be NaN).
        first, last = sorted((int(begin_average_index), int(end_average_index)))
        heatmap = gradients[0, first : last + 1].mean(dim=0).view(64)
    else:
        heatmap = gradients[0, int(plane_index)].view(64)

    save_to = os.path.join(constants.FIGURE_DIRECTORY, "gradients.svg")
    board.render_heatmap(heatmap, save_to=save_to)
    # NOTE(review): assumes render_heatmap writes "<stem>_board.svg" and
    # "<stem>_colorbar.svg" derived from save_to -- confirm against lczerolens.
    stem, _ = os.path.splitext(save_to)
    return f"{stem}_board.svg", f"{stem}_colorbar.svg"


def initial_load(
    model_name: str,
    board_fen: str,
    game_pgn: str,
    target: str,
    average_over_planes: bool,
    begin_average_index: int,
    end_average_index: int,
    plane_index: int,
):
    """Load the model and position, then compute and render gradients on page load.

    Returns:
        ``(model, board, gradients, board_svg_path, colorbar_svg_path)``.
    """
    model = get_model(model_name)
    board = get_board(game_pgn, board_fen)
    gradients = get_gradients(model, board, target)
    plots = render_gradients(
        board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index
    )
    return model, board, gradients, *plots


def on_board_change(
    model: LczeroModel,
    game_pgn: str,
    board_fen: str,
    target: str,
    average_over_planes: bool,
    begin_average_index: int,
    end_average_index: int,
    plane_index: int,
):
    """Rebuild the position from PGN/FEN and recompute gradients with the current model.

    Returns:
        ``(board, gradients, board_svg_path, colorbar_svg_path)``.
    """
    board = get_board(game_pgn, board_fen)
    gradients = get_gradients(model, board, target)
    plots = render_gradients(
        board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index
    )
    return board, gradients, *plots


def on_model_change(
    model_name: str,
    board: LczeroBoard,
    target: str,
    average_over_planes: bool,
    begin_average_index: int,
    end_average_index: int,
    plane_index: int,
):
    """Load the newly selected model and recompute gradients for the current position.

    Returns:
        ``(model, gradients, board_svg_path, colorbar_svg_path)``.
    """
    model = get_model(model_name)
    gradients = get_gradients(model, board, target)
    plots = render_gradients(
        board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index
    )
    return model, gradients, *plots


def on_target_change(
    model: LczeroModel,
    board: LczeroBoard,
    target: str,
    average_over_planes: bool,
    begin_average_index: int,
    end_average_index: int,
    plane_index: int,
):
    """Recompute gradients for the newly selected target.

    Returns:
        ``(gradients, board_svg_path, colorbar_svg_path)``.
    """
    gradients = get_gradients(model, board, target)
    plots = render_gradients(
        board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index
    )
    return gradients, *plots


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                    value="",
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
            with gr.Group():
                model_name = gr.Dropdown(
                    label="Model",
                    choices=constants.ONNX_MODEL_NAMES,
                )
                target = gr.Radio(
                    ["win", "draw", "loss", "best_move"],
                    label="Target",
                    value="win",
                )
            with gr.Group():
                average_over_planes = gr.Checkbox(label="Average over Planes", value=False)
                with gr.Accordion("Average over planes", open=False):
                    begin_average_index = gr.Slider(
                        label="Begin average index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=0,
                    )
                    end_average_index = gr.Slider(
                        label="End average index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=111,
                    )
                plane_index = gr.Slider(
                    label="Plane index",
                    minimum=0,
                    maximum=111,
                    step=1,
                    value=0,
                )
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Per-session state shared between the event handlers below.
    model = gr.State(value=None)
    board = gr.State(value=None)
    gradients = gr.State(value=None)

    # Controls that only affect how already-computed gradients are drawn.
    render_controls = [average_over_planes, begin_average_index, end_average_index, plane_index]

    # Inputs must follow initial_load's parameter order (board_fen before game_pgn);
    # they were previously swapped, feeding the FEN string to the PGN parser.
    interface.load(
        initial_load,
        inputs=[model_name, board_fen, game_pgn, target, *render_controls],
        outputs=[model, board, gradients, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    for position_box in (game_pgn, board_fen):
        position_box.submit(
            on_board_change,
            inputs=[model, game_pgn, board_fen, target, *render_controls],
            outputs=[board, gradients, image_board, colorbar],
            concurrency_id="trace_queue",
        )
    model_name.change(
        on_model_change,
        inputs=[model_name, board, target, *render_controls],
        outputs=[model, gradients, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    target.change(
        on_target_change,
        inputs=[model, board, target, *render_controls],
        outputs=[gradients, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    # Re-renders share the same queue: every handler writes the same files under
    # constants.FIGURE_DIRECTORY, so renders must not interleave.
    for render_arg in render_controls:
        render_arg.change(
            render_gradients,
            inputs=[board, gradients, *render_controls],
            outputs=[image_board, colorbar],
            concurrency_id="trace_queue",
        )