demo / src /interfaces /act_max_interface.py
Xmaster6y's picture
better demo
3b6ef01 unverified
raw
history blame
2.55 kB
"""
Gradio interface for plotting the boards that most strongly activate a selected dictionary feature.
"""
import chess
import gradio as gr
import uuid
import torch
from lczerolens.encodings import encode_move
from src import constants, global_variables, visualisation
def render_feature_index(
    file_id,
    feature_index
):
    """Render the 16 boards that most strongly activate one dictionary feature.

    Rows of ``global_variables.f_ds`` are ranked by their value in column
    ``feature_index`` of ``opt_features``. For each of the top 16 rows:
    - the feature's activations over the 64 rows belonging to the same board
      are drawn as a heatmap on that board (``opt_fen``);
    - the move ``moves_opt[current_depth + 6]`` is drawn as an arrow;
    - the resulting SVG is saved under ``constants.FIGURES_FOLER``.

    Args:
        file_id: Prefix for the saved SVG file names. When ``None``, a fresh
            UUID4 string is generated and returned so the caller can reuse it.
        feature_index: Column of ``opt_features`` whose activation is ranked
            and plotted.

    Returns:
        Tuple ``(file_id, *svg_paths, *colorbar_figures)``: the file id, then
        16 SVG paths, then the 16 figures returned by
        ``visualisation.render_heatmap``. Both lists are in descending-activation
        order.

    Raises:
        IndexError: If a top row's ``pixel_index`` does not leave room for a
            full block of 64 rows inside ``opt_features``.
    """
    if file_id is None:
        file_id = str(uuid.uuid4())
    dataset = global_variables.f_ds
    opt_features = dataset["opt_features"]
    # Fetch the column once, not once per board: indexing the dataset by
    # column name returns the whole column.
    pixel_indices = dataset["pixel_index"]
    top_rows = opt_features[:, feature_index].topk(16).indices
    n_rows = opt_features.shape[0]

    board_images = []
    colorbars = []
    for rank, row_tensor in enumerate(top_rows):
        row = row_tensor.item()
        sample = dataset[row]
        # The 64 squares of a board are expected to be stored as consecutive
        # rows, with ``pixel_index`` giving this row's offset inside that run.
        # NOTE(review): layout inferred from the per-square indexing this
        # replaces (idx + i - pixel_index for i in range(64)) — confirm.
        start = row - int(pixel_indices[row])
        stop = start + 64
        if start < 0 or stop > n_rows:
            raise IndexError(
                f"Row {row} with pixel_index {row - start} does not fit a "
                f"64-row board block in opt_features of length {n_rows}."
            )
        # One slice instead of 64 scalar lookups. Clone it so the heatmap never
        # aliases the dataset tensor (e.g. if rendering modifies it in place).
        features = opt_features[start:stop, feature_index].clone()

        board = chess.Board(sample["opt_fen"])
        move = chess.Move.from_uci(sample["moves_opt"][sample["current_depth"] + 6])
        # In python-chess, ``board.turn`` is True when white is to move. When
        # black is to move, flip the 8x8 grid along its first axis before
        # flattening it back to 64 values.
        if board.turn:
            heatmap = features
        else:
            heatmap = features.view(8, 8).flip(0).reshape(64)

        svg_board, fig = visualisation.render_heatmap(
            board,
            heatmap,
            arrows=[(move.from_square, move.to_square)],
        )
        svg_path = f"{constants.FIGURES_FOLER}/{file_id}_{rank}.svg"
        with open(svg_path, "w", encoding="utf-8") as f:
            f.write(svg_board)
        board_images.append(svg_path)
        colorbars.append(fig)
    return file_id, *board_images, *colorbars
with gr.Blocks() as interface:
    # Control: which dictionary feature to inspect.
    with gr.Row():
        feature_index = gr.Slider(
            label="Feature index",
            minimum=0,
            maximum=constants.DICTIONARY_SIZE - 1,
            step=1,
            value=0,
        )

    # 4x4 grid of output slots, created in row-major order. Slot ``idx`` of
    # ``board_images`` / ``colorbars`` therefore receives the ``idx``-th
    # path / figure returned by render_feature_index (see ``outputs`` below).
    board_images, colorbars = [], []
    for i in range(4):
        with gr.Row():
            for j in range(4):
                idx = i * 4 + j
                with gr.Column(), gr.Group():
                    with gr.Row():
                        board_images.append(gr.Image(label=f"Board {idx}"))
                    with gr.Row():
                        colorbars.append(gr.Plot(label=f"Colorbar {idx}"))

    # Holds the file id between calls. It starts as None, so the first call
    # to render_feature_index generates a new id.
    file_id = gr.State(None)
    feature_index.change(
        render_feature_index,
        inputs=[file_id, feature_index],
        outputs=[file_id, *board_images, *colorbars],
    )