demo / src /interfaces /fen_feature_interface.py
Xmaster6y's picture
better demo
3b6ef01 unverified
"""
Gradio interface for visualising SAE feature activations on a chess board.
"""
import chess
import gradio as gr
import uuid
import torch
from lczerolens.encodings import encode_move
from src import constants, global_variables, visualisation
def compute_features_fn(
    features,
    model_output,
    file_id,
    root_fen,
    traj_fen,
    feature_index
):
    """
    Run the generator on a (root, trajectory) FEN pair and summarise the result.

    The incoming ``features`` and ``model_output`` values (Gradio state) are
    ignored and replaced by the freshly generated ones; ``file_id`` is
    forwarded to ``render_feature_index``.

    Returns:
        The values returned by ``render_feature_index`` for ``feature_index``,
        followed by a multi-line text summary: the ``wdl`` entries 0 (root)
        and 1 (traj) of the model output, the reconstruction MSE (summed over
        dim 1, then averaged) for the whole activation and for each half of
        ``constants.ACTIVATION_DIM``, the mean L0 (count of positive features)
        for the whole dictionary and each half of ``constants.DICTIONARY_SIZE``,
        and the top-5 features / squares ranked by mean activation and by
        fraction of positive activations.
    """
    mse = torch.nn.functional.mse_loss

    def summed_mse(pred, target):
        # Squared error summed over dim 1, then averaged over the rest.
        return mse(pred, target, reduction='none').sum(dim=1).mean()

    def mean_l0(acts):
        # Average number of strictly positive entries along dim 1.
        return (acts > 0).sum(dim=1).float().mean()

    model_output, pixel_acts, sae_output = global_variables.generator.generate(
        root_fen=root_fen,
        traj_fen=traj_fen,
    )
    features = sae_output["features"]
    x_hat = sae_output["x_hat"]
    rendered = render_feature_index(
        features, model_output, file_id, traj_fen, feature_index
    )

    a_mid = constants.ACTIVATION_DIM // 2
    f_mid = constants.DICTIONARY_SIZE // 2
    is_active = (features > 0).float()

    board = chess.Board(traj_fen)

    def oriented(per_square):
        # When black is to move (board.turn is False) the 64 values are
        # flipped along the first axis of an 8x8 view before ranking --
        # presumably because features follow the side-to-move's perspective
        # (TODO confirm).
        if board.turn:
            return per_square
        return per_square.view(8, 8).flip(0).view(64)

    def top5(values):
        return values.topk(5).indices.tolist()

    def square_names(per_square):
        return [chess.SQUARE_NAMES[sq] for sq in top5(oriented(per_square))]

    wdl = model_output['wdl']
    summary_lines = [
        f"Root WDL: {wdl[0]}",
        f"Traj WDL: {wdl[1]}",
        f"MSE loss: {summed_mse(x_hat, pixel_acts)}",
        f"MSE loss (root): {summed_mse(x_hat[:, :a_mid], pixel_acts[:, :a_mid])}",
        f"MSE loss (traj): {summed_mse(x_hat[:, a_mid:], pixel_acts[:, a_mid:])}",
        f"L0 loss: {mean_l0(features)}",
        f"L0 loss (c): {mean_l0(features[:, :f_mid])}",
        f"L0 loss (d): {mean_l0(features[:, f_mid:])}",
        f"Most active features (avg): {top5(features.mean(dim=0))}",
        f"Most active features (active): {top5(is_active.mean(dim=0))}",
        f"Most active pixels (avg): {square_names(features.mean(dim=1))}",
        f"Most active pixels (active): {square_names(is_active.mean(dim=1))}",
    ]
    return (*rendered, "\n".join(summary_lines))
def _best_legal_move(board, policy_logits):
    """
    Return the legal move of ``board`` with the highest policy logit.

    ``policy_logits`` is indexed with ``encode_move(move, (board.turn, not board.turn))``.
    Ties resolve to the first maximal move yielded by ``board.legal_moves``.
    Returns ``None`` when the position has no legal move.
    """
    us_them = (board.turn, not board.turn)
    return max(
        board.legal_moves,
        key=lambda move: policy_logits[encode_move(move, us_them)].item(),
        default=None,
    )


