"""Gradio interface for plotting activations."""

import chess
import chess.pgn
import io
import gradio as gr
import os
import torch

from lczerolens import LczeroBoard, LczeroModel, Lens

from demo import constants
from demo.utils import get_info


def get_model(model_name: str):
    """Load the ONNX model ``model_name`` from ``constants.ONNX_MODEL_DIRECTORY``."""
    return LczeroModel.from_onnx_path(os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name))


def get_activations(model: LczeroModel, board: LczeroBoard):
    """Return the ``conv2/relu`` activations of every block for ``board``.

    The list is ordered by block index; each entry is element ``[0]`` of that
    block's output (presumably the batch dimension — TODO confirm).
    """
    # Raw string: "\d" is an invalid escape in a plain literal (SyntaxWarning
    # on Python 3.12+). The resulting pattern text is unchanged.
    lens = Lens.from_name("activation", r"block\d/conv2/relu")
    with torch.no_grad():
        results = lens.analyse(model, board)
    # NOTE(review): assumes `results` holds exactly one entry per block, keyed
    # "block{i}/conv2/relu_output" for i = 0..N-1 — confirm against lczerolens.
    return [results[f"block{i}/conv2/relu_output"][0] for i in range(len(results))]


def get_board(game_pgn: str, board_fen: str):
    """Build the board to analyse from a PGN game or, failing that, a FEN string.

    A non-blank PGN takes precedence over the FEN. If the chosen input cannot be
    parsed, a Gradio warning is shown and the standard starting position is
    returned instead.
    """
    if game_pgn and game_pgn.strip():
        try:
            game = chess.pgn.read_game(io.StringIO(game_pgn))
            if game is None:
                # read_game returns None when the text contains no game at all.
                raise ValueError("No game found in PGN.")
            # Replay from the game's own initial position so PGNs carrying a
            # SetUp/FEN header end up in the right position. Without such a
            # header this is the standard starting position.
            board = LczeroBoard(game.board().fen())
            for move in game.mainline_moves():
                board.push(move)
        except Exception as e:
            # Deliberate best-effort at the UI boundary: report and fall back.
            print(e)
            gr.Warning("Error parsing PGN, using starting position.")
            board = LczeroBoard()
    else:
        try:
            board = LczeroBoard(board_fen)
        except Exception as e:
            print(e)
            gr.Warning("Invalid FEN, using starting position.")
            board = LczeroBoard()
    return board


def render_activations(board: LczeroBoard, activations, layer_index: int, channel_index: int):
    """Render one activation channel as a heatmap on ``board``.

    Layer/channel indices past the end are clamped to the last valid index and
    a Gradio warning is shown.

    Returns:
        Paths of the rendered board SVG and colorbar SVG, in that order.

    Raises:
        gr.Error: if no board or activations have been computed yet.
    """
    if board is None or not activations:
        raise gr.Error("No activations available yet: load a board and a model first.")
    # Slider values may arrive as floats; list indexing requires ints.
    layer_index = int(layer_index)
    channel_index = int(channel_index)

    if layer_index >= len(activations):
        safe_layer_index = len(activations) - 1
        gr.Warning(f"Layer index {layer_index} out of range, using last layer ({safe_layer_index}).")
    else:
        safe_layer_index = layer_index
    layer = activations[safe_layer_index]

    if channel_index >= layer.shape[0]:
        safe_channel_index = layer.shape[0] - 1
        gr.Warning(f"Channel index {channel_index} out of range, using last channel ({safe_channel_index}).")
    else:
        safe_channel_index = channel_index

    # Flatten the channel to 64 values (one per square); reshape, unlike view,
    # also handles non-contiguous tensors.
    heatmap = layer[safe_channel_index].reshape(64)
    figure_dir = constants.FIGURE_DIRECTORY
    # NOTE(review): assumes render_heatmap writes "<stem>_board.svg" and
    # "<stem>_colorbar.svg" next to `save_to` — confirm against lczerolens.
    # These fixed paths are shared by all sessions, so concurrent renders can
    # overwrite each other's figures.
    board.render_heatmap(
        heatmap,
        save_to=f"{figure_dir}/activations.svg",
    )
    return f"{figure_dir}/activations_board.svg", f"{figure_dir}/activations_colorbar.svg"


def initial_load(model_name: str, board_fen: str, game_pgn: str, layer_index: int, channel_index: int):
    """Load model and board on page load and render the first heatmap.

    Returns the model, board, activations, info text (from ``get_info``) and
    the two figure paths, in the order wired to ``interface.load``.
    """
    model = get_model(model_name)
    board = get_board(game_pgn, board_fen)
    activations = get_activations(model, board)
    info = get_info(model, board)
    plots = render_activations(board, activations, layer_index, channel_index)
    return model, board, activations, info, *plots


def on_board_change(model: LczeroModel, game_pgn: str, board_fen: str, layer_index: int, channel_index: int):
    """Rebuild the board from PGN/FEN, recompute activations and re-render."""
    board = get_board(game_pgn, board_fen)
    activations = get_activations(model, board)
    info = get_info(model, board)
    plots = render_activations(board, activations, layer_index, channel_index)
    return board, activations, info, *plots


def on_model_change(model_name: str, board: LczeroBoard, layer_index: int, channel_index: int):
    """Load the newly selected model, recompute activations and re-render."""
    model = get_model(model_name)
    activations = get_activations(model, board)
    info = get_info(model, board)
    plots = render_activations(board, activations, layer_index, channel_index)
    return model, activations, info, *plots


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                    value="",
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
                model_name = gr.Dropdown(
                    label="Model",
                    choices=constants.ONNX_MODEL_NAMES,
                )
            with gr.Group():
                info = gr.Textbox(label="Info", lines=1, value="")
            with gr.Group():
                # Slider bounds are fixed; render_activations clamps indices
                # that exceed the loaded model's layers/channels.
                layer_index = gr.Slider(
                    label="Layer index",
                    minimum=0,
                    maximum=19,
                    step=1,
                    value=0,
                )
                channel_index = gr.Slider(
                    label="Channel index",
                    minimum=0,
                    maximum=200,
                    step=1,
                    value=0,
                )
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Per-session state shared between event handlers.
    model = gr.State(value=None)
    board = gr.State(value=None)
    activations = gr.State(value=None)

    # Inputs are passed positionally, so their order must match
    # initial_load(model_name, board_fen, game_pgn, layer_index, channel_index).
    interface.load(
        initial_load,
        inputs=[model_name, board_fen, game_pgn, layer_index, channel_index],
        outputs=[model, board, activations, info, image_board, colorbar],
        concurrency_limit=1,
        concurrency_id="trace_queue",
    )
    game_pgn.submit(
        on_board_change,
        inputs=[model, game_pgn, board_fen, layer_index, channel_index],
        outputs=[board, activations, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    board_fen.submit(
        on_board_change,
        inputs=[model, game_pgn, board_fen, layer_index, channel_index],
        outputs=[board, activations, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    model_name.change(
        on_model_change,
        inputs=[model_name, board, layer_index, channel_index],
        outputs=[model, activations, info, image_board, colorbar],
        concurrency_id="trace_queue",
    )
    # Slider changes only re-render from cached activations (no model call).
    layer_index.change(
        render_activations,
        inputs=[board, activations, layer_index, channel_index],
        outputs=[image_board, colorbar],
    )
    channel_index.change(
        render_activations,
        inputs=[board, activations, layer_index, channel_index],
        outputs=[image_board, colorbar],
    )