Spaces:
Running
Running
"""
Gradio interface for plotting input gradients of an lczero model as a board heatmap.
"""
import chess | |
import chess.pgn | |
import io | |
import gradio as gr | |
import os | |
from lczerolens import LczeroBoard, LczeroModel, Lens | |
from demo import constants | |
from demo.utils import get_info | |
def get_model(model_name: str):
    """Load the ONNX checkpoint named *model_name* from the configured model directory.

    Args:
        model_name: file name of the ONNX model inside ``constants.ONNX_MODEL_DIRECTORY``.

    Returns:
        The ``LczeroModel`` built by ``LczeroModel.from_onnx_path``.
    """
    onnx_path = os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name)
    return LczeroModel.from_onnx_path(onnx_path)
def get_gradients(model: LczeroModel, board: LczeroBoard, target: str):
    """Compute input gradients of a chosen model output for *board*.

    Args:
        model: the loaded lczero model.
        board: position to analyse.
        target: one of ``"win"``, ``"draw"``, ``"loss"`` (a column of the WDL
            head) or ``"best_move"`` (the maximum policy value). Any other
            string raises ``KeyError`` while the lens evaluates the target.

    Returns:
        The ``"input_grad"`` entry of the gradient lens' analysis results.
    """
    gradient_lens = Lens.from_name("gradient")

    def init_target(model):
        # ``model`` here is whatever the lens passes in; outputs are reached by
        # module names containing "/", hence ``getattr``.
        if target != "best_move":
            wdl_column = {"win": 0, "draw": 1, "loss": 2}[target]
            return getattr(model, "output/wdl").output[:, wdl_column]
        # Max over dim 1 of the policy output, i.e. the top move's value per batch row.
        return getattr(model, "output/policy").output.max(dim=1).values

    analysis = gradient_lens.analyse(model, board, init_target=init_target)
    return analysis["input_grad"]
def get_board(game_pgn: str, board_fen: str):
    """Build the position to analyse from a PGN game or, failing that, a FEN string.

    A non-empty ``game_pgn`` takes precedence over ``board_fen``. The PGN's
    mainline is replayed move by move from the game's own initial position, so
    the board keeps its move history and a ``[FEN]`` header is honoured.

    Args:
        game_pgn: PGN text of a game (may be empty).
        board_fen: FEN string, used only when ``game_pgn`` is empty.

    Returns:
        An ``LczeroBoard``. If parsing fails, a Gradio warning is shown and the
        standard starting position is returned instead.
    """
    if game_pgn:
        try:
            game = chess.pgn.read_game(io.StringIO(game_pgn))
            if game is None:
                raise ValueError("No game found in PGN.")
            # python-chess does not raise on illegal or unparsable moves; it
            # records them in ``game.errors`` and stops reading that line.
            if game.errors:
                raise ValueError(f"PGN contains errors: {game.errors}")
            # Start from the game's initial position (respects a [FEN] header)
            # rather than the standard one. ``push`` does not check legality,
            # so a wrong start position would silently produce a wrong board.
            board = LczeroBoard(game.board().fen())
            for move in game.mainline_moves():
                board.push(move)
        except Exception as e:
            print(e)
            gr.Warning("Error parsing PGN, using starting position.")
            board = LczeroBoard()
    else:
        try:
            board = LczeroBoard(board_fen)
        except Exception as e:
            print(e)
            gr.Warning("Invalid FEN, using starting position.")
            board = LczeroBoard()
    return board
def render_gradients(board: LczeroBoard, gradients, average_over_planes: bool, begin_average_index: int, end_average_index: int, plane_index: int):
    """Draw the gradient heatmap for *board* and return the generated SVG paths.

    Args:
        board: position whose ``render_heatmap`` draws the heatmap.
        gradients: input gradients indexed as ``gradients[0, plane]``. Each
            selected plane is reshaped to 64 squares (presumably shape
            ``(1, planes, 8, 8)`` -- TODO confirm).
        average_over_planes: if true, average the planes from
            ``begin_average_index`` through ``end_average_index``, both
            inclusive; otherwise show only ``plane_index``.
        begin_average_index: one end of the averaged plane range.
        end_average_index: other end of the averaged plane range (inclusive).
        plane_index: plane shown when not averaging.

    Returns:
        ``(board_svg_path, colorbar_svg_path)`` inside ``constants.FIGURE_DIRECTORY``.
    """
    if average_over_planes:
        # Slider values may arrive as floats and in either order. Normalise them
        # so the slice is never empty, since an empty mean is an all-NaN heatmap.
        first, last = sorted((int(begin_average_index), int(end_average_index)))
        # Inclusive end, so the last plane (the slider maximum) can be averaged.
        heatmap = gradients[0, first : last + 1].mean(dim=0).view(64)
    else:
        heatmap = gradients[0, int(plane_index)].view(64)
    board.render_heatmap(
        heatmap,
        save_to=f"{constants.FIGURE_DIRECTORY}/gradients.svg",
    )
    # NOTE(review): render_heatmap appears to write "<stem>_board.svg" and
    # "<stem>_colorbar.svg" next to save_to; confirm against lczerolens.
    return (
        f"{constants.FIGURE_DIRECTORY}/gradients_board.svg",
        f"{constants.FIGURE_DIRECTORY}/gradients_colorbar.svg",
    )
