"""
Gradio interface for plotting encodings.
"""

import chess
import gradio as gr

from demo import constants, visualisation
from lczerolens import board_encodings
from lczerolens.board import LczeroBoard


def make_encoding_plot(
    board_fen: str,
    action_seq: str,
    plane_index: float,
    color_flip: bool,
):
    """Render one plane of the board's input encoding as a heatmap.

    Args:
        board_fen: FEN of the starting position. If it is rejected with a
            ``ValueError``, a warning is shown and the default board is used.
        action_seq: Whitespace-separated UCI moves pushed onto the board. If
            any move raises ``ValueError``, a warning is shown and the board
            is reset to the default ``LczeroBoard()`` (the FEN is discarded
            too).
        plane_index: Index of the plane to select from the encoded tensor.
            Coerced to ``int`` because the slider may deliver a float.
        color_flip: When true and it is black's turn, the selected plane is
            flipped along its first dimension before rendering.

    Returns:
        A ``(svg_path, fig)`` pair: the path of the SVG file written under
        ``constants.FIGURE_DIRECTORY`` and the second value returned by
        ``visualisation.render_heatmap`` (wired to the ``gr.Plot`` output).
    """
    try:
        board = LczeroBoard(board_fen)
    except ValueError:
        board = LczeroBoard()
        gr.Warning("Invalid FEN, using starting position.")

    if action_seq:
        try:
            for action in action_seq.split():
                board.push_uci(action)
        except ValueError:
            gr.Warning("Invalid action sequence, using starting position.")
            # Moves pushed before the failure are dropped as well.
            board = LczeroBoard()

    board_tensor = board_encodings.board_to_input_tensor(board)
    # Gradio documents the Slider value as a float; tensor indexing needs an
    # integer, so normalise it here.
    heatmap = board_tensor[int(plane_index)]
    if color_flip and board.turn == chess.BLACK:
        # NOTE(review): presumably the encoding is laid out from the side to
        # move's perspective, so this flip restores the displayed board
        # orientation - confirm against board_encodings.
        heatmap = heatmap.flip(0)

    svg_board, fig = visualisation.render_heatmap(
        board, heatmap.view(64), vmin=0.0, vmax=1.0
    )

    # Build the output path once so the written file and the returned path
    # cannot drift apart.
    svg_path = f"{constants.FIGURE_DIRECTORY}/encoding.svg"
    # Explicit encoding keeps the SVG output independent of the platform's
    # default locale encoding.
    with open(svg_path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return svg_path, fig


with gr.Blocks() as interface:
    with gr.Row():
        # Left column: position inputs, plane selection and the colorbar.
        with gr.Column():
            board_fen = gr.Textbox(
                label="Board starting FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            action_seq = gr.Textbox(
                label="Action sequence",
                lines=1,
                max_lines=1,
                value=(
                    "e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 "
                    "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"
                ),
            )
            with gr.Group():
                with gr.Row():
                    plane_index = gr.Slider(
                        label="Plane index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=0,
                        scale=3,
                    )
                    color_flip = gr.Checkbox(
                        label="Color flip", value=True, scale=1
                    )

            colorbar = gr.Plot(label="Colorbar")
        # Right column: the rendered board image.
        with gr.Column():
            image = gr.Image(label="Board")

    policy_inputs = [
        board_fen,
        action_seq,
        plane_index,
        color_flip,
    ]
    policy_outputs = [image, colorbar]

    # Re-render whenever any input changes, and once on page load.
    board_fen.submit(
        make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs
    )
    action_seq.submit(
        make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs
    )
    plane_index.change(
        make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs
    )
    color_flip.change(
        make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs
    )
    interface.load(
        make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs
    )