Spaces:
Runtime error
Runtime error
File size: 2,546 Bytes
3b6ef01 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
"""
Gradio interface for plotting dictionary feature activations on chess boards.
"""
import chess
import gradio as gr
import uuid
import torch
from lczerolens.encodings import encode_move
from src import constants, global_variables, visualisation
def _board_feature_activations(opt_features, row_index, pixel_index, feature_index):
    """Return the ``feature_index`` activations for all 64 rows of one board.

    The board that contains row ``row_index`` spans the 64 consecutive rows
    of ``opt_features`` starting at ``row_index - pixel_index``.
    """
    start = int(row_index - pixel_index)
    # One slice replaces 64 scalar lookups + torch.stack. Clone so callers get
    # an independent tensor (as the stacked copy was), not a view into the
    # shared feature table that downstream rendering could mutate.
    return opt_features[start:start + 64, feature_index].clone()


def render_feature_index(
    file_id,
    feature_index
):
    """Render the 16 boards whose rows most strongly activate a feature.

    Args:
        file_id: Identifier used to name the written SVG files; a fresh
            UUID string is generated when ``None``.
        feature_index: Column of ``opt_features`` to visualise.

    Returns:
        ``(file_id, *board_image_paths, *colorbars)``: the id, 16 SVG file
        paths and 16 colorbar figures from ``visualisation.render_heatmap``,
        in the order the interface wires up its outputs.
    """
    if file_id is None:
        file_id = str(uuid.uuid4())

    dataset = global_variables.f_ds
    opt_features = dataset["opt_features"]
    # Hoisted out of the loop: column access on the dataset object may
    # rebuild the whole column on every call.
    pixel_indices = dataset["pixel_index"]
    top_rows = opt_features[:, feature_index].topk(16).indices

    board_images = []
    colorbars = []
    for topi, idx in enumerate(top_rows):
        sample = dataset[idx.item()]
        features = _board_feature_activations(
            opt_features, idx, pixel_indices[idx], feature_index
        )

        # NOTE(review): the +6 offset into moves_opt is kept from the
        # original; its meaning is not visible in this module.
        uci_move = sample["moves_opt"][sample["current_depth"] + 6]
        move = chess.Move.from_uci(uci_move)
        board = chess.Board(sample["opt_fen"])

        # board.turn is True for white to move; otherwise the 8x8 grid is
        # flipped along its first axis before display.
        if board.turn:
            heatmap = features.reshape(64)
        else:
            heatmap = features.reshape(8, 8).flip(0).reshape(64)

        svg_board, fig = visualisation.render_heatmap(
            board,
            heatmap,
            arrows=[(move.from_square, move.to_square)],
        )
        # FIGURES_FOLER is spelled as defined in the constants module.
        svg_path = f"{constants.FIGURES_FOLER}/{file_id}_{topi}.svg"
        with open(svg_path, "w", encoding="utf-8") as f:
            f.write(svg_board)
        board_images.append(svg_path)
        colorbars.append(fig)

    return file_id, *board_images, *colorbars
# Layout: one slider row, then a 4x4 grid where every cell stacks a board
# image above its colorbar plot. Moving the slider re-renders all 16 cells.
with gr.Blocks() as interface:
    with gr.Row():
        feature_index = gr.Slider(
            label="Feature index",
            minimum=0,
            maximum=constants.DICTIONARY_SIZE - 1,
            step=1,
            value=0,
        )

    board_images = []
    colorbars = []
    for grid_row in range(4):
        with gr.Row():
            for grid_col in range(4):
                slot = grid_row * 4 + grid_col
                with gr.Column(), gr.Group():
                    with gr.Row():
                        board_images.append(gr.Image(label=f"Board {slot}"))
                    with gr.Row():
                        colorbars.append(gr.Plot(label=f"Colorbar {slot}"))

    # Per-session id used by the callback to name its SVG files.
    file_id = gr.State(None)
    feature_index.change(
        render_feature_index,
        inputs=[file_id, feature_index],
        outputs=[file_id, *board_images, *colorbars],
    )