""" Gradio interface for plotting policy. """ import chess import gradio as gr import uuid import torch from lczerolens.encodings import encode_move from src import constants, global_variables, visualisation def compute_features_fn( features, model_output, file_id, root_idx, traj_idx, start_fen, move_seq, feature_index ): error_return = [features, model_output, file_id, root_idx, traj_idx] + [None] * 5 root_board = None traj_board = None try: board = chess.Board(start_fen) except ValueError: board = chess.Board() gr.Warning("Invalid FEN, using starting position.") return error_return i = 0 if root_idx == 0: root_board = board.copy() if traj_idx == 0: traj_board = board.copy() if move_seq: try: if move_seq.startswith("1."): for move in move_seq.split(): if root_board is not None and traj_board is not None: break if move.endswith("."): continue board.push_san(move) i += 1 if i == root_idx: root_board = board.copy() if i == traj_idx: traj_board = board.copy() else: for move in move_seq.split(): if root_board is not None and traj_board is not None: break board.push_uci(move) i += 1 if i == root_idx: root_board = board.copy() if i == traj_idx: traj_board = board.copy() except ValueError: gr.Warning(f"Invalid move {move}.") return error_return if root_board is None or traj_board is None: gr.Warning("Invalid move sequence.") return error_return model_output, pixel_acts, sae_output = global_variables.generator.generate( root_board=root_board, traj_board=traj_board ) current_root_fen = root_board.fen() current_traj_fen = traj_board.fen() features = sae_output["features"] x_hat = sae_output["x_hat"] first_output = render_feature_index( features, model_output, file_id, root_idx, traj_idx, current_traj_fen, feature_index ) half_a_dim = constants.ACTIVATION_DIM // 2 half_f_dim = constants.DICTIONARY_SIZE // 2 pixel_f_avg = features.mean(dim=0) pixel_f_active = (features > 0).float().mean(dim=0) pixel_p_avg = features.mean(dim=1) pixel_p_active = (features > 0).float().mean(dim=1) if board.turn: most_avg_pixels = pixel_p_avg.topk(5).indices.tolist() most_active_pixels = pixel_p_active.topk(5).indices.tolist() else: most_avg_pixels = pixel_p_avg.view(8,8).flip(0).view(64).topk(5).indices.tolist() most_active_pixels = pixel_p_active.view(8,8).flip(0).view(64).topk(5).indices.tolist() info = f"Root WDL: {model_output['wdl'][0]}\n" info += f"Traj WDL: {model_output['wdl'][1]}\n" info += f"MSE loss: {torch.nn.functional.mse_loss(x_hat, pixel_acts, reduction='none').sum(dim=1).mean()}\n" info += f"MSE loss (root): {torch.nn.functional.mse_loss(x_hat[:,:half_a_dim], pixel_acts[:,:half_a_dim], reduction='none').sum(dim=1).mean()}\n" info += f"MSE loss (traj): {torch.nn.functional.mse_loss(x_hat[:,half_a_dim:], pixel_acts[:,half_a_dim:], reduction='none').sum(dim=1).mean()}\n" info += f"L0 loss: {(features>0).sum(dim=1).float().mean()}\n" info += f"L0 loss (c): {(features[:,:half_f_dim]>0).sum(dim=1).float().mean()}\n" info += f"L0 loss (d): {(features[:,half_f_dim:]>0).sum(dim=1).float().mean()}\n" info += f"Most active features (avg): {pixel_f_avg.topk(5).indices.tolist()}\n" info += f"Most active features (active): {pixel_f_active.topk(5).indices.tolist()}\n" info += f"Most active pixels (avg): {[chess.SQUARE_NAMES[p] for p in most_avg_pixels]}\n" info += f"Most active pixels (active): {[chess.SQUARE_NAMES[p] for p in most_active_pixels]}" return *first_output, current_root_fen, current_traj_fen, info def render_feature_index( features, model_output, file_id, root_idx, traj_idx, traj_fen, feature_index, ): if file_id is None: file_id = str(uuid.uuid4()) board = chess.Board(traj_fen) pixel_features = features[:,feature_index] if board.turn: heatmap = pixel_features.view(64) else: heatmap = pixel_features.view(8,8).flip(0).view(64) best_legal_logit = None best_legal_move = None for move in board.legal_moves: move_index = encode_move(move, (board.turn, not board.turn)) logit = model_output["policy"][1,move_index].item() if best_legal_logit is None: best_legal_logit = logit else: best_legal_move = move svg_board, fig = visualisation.render_heatmap( board, heatmap, arrows=[(best_legal_move.from_square, best_legal_move.to_square)], ) with open(f"{constants.FIGURES_FOLER}/{file_id}.svg", "w") as f: f.write(svg_board) return ( features, model_output, file_id, root_idx, traj_idx, f"{constants.FIGURES_FOLER}/{file_id}.svg", fig ) def make_features_fn(var, direction): def _make_features_fn( features, model_output, file_id, root_idx, traj_idx, start_fen, move_seq, feature_index ): move_count = len([mv for mv in move_seq.split() if not mv.endswith(".")]) if var == "root": root_idx += direction if root_idx < 0: gr.Warning("Already at first board.") root_idx = 0 elif root_idx >= move_count: gr.Warning("Already at last board.") root_idx = move_count - 1 elif root_idx > traj_idx: gr.Warning("Root should be before traj.") root_idx = traj_idx elif var == "traj": traj_idx += direction if traj_idx < 0: gr.Warning("Already at first board.") traj_idx = 0 elif traj_idx >= move_count: gr.Warning("Already at last board.") traj_idx = move_count - 1 elif traj_idx < root_idx: gr.Warning("Traj should be after root.") traj_idx = root_idx return compute_features_fn( features, model_output, file_id, root_idx, traj_idx, start_fen, move_seq, feature_index ) return _make_features_fn with gr.Blocks() as interface: with gr.Row(): with gr.Column(): start_fen = gr.Textbox( label="Starting FEN", lines=1, max_lines=1, value=chess.STARTING_FEN, ) move_seq = gr.Textbox( label="Move sequence", lines=1, max_lines=20, value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"), ) with gr.Group(): with gr.Row(): previous_root_button = gr.Button("Previous root") next_root_button = gr.Button("Next root") with gr.Row(): previous_traj_button = gr.Button("Previous traj") next_traj_button = gr.Button("Next traj") with gr.Group(): with gr.Row(): current_root_fen = gr.Textbox( label="Root FEN", lines=1, max_lines=1, interactive=False ) with gr.Row(): current_traj_fen = gr.Textbox( label="Traj FEN", lines=1, max_lines=1, interactive=False ) with gr.Row(): feature_index = gr.Slider( label="Feature index", minimum=0, maximum=constants.DICTIONARY_SIZE-1, step=1, value=0, ) with gr.Group(): with gr.Row(): info = gr.Textbox(label="Info", lines=1, max_lines=20, value="") with gr.Row(): colorbar = gr.Plot(label="Colorbar") with gr.Column(): board_image = gr.Image(label="Board") features = gr.State(None) model_output = gr.State(None) file_id = gr.State(None) root_idx = gr.State(0) traj_idx = gr.State(0) state = [features, model_output, file_id, root_idx, traj_idx] base_inputs = [start_fen, move_seq, feature_index] base_outputs = [board_image, colorbar, current_root_fen, current_traj_fen, info] previous_root_button.click( make_features_fn(var="root", direction=-1), inputs=state + base_inputs, outputs=state + base_outputs, ) next_root_button.click( make_features_fn(var="root", direction=1), inputs=state + base_inputs, outputs=state + base_outputs, ) previous_traj_button.click( make_features_fn(var="traj", direction=-1), inputs=state + base_inputs, outputs=state + base_outputs, ) next_traj_button.click( make_features_fn(var="traj", direction=1), inputs=state + base_inputs, outputs=state + base_outputs, ) feature_index.change( render_feature_index, inputs=state + [current_traj_fen, feature_index], outputs=state + [board_image, colorbar], )