Spaces:
Runtime error
Runtime error
""" | |
Gradio interface for browsing the boards that most strongly activate a
dictionary feature, shown as per-square activation heatmaps.
""" | |
import chess | |
import gradio as gr | |
import uuid | |
import torch | |
from lczerolens.encodings import encode_move | |
from src import constants, global_variables, visualisation | |
# Number of top-activating boards rendered per feature; must match the 4x4
# grid of Image/Plot outputs wired to this callback.
_N_BOARDS = 16


def _square_activations(opt_features, row_index, pixel_index, feature_index):
    """Gather the 64 activations of ``feature_index`` for one board.

    The rows for a board are read as 64 consecutive rows of ``opt_features``
    starting at ``row_index - pixel_index`` (``pixel_index`` being the offset
    of ``row_index`` inside that run). A single advanced-indexing gather is
    used instead of 64 scalar lookups; negative starts wrap and out-of-range
    rows raise ``IndexError``, exactly like per-element indexing would.

    Returns:
        A 1-D tensor of length 64 with the dtype of ``opt_features``.
    """
    start = int(row_index) - int(pixel_index)
    rows = torch.arange(start, start + 64, device=opt_features.device)
    return opt_features[rows, feature_index]


def _orient_heatmap(features, white_to_move):
    """Return the 64 activations as a flat vector oriented for display.

    When it is black's turn the 8x8 grid is flipped along its first axis,
    presumably because activations are stored from the side-to-move's
    perspective (NOTE(review): confirm against the encoder).
    """
    if white_to_move:
        return features.view(64)
    return features.view(8, 8).flip(0).view(64)


def render_feature_index(file_id, feature_index):
    """Render the boards that most strongly activate a dictionary feature.

    Args:
        file_id: Prefix for the SVG files written to
            ``constants.FIGURES_FOLER``. When ``None`` a fresh UUID is
            generated; it is returned so the caller can pass it back on the
            next call.
        feature_index: Column of ``global_variables.f_ds["opt_features"]``
            to rank by. Coerced to ``int`` because a Gradio slider may
            deliver a float, which tensor indexing rejects.

    Returns:
        A flat tuple ``(file_id, *svg_paths, *figures)``: the file id, then
        ``_N_BOARDS`` SVG paths, then the ``_N_BOARDS`` figures returned by
        ``visualisation.render_heatmap``.
    """
    if file_id is None:
        file_id = str(uuid.uuid4())
    feature_index = int(feature_index)

    opt_features = global_variables.f_ds["opt_features"]
    # Read the column once rather than once per rendered board.
    pixel_indices = global_variables.f_ds["pixel_index"]
    top_rows = opt_features[:, feature_index].topk(_N_BOARDS).indices

    board_images = []
    colorbars = []
    for rank, row in enumerate(top_rows):
        sample = global_variables.f_ds[row.item()]
        activations = _square_activations(
            opt_features, row, pixel_indices[row], feature_index
        )
        # NOTE(review): the +6 offset into "moves_opt" is not explained here;
        # confirm what move it selects relative to "current_depth".
        uci_move = sample["moves_opt"][sample["current_depth"] + 6]
        move = chess.Move.from_uci(uci_move)
        board = chess.Board(sample["opt_fen"])

        svg_board, fig = visualisation.render_heatmap(
            board,
            _orient_heatmap(activations, board.turn),
            arrows=[(move.from_square, move.to_square)],
        )
        svg_path = f"{constants.FIGURES_FOLER}/{file_id}_{rank}.svg"
        with open(svg_path, "w", encoding="utf-8") as f:
            f.write(svg_board)
        board_images.append(svg_path)
        colorbars.append(fig)
    return file_id, *board_images, *colorbars
# Layout: a feature-index slider on top, then a 4x4 grid where each cell
# stacks a board image above its colorbar plot. Moving the slider re-renders
# every cell through ``render_feature_index``.
with gr.Blocks() as interface:
    with gr.Row():
        feature_index = gr.Slider(
            label="Feature index",
            minimum=0,
            maximum=constants.DICTIONARY_SIZE - 1,
            step=1,
            value=0,
        )

    board_images, colorbars = [], []
    for grid_row in range(4):
        with gr.Row():
            for grid_col in range(4):
                cell = grid_row * 4 + grid_col
                with gr.Column(), gr.Group():
                    with gr.Row():
                        board_images.append(gr.Image(label=f"Board {cell}"))
                    with gr.Row():
                        colorbars.append(gr.Plot(label=f"Colorbar {cell}"))

    # Per-session id used to name the SVG files; starts empty and is filled
    # in by the callback on its first run.
    file_id = gr.State(None)

    feature_index.change(
        fn=render_feature_index,
        inputs=[file_id, feature_index],
        outputs=[file_id, *board_images, *colorbars],
    )