Xmaster6y's picture
input encoding
273af2d unverified
"""
Gradio interface for plotting the input-encoding planes of a chess board.
"""
import chess
import chess.pgn
import io
import gradio as gr
from lczerolens.board import LczeroBoard, InputEncoding
from ..constants import FIGURE_DIRECTORY
def make_render(game_pgn: str, board_fen: str, input_encoding: InputEncoding, plane_index: int):
    """Build a board from the user's PGN/FEN and render one input plane.

    A non-blank PGN takes precedence over the FEN. If the selected input cannot
    be parsed, a Gradio warning is shown and the starting position is used.

    Args:
        game_pgn: PGN text; its mainline is replayed to obtain the position.
        board_fen: FEN string, used only when ``game_pgn`` is blank.
        input_encoding: Encoding forwarded to ``make_board_plot``.
        plane_index: Index of the input plane forwarded to ``make_board_plot``.

    Returns:
        ``(board, *make_board_plot(board, input_encoding, plane_index))``.
    """
    # Whitespace-only PGN text is not a game; fall through to the FEN instead.
    if game_pgn and game_pgn.strip():
        board = _board_from_pgn(game_pgn)
    else:
        board = _board_from_fen(board_fen)
    return board, *make_board_plot(board, input_encoding, plane_index)


def _board_from_pgn(game_pgn: str) -> LczeroBoard:
    """Replay the PGN mainline; warn and return the starting position on failure."""
    try:
        game = chess.pgn.read_game(io.StringIO(game_pgn))
        if game is None:
            # read_game returns None when the text contains no game at all.
            raise ValueError("No game found in PGN.")
        # Start from the game's own initial position (honours a FEN/SetUp
        # header) rather than assuming the standard starting position.
        board = LczeroBoard(game.board().fen())
        for move in game.mainline_moves():
            board.push(move)
    except Exception as e:  # UI boundary: any parse failure falls back.
        print(e)
        gr.Warning("Error parsing PGN, using starting position.")
        board = LczeroBoard()
    return board


def _board_from_fen(board_fen: str) -> LczeroBoard:
    """Build a board from a FEN; warn and return the starting position on failure."""
    try:
        board = LczeroBoard(board_fen)
    except Exception as e:  # UI boundary: any invalid FEN falls back.
        print(e)
        gr.Warning("Invalid FEN, using starting position.")
        board = LczeroBoard()
    return board
def make_board_plot(board: LczeroBoard, input_encoding: InputEncoding, plane_index: int):
    """Render one plane of the board's input tensor as a heatmap.

    Args:
        board: Board whose input tensor is rendered.
        input_encoding: Encoding passed to ``board.to_input_tensor``.
        plane_index: Index of the plane to render. Gradio sliders deliver their
            value as a float, and tensors cannot be indexed by floats, so it is
            coerced to ``int`` first.

    Returns:
        ``(board_svg_path, colorbar_svg_path)`` under ``FIGURE_DIRECTORY``.
    """
    plane_index = int(plane_index)
    input_tensor = board.to_input_tensor(input_encoding=input_encoding)
    stem = f"{FIGURE_DIRECTORY}/encodings"
    board.render_heatmap(
        # NOTE(review): view(64) assumes each plane holds 64 values (8x8).
        input_tensor[plane_index].view(64),
        save_to=f"{stem}.svg",
        vmin=0,
        vmax=1,
    )
    # NOTE(review): presumably render_heatmap writes "<stem>_board.svg" and
    # "<stem>_colorbar.svg" when given save_to — confirm against lczerolens.
    return f"{stem}_board.svg", f"{stem}_colorbar.svg"
with gr.Blocks() as interface:
    with gr.Row():
        # Left column: position inputs and plane selection.
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(label="Game PGN", lines=1, value="")
                board_fen = gr.Textbox(
                    label="Board FEN",
                    value=chess.STARTING_FEN,
                    lines=1,
                    max_lines=1,
                )
                input_encoding = gr.Radio(
                    label="Input encoding",
                    value=InputEncoding.INPUT_CLASSICAL_112_PLANE,
                    choices=[
                        ("classical", InputEncoding.INPUT_CLASSICAL_112_PLANE),
                        ("repeated", InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED),
                        ("no history repeated", InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED),
                        ("no history zeros", InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS),
                    ],
                )
            with gr.Group():
                with gr.Row():
                    plane_index = gr.Slider(
                        label="Plane index",
                        value=0,
                        minimum=0,
                        maximum=111,
                        step=1,
                    )
        # Right column: rendered heatmap and its colorbar.
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Per-session board, refreshed by make_render and read by make_board_plot.
    state_board = gr.State(value=LczeroBoard())

    # Page load and submitting a PGN/FEN rebuild the board and re-plot.
    render_inputs = [game_pgn, board_fen, input_encoding, plane_index]
    render_outputs = [state_board, image_board, colorbar]
    for register_render in (interface.load, game_pgn.submit, board_fen.submit):
        register_render(
            make_render,
            inputs=render_inputs,
            outputs=render_outputs,
        )

    # Changing the encoding or plane only re-plots the stored board.
    for register_plot in (input_encoding.change, plane_index.change):
        register_plot(
            make_board_plot,
            inputs=[state_board, input_encoding, plane_index],
            outputs=[image_board, colorbar],
        )