"""
Gradio interface for visualizing the policy of a model.
"""

import chess
import chess.svg
import gradio as gr
import torch
from lczero.backends import Backend, GameState, Weights

from demo import constants, utils, visualisation
from lczerolens import move_encodings
from lczerolens.model import lczero as lczero_utils
from lczerolens.xai import PolicyLens
from lczerolens.board import LczeroBoard


def list_models():
    """
    List the models in the model directory.

    Returns a sorted list of single-element rows (one per model), built from
    the first field of each entry returned by ``utils.get_models_info``; this
    is the row format used to populate the models dataframe.
    """
    models_info = utils.get_models_info(onnx=False)
    return sorted([[model_info[0]] for model_info in models_info])


def on_select_model_df(
    evt: gr.SelectData,
):
    """
    When a model is selected, return the value of the selected cell.
    """
    return evt.value


def make_policy_plot(
    board_fen,
    action_seq,
    view,
    model_name,
    depth,
    use_softmax,
    aggregate_topk,
    render_bestk,
    only_legal,
):
    """
    Compute the policy of the selected model and render it.

    Args:
        board_fen: FEN string of the position to start from; an invalid FEN
            falls back to the starting position (with a UI warning).
        action_seq: Whitespace-separated UCI moves pushed on the board; an
            invalid sequence resets the board to the starting position.
        view: ``"from"`` selects the first aggregate returned by
            ``PolicyLens.aggregate_policy``; any other value selects the
            second one.
        model_name: File name of the weights inside
            ``constants.LEELA_MODEL_DIRECTORY``; empty means "none selected".
        depth: Accepted for interface compatibility; not used here.
        use_softmax: Forwarded as ``softmax`` to the backend prediction.
        aggregate_topk: Forwarded (as ``int``) to the policy aggregation.
        render_bestk: Maximum number of best moves drawn as arrows.
        only_legal: Restrict the prediction and the arrows to legal moves.

    Returns:
        A 4-tuple matching the interface outputs: path of the rendered board
        SVG, the colorbar figure, the value text, and the policy
        distribution figure. When no model is selected, the tuple is
        ``(None, None, "", None)``.
    """
    if model_name == "":
        gr.Warning(
            "Please select a model.",
        )
        # One entry per wired output (image, colorbar, game info, density);
        # returning fewer values makes the Gradio event handler fail.
        return (
            None,
            None,
            "",
            None,
        )

    try:
        board = LczeroBoard(board_fen)
    except ValueError:
        board = LczeroBoard()
        gr.Warning("Invalid FEN, using starting position.")

    if action_seq:
        try:
            for action in action_seq.split():
                board.push_uci(action)
        except ValueError:
            gr.Warning("Invalid action sequence, using starting position.")
            board = LczeroBoard()

    lczero_weights = Weights(f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}")
    lczero_backend = Backend(lczero_weights)
    uci_moves = [move.uci() for move in board.move_stack]
    # NOTE(review): only the move list is handed to the backend, so a custom
    # starting FEN is not forwarded here — confirm whether GameState should
    # also receive the root position.
    lczero_game = GameState(moves=uci_moves)

    policy, value = lczero_utils.prediction_from_backend(
        lczero_backend,
        lczero_game,
        softmax=use_softmax,
        only_legal=only_legal,
        illegal_value=0,
    )
    pickup_agg, dropoff_agg = PolicyLens.aggregate_policy(policy, int(aggregate_topk))

    aggregate = pickup_agg if view == "from" else dropoff_agg
    if board.turn == chess.WHITE:
        heatmap = aggregate
    else:
        # Flip the 8x8 grid along its first axis when black is to move
        # (presumably because the policy is side-to-move relative — confirm).
        heatmap = aggregate.view(8, 8).flip(0).view(64)

    us_them = (board.turn, not board.turn)
    legal_indices = [move_encodings.encode_move(move, us_them) for move in board.legal_moves]
    # Sliders may deliver a float; torch.topk needs an integer k.
    best_k = int(render_bestk)

    if only_legal:
        if legal_indices:
            # Rank the legal entries only: ranking a zero-filled vector would
            # let illegal indices win whenever legal values are negative or
            # when there are fewer legal moves than requested arrows.
            legal_index_tensor = torch.tensor(legal_indices, dtype=torch.long, device=policy.device)
            legal_policy = policy[legal_index_tensor]
            if (legal_policy < 0).any():
                gr.Warning("Some legal moves have negative policy.")
            top_legal = torch.topk(legal_policy, min(best_k, len(legal_indices)))
            top_indices = legal_index_tensor[top_legal.indices]
        else:
            # No legal move (e.g. game over): nothing to draw.
            top_indices = []
    else:
        top_indices = torch.topk(policy, best_k).indices

    arrows = []
    for move_index in top_indices:
        move = move_encodings.decode_move(move_index, us_them)
        arrows.append((move.from_square, move.to_square))

    svg_board, fig = visualisation.render_heatmap(board, heatmap, arrows=arrows)
    figure_path = f"{constants.FIGURE_DIRECTORY}/policy.svg"
    with open(figure_path, "w", encoding="utf-8") as f:
        f.write(svg_board)

    # Second pass without softmax or legality filtering, used for the
    # distribution plot.
    raw_policy, _ = lczero_utils.prediction_from_backend(
        lczero_backend,
        lczero_game,
        softmax=False,
        only_legal=False,
        illegal_value=0,
    )
    fig_dist = visualisation.render_policy_distribution(
        raw_policy,
        legal_indices,
    )
    return (
        figure_path,
        fig,
        (f"Value: {value:.2f}"),
        fig_dist,
    )


