""" Gradio interface for plotting policy. """ import copy import chess import gradio as gr from demo import constants, utils, visualisation from lczerolens.board import LczeroBoard cache = None boards = None board_index = 0 def list_models(): """ List the models in the model directory. """ models_info = utils.get_models_info(leela=False) return sorted([[model_info[0]] for model_info in models_info]) def on_select_model_df( evt: gr.SelectData, ): """ When a model is selected, update the statement. """ return evt.value def compute_cache( board_fen, action_seq, model_name, plane_index, history_index, ): global cache global boards if model_name == "": gr.Warning("No model selected.") return None, None, None, None, None try: board = LczeroBoard(board_fen) except ValueError: board = LczeroBoard() gr.Warning("Invalid FEN, using starting position.") boards = [board.copy()] if action_seq: try: if action_seq.startswith("1."): for action in action_seq.split(): if action.endswith("."): continue board.push_san(action) boards.append(board.copy()) else: for action in action_seq.split(): board.push_uci(action) boards.append(board.copy()) except ValueError: gr.Warning(f"Invalid action {action} stopping before it.") wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "crp") cache = [] for board in boards: relevance = lens.compute_heatmap(board, wrapper) cache.append(copy.deepcopy(relevance)) return ( *make_plot( plane_index, ), *make_history_plot( history_index, ), ) def make_plot( plane_index, ): global cache global boards global board_index if cache is None: gr.Warning("Cache not computed!") return None, None, None board = boards[board_index] relevance_tensor = cache[board_index] a_max = relevance_tensor.abs().max() if a_max != 0: relevance_tensor = relevance_tensor / a_max vmin = -1 vmax = 1 heatmap = relevance_tensor[plane_index - 1].view(64) if board.turn == chess.BLACK: heatmap = heatmap.view(8, 8).flip(0).view(64) svg_board, fig = visualisation.render_heatmap(board, heatmap, vmin=vmin, vmax=vmax) with open(f"{constants.FIGURE_DIRECTORY}/lrp.svg", "w") as f: f.write(svg_board) return f"{constants.FIGURE_DIRECTORY}/lrp.svg", board.fen(), fig def make_history_plot( history_index, ): global cache global boards global board_index if cache is None: gr.Warning("Cache not computed!") return None, None board = boards[board_index] relevance_tensor = cache[board_index] a_max = relevance_tensor.abs().max() if a_max != 0: relevance_tensor = relevance_tensor / a_max vmin = -1 vmax = 1 heatmap = relevance_tensor[13 * (history_index - 1) : 13 * history_index - 1].sum(dim=0).view(64) if board.turn == chess.BLACK: heatmap = heatmap.view(8, 8).flip(0).view(64) if board_index - history_index + 1 < 0: history_board = LczeroBoard(fen=None) else: history_board = boards[board_index - history_index + 1] svg_board, fig = visualisation.render_heatmap(history_board, heatmap, vmin=vmin, vmax=vmax) with open(f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", "w") as f: f.write(svg_board) return f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", fig def previous_board( plane_index, history_index, ): global board_index board_index -= 1 if board_index < 0: gr.Warning("Already at first board.") board_index = 0 return ( *make_plot( plane_index, ), *make_history_plot( history_index, ), ) def next_board( plane_index, history_index, ): global board_index board_index += 1 if board_index >= len(boards): gr.Warning("Already at last board.") board_index = len(boards) - 1 return ( *make_plot( plane_index, ), *make_history_plot( history_index, ), ) with gr.Blocks() as interface: with gr.Row(): with gr.Column(scale=2): model_df = gr.Dataframe( headers=["Available models"], datatype=["str"], interactive=False, type="array", value=list_models, ) with gr.Column(scale=1): with gr.Row(): model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7) model_df.select( on_select_model_df, None, model_name, ) with gr.Row(): with gr.Column(): board_fen = gr.Textbox( label="Board starting FEN", lines=1, max_lines=1, value=chess.STARTING_FEN, ) action_seq = gr.Textbox( label="Action sequence", lines=1, max_lines=1, value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"), ) compute_cache_button = gr.Button("Compute heatmaps") with gr.Group(): with gr.Row(): plane_index = gr.Slider( label="Plane index", minimum=1, maximum=112, step=1, value=1, ) with gr.Row(): previous_board_button = gr.Button("Previous board") next_board_button = gr.Button("Next board") current_board_fen = gr.Textbox( label="Board FEN", lines=1, max_lines=1, ) colorbar = gr.Plot(label="Colorbar") with gr.Column(): image = gr.Image(label="Board") with gr.Row(): with gr.Column(): with gr.Group(): with gr.Row(): histroy_index = gr.Slider( label="History index", minimum=1, maximum=8, step=1, value=1, ) history_colorbar = gr.Plot(label="Colorbar") with gr.Column(): history_image = gr.Image(label="Board") base_inputs = [ plane_index, histroy_index, ] outputs = [ image, current_board_fen, colorbar, history_image, history_colorbar, ] compute_cache_button.click( compute_cache, inputs=[board_fen, action_seq, model_name] + base_inputs, outputs=outputs, ) previous_board_button.click(previous_board, inputs=base_inputs, outputs=outputs) next_board_button.click(next_board, inputs=base_inputs, outputs=outputs) plane_index.change( make_plot, inputs=plane_index, outputs=[image, current_board_fen, colorbar], ) histroy_index.change( make_history_plot, inputs=histroy_index, outputs=[history_image, history_colorbar], )