Spaces:
Running
Running
File size: 2,995 Bytes
3333fb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
"""
Gradio interface for plotting the input encoding planes of a board.
"""
import chess
import chess.pgn
import io
import gradio as gr
from lczerolens.board import LczeroBoard
from ..constants import FIGURE_DIRECTORY
def make_render(game_pgn: str, board_fen: str, plane_index: int):
    """Build a board from the user's PGN or FEN and render one input plane.

    The PGN takes precedence over the FEN. When ``game_pgn`` contains
    non-whitespace text, the game's mainline is replayed onto a fresh
    ``LczeroBoard``. Replay starts from the PGN's own initial position, so
    games with a ``FEN``/``SetUp`` header work. Otherwise ``board_fen`` is
    parsed. On any parsing failure a Gradio warning is shown and the standard
    starting position is used instead.

    Args:
        game_pgn: PGN text of a game; may be empty.
        board_fen: FEN string, used only when ``game_pgn`` is blank.
        plane_index: Index of the input-encoding plane to render.

    Returns:
        ``(board, board_svg_path, colorbar_svg_path)``, matching the
        ``[state_board, image_board, colorbar]`` outputs of the interface.
    """
    if game_pgn and game_pgn.strip():
        try:
            game = chess.pgn.read_game(io.StringIO(game_pgn))
            if game is None:
                # read_game returns None when no game can be read at all.
                raise ValueError("No game found in PGN.")
            # python-chess is lenient: illegal/ambiguous moves are recorded in
            # game.errors and the mainline is cut short instead of raising.
            errors = getattr(game, "errors", None)
            if errors:
                raise ValueError(f"PGN contains errors: {errors}")
            # Honour a custom initial position (FEN/SetUp headers) instead of
            # always replaying the moves from the standard start.
            board = LczeroBoard(game.board().fen())
            # Push moves one by one so the board keeps the move history.
            for move in game.mainline_moves():
                board.push(move)
        except Exception as e:
            # UI boundary: report and fall back rather than crash the app.
            print(e)
            gr.Warning("Error parsing PGN, using starting position.")
            board = LczeroBoard()
    else:
        try:
            board = LczeroBoard(board_fen)
        except Exception as e:
            print(e)
            gr.Warning("Invalid FEN, using starting position.")
            board = LczeroBoard()
    return board, *make_board_plot(board, plane_index)
def make_board_plot(board: LczeroBoard, plane_index: int):
    """Render one input-encoding plane of ``board`` as a heatmap on the board.

    Args:
        board: Position whose input tensor is rendered.
        plane_index: Index of the plane in ``board.to_input_tensor()``.
            Coerced with ``int()`` because Gradio sliders are typed as float.

    Returns:
        ``(board_svg_path, colorbar_svg_path)`` under ``FIGURE_DIRECTORY``.
    """
    input_tensor = board.to_input_tensor()
    # Each plane is flattened to 64 values, one per square.
    # NOTE(review): assumes a plane is 8x8 — confirm against to_input_tensor.
    plane = input_tensor[int(plane_index)].view(64)
    stem = f"{FIGURE_DIRECTORY}/encodings"
    board.render_heatmap(
        plane,
        save_to=f"{stem}.svg",
        vmin=0,
        vmax=1,
    )
    # NOTE(review): assumes render_heatmap derives "<stem>_board.svg" and
    # "<stem>_colorbar.svg" from save_to — confirm in lczerolens.
    return f"{stem}_board.svg", f"{stem}_colorbar.svg"
with gr.Blocks() as interface:
    with gr.Row():
        # Left column: position input and plane selection.
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(label="Game PGN", lines=1, value="")
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
            with gr.Group():
                with gr.Row():
                    plane_index = gr.Slider(
                        label="Plane index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=0,
                    )
        # Right column: rendered heatmap and its colour scale.
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # The last board built by make_render, reused when only the plane changes.
    state_board = gr.State(value=LczeroBoard())
    render_inputs = [game_pgn, board_fen, plane_index]
    render_outputs = [state_board, image_board, colorbar]

    # Page load and submitting either text box all rebuild the board.
    for _listen in (interface.load, game_pgn.submit, board_fen.submit):
        _listen(make_render, inputs=render_inputs, outputs=render_outputs)
    del _listen

    # Moving the slider only re-renders the stored board.
    plane_index.change(
        make_board_plot,
        inputs=[state_board, plane_index],
        outputs=[image_board, colorbar],
    )
|