Spaces:
Running
Running
"""
Gradio interface for plotting activations.
"""
import chess | |
import chess.pgn | |
import io | |
import gradio as gr | |
import os | |
import torch | |
from lczerolens import LczeroBoard, LczeroModel, Lens | |
from .. import constants | |
def get_model(model_name: str):
    """Load an ONNX network by file name.

    Args:
        model_name: File name of the model inside ``constants.ONNX_MODEL_DIRECTORY``.

    Returns:
        The ``LczeroModel`` built from that ONNX file.
    """
    onnx_path = os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name)
    return LczeroModel.from_onnx_path(onnx_path)
def get_activations(model: LczeroModel, board: LczeroBoard):
    """Collect the post-ReLU output of every block's second convolution.

    Runs the "activation" lens with gradients disabled.

    Args:
        model: The loaded network.
        board: The position to evaluate.

    Returns:
        A list ordered by block index. Element ``i`` is
        ``results[f"block{i}/conv2/relu_output"][0]``, i.e. the first entry along
        the leading axis (presumably the batch axis — TODO confirm).
    """
    # Raw string: ``\d`` must reach the regex engine as a digit class, not as an
    # (invalid) string escape. ``\d+`` rather than ``\d`` so that block10,
    # block11, ... are matched too; a single ``\d`` stopped at block9.
    lens = Lens.from_name("activation", r"block\d+/conv2/relu")
    with torch.no_grad():
        results = lens.analyse(model, board)
    # Assumes the lens yields exactly one "<module>_output" entry per matched
    # block, numbered contiguously from 0 — TODO confirm against lczerolens.
    return [results[f"block{i}/conv2/relu_output"][0] for i in range(len(results))]
def get_board(game_pgn: str, board_fen: str):
    """Build the position to analyse from a PGN game or, failing that, a FEN.

    A non-blank PGN takes precedence over the FEN. Its mainline is replayed
    move by move onto an ``LczeroBoard``. If the PGN (or, when no PGN is
    given, the FEN) cannot be used, a Gradio warning is shown and the standard
    starting position is returned instead.

    Args:
        game_pgn: PGN text. Empty or whitespace-only means "not provided".
        board_fen: FEN string, used only when no PGN is provided.

    Returns:
        The ``LczeroBoard`` to analyse.
    """
    if game_pgn and game_pgn.strip():
        try:
            game = chess.pgn.read_game(io.StringIO(game_pgn))
            # read_game returns None when it finds no game. It also records
            # illegal/ambiguous moves in ``game.errors`` instead of raising.
            # Treat both as parse failures rather than analysing a truncated game.
            if game is None:
                raise ValueError("No game found in PGN.")
            if game.errors:
                raise ValueError(f"PGN parsing errors: {game.errors}")
            # Start from the game's own initial position so a FEN/SetUp header is
            # honoured. Pushing the moves (rather than jumping straight to the
            # final position) keeps the played moves in the board's history.
            board = LczeroBoard(game.board().fen())
            for move in game.mainline_moves():
                board.push(move)
        except Exception as e:
            print(e)
            gr.Warning("Error parsing PGN, using starting position.")
            board = LczeroBoard()
        return board

    try:
        board = LczeroBoard(board_fen)
    except Exception as e:
        print(e)
        gr.Warning("Invalid FEN, using starting position.")
        board = LczeroBoard()
    return board
def _clamp_index(index, size: int, label: str) -> int:
    """Return *index* as an int clamped to ``[0, size - 1]``, warning if it had to move."""
    index = int(index)  # slider values may arrive as floats; indexing needs ints
    if index >= size:
        last = size - 1
        gr.Warning(f"{label.capitalize()} index {index} out of range, using last {label} ({last}).")
        return last
    if index < 0:
        # Without this, a negative index would silently wrap to the end.
        gr.Warning(f"{label.capitalize()} index {index} out of range, using first {label} (0).")
        return 0
    return index


def render_activations(board: LczeroBoard, activations, layer_index: int, channel_index: int):
    """Render one channel of one layer's activations as a heatmap on *board*.

    Out-of-range indices are clamped into range, with a Gradio warning.

    Args:
        board: Position the heatmap is drawn on.
        activations: Sequence of per-layer tensors. ``activations[layer][channel]``
            must hold 64 values, one per square.
        layer_index: Requested layer, clamped to ``[0, len(activations) - 1]``.
        channel_index: Requested channel, clamped to
            ``[0, activations[layer].shape[0] - 1]``.

    Returns:
        ``(board_svg_path, colorbar_svg_path)`` under ``constants.FIGURE_DIRECTORY``.

    Raises:
        gr.Error: If there are no activations to render (e.g. the initial load failed).
    """
    if not activations:
        raise gr.Error("No activations available; load a model and a position first.")

    layer = activations[_clamp_index(layer_index, len(activations), "layer")]
    channel = _clamp_index(channel_index, layer.shape[0], "channel")
    # reshape (not view) so a non-contiguous tensor does not raise.
    heatmap = layer[channel].reshape(64)

    # render_heatmap is given ".../activations.svg". The two paths returned below
    # assume it writes "_board.svg" / "_colorbar.svg" siblings — confirm in lczerolens.
    # NOTE(review): these paths are fixed, so concurrent sessions overwrite each
    # other's figures.
    board.render_heatmap(
        heatmap,
        save_to=f"{constants.FIGURE_DIRECTORY}/activations.svg",
    )
    return (
        f"{constants.FIGURE_DIRECTORY}/activations_board.svg",
        f"{constants.FIGURE_DIRECTORY}/activations_colorbar.svg",
    )
