lczerolens-demo / app /encoding_interface.py
Xmaster6y's picture
removed chess.Board
5e4365f
"""
Gradio interface for plotting encodings.
"""
import chess
import gradio as gr
from demo import constants, visualisation
from lczerolens import board_encodings
from lczerolens.board import LczeroBoard
def make_encoding_plot(
    board_fen,
    action_seq,
    plane_index,
    color_flip,
):
    """Render a single input-encoding plane of a board as a heatmap.

    The board is built from ``board_fen``. If the FEN is rejected, it falls
    back to the starting position and shows a UI warning. It is then advanced
    by the space-separated UCI moves in ``action_seq``. If any move is
    rejected, a warning is shown and the board is reset to the starting
    position.

    Args:
        board_fen: FEN string of the board to start from.
        action_seq: Space-separated UCI moves to play from that board; may be
            empty.
        plane_index: Index of the plane to display, taken from the tensor
            returned by ``board_encodings.board_to_input_tensor``.
        color_flip: When truthy and it is Black to move, flip the plane along
            its first axis before plotting.

    Returns:
        A ``(svg_path, fig)`` pair:
            - the path of the SVG board written under
              ``constants.FIGURE_DIRECTORY``;
            - the figure returned by ``visualisation.render_heatmap``.
    """
    try:
        board = LczeroBoard(board_fen)
    except ValueError:
        board = LczeroBoard()
        gr.Warning("Invalid FEN, using starting position.")
    if action_seq:
        try:
            for action in action_seq.split():
                board.push_uci(action)
        except ValueError:
            # python-chess reports unparsable/illegal UCI moves with ValueError subclasses.
            gr.Warning("Invalid action sequence, using starting position.")
            board = LczeroBoard()
    board_tensor = board_encodings.board_to_input_tensor(board)
    # Slider values may arrive as floats; tensor indexing requires an int.
    heatmap = board_tensor[int(plane_index)]
    if color_flip and board.turn == chess.BLACK:
        # NOTE(review): presumably re-orients the plane to White's perspective,
        # assuming the encoding is side-to-move relative — confirm.
        heatmap = heatmap.flip(0)
    # reshape (unlike view) also works if the plane is not contiguous in memory.
    svg_board, fig = visualisation.render_heatmap(board, heatmap.reshape(64), vmin=0.0, vmax=1.0)
    svg_path = f"{constants.FIGURE_DIRECTORY}/encoding.svg"
    with open(svg_path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return svg_path, fig
# Layout: left column holds the inputs (FEN, moves, plane controls) and the
# colorbar plot; right column shows the rendered board image.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            board_fen = gr.Textbox(
                value=chess.STARTING_FEN,
                label="Board starting FEN",
                max_lines=1,
                lines=1,
            )
            action_seq = gr.Textbox(
                value="e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 d4d5 e5e4 f3d4 c6e5 f2f4 e5g6",
                label="Action sequence",
                max_lines=1,
                lines=1,
            )
            with gr.Group():
                with gr.Row():
                    plane_index = gr.Slider(
                        value=0,
                        minimum=0,
                        maximum=111,
                        step=1,
                        label="Plane index",
                        scale=3,
                    )
                    color_flip = gr.Checkbox(value=True, label="Color flip", scale=1)
                colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            image = gr.Image(label="Board")

    # Every trigger re-renders with the same inputs and outputs.
    policy_inputs = [board_fen, action_seq, plane_index, color_flip]
    policy_outputs = [image, colorbar]
    for _register in (
        board_fen.submit,
        action_seq.submit,
        plane_index.change,
        color_flip.change,
        interface.load,  # initial render when the page opens
    ):
        _register(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)