""" Gradio interface for plotting attention. """ import copy import chess import gradio as gr from lczerolens.board import LczeroBoard from demo import constants, utils, visualisation def list_models(): """ List the models in the model directory. """ models_info = utils.get_models_info(leela=False) return sorted([[model_info[0]] for model_info in models_info]) def on_select_model_df( evt: gr.SelectData, ): """ When a model is selected, update the statement. """ return evt.value def compute_cache( board_fen, action_seq, model_name, attention_layer, attention_head, square, state_board_index, state_boards, state_cache, ): if model_name == "": gr.Warning("No model selected.") return None, None, None, state_boards, state_cache try: board = LczeroBoard(board_fen) except ValueError: board = LczeroBoard() gr.Warning("Invalid FEN, using starting position.") state_boards = [board.copy()] if action_seq: try: if action_seq.startswith("1."): for action in action_seq.split(): if action.endswith("."): continue board.push_san(action) state_boards.append(board.copy()) else: for action in action_seq.split(): board.push_uci(action) state_boards.append(board.copy()) except ValueError: gr.Warning(f"Invalid action {action} stopping before it.") try: wrapper, lens = utils.get_wrapper_lens_from_state( model_name, "activation", lens_name="attention", module_exp=r"encoder\d+/mha/QK/softmax", ) except ValueError: gr.Warning("Could not load model.") return None, None, None, state_boards, state_cache state_cache = [] for board in state_boards: attention_cache = copy.deepcopy(lens.analyse_board(board, wrapper)) state_cache.append(attention_cache) return ( *make_plot( attention_layer, attention_head, square, state_board_index, state_boards, state_cache, ), state_boards, state_cache, ) def make_plot( attention_layer, attention_head, square, state_board_index, state_boards, state_cache, ): if state_cache == []: gr.Warning("No cache available.") return None, None, None board = state_boards[state_board_index] num_attention_layers = len(state_cache[state_board_index]) if attention_layer > num_attention_layers: gr.Warning( f"Attention layer {attention_layer} does not exist, " f"using layer {num_attention_layers} instead." ) attention_layer = num_attention_layers key = f"encoder{attention_layer-1}/mha/QK/softmax" try: attention_tensor = state_cache[state_board_index][key] except KeyError: gr.Warning(f"Combination {key} does not exist.") return None, None, None if attention_head > attention_tensor.shape[1]: gr.Warning( f"Attention head {attention_head} does not exist, " f"using head {attention_tensor.shape[1]+1} instead." 
) attention_head = attention_tensor.shape[1] try: square_index = chess.SQUARE_NAMES.index(square) except ValueError: gr.Warning(f"Invalid square {square}, using a1 instead.") square_index = 0 square = "a1" if board.turn == chess.BLACK: square_index = chess.square_mirror(square_index) heatmap = attention_tensor[0, attention_head - 1, square_index] if board.turn == chess.BLACK: heatmap = heatmap.view(8, 8).flip(0).view(64) svg_board, fig = visualisation.render_heatmap(board, heatmap, square=square) with open(f"{constants.FIGURE_DIRECTORY}/attention.svg", "w") as f: f.write(svg_board) return f"{constants.FIGURE_DIRECTORY}/attention.svg", board.fen(), fig def previous_board( attention_layer, attention_head, square, state_board_index, state_boards, state_cache, ): state_board_index -= 1 if state_board_index < 0: gr.Warning("Already at first board.") state_board_index = 0 return ( *make_plot( attention_layer, attention_head, square, state_board_index, state_boards, state_cache, ), state_board_index, ) def next_board( attention_layer, attention_head, square, state_board_index, state_boards, state_cache, ): state_board_index += 1 if state_board_index >= len(state_boards): gr.Warning("Already at last board.") state_board_index = len(state_boards) - 1 return ( *make_plot( attention_layer, attention_head, square, state_board_index, state_boards, state_cache, ), state_board_index, ) with gr.Blocks() as interface: with gr.Row(): with gr.Column(scale=2): model_df = gr.Dataframe( headers=["Available models"], datatype=["str"], interactive=False, type="array", value=list_models, ) with gr.Column(scale=1): with gr.Row(): model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7) model_df.select( on_select_model_df, None, model_name, ) with gr.Row(): with gr.Column(): board_fen = gr.Textbox( label="Board starting FEN", lines=1, max_lines=1, value=chess.STARTING_FEN, ) action_seq = gr.Textbox( label="Action sequence", lines=1, max_lines=1, value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"), ) compute_cache_button = gr.Button("Compute cache") with gr.Group(): with gr.Row(): attention_layer = gr.Slider( label="Attention layer", minimum=1, maximum=24, step=1, value=1, ) attention_head = gr.Slider( label="Attention head", minimum=1, maximum=24, step=1, value=1, ) with gr.Row(): square = gr.Textbox( label="Square", lines=1, max_lines=1, value="a1", scale=1, ) with gr.Row(): previous_board_button = gr.Button("Previous board") next_board_button = gr.Button("Next board") current_board_fen = gr.Textbox( label="Board FEN", lines=1, max_lines=1, ) colorbar = gr.Plot(label="Colorbar") with gr.Column(): image = gr.Image(label="Board") state_board_index = gr.State(0) state_boards = gr.State([]) state_cache = gr.State([]) base_inputs = [ attention_layer, attention_head, square, state_board_index, state_boards, state_cache, ] outputs = [image, current_board_fen, colorbar] compute_cache_button.click( compute_cache, inputs=[board_fen, action_seq, model_name] + base_inputs, outputs=outputs + [state_boards, state_cache], ) previous_board_button.click( previous_board, inputs=base_inputs, outputs=outputs + [state_board_index], ) next_board_button.click(next_board, inputs=base_inputs, outputs=outputs + [state_board_index]) attention_layer.change(make_plot, inputs=base_inputs, outputs=outputs) attention_head.change(make_plot, inputs=base_inputs, outputs=outputs) square.submit(make_plot, inputs=base_inputs, outputs=outputs)
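

# Standalone entry point: a minimal sketch, not part of the original module. The
# `interface` Blocks object is presumably imported and mounted by the wider demo app,
# but launching it directly with Gradio's `launch()` is convenient for local testing.
if __name__ == "__main__":
    interface.launch()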