"""
Gradio interface for plotting the input gradients of an LCZero model.
"""

import chess
import chess.pgn
import io
import gradio as gr
import os

from lczerolens import LczeroBoard, LczeroModel, Lens

from demo import constants
from demo.utils import get_info


def get_model(model_name: str) -> LczeroModel:
    """Load the ONNX model named ``model_name`` from ``constants.ONNX_MODEL_DIRECTORY``."""
    return LczeroModel.from_onnx_path(os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name))


def get_gradients(model: LczeroModel, board: LczeroBoard, target: str):
    """Compute the gradient of ``target`` with respect to the model input for ``board``.

    Args:
        model: The model to analyse.
        board: The position fed to the model.
        target: ``"best_move"`` to differentiate the maximum of the policy output,
            or one of ``"win"``/``"draw"``/``"loss"`` to differentiate the matching
            column (0/1/2) of the WDL output.

    Returns:
        The ``"input_grad"`` entry of the gradient lens results.

    Raises:
        KeyError: If ``target`` is not one of the supported values.
    """
    lens = Lens.from_name("gradient")

    def init_target(model):
        # Selects the scalar(s) whose gradient the lens back-propagates.
        if target == "best_move":
            return getattr(model, "output/policy").output.max(dim=1).values
        wdl_index = {"win": 0, "draw": 1, "loss": 2}[target]
        return getattr(model, "output/wdl").output[:, wdl_index]

    results = lens.analyse(model, board, init_target=init_target)
    return results["input_grad"]


def get_board(game_pgn: str, board_fen: str) -> LczeroBoard:
    """Build the board to analyse from a PGN (preferred) or a FEN string.

    When ``game_pgn`` is non-empty, its mainline is replayed move by move so the
    board keeps the full move history. Otherwise ``board_fen`` is parsed. On any
    parsing failure a Gradio warning is shown and the starting position is used.
    """
    if game_pgn:
        try:
            game = chess.pgn.read_game(io.StringIO(game_pgn))
            if game is None:
                raise ValueError("No game found in PGN.")
            if game.errors:
                # read_game records (rather than raises) illegal/ambiguous moves and
                # stops the mainline there; don't silently analyse a truncated game.
                raise ValueError(f"PGN contains errors: {game.errors}")
            # Start from the game's own initial position so a FEN/SetUp header is
            # honoured, then push the moves to keep the history on the board.
            board = LczeroBoard(game.board().fen())
            for move in game.mainline_moves():
                board.push(move)
        except Exception as e:
            print(e)
            gr.Warning("Error parsing PGN, using starting position.")
            board = LczeroBoard()
        return board

    try:
        return LczeroBoard(board_fen)
    except Exception as e:
        print(e)
        gr.Warning("Invalid FEN, using starting position.")
        return LczeroBoard()


def render_gradients(
    board: LczeroBoard,
    gradients,
    average_over_planes: bool,
    begin_average_index: int,
    end_average_index: int,
    plane_index: int,
):
    """Render the input gradients as a heatmap over the board.

    When ``average_over_planes`` is set, the gradients are averaged over the planes
    from ``begin_average_index`` to ``end_average_index`` *inclusive* (the bounds may
    be given in either order). Otherwise only plane ``plane_index`` is shown.

    Returns:
        The paths of the rendered board SVG and colorbar SVG.
    """
    # NOTE(review): assumes gradients is (batch, planes, ...) with 64 values per
    # plane (one per square), as implied by the `.view(64)` below.
    if average_over_planes:
        # Slider values may arrive as floats; tensor indices must be integers.
        first, last = sorted((int(begin_average_index), int(end_average_index)))
        # The "End average index" slider tops out at the last plane, so the end
        # bound is inclusive; otherwise that plane is unreachable and equal bounds
        # would average an empty slice (NaN heatmap).
        heatmap = gradients[0, first : last + 1].mean(dim=0).view(64)
    else:
        heatmap = gradients[0, int(plane_index)].view(64)
    board.render_heatmap(
        heatmap,
        save_to=f"{constants.FIGURE_DIRECTORY}/gradients.svg",
    )
    # NOTE(review): render_heatmap is expected to derive the *_board.svg and
    # *_colorbar.svg files from `save_to`; confirm against the lczerolens version.
    return f"{constants.FIGURE_DIRECTORY}/gradients_board.svg", f"{constants.FIGURE_DIRECTORY}/gradients_colorbar.svg"


