# lczerolens-demo / app/crp_interface.py
# (repository page header: Xmaster6y's picture, "removed chess.Board", 5e4365f)
"""
Gradio interface for plotting relevance heatmaps computed with the "crp" lens.
"""
import copy
import chess
import gradio as gr
from demo import constants, utils, visualisation
from lczerolens.board import LczeroBoard
# Module-level state shared by the callbacks of this interface.
# Heatmaps filled in by compute_cache, one entry per element of `boards`;
# stays None until the first computation.
cache = None
# Board sequence filled in by compute_cache: the starting position followed
# by one copy of the board after each pushed move.
boards = None
# Index into `boards` / `cache` of the position currently displayed.
board_index = 0
def list_models():
    """
    Build the sorted rows of the "Available models" table.

    Each row is a one-element list holding the first field of an entry
    returned by `utils.get_models_info(leela=False)`.
    """
    rows = []
    for info in utils.get_models_info(leela=False):
        rows.append([info[0]])
    rows.sort()
    return rows
def on_select_model_df(evt: gr.SelectData):
    """
    Forward the value carried by a selection event.

    The returned value is what the listener writes into its output component.
    """
    selected_value = evt.value
    return selected_value
def compute_cache(
    board_fen,
    action_seq,
    model_name,
    plane_index,
    history_index,
):
    """
    Compute the heatmaps for every position of a game and plot the current one.

    The board sequence starts at `board_fen` (or the default `LczeroBoard()`
    when the FEN is rejected) and gains one position per move of `action_seq`.
    Moves are read as SAN when the sequence starts with "1." (tokens ending
    with "." are skipped), otherwise as UCI. Parsing stops at the first move
    that raises `ValueError`; the positions reached before it are kept.

    The module-level `boards`, `cache` and `board_index` are only updated once
    every heatmap has been computed, so a failure part-way leaves the previous
    state untouched, and `board_index` is clamped to the new sequence length.

    Returns the three outputs of `make_plot(plane_index)` followed by the two
    outputs of `make_history_plot(history_index)`, or five `None` values when
    no model is selected.
    """
    global cache
    global boards
    global board_index

    if model_name == "":
        gr.Warning("No model selected.")
        return None, None, None, None, None

    try:
        board = LczeroBoard(board_fen)
    except ValueError:
        board = LczeroBoard()
        gr.Warning("Invalid FEN, using starting position.")

    # Build the new state in locals first; the globals are swapped in together
    # at the end so `boards` and `cache` never disagree.
    new_boards = [board.copy()]
    if action_seq:
        try:
            if action_seq.startswith("1."):
                for action in action_seq.split():
                    # Skip move-number tokens such as "1." or "12.".
                    if action.endswith("."):
                        continue
                    board.push_san(action)
                    new_boards.append(board.copy())
            else:
                for action in action_seq.split():
                    board.push_uci(action)
                    new_boards.append(board.copy())
        except ValueError:
            gr.Warning(f"Invalid action {action} stopping before it.")

    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "crp")
    new_cache = []
    for position in new_boards:
        relevance = lens.compute_heatmap(position, wrapper)
        new_cache.append(copy.deepcopy(relevance))

    boards = new_boards
    cache = new_cache
    # A previous, longer game may have left the index past the end of the new
    # sequence; keep it when still valid, otherwise fall back to the last board.
    board_index = min(board_index, len(boards) - 1)

    return (
        *make_plot(
            plane_index,
        ),
        *make_history_plot(
            history_index,
        ),
    )
def make_plot(
    plane_index,
):
    """
    Render the heatmap of one plane for the currently selected board.

    `plane_index` is 1-based. The heatmap tensor of the current board is
    divided by its largest absolute value (when non-zero) and the selected
    plane is drawn with a fixed [-1, 1] colour range. The SVG is written to
    `lrp.svg` in `constants.FIGURE_DIRECTORY`.

    Returns `(svg_path, board_fen, figure)`, or three `None` values when the
    cache has not been computed yet.
    """
    if cache is None:
        gr.Warning("Cache not computed!")
        return None, None, None

    current_board = boards[board_index]
    relevance = cache[board_index]
    peak = relevance.abs().max()
    if peak != 0:
        relevance = relevance / peak

    plane = relevance[plane_index - 1].view(64)
    if current_board.turn == chess.BLACK:
        # Mirror the ranks when black is to move (presumably because the
        # planes are laid out from the side-to-move's point of view).
        plane = plane.view(8, 8).flip(0).view(64)

    svg_board, fig = visualisation.render_heatmap(current_board, plane, vmin=-1, vmax=1)
    svg_path = f"{constants.FIGURE_DIRECTORY}/lrp.svg"
    with open(svg_path, "w") as f:
        f.write(svg_board)
    return svg_path, current_board.fen(), fig
