Spaces:
Runtime error
Runtime error
""" | |
Gradio interface for plotting policy. | |
""" | |
import chess | |
import gradio as gr | |
import uuid | |
import torch | |
from lczerolens.encodings import encode_move | |
from src import constants, global_variables, visualisation | |
def compute_features_fn(
    features,
    model_output,
    file_id,
    root_idx,
    traj_idx,
    start_fen,
    move_seq,
    feature_index
):
    """Replay the moves, run the generator and build the board view plus stats.

    The move sequence is replayed from ``start_fen`` until the boards reached
    after ``root_idx`` and ``traj_idx`` plies are both captured (index 0 is the
    starting position). Those two boards are fed to
    ``global_variables.generator.generate`` and the resulting SAE features are
    rendered through ``render_feature_index``.

    Args:
        features, model_output, file_id: Previous session state; echoed back
            unchanged on failure and replaced on success.
        root_idx, traj_idx: Ply indices of the root and trajectory boards.
        start_fen: FEN of the starting position. An unparsable FEN triggers a
            warning and falls back to the standard starting position.
        move_seq: Whitespace-separated moves. A sequence starting with ``"1."``
            is parsed as SAN (tokens ending in ``"."`` are skipped), anything
            else as UCI.
        feature_index: Column of ``features`` to render as a heatmap.

    Returns:
        A 10-item sequence: the five state values, the board image path, the
        colorbar figure, the root FEN, the traj FEN and the info text. On
        failure the last five items are ``None``.
    """
    error_return = [features, model_output, file_id, root_idx, traj_idx] + [None] * 5

    try:
        board = chess.Board(start_fen)
    except ValueError:
        # Do what the warning says: carry on from the standard starting
        # position instead of aborting after having announced the fallback.
        board = chess.Board()
        gr.Warning("Invalid FEN, using starting position.")

    # Index 0 designates the starting position itself.
    root_board = board.copy() if root_idx == 0 else None
    traj_board = board.copy() if traj_idx == 0 else None

    if move_seq:
        uses_san = move_seq.startswith("1.")
        push = board.push_san if uses_san else board.push_uci
        ply = 0
        for move in move_seq.split():
            if root_board is not None and traj_board is not None:
                break
            # Move-number tokens ("1.", "2.", ...) only occur in SAN input.
            if uses_san and move.endswith("."):
                continue
            try:
                push(move)
            except ValueError:
                gr.Warning(f"Invalid move {move}.")
                return error_return
            ply += 1
            if ply == root_idx:
                root_board = board.copy()
            if ply == traj_idx:
                traj_board = board.copy()

    if root_board is None or traj_board is None:
        gr.Warning("Invalid move sequence.")
        return error_return

    model_output, pixel_acts, sae_output = global_variables.generator.generate(
        root_board=root_board,
        traj_board=traj_board
    )
    current_root_fen = root_board.fen()
    current_traj_fen = traj_board.fen()
    features = sae_output["features"]
    x_hat = sae_output["x_hat"]

    first_output = render_feature_index(
        features,
        model_output,
        file_id,
        root_idx,
        traj_idx,
        current_traj_fen,
        feature_index
    )

    half_a_dim = constants.ACTIVATION_DIM // 2
    half_f_dim = constants.DICTIONARY_SIZE // 2
    active = features > 0

    # dim 0 indexes the 64 squares (see the 8x8 views below); dim 1 indexes
    # the dictionary features.
    pixel_f_avg = features.mean(dim=0)
    pixel_f_active = active.float().mean(dim=0)
    pixel_p_avg = features.mean(dim=1)
    pixel_p_active = active.float().mean(dim=1)

    def _top_squares(per_pixel):
        # Use the trajectory board's side to move, matching the orientation
        # applied to the heatmap in render_feature_index.
        if not traj_board.turn:
            per_pixel = per_pixel.view(8, 8).flip(0).view(64)
        return [chess.SQUARE_NAMES[p] for p in per_pixel.topk(5).indices.tolist()]

    def _summed_mse(pred, target):
        # Squared error summed over the activation dim, averaged over squares.
        return torch.nn.functional.mse_loss(pred, target, reduction="none").sum(dim=1).mean()

    def _mean_l0(mask):
        # Average number of active features per square.
        return mask.sum(dim=1).float().mean()

    info = "\n".join(
        [
            f"Root WDL: {model_output['wdl'][0]}",
            f"Traj WDL: {model_output['wdl'][1]}",
            f"MSE loss: {_summed_mse(x_hat, pixel_acts)}",
            f"MSE loss (root): {_summed_mse(x_hat[:, :half_a_dim], pixel_acts[:, :half_a_dim])}",
            f"MSE loss (traj): {_summed_mse(x_hat[:, half_a_dim:], pixel_acts[:, half_a_dim:])}",
            f"L0 loss: {_mean_l0(active)}",
            f"L0 loss (c): {_mean_l0(active[:, :half_f_dim])}",
            f"L0 loss (d): {_mean_l0(active[:, half_f_dim:])}",
            f"Most active features (avg): {pixel_f_avg.topk(5).indices.tolist()}",
            f"Most active features (active): {pixel_f_active.topk(5).indices.tolist()}",
            f"Most active pixels (avg): {_top_squares(pixel_p_avg)}",
            f"Most active pixels (active): {_top_squares(pixel_p_active)}",
        ]
    )
    return (*first_output, current_root_fen, current_traj_fen, info)