with gr.Blocks() as interface:
    # Model selection: table of available models and the current selection.
    with gr.Row():
        with gr.Column(scale=2):
            model_df = gr.Dataframe(
                headers=["Available models"],
                datatype=["str"],
                interactive=False,
                type="array",
                value=list_models,
            )
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)

    model_df.select(
        on_select_model_df,
        None,
        model_name,
    )

    with gr.Row():
        # Left column: position inputs, plotting options and text outputs.
        with gr.Column():
            board_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            action_seq = gr.Textbox(
                label="Action sequence",
                lines=1,
                max_lines=1,
                value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
            )
            with gr.Group():
                with gr.Row():
                    depth = gr.Radio(label="Depth", choices=[0], value=0)
                    use_softmax = gr.Checkbox(label="Use softmax", value=True)
                with gr.Row():
                    aggregate_topk = gr.Slider(
                        label="Aggregate top k",
                        minimum=1,
                        maximum=1858,
                        step=1,
                        value=1858,
                        scale=3,
                    )
                    view = gr.Radio(
                        label="View",
                        choices=["from", "to"],
                        value="from",
                        scale=1,
                    )
                with gr.Row():
                    render_bestk = gr.Slider(
                        label="Render best k",
                        minimum=1,
                        maximum=5,
                        step=1,
                        value=5,
                        scale=3,
                    )
                    only_legal = gr.Checkbox(label="Only legal", value=True, scale=1)

            policy_button = gr.Button("Plot policy")
            colorbar = gr.Plot(label="Colorbar")
            game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="")
        # Right column: rendered board and policy distribution.
        with gr.Column():
            image = gr.Image(label="Board")
            density_plot = gr.Plot(label="Density")

    # Order must match the parameters of make_policy_plot.
    policy_inputs = [
        board_fen,
        action_seq,
        view,
        model_name,
        depth,
        use_softmax,
        aggregate_topk,
        render_bestk,
        only_legal,
    ]
    # Order must match the tuple returned by make_policy_plot.
    policy_outputs = [image, colorbar, game_info, density_plot]
    policy_button.click(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)