"""Gradio interface for inspecting SAE features on a board, with the policy's best move drawn as an arrow."""

import uuid

import chess
import gradio as gr
import torch
from lczerolens.encodings import encode_move

from src import constants, global_variables, visualisation

# Number of entries listed in each "Most active ..." line of the info box.
_TOP_K = 5


def _to_board_orientation(square_values, white_to_move):
    """Flatten 64 per-square values, flipping the 8x8 grid along its first axis when black is to move.

    NOTE(review): presumably the activations are stored side-to-move-relative
    and the flip maps them back to absolute python-chess square order — confirm.
    """
    if white_to_move:
        return square_values.view(64)
    return square_values.view(8, 8).flip(0).view(64)


def _reconstruction_mse(x_hat, target):
    """Squared error summed over dim 1 and averaged over dim 0."""
    return (
        torch.nn.functional.mse_loss(x_hat, target, reduction="none")
        .sum(dim=1)
        .mean()
    )


def _mean_l0(features):
    """Average number of strictly positive entries per row."""
    return (features > 0).sum(dim=1).float().mean()


def _best_legal_move(board, model_output):
    """Return the legal move with the highest trajectory policy logit, or None if there are no legal moves.

    Logits are read from row 1 of ``model_output["policy"]`` (the trajectory
    position, matching the WDL labelling used in the info text).
    """
    perspective = (board.turn, not board.turn)
    best_move = None
    best_logit = None
    for move in board.legal_moves:
        move_index = encode_move(move, perspective)
        logit = model_output["policy"][1, move_index].item()
        if best_logit is None or logit > best_logit:
            best_logit = logit
            best_move = move
    return best_move


def _build_info(features, model_output, pixel_acts, x_hat, traj_fen):
    """Format the WDL, reconstruction-loss, L0 and top-k summary shown in the Info box."""
    half_a_dim = constants.ACTIVATION_DIM // 2
    half_f_dim = constants.DICTIONARY_SIZE // 2
    active = (features > 0).float()

    # Per-feature statistics (reduced over squares, dim 0).
    feature_avg = features.mean(dim=0)
    feature_active = active.mean(dim=0)

    # Per-square statistics (reduced over features, dim 1), in board order.
    white_to_move = chess.Board(traj_fen).turn
    square_avg = _to_board_orientation(features.mean(dim=1), white_to_move)
    square_active = _to_board_orientation(active.mean(dim=1), white_to_move)
    top_avg_squares = square_avg.topk(_TOP_K).indices.tolist()
    top_active_squares = square_active.topk(_TOP_K).indices.tolist()

    lines = [
        f"Root WDL: {model_output['wdl'][0]}",
        f"Traj WDL: {model_output['wdl'][1]}",
        f"MSE loss: {_reconstruction_mse(x_hat, pixel_acts)}",
        f"MSE loss (root): {_reconstruction_mse(x_hat[:, :half_a_dim], pixel_acts[:, :half_a_dim])}",
        f"MSE loss (traj): {_reconstruction_mse(x_hat[:, half_a_dim:], pixel_acts[:, half_a_dim:])}",
        f"L0 loss: {_mean_l0(features)}",
        f"L0 loss (c): {_mean_l0(features[:, :half_f_dim])}",
        f"L0 loss (d): {_mean_l0(features[:, half_f_dim:])}",
        f"Most active features (avg): {feature_avg.topk(_TOP_K).indices.tolist()}",
        f"Most active features (active): {feature_active.topk(_TOP_K).indices.tolist()}",
        f"Most active pixels (avg): {[chess.SQUARE_NAMES[p] for p in top_avg_squares]}",
        f"Most active pixels (active): {[chess.SQUARE_NAMES[p] for p in top_active_squares]}",
    ]
    return "\n".join(lines)


def compute_features_fn(
    features, model_output, file_id, root_fen, traj_fen, feature_index
):
    """Run the generator on a (root, trajectory) FEN pair and render one feature.

    Wired to the "Compute features" button. The incoming ``features`` and
    ``model_output`` states are ignored and replaced with fresh results; they
    are in the signature only because Gradio passes every registered input.

    Args:
        features: Previous feature state (unused, overwritten).
        model_output: Previous model-output state (unused, overwritten).
        file_id: Name stem for the SVG file; a UUID is generated when None.
        root_fen: FEN of the root position.
        traj_fen: FEN of the trajectory position that is drawn.
        feature_index: Index of the dictionary feature to display.

    Returns:
        ``(features, model_output, file_id, svg_path, colorbar_fig, info)``,
        matching the outputs registered on the button.
    """
    model_output, pixel_acts, sae_output = global_variables.generator.generate(
        root_fen=root_fen, traj_fen=traj_fen
    )
    features = sae_output["features"]
    x_hat = sae_output["x_hat"]

    rendered = render_feature_index(
        features, model_output, file_id, traj_fen, feature_index
    )
    info = _build_info(features, model_output, pixel_acts, x_hat, traj_fen)
    return (*rendered, info)


def render_feature_index(
    features, model_output, file_id, traj_fen, feature_index
):
    """Draw the heatmap of one feature over the trajectory board and save it as SVG.

    Args:
        features: Feature tensor from the generator, indexed as
            ``features[square, feature]`` (reshaped to 8x8 per feature), or
            None if nothing has been computed yet.
        model_output: Generator output containing ``"policy"``, or None.
        file_id: Name stem for the SVG file; a UUID is generated when None.
        traj_fen: FEN of the board to draw.
        feature_index: Index of the dictionary feature to display.

    Returns:
        ``(features, model_output, file_id, svg_path, colorbar_fig)``. When
        no features are available yet, the last two are None.
    """
    if features is None or model_output is None:
        # Slider moved before "Compute features" was clicked: nothing to draw.
        return features, model_output, file_id, None, None
    if file_id is None:
        file_id = str(uuid.uuid4())

    board = chess.Board(traj_fen)
    heatmap = _to_board_orientation(features[:, int(feature_index)], board.turn)

    # Mate/stalemate positions have no legal move, hence no arrow.
    best_move = _best_legal_move(board, model_output)
    arrows = [] if best_move is None else [(best_move.from_square, best_move.to_square)]
    svg_board, fig = visualisation.render_heatmap(board, heatmap, arrows=arrows)

    svg_path = f"{constants.FIGURES_FOLER}/{file_id}.svg"
    with open(svg_path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return features, model_output, file_id, svg_path, fig


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            root_fen = gr.Textbox(
                label="Root FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            traj_fen = gr.Textbox(
                label="Trajectory FEN",
                lines=1,
                max_lines=1,
                value="rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1",
            )
            compute_features = gr.Button("Compute features")
            with gr.Group():
                with gr.Row():
                    feature_index = gr.Slider(
                        label="Feature index",
                        minimum=0,
                        maximum=constants.DICTIONARY_SIZE - 1,
                        step=1,
                        value=0,
                    )
            with gr.Group():
                with gr.Row():
                    info = gr.Textbox(label="Info", lines=1, max_lines=20, value="")
                with gr.Row():
                    colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            board_image = gr.Image(label="Board")

    # Per-session state shared between the button and the slider callbacks.
    features = gr.State(None)
    model_output = gr.State(None)
    file_id = gr.State(None)

    compute_features.click(
        compute_features_fn,
        inputs=[features, model_output, file_id, root_fen, traj_fen, feature_index],
        outputs=[features, model_output, file_id, board_image, colorbar, info],
    )
    feature_index.change(
        render_feature_index,
        inputs=[features, model_output, file_id, traj_fen, feature_index],
        outputs=[features, model_output, file_id, board_image, colorbar],
    )