def render_feature_index(
    features,
    model_output,
    file_id,
    root_idx,
    traj_idx,
    traj_fen,
    feature_index,
):
    """Render one feature column as a board heatmap with the top policy move.

    Args:
        features: Tensor whose column ``feature_index`` holds one value per
            square (64 entries), or ``None`` if nothing was computed yet.
        model_output: Mapping with a ``"policy"`` tensor; row 1 is read for the
            trajectory board.
        file_id: Name stem of the SVG file; a fresh UUID is generated if None.
        root_idx, traj_idx: Session state, passed through unchanged.
        traj_fen: FEN of the trajectory board to draw.
        feature_index: Index of the feature column to display.

    Returns:
        ``(features, model_output, file_id, root_idx, traj_idx, svg_path, fig)``.
        ``svg_path`` and ``fig`` are ``None`` when there is nothing to render.
    """
    # The slider's change event can fire before any position was evaluated
    # (or after a failed evaluation cleared the FEN box): nothing to draw.
    if features is None or model_output is None or not traj_fen:
        gr.Warning("No features computed yet.")
        return (features, model_output, file_id, root_idx, traj_idx, None, None)

    if file_id is None:
        file_id = str(uuid.uuid4())

    board = chess.Board(traj_fen)
    # Sliders may deliver floats; tensor indexing requires an integer.
    pixel_features = features[:, int(feature_index)]
    if board.turn:
        heatmap = pixel_features.view(64)
    else:
        # Flip ranks when black is to move.
        heatmap = pixel_features.view(8, 8).flip(0).view(64)

    # Pick the legal move with the highest policy logit.
    best_legal_logit = None
    best_legal_move = None
    for move in board.legal_moves:
        move_index = encode_move(move, (board.turn, not board.turn))
        logit = model_output["policy"][1, move_index].item()
        if best_legal_logit is None or logit > best_legal_logit:
            best_legal_logit = logit
            best_legal_move = move

    # No legal move (mate/stalemate): draw the board without an arrow.
    # NOTE(review): assumes render_heatmap accepts an empty arrows list.
    arrows = []
    if best_legal_move is not None:
        arrows.append((best_legal_move.from_square, best_legal_move.to_square))

    svg_board, fig = visualisation.render_heatmap(
        board,
        heatmap,
        arrows=arrows,
    )
    svg_path = f"{constants.FIGURES_FOLER}/{file_id}.svg"
    with open(svg_path, "w") as f:
        f.write(svg_board)
    return (
        features,
        model_output,
        file_id,
        root_idx,
        traj_idx,
        svg_path,
        fig
    )