def initial_load(
    model_name: str,
    board_fen: str,
    game_pgn: str,
    target: str,
    average_over_planes: bool,
    begin_average_index: int,
    end_average_index: int,
    plane_index: int,
):
    """Load model and board on page load; return states, info text and both plots."""
    model = get_model(model_name)
    board = get_board(game_pgn, board_fen)
    gradients = get_gradients(model, board, target)
    info = get_info(model, board)
    plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
    return model, board, gradients, info, *plots


def on_board_change(
    model: LczeroModel,
    game_pgn: str,
    board_fen: str,
    target: str,
    average_over_planes: bool,
    begin_average_index: int,
    end_average_index: int,
    plane_index: int,
):
    """Rebuild the board from PGN/FEN and recompute gradients, info and plots."""
    board = get_board(game_pgn, board_fen)
    gradients = get_gradients(model, board, target)
    info = get_info(model, board)
    plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
    return board, gradients, info, *plots


def on_model_change(
    model_name: str,
    board: LczeroBoard,
    target: str,
    average_over_planes: bool,
    begin_average_index: int,
    end_average_index: int,
    plane_index: int,
):
    """Load the newly selected model and recompute gradients, info and plots."""
    model = get_model(model_name)
    gradients = get_gradients(model, board, target)
    info = get_info(model, board)
    plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
    return model, gradients, info, *plots


def on_target_change(
    model: LczeroModel,
    board: LczeroBoard,
    target: str,
    average_over_planes: bool,
    begin_average_index: int,
    end_average_index: int,
    plane_index: int,
):
    """Recompute gradients for the new target and re-render the plots."""
    gradients = get_gradients(model, board, target)
    plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
    return gradients, *plots


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                    value="",
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
                model_name = gr.Dropdown(
                    label="Model",
                    choices=constants.ONNX_MODEL_NAMES,
                )
            with gr.Group():
                info = gr.Textbox(label="Info", lines=1, value="")
            with gr.Group():
                target = gr.Radio(
                    ["win", "draw", "loss", "best_move"],
                    label="Target",
                    value="win",
                )
                average_over_planes = gr.Checkbox(label="Average over Planes", value=False)
                with gr.Accordion("Average over planes", open=False):
                    begin_average_index = gr.Slider(
                        label="Begin average index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=0,
                    )
                    end_average_index = gr.Slider(
                        label="End average index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=111,
                    )
                plane_index = gr.Slider(
                    label="Plane index",
                    minimum=0,
                    maximum=111,
                    step=1,
                    value=0,
                )
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Per-session state shared between the event handlers.
    model = gr.State(value=None)
    board = gr.State(value=None)
    gradients = gr.State(value=None)

    # Every handler below ends in render_gradients, which writes to fixed paths in
    # FIGURE_DIRECTORY, so they all share one concurrency queue to avoid sessions
    # overwriting each other's figures.
    interface.load(
        initial_load,
        inputs=[model_name, game_pgn, board_fen, target, average_over_planes, begin_average_index, end_average_index, plane_index],
        outputs=[model, board, gradients, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    game_pgn.submit(
        on_board_change,
        inputs=[model, game_pgn, board_fen, target, average_over_planes, begin_average_index, end_average_index, plane_index],
        outputs=[board, gradients, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    board_fen.submit(
        on_board_change,
        inputs=[model, game_pgn, board_fen, target, average_over_planes, begin_average_index, end_average_index, plane_index],
        outputs=[board, gradients, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    model_name.change(
        on_model_change,
        inputs=[model_name, board, target, average_over_planes, begin_average_index, end_average_index, plane_index],
        outputs=[model, gradients, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    target.change(
        on_target_change,
        inputs=[model, board, target, average_over_planes, begin_average_index, end_average_index, plane_index],
        outputs=[gradients, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    for render_arg in [average_over_planes, begin_average_index, end_average_index, plane_index]:
        render_arg.change(
            render_gradients,
            inputs=[board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index],
            outputs=[image_board, colorbar],
            concurrency_id="trace_queue",
        )