def initial_load(model_name: str, board_fen: str, game_pgn: str, target: str, average_over_planes: bool, begin_average_index: int, end_average_index: int, plane_index: int):
    """Populate the app state on page load.

    Loads the model, builds the board, computes gradients and model info, then
    renders the heatmap.

    Returns:
        ``(model, board, gradients, info, *plots)`` where ``plots`` are the
        paths returned by ``render_gradients``.
    """
    render_options = (average_over_planes, begin_average_index, end_average_index, plane_index)
    loaded_model = get_model(model_name)
    loaded_board = get_board(game_pgn, board_fen)
    grads = get_gradients(loaded_model, loaded_board, target)
    model_info = get_info(loaded_model, loaded_board)
    rendered = render_gradients(loaded_board, grads, *render_options)
    return (loaded_model, loaded_board, grads, model_info, *rendered)
def on_board_change(model: LczeroModel, game_pgn: str, board_fen: str, target: str, average_over_planes: bool, begin_average_index: int, end_average_index: int, plane_index: int):
    """Recompute everything that depends on the position after PGN/FEN is submitted.

    The already-loaded *model* is reused.

    Returns:
        ``(board, gradients, info, *plots)`` where ``plots`` are the paths
        returned by ``render_gradients``.
    """
    new_board = get_board(game_pgn, board_fen)
    new_gradients = get_gradients(model, new_board, target)
    new_info = get_info(model, new_board)
    rendered = render_gradients(
        new_board,
        new_gradients,
        average_over_planes,
        begin_average_index,
        end_average_index,
        plane_index,
    )
    return (new_board, new_gradients, new_info) + tuple(rendered)
def on_model_change(model_name: str, board: LczeroBoard, target: str, average_over_planes: bool, begin_average_index: int, end_average_index: int, plane_index: int):
    """Load the newly selected model and recompute its outputs for the current *board*.

    Returns:
        ``(model, gradients, info, *plots)`` where ``plots`` are the paths
        returned by ``render_gradients``.
    """
    new_model = get_model(model_name)
    new_gradients = get_gradients(new_model, board, target)
    new_info = get_info(new_model, board)
    render_options = (average_over_planes, begin_average_index, end_average_index, plane_index)
    rendered = render_gradients(board, new_gradients, *render_options)
    return (new_model, new_gradients, new_info, *rendered)
def on_target_change(model: LczeroModel, board: LczeroBoard, target: str, average_over_planes: bool, begin_average_index: int, end_average_index: int, plane_index: int):
    """Recompute gradients for a new *target* and redraw the heatmap.

    The model and board are unchanged, so model info is not recomputed.

    Returns:
        ``(gradients, *plots)`` where ``plots`` are the paths returned by
        ``render_gradients``.
    """
    fresh_gradients = get_gradients(model, board, target)
    return (
        fresh_gradients,
        *render_gradients(
            board,
            fresh_gradients,
            average_over_planes,
            begin_average_index,
            end_average_index,
            plane_index,
        ),
    )
# Build the UI and wire events. All handlers that trace the model or write the
# shared SVG files in FIGURE_DIRECTORY go through the "trace_queue" concurrency
# group.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                    value="",
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
                model_name = gr.Dropdown(
                    label="Model",
                    choices=constants.ONNX_MODEL_NAMES,
                    # Pre-select a model: otherwise the page-load handler gets
                    # None and cannot build a model path.
                    value=constants.ONNX_MODEL_NAMES[0] if constants.ONNX_MODEL_NAMES else None,
                )
            with gr.Group():
                info = gr.Textbox(label="Info", lines=1, value="")
            with gr.Group():
                target = gr.Radio(
                    ["win", "draw", "loss", "best_move"], label="Target",
                    value="win",
                )
                average_over_planes = gr.Checkbox(label="Average over Planes", value=False)
                with gr.Accordion("Average over planes", open=False):
                    begin_average_index = gr.Slider(
                        label="Begin average index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=0,
                    )
                    end_average_index = gr.Slider(
                        label="End average index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=111,
                    )
                plane_index = gr.Slider(
                    label="Plane index",
                    minimum=0,
                    maximum=111,
                    step=1,
                    value=0,
                )
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Per-session state shared between handlers.
    model = gr.State(value=None)
    board = gr.State(value=None)
    gradients = gr.State(value=None)

    # Input order must match initial_load(model_name, board_fen, game_pgn, ...);
    # listing game_pgn first would swap the PGN and FEN values.
    interface.load(
        initial_load,
        inputs=[model_name, board_fen, game_pgn, target, average_over_planes, begin_average_index, end_average_index, plane_index],
        outputs=[model, board, gradients, info, image_board, colorbar],
        concurrency_id="trace_queue"
    )
    game_pgn.submit(
        on_board_change,
        inputs=[model, game_pgn, board_fen, target, average_over_planes, begin_average_index, end_average_index, plane_index],
        outputs=[board, gradients, info, image_board, colorbar],
        concurrency_id="trace_queue"
    )
    board_fen.submit(
        on_board_change,
        inputs=[model, game_pgn, board_fen, target, average_over_planes, begin_average_index, end_average_index, plane_index],
        outputs=[board, gradients, info, image_board, colorbar],
        concurrency_id="trace_queue"
    )
    model_name.change(
        on_model_change,
        inputs=[model_name, board, target, average_over_planes, begin_average_index, end_average_index, plane_index],
        outputs=[model, gradients, info, image_board, colorbar],
        concurrency_id="trace_queue"
    )
    target.change(
        on_target_change,
        inputs=[model, board, target, average_over_planes, begin_average_index, end_average_index, plane_index],
        outputs=[gradients, image_board, colorbar],
        concurrency_id="trace_queue"
    )
    for render_arg in [average_over_planes, begin_average_index, end_average_index, plane_index]:
        # Re-rendering writes the same shared SVG files as the handlers above,
        # so it joins the same concurrency group.
        render_arg.change(
            render_gradients,
            inputs=[board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index],
            outputs=[image_board, colorbar],
            concurrency_id="trace_queue"
        )