def render_feature_index(
    features,
    model_output,
    file_id,
    traj_fen,
    feature_index
):
    """
    Draw one SAE feature as a heatmap on the trajectory board.

    The board is annotated with an arrow for the legal move with the highest
    logit in ``model_output["policy"][1]``, and the rendered SVG is written to
    ``{constants.FIGURES_FOLER}/{file_id}.svg``.

    Args:
        features: Feature activations indexed as ``features[:, feature_index]``;
            each column must hold 64 values (it is reshaped to 64 / 8x8).
        model_output: Mapping with a ``"policy"`` entry; row 1 is read as
            logits indexed by ``encode_move`` (presumably the trajectory
            position -- TODO confirm).
        file_id: Base name of the SVG file; a fresh UUID4 string is used when
            ``None``.
        traj_fen: FEN of the position to draw.
        feature_index: Which feature (column of ``features``) to draw.

    Returns:
        ``(features, model_output, file_id, svg_path, fig)``: the inputs passed
        through (with ``file_id`` filled in), the written SVG path and the
        figure returned by ``visualisation.render_heatmap``.

    Raises:
        gr.Error: If ``features`` is ``None`` (features not computed yet).
    """
    if features is None:
        # The slider can fire before "Compute features" has filled the state.
        raise gr.Error("Compute features first.")
    if file_id is None:
        file_id = str(uuid.uuid4())
    board = chess.Board(traj_fen)
    # Slider values may arrive as floats; tensor indexing needs an int.
    pixel_features = features[:, int(feature_index)]
    if board.turn:
        heatmap = pixel_features.view(64)
    else:
        # Black to move: flip the 8x8 grid along its first axis, presumably
        # because features follow the side-to-move's perspective (TODO confirm).
        heatmap = pixel_features.view(8, 8).flip(0).view(64)

    best_move = _best_legal_move(board, model_output["policy"][1])
    # No legal move (game over): draw the heatmap without an arrow.
    # NOTE(review): assumes render_heatmap accepts an empty arrow list.
    arrows = (
        []
        if best_move is None
        else [(best_move.from_square, best_move.to_square)]
    )
    svg_board, fig = visualisation.render_heatmap(
        board,
        heatmap,
        arrows=arrows,
    )

    svg_path = f"{constants.FIGURES_FOLER}/{file_id}.svg"
    with open(svg_path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return (
        features,
        model_output,
        file_id,
        svg_path,
        fig
    )
# Layout: the left column holds the FEN inputs, the "Compute features"
# button, the feature-index slider, the info box and the colorbar; the right
# column shows the board image. Components appear in the order they are
# created inside the nested `with` contexts.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            # Both FENs are passed to compute_features_fn; traj_fen is also
            # the position drawn by render_feature_index.
            root_fen = gr.Textbox(
                label="Root FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            # Default trajectory: the position after 1. e4.
            traj_fen = gr.Textbox(
                label="Trajectory FEN",
                lines=1,
                max_lines=1,
                value="rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1",
            )
            compute_features = gr.Button("Compute features")
            with gr.Group():
                with gr.Row():
                    # Selects which dictionary feature is drawn as the heatmap.
                    feature_index = gr.Slider(
                        label="Feature index",
                        minimum=0,
                        maximum=constants.DICTIONARY_SIZE-1,
                        step=1,
                        value=0,
                    )
            with gr.Group():
                with gr.Row():
                    # Text summary (WDL, MSE, L0, top features/squares)
                    # returned by compute_features_fn.
                    info = gr.Textbox(label="Info", lines=1, max_lines=20, value="")
                with gr.Row():
                    colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            # Receives the SVG file path returned by the callbacks.
            board_image = gr.Image(label="Board")
    # State shared by the two callbacks below: SAE features, model outputs,
    # and the id used to name the SVG file on disk.
    features = gr.State(None)
    model_output = gr.State(None)
    file_id = gr.State(None)
    # Full recompute: run the generator, redraw the board and refresh info.
    compute_features.click(
        compute_features_fn,
        inputs=[features, model_output, file_id, root_fen, traj_fen, feature_index],
        outputs=[features, model_output, file_id, board_image, colorbar, info],
    )
    # Redraw only: reuse the stored features/outputs for a new feature index.
    feature_index.change(
        render_feature_index,
        inputs=[features, model_output, file_id, traj_fen, feature_index],
        outputs=[features, model_output, file_id, board_image, colorbar],
    )