def initial_load(model_name: str, board_fen: str, game_pgn: str, layer_index: int, channel_index: int):
    """Populate every piece of session state on page load.

    Loads the model, builds the board, computes its activations and renders
    the selected layer/channel.

    Returns:
        ``(model, board, activations, *rendered)``, where ``rendered`` is whatever
        ``render_activations`` returns.
    """
    loaded_model = get_model(model_name)
    position = get_board(game_pgn, board_fen)
    layer_activations = get_activations(loaded_model, position)
    rendered = render_activations(position, layer_activations, layer_index, channel_index)
    return (loaded_model, position, layer_activations) + tuple(rendered)
def on_board_change(model: LczeroModel, game_pgn: str, board_fen: str, layer_index: int, channel_index: int):
    """Refresh the position-dependent state after the PGN or FEN is submitted.

    The model is reused. Only the board, its activations and the rendered
    figures are recomputed.

    Returns:
        ``(board, activations, *rendered)``, where ``rendered`` is whatever
        ``render_activations`` returns.
    """
    position = get_board(game_pgn, board_fen)
    layer_activations = get_activations(model, position)
    rendered = render_activations(position, layer_activations, layer_index, channel_index)
    return (position, layer_activations) + tuple(rendered)
def on_model_change(model_name: str, board: LczeroBoard, layer_index: int, channel_index: int):
    """Swap in a new model and recompute activations for the current board.

    Returns:
        ``(model, activations, *rendered)``, where ``rendered`` is whatever
        ``render_activations`` returns.
    """
    loaded_model = get_model(model_name)
    layer_activations = get_activations(loaded_model, board)
    rendered = render_activations(board, layer_activations, layer_index, channel_index)
    return (loaded_model, layer_activations) + tuple(rendered)
# UI layout and event wiring for the activation viewer.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                    value="",
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                    value=chess.STARTING_FEN,
                )
            with gr.Group():
                model_name = gr.Dropdown(
                    label="Model",
                    choices=constants.ONNX_MODEL_NAMES,
                    # Pre-select a model explicitly: initial_load runs on page load
                    # and get_model cannot join a None name into a path.
                    value=constants.ONNX_MODEL_NAMES[0] if constants.ONNX_MODEL_NAMES else None,
                )
                layer_index = gr.Slider(
                    label="Layer index",
                    minimum=0,
                    maximum=19,
                    step=1,
                    value=0,
                )
                channel_index = gr.Slider(
                    label="Channel index",
                    minimum=0,
                    maximum=200,
                    step=1,
                    value=0,
                )
        with gr.Column():
            image_board = gr.Image(label="Board", interactive=False)
            colorbar = gr.Image(label="Colorbar", interactive=False)

    # Per-session state shared by the event handlers below.
    model = gr.State(value=None)
    board = gr.State(value=None)
    activations = gr.State(value=None)

    interface.load(
        initial_load,
        # Inputs are passed positionally, so this order must follow
        # initial_load's signature: (model_name, board_fen, game_pgn, ...).
        # Previously PGN and FEN were swapped here.
        inputs=[model_name, board_fen, game_pgn, layer_index, channel_index],
        outputs=[model, board, activations, image_board, colorbar],
    )
    game_pgn.submit(
        on_board_change,
        inputs=[model, game_pgn, board_fen, layer_index, channel_index],
        outputs=[board, activations, image_board, colorbar],
    )
    board_fen.submit(
        on_board_change,
        inputs=[model, game_pgn, board_fen, layer_index, channel_index],
        outputs=[board, activations, image_board, colorbar],
    )
    model_name.change(
        on_model_change,
        inputs=[model_name, board, layer_index, channel_index],
        outputs=[model, activations, image_board, colorbar],
    )
    # Changing only the layer/channel re-renders from cached activations.
    layer_index.change(
        render_activations,
        inputs=[board, activations, layer_index, channel_index],
        outputs=[image_board, colorbar],
    )
    channel_index.change(
        render_activations,
        inputs=[board, activations, layer_index, channel_index],
        outputs=[image_board, colorbar],
    )