def make_history_plot(
    history_index,
):
    """
    Render the summed heatmap of one history step for the current board.

    `history_index` is 1-based; step 1 is the current position and step `k`
    is drawn on the board reached `k - 1` moves earlier (an empty board when
    the sequence does not go back that far). The SVG is written to
    `lrp_history.svg` in `constants.FIGURE_DIRECTORY`.

    Returns `(svg_path, figure)`, or two `None` values when the cache has not
    been computed yet.
    """
    if cache is None:
        gr.Warning("Cache not computed!")
        return None, None

    current_board = boards[board_index]
    relevance = cache[board_index]
    peak = relevance.abs().max()
    if peak != 0:
        relevance = relevance / peak

    # Each history step owns a block of 13 planes; the slice keeps the first
    # 12 and leaves the block's last plane out (presumably a non-piece
    # plane; confirm against the input encoding).
    first_plane = 13 * (history_index - 1)
    last_plane = 13 * history_index - 1
    summed = relevance[first_plane:last_plane].sum(dim=0).view(64)
    if current_board.turn == chess.BLACK:
        summed = summed.view(8, 8).flip(0).view(64)

    past_index = board_index - history_index + 1
    if past_index < 0:
        history_board = LczeroBoard(fen=None)
    else:
        history_board = boards[past_index]

    svg_board, fig = visualisation.render_heatmap(history_board, summed, vmin=-1, vmax=1)
    svg_path = f"{constants.FIGURE_DIRECTORY}/lrp_history.svg"
    with open(svg_path, "w") as f:
        f.write(svg_board)
    return svg_path, fig
def previous_board(
    plane_index,
    history_index,
):
    """
    Step back one board and redraw both plots.

    Warns and stays on the first board when there is nothing before it.
    Returns the outputs of `make_plot` followed by those of
    `make_history_plot`.
    """
    global board_index

    if board_index < 1:
        gr.Warning("Already at first board.")
        board_index = 0
    else:
        board_index -= 1

    plot_outputs = make_plot(plane_index)
    history_outputs = make_history_plot(history_index)
    return (*plot_outputs, *history_outputs)
def next_board(
    plane_index,
    history_index,
):
    """
    Step forward one board and redraw both plots.

    Warns and stays on the last board when there is nothing after it. When no
    board sequence has been computed yet, warns and returns five `None`
    values without touching `board_index`.

    Otherwise returns the outputs of `make_plot` followed by those of
    `make_history_plot`.
    """
    global board_index

    # `boards` is None until the cache is computed; `len(boards)` would raise
    # a TypeError and the index would already have been advanced.
    if boards is None:
        gr.Warning("Cache not computed!")
        return None, None, None, None, None

    last_index = len(boards) - 1
    if board_index + 1 > last_index:
        gr.Warning("Already at last board.")
        board_index = last_index
    else:
        board_index += 1

    return (
        *make_plot(
            plane_index,
        ),
        *make_history_plot(
            history_index,
        ),
    )
# Layout and event wiring of the interface. Every component and listener is
# created inside the Blocks context.
with gr.Blocks() as interface:
    # Model selection: table of available models and the selected model name.
    with gr.Row():
        with gr.Column(scale=2):
            model_df = gr.Dataframe(
                headers=["Available models"],
                datatype=["str"],
                interactive=False,
                type="array",
                value=list_models,
            )
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)

    # Selecting a cell of the table writes its value into `model_name`.
    model_df.select(
        on_select_model_df,
        None,
        model_name,
    )

    # Game input, per-plane heatmap controls and the rendered board.
    with gr.Row():
        with gr.Column():
            board_fen = gr.Textbox(
                label="Board starting FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            action_seq = gr.Textbox(
                label="Action sequence",
                lines=1,
                max_lines=1,
                value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
            )
            compute_cache_button = gr.Button("Compute heatmaps")
            with gr.Group():
                with gr.Row():
                    # 1-based plane selector; make_plot subtracts 1 before indexing.
                    plane_index = gr.Slider(
                        label="Plane index",
                        minimum=1,
                        maximum=112,
                        step=1,
                        value=1,
                    )
                with gr.Row():
                    previous_board_button = gr.Button("Previous board")
                    next_board_button = gr.Button("Next board")
            current_board_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
            )
            colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            image = gr.Image(label="Board")

    # History heatmap controls and the rendered history board.
    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    # NOTE: the name is misspelled ("histroy"); it is kept because it
                    # is a module-level name used consistently below.
                    histroy_index = gr.Slider(
                        label="History index",
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=1,
                    )
            history_colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            history_image = gr.Image(label="Board")

    # Inputs shared by the callbacks that redraw both plots.
    base_inputs = [
        plane_index,
        histroy_index,
    ]
    # Order matches the values returned by compute_cache, previous_board and
    # next_board: the three make_plot outputs, then the two make_history_plot outputs.
    outputs = [
        image,
        current_board_fen,
        colorbar,
        history_image,
        history_colorbar,
    ]

    compute_cache_button.click(
        compute_cache,
        inputs=[board_fen, action_seq, model_name] + base_inputs,
        outputs=outputs,
    )

    previous_board_button.click(previous_board, inputs=base_inputs, outputs=outputs)
    next_board_button.click(next_board, inputs=base_inputs, outputs=outputs)

    # Moving a slider only redraws the plot that depends on it.
    plane_index.change(
        make_plot,
        inputs=plane_index,
        outputs=[image, current_board_fen, colorbar],
    )

    histroy_index.change(
        make_history_plot,
        inputs=histroy_index,
        outputs=[history_image, history_colorbar],
    )