def make_features_fn(var, direction):
    """Build a click handler that steps the root or traj board index.

    Args:
        var: ``"root"`` or ``"traj"``, the index the handler moves.
        direction: Signed step added to that index (the buttons use -1 / +1).

    Returns:
        A callback with the same signature and return value as
        ``compute_features_fn``; it adjusts the chosen index, clamps it, then
        delegates to ``compute_features_fn``.
    """
    def _make_features_fn(
        features,
        model_output,
        file_id,
        root_idx,
        traj_idx,
        start_fen,
        move_seq,
        feature_index
    ):
        # Tokens ending in "." are move numbers, not moves.
        move_count = sum(1 for mv in move_seq.split() if not mv.endswith("."))
        # Board indices run from 0 (starting position) to move_count (position
        # after the last move), so move_count itself is a valid index. Clamping
        # to move_count - 1 hid the final board and produced -1 for an empty
        # move sequence.
        if var == "root":
            root_idx += direction
            if root_idx < 0:
                gr.Warning("Already at first board.")
                root_idx = 0
            elif root_idx > move_count:
                gr.Warning("Already at last board.")
                root_idx = move_count
            elif root_idx > traj_idx:
                gr.Warning("Root should be before traj.")
                root_idx = traj_idx
        elif var == "traj":
            traj_idx += direction
            if traj_idx < 0:
                gr.Warning("Already at first board.")
                traj_idx = 0
            elif traj_idx > move_count:
                gr.Warning("Already at last board.")
                traj_idx = move_count
            elif traj_idx < root_idx:
                gr.Warning("Traj should be after root.")
                traj_idx = root_idx
        return compute_features_fn(
            features,
            model_output,
            file_id,
            root_idx,
            traj_idx,
            start_fen,
            move_seq,
            feature_index
        )
    return _make_features_fn
# UI definition: inputs, navigation and read-outs on the left, board on the right.
# NOTE(review): the container nesting below is reconstructed from a flattened
# source listing; confirm it against the deployed layout.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            start_fen = gr.Textbox(label="Starting FEN", lines=1, max_lines=1, value=chess.STARTING_FEN)
            move_seq = gr.Textbox(
                label="Move sequence",
                lines=1,
                max_lines=20,
                value=(
                    "e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 "
                    "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"
                ),
            )
            # Navigation buttons for the two board cursors.
            with gr.Group():
                with gr.Row():
                    previous_root_button = gr.Button("Previous root")
                    next_root_button = gr.Button("Next root")
                with gr.Row():
                    previous_traj_button = gr.Button("Previous traj")
                    next_traj_button = gr.Button("Next traj")
            # Read-only FENs of the selected boards and the feature selector.
            with gr.Group():
                with gr.Row():
                    current_root_fen = gr.Textbox(label="Root FEN", lines=1, max_lines=1, interactive=False)
                with gr.Row():
                    current_traj_fen = gr.Textbox(label="Traj FEN", lines=1, max_lines=1, interactive=False)
                with gr.Row():
                    feature_index = gr.Slider(
                        label="Feature index",
                        minimum=0,
                        maximum=constants.DICTIONARY_SIZE-1,
                        step=1,
                        value=0,
                    )
            # Textual statistics and the heatmap colorbar.
            with gr.Group():
                with gr.Row():
                    info = gr.Textbox(label="Info", lines=1, max_lines=20, value="")
                with gr.Row():
                    colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            board_image = gr.Image(label="Board")

    # Per-session state carried between callbacks.
    features = gr.State(None)
    model_output = gr.State(None)
    file_id = gr.State(None)
    root_idx = gr.State(0)
    traj_idx = gr.State(0)

    state = [features, model_output, file_id, root_idx, traj_idx]
    base_inputs = [start_fen, move_seq, feature_index]
    base_outputs = [board_image, colorbar, current_root_fen, current_traj_fen, info]

    # Each navigation button steps one cursor by one ply and recomputes.
    _navigation = (
        (previous_root_button, "root", -1),
        (next_root_button, "root", 1),
        (previous_traj_button, "traj", -1),
        (next_traj_button, "traj", 1),
    )
    for _button, _var, _direction in _navigation:
        _button.click(
            make_features_fn(var=_var, direction=_direction),
            inputs=state + base_inputs,
            outputs=state + base_outputs,
        )

    # Moving the slider only re-renders the heatmap from the stored features.
    feature_index.change(
        render_feature_index,
        inputs=state + [current_traj_fen, feature_index],
        outputs=state + [board_image, colorbar],
    )