- app/__init__.py +0 -0
- app/attention_interface.py +292 -0
- app/backend_interface.py +207 -0
- app/board_interface.py +80 -0
- app/constants.py +7 -0
- app/convert_interface.py +201 -0
- app/crp_interface.py +279 -0
- app/encoding_interface.py +82 -0
- app/figures/.gitignore +2 -0
- app/leela_models/.gitignore +2 -0
- app/lrp_interface.py +279 -0
- app/main.py +50 -0
- app/onnx_models/.gitignore +2 -0
- app/policy_interface.py +276 -0
- app/state.py +18 -0
- app/statistics_interface.py +189 -0
- app/utils.py +121 -0
- app/visualisation.py +303 -0
app/__init__.py
ADDED
File without changes
app/attention_interface.py
ADDED
@@ -0,0 +1,292 @@
"""
Gradio interface for plotting attention.
"""

import copy

import chess
import gradio as gr

from demo import constants, utils, visualisation


def list_models():
    """
    List the models in the model directory.
    """
    models_info = utils.get_models_info(leela=False)
    return sorted([[model_info[0]] for model_info in models_info])


def on_select_model_df(
    evt: gr.SelectData,
):
    """
    When a model is selected, update the statement.
    """
    return evt.value


def compute_cache(
    board_fen,
    action_seq,
    model_name,
    attention_layer,
    attention_head,
    square,
    state_board_index,
    state_boards,
    state_cache,
):
    if model_name == "":
        gr.Warning("No model selected.")
        return None, None, None, state_boards, state_cache

    try:
        board = chess.Board(board_fen)
    except ValueError:
        board = chess.Board()
        gr.Warning("Invalid FEN, using starting position.")
    state_boards = [board.copy()]
    if action_seq:
        try:
            if action_seq.startswith("1."):
                for action in action_seq.split():
                    if action.endswith("."):
                        continue
                    board.push_san(action)
                    state_boards.append(board.copy())
            else:
                for action in action_seq.split():
                    board.push_uci(action)
                    state_boards.append(board.copy())
        except ValueError:
            gr.Warning(f"Invalid action {action} stopping before it.")
    try:
        wrapper, lens = utils.get_wrapper_lens_from_state(
            model_name,
            "activation",
            lens_name="attention",
            module_exp=r"encoder\d+/mha/QK/softmax",
        )
    except ValueError:
        gr.Warning("Could not load model.")
        return None, None, None, state_boards, state_cache
    state_cache = []
    for board in state_boards:
        attention_cache = copy.deepcopy(lens.analyse_board(board, wrapper))
        state_cache.append(attention_cache)
    return (
        *make_plot(
            attention_layer,
            attention_head,
            square,
            state_board_index,
            state_boards,
            state_cache,
        ),
        state_boards,
        state_cache,
    )


def make_plot(
    attention_layer,
    attention_head,
    square,
    state_board_index,
    state_boards,
    state_cache,
):
    if state_cache == []:
        gr.Warning("No cache available.")
        return None, None, None

    board = state_boards[state_board_index]
    num_attention_layers = len(state_cache[state_board_index])
    if attention_layer > num_attention_layers:
        gr.Warning(
            f"Attention layer {attention_layer} does not exist, " f"using layer {num_attention_layers} instead."
        )
        attention_layer = num_attention_layers

    key = f"encoder{attention_layer-1}/mha/QK/softmax"
    try:
        attention_tensor = state_cache[state_board_index][key]
    except KeyError:
        gr.Warning(f"Combination {key} does not exist.")
        return None, None, None
    if attention_head > attention_tensor.shape[1]:
        gr.Warning(
            f"Attention head {attention_head} does not exist, " f"using head {attention_tensor.shape[1]+1} instead."
        )
        attention_head = attention_tensor.shape[1]
    try:
        square_index = chess.SQUARE_NAMES.index(square)
    except ValueError:
        gr.Warning(f"Invalid square {square}, using a1 instead.")
        square_index = 0
        square = "a1"
    if board.turn == chess.BLACK:
        square_index = chess.square_mirror(square_index)

    heatmap = attention_tensor[0, attention_head - 1, square_index]
    if board.turn == chess.BLACK:
        heatmap = heatmap.view(8, 8).flip(0).view(64)
    svg_board, fig = visualisation.render_heatmap(board, heatmap, square=square)
    with open(f"{constants.FIGURE_DIRECTORY}/attention.svg", "w") as f:
        f.write(svg_board)
    return f"{constants.FIGURE_DIRECTORY}/attention.svg", board.fen(), fig


def previous_board(
    attention_layer,
    attention_head,
    square,
    state_board_index,
    state_boards,
    state_cache,
):
    state_board_index -= 1
    if state_board_index < 0:
        gr.Warning("Already at first board.")
        state_board_index = 0
    return (
        *make_plot(
            attention_layer,
            attention_head,
            square,
            state_board_index,
            state_boards,
            state_cache,
        ),
        state_board_index,
    )


def next_board(
    attention_layer,
    attention_head,
    square,
    state_board_index,
    state_boards,
    state_cache,
):
    state_board_index += 1
    if state_board_index >= len(state_boards):
        gr.Warning("Already at last board.")
        state_board_index = len(state_boards) - 1
    return (
        *make_plot(
            attention_layer,
            attention_head,
            square,
            state_board_index,
            state_boards,
            state_cache,
        ),
        state_board_index,
    )


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column(scale=2):
            model_df = gr.Dataframe(
                headers=["Available models"],
                datatype=["str"],
                interactive=False,
                type="array",
                value=list_models,
            )
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)

    model_df.select(
        on_select_model_df,
        None,
        model_name,
    )

    with gr.Row():
        with gr.Column():
            board_fen = gr.Textbox(
                label="Board starting FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            action_seq = gr.Textbox(
                label="Action sequence",
                lines=1,
                max_lines=1,
                value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
            )
            compute_cache_button = gr.Button("Compute cache")

            with gr.Group():
                with gr.Row():
                    attention_layer = gr.Slider(
                        label="Attention layer",
                        minimum=1,
                        maximum=24,
                        step=1,
                        value=1,
                    )
                    attention_head = gr.Slider(
                        label="Attention head",
                        minimum=1,
                        maximum=24,
                        step=1,
                        value=1,
                    )
                with gr.Row():
                    square = gr.Textbox(
                        label="Square",
                        lines=1,
                        max_lines=1,
                        value="a1",
                        scale=1,
                    )
            with gr.Row():
                previous_board_button = gr.Button("Previous board")
                next_board_button = gr.Button("Next board")
            current_board_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
            )
            colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            image = gr.Image(label="Board")

    state_board_index = gr.State(0)
    state_boards = gr.State([])
    state_cache = gr.State([])
    base_inputs = [
        attention_layer,
        attention_head,
        square,
        state_board_index,
        state_boards,
        state_cache,
    ]
    outputs = [image, current_board_fen, colorbar]

    compute_cache_button.click(
        compute_cache,
        inputs=[board_fen, action_seq, model_name] + base_inputs,
        outputs=outputs + [state_boards, state_cache],
    )

    previous_board_button.click(
        previous_board,
        inputs=base_inputs,
        outputs=outputs + [state_board_index],
    )
    next_board_button.click(next_board, inputs=base_inputs, outputs=outputs + [state_board_index])

    attention_layer.change(make_plot, inputs=base_inputs, outputs=outputs)
    attention_head.change(make_plot, inputs=base_inputs, outputs=outputs)
    square.submit(make_plot, inputs=base_inputs, outputs=outputs)
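A note on the `view(8, 8).flip(0).view(64)` pattern used above: the lens output appears to be indexed from the side to move's perspective, and flipping the ranks re-orients it to a white-oriented board before rendering. A minimal standalone sketch of that reshaping (illustrative only, not part of this commit):

import torch

# 64 per-square values, ordered a1..h8 from the mover's point of view
heatmap = torch.arange(64, dtype=torch.float32)
# mirror the ranks so the values line up with a white-oriented board image
white_oriented = heatmap.view(8, 8).flip(0).view(64)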
app/backend_interface.py
ADDED
@@ -0,0 +1,207 @@
"""
Gradio interface for visualizing the policy of a model.
"""

import chess
import chess.svg
import gradio as gr
import torch
from lczero.backends import Backend, GameState, Weights

from demo import constants, utils, visualisation
from lczerolens import move_encodings
from lczerolens.model import lczero as lczero_utils
from lczerolens.xai import PolicyLens


def list_models():
    """
    List the models in the model directory.
    """
    models_info = utils.get_models_info(onnx=False)
    return sorted([[model_info[0]] for model_info in models_info])


def on_select_model_df(
    evt: gr.SelectData,
):
    """
    When a model is selected, update the statement.
    """
    return evt.value


def make_policy_plot(
    board_fen,
    action_seq,
    view,
    model_name,
    depth,
    use_softmax,
    aggregate_topk,
    render_bestk,
    only_legal,
):
    if model_name == "":
        gr.Warning(
            "Please select a model.",
        )
        return (
            None,
            None,
            "",
            None,
        )
    try:
        board = chess.Board(board_fen)
    except ValueError:
        board = chess.Board()
        gr.Warning("Invalid FEN, using starting position.")
    if action_seq:
        try:
            for action in action_seq.split():
                board.push_uci(action)
        except ValueError:
            gr.Warning("Invalid action sequence, using starting position.")
            board = chess.Board()
    lczero_weights = Weights(f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}")
    lczero_backend = Backend(lczero_weights)
    uci_moves = [move.uci() for move in board.move_stack]
    lczero_game = GameState(moves=uci_moves)
    policy, value = lczero_utils.prediction_from_backend(
        lczero_backend,
        lczero_game,
        softmax=use_softmax,
        only_legal=only_legal,
        illegal_value=0,
    )
    pickup_agg, dropoff_agg = PolicyLens.aggregate_policy(policy, int(aggregate_topk))

    if view == "from":
        if board.turn == chess.WHITE:
            heatmap = pickup_agg
        else:
            heatmap = pickup_agg.view(8, 8).flip(0).view(64)
    else:
        if board.turn == chess.WHITE:
            heatmap = dropoff_agg
        else:
            heatmap = dropoff_agg.view(8, 8).flip(0).view(64)
    us_them = (board.turn, not board.turn)
    if only_legal:
        legal_moves = [move_encodings.encode_move(move, us_them) for move in board.legal_moves]
        filtered_policy = torch.zeros(1858)
        filtered_policy[legal_moves] = policy[legal_moves]
        if (filtered_policy < 0).any():
            gr.Warning("Some legal moves have negative policy.")
        topk_moves = torch.topk(filtered_policy, render_bestk)
    else:
        topk_moves = torch.topk(policy, render_bestk)
    arrows = []
    for move_index in topk_moves.indices:
        move = move_encodings.decode_move(move_index, us_them)
        arrows.append((move.from_square, move.to_square))
    svg_board, fig = visualisation.render_heatmap(board, heatmap, arrows=arrows)
    with open(f"{constants.FIGURE_DIRECTORY}/policy.svg", "w") as f:
        f.write(svg_board)
    raw_policy, _ = lczero_utils.prediction_from_backend(
        lczero_backend,
        lczero_game,
        softmax=False,
        only_legal=False,
        illegal_value=0,
    )
    fig_dist = visualisation.render_policy_distribution(
        raw_policy,
        [move_encodings.encode_move(move, us_them) for move in board.legal_moves],
    )
    return (
        f"{constants.FIGURE_DIRECTORY}/policy.svg",
        fig,
        (f"Value: {value:.2f}"),
        fig_dist,
    )


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column(scale=2):
            model_df = gr.Dataframe(
                headers=["Available models"],
                datatype=["str"],
                interactive=False,
                type="array",
                value=list_models,
            )
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)

    model_df.select(
        on_select_model_df,
        None,
        model_name,
    )
    with gr.Row():
        with gr.Column():
            board_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            action_seq = gr.Textbox(
                label="Action sequence",
                lines=1,
                max_lines=1,
                value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
            )
            with gr.Group():
                with gr.Row():
                    depth = gr.Radio(label="Depth", choices=[0], value=0)
                    use_softmax = gr.Checkbox(label="Use softmax", value=True)
                with gr.Row():
                    aggregate_topk = gr.Slider(
                        label="Aggregate top k",
                        minimum=1,
                        maximum=1858,
                        step=1,
                        value=1858,
                        scale=3,
                    )
                    view = gr.Radio(
                        label="View",
                        choices=["from", "to"],
                        value="from",
                        scale=1,
                    )
                with gr.Row():
                    render_bestk = gr.Slider(
                        label="Render best k",
                        minimum=1,
                        maximum=5,
                        step=1,
                        value=5,
                        scale=3,
                    )
                    only_legal = gr.Checkbox(label="Only legal", value=True, scale=1)

            policy_button = gr.Button("Plot policy")
            colorbar = gr.Plot(label="Colorbar")
            game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="")
        with gr.Column():
            image = gr.Image(label="Board")
            density_plot = gr.Plot(label="Density")

    policy_inputs = [
        board_fen,
        action_seq,
        view,
        model_name,
        depth,
        use_softmax,
        aggregate_topk,
        render_bestk,
        only_legal,
    ]
    policy_outputs = [image, colorbar, game_info, density_plot]
    policy_button.click(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
app/board_interface.py
ADDED
@@ -0,0 +1,80 @@
"""
Gradio interface for plotting a board.
"""

import chess
import chess.svg
import gradio as gr

from demo import constants


def make_board_plot(board_fen, arrows, square):
    try:
        board = chess.Board(board_fen)
    except ValueError:
        board = chess.Board()
        gr.Warning("Invalid FEN, using starting position.")
    try:
        if arrows:
            arrows_list = arrows.split(" ")
            chess_arrows = []
            for arrow in arrows_list:
                from_square, to_square = arrow[:2], arrow[2:]
                chess_arrows.append(
                    (
                        chess.parse_square(from_square),
                        chess.parse_square(to_square),
                    )
                )
        else:
            chess_arrows = []
    except ValueError:
        chess_arrows = []
        gr.Warning("Invalid arrows, using none.")

    color_dict = {chess.parse_square(square): "#FF0000"} if square else {}
    svg_board = chess.svg.board(
        board,
        size=350,
        arrows=chess_arrows,
        fill=color_dict,
    )
    with open(f"{constants.FIGURE_DIRECTORY}/board.svg", "w") as f:
        f.write(svg_board)
    return f"{constants.FIGURE_DIRECTORY}/board.svg"


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            board_fen = gr.Textbox(
                label="Board starting FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            arrows = gr.Textbox(
                label="Arrows",
                lines=1,
                max_lines=1,
                value="",
                placeholder="e2e4 e7e5",
            )
            square = gr.Textbox(
                label="Square",
                lines=1,
                max_lines=1,
                value="",
                placeholder="e4",
            )
        with gr.Column():
            image = gr.Image(label="Board", interactive=False)

    inputs = [
        board_fen,
        arrows,
        square,
    ]
    board_fen.submit(make_board_plot, inputs=inputs, outputs=image)
    arrows.submit(make_board_plot, inputs=inputs, outputs=image)
    interface.load(make_board_plot, inputs=inputs, outputs=image)
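For reference, a hypothetical direct call to `make_board_plot` outside the Gradio UI (it assumes the module context above and an existing `demo/figures` directory; the argument values are illustrative):

svg_path = make_board_plot(
    chess.STARTING_FEN,  # board_fen
    "e2e4 e7e5",  # arrows, in the from-to square format of the placeholder
    "e4",  # square to highlight
)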
app/constants.py
ADDED
@@ -0,0 +1,7 @@
"""
Constants for the demo.
"""

MODEL_DIRECTORY = "demo/onnx_models"
LEELA_MODEL_DIRECTORY = "demo/leela_models"
FIGURE_DIRECTORY = "demo/figures"
app/convert_interface.py
ADDED
@@ -0,0 +1,201 @@
"""
Gradio interface for converting models.
"""

import os
import uuid

import gradio as gr

from demo import constants, utils
from lczerolens.model import lczero as lczero_utils


def list_models():
    """
    List the models in the model directory.
    """
    models_info = utils.get_models_info()
    return sorted([[model_info[0]] for model_info in models_info])


def on_select_model_df(
    evt: gr.SelectData,
):
    """
    When a model is selected, update the statement.
    """
    return evt.value


def convert_model(
    model_name: str,
):
    """
    Convert the model.
    """
    if model_name == "":
        gr.Warning(
            "Please select a model.",
        )
        return list_models(), ""
    if model_name.endswith(".onnx"):
        gr.Warning(
            "ONNX conversion not implemented.",
        )
        return list_models(), ""
    try:
        lczero_utils.convert_to_onnx(
            f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}",
            f"{constants.MODEL_DIRECTORY}/{model_name[:-6]}.onnx",
        )
    except RuntimeError:
        gr.Warning(
            f"Could not convert net at `{model_name}`.",
        )
        return list_models(), "Conversion failed"
    return list_models(), "Conversion successful"


def upload_model(
    model_file: gr.File,
):
    """
    Upload the model.
    """
    if model_file is None:
        gr.Warning(
            "File not uploaded.",
        )
        return list_models()
    try:
        id = uuid.uuid4()
        tmp_file_path = f"{constants.LEELA_MODEL_DIRECTORY}/{id}"
        with open(
            tmp_file_path,
            "wb",
        ) as f:
            f.write(model_file)
        utils.save_model(tmp_file_path)
    except RuntimeError:
        gr.Warning(
            "Invalid file type.",
        )
    finally:
        if os.path.exists(tmp_file_path):
            os.remove(tmp_file_path)
    return list_models()


def get_model_description(
    model_name: str,
):
    """
    Get the model description.
    """
    if model_name == "":
        gr.Warning(
            "Please select a model.",
        )
        return ""
    if model_name.endswith(".onnx"):
        gr.Warning(
            "ONNX description not implemented.",
        )
        return ""
    try:
        description = lczero_utils.describenet(
            f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}",
        )
    except RuntimeError:
        raise gr.Error(
            f"Could not describe net at `{model_name}`.",
        )
    return description


def get_model_path(
    model_name: str,
):
    """
    Get the model path.
    """
    if model_name == "":
        gr.Warning(
            "Please select a model.",
        )
        return None
    if model_name.endswith(".onnx"):
        return f"{constants.MODEL_DIRECTORY}/{model_name}"
    else:
        return f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}"


with gr.Blocks() as interface:
    model_file = gr.File(type="binary")
    upload_button = gr.Button(
        value="Upload",
    )
    with gr.Row():
        with gr.Column(scale=2):
            model_df = gr.Dataframe(
                headers=["Available models"],
                datatype=["str"],
                interactive=False,
                type="array",
                value=list_models,
            )
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
            conversion_status = gr.Textbox(
                label="Conversion status",
                lines=1,
                interactive=False,
            )

            convert_button = gr.Button(
                value="Convert",
            )
            describe_button = gr.Button(
                value="Describe model",
            )
            model_description = gr.Textbox(
                label="Model description",
                lines=1,
                interactive=False,
            )
            download_button = gr.Button(
                value="Get download link",
            )
            download_file = gr.File(
                type="filepath",
                label="Download link",
                interactive=False,
            )

    model_df.select(
        on_select_model_df,
        None,
        model_name,
    )
    upload_button.click(
        upload_model,
        model_file,
        model_df,
    )
    convert_button.click(
        convert_model,
        model_name,
        [model_df, conversion_status],
    )
    describe_button.click(
        get_model_description,
        model_name,
        model_description,
    )
    download_button.click(
        get_model_path,
        model_name,
        download_file,
    )
app/crp_interface.py
ADDED
@@ -0,0 +1,279 @@
"""
Gradio interface for plotting CRP heatmaps.
"""

import copy

import chess
import gradio as gr

from demo import constants, utils, visualisation

cache = None
boards = None
board_index = 0


def list_models():
    """
    List the models in the model directory.
    """
    models_info = utils.get_models_info(leela=False)
    return sorted([[model_info[0]] for model_info in models_info])


def on_select_model_df(
    evt: gr.SelectData,
):
    """
    When a model is selected, update the statement.
    """
    return evt.value


def compute_cache(
    board_fen,
    action_seq,
    model_name,
    plane_index,
    history_index,
):
    global cache
    global boards
    if model_name == "":
        gr.Warning("No model selected.")
        return None, None, None, None, None
    try:
        board = chess.Board(board_fen)
    except ValueError:
        board = chess.Board()
        gr.Warning("Invalid FEN, using starting position.")
    boards = [board.copy()]
    if action_seq:
        try:
            if action_seq.startswith("1."):
                for action in action_seq.split():
                    if action.endswith("."):
                        continue
                    board.push_san(action)
                    boards.append(board.copy())
            else:
                for action in action_seq.split():
                    board.push_uci(action)
                    boards.append(board.copy())
        except ValueError:
            gr.Warning(f"Invalid action {action} stopping before it.")
    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "crp")
    cache = []
    for board in boards:
        relevance = lens.compute_heatmap(board, wrapper)
        cache.append(copy.deepcopy(relevance))
    return (
        *make_plot(
            plane_index,
        ),
        *make_history_plot(
            history_index,
        ),
    )


def make_plot(
    plane_index,
):
    global cache
    global boards
    global board_index

    if cache is None:
        gr.Warning("Cache not computed!")
        return None, None, None

    board = boards[board_index]
    relevance_tensor = cache[board_index]
    a_max = relevance_tensor.abs().max()
    if a_max != 0:
        relevance_tensor = relevance_tensor / a_max
    vmin = -1
    vmax = 1
    heatmap = relevance_tensor[plane_index - 1].view(64)
    if board.turn == chess.BLACK:
        heatmap = heatmap.view(8, 8).flip(0).view(64)
    svg_board, fig = visualisation.render_heatmap(board, heatmap, vmin=vmin, vmax=vmax)
    with open(f"{constants.FIGURE_DIRECTORY}/lrp.svg", "w") as f:
        f.write(svg_board)
    return f"{constants.FIGURE_DIRECTORY}/lrp.svg", board.fen(), fig


def make_history_plot(
    history_index,
):
    global cache
    global boards
    global board_index

    if cache is None:
        gr.Warning("Cache not computed!")
        return None, None

    board = boards[board_index]
    relevance_tensor = cache[board_index]
    a_max = relevance_tensor.abs().max()
    if a_max != 0:
        relevance_tensor = relevance_tensor / a_max
    vmin = -1
    vmax = 1
    heatmap = relevance_tensor[13 * (history_index - 1) : 13 * history_index - 1].sum(dim=0).view(64)
    if board.turn == chess.BLACK:
        heatmap = heatmap.view(8, 8).flip(0).view(64)
    if board_index - history_index + 1 < 0:
        history_board = chess.Board(fen=None)
    else:
        history_board = boards[board_index - history_index + 1]
    svg_board, fig = visualisation.render_heatmap(history_board, heatmap, vmin=vmin, vmax=vmax)
    with open(f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", "w") as f:
        f.write(svg_board)
    return f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", fig


def previous_board(
    plane_index,
    history_index,
):
    global board_index
    board_index -= 1
    if board_index < 0:
        gr.Warning("Already at first board.")
        board_index = 0
    return (
        *make_plot(
            plane_index,
        ),
        *make_history_plot(
            history_index,
        ),
    )


def next_board(
    plane_index,
    history_index,
):
    global board_index
    board_index += 1
    if board_index >= len(boards):
        gr.Warning("Already at last board.")
        board_index = len(boards) - 1
    return (
        *make_plot(
            plane_index,
        ),
        *make_history_plot(
            history_index,
        ),
    )


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column(scale=2):
            model_df = gr.Dataframe(
                headers=["Available models"],
                datatype=["str"],
                interactive=False,
                type="array",
                value=list_models,
            )
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)

    model_df.select(
        on_select_model_df,
        None,
        model_name,
    )

    with gr.Row():
        with gr.Column():
            board_fen = gr.Textbox(
                label="Board starting FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            action_seq = gr.Textbox(
                label="Action sequence",
                lines=1,
                max_lines=1,
                value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
            )
            compute_cache_button = gr.Button("Compute heatmaps")

            with gr.Group():
                with gr.Row():
                    plane_index = gr.Slider(
                        label="Plane index",
                        minimum=1,
                        maximum=112,
                        step=1,
                        value=1,
                    )
            with gr.Row():
                previous_board_button = gr.Button("Previous board")
                next_board_button = gr.Button("Next board")
            current_board_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
            )
            colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            image = gr.Image(label="Board")

    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    history_index = gr.Slider(
                        label="History index",
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=1,
                    )
            history_colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            history_image = gr.Image(label="Board")

    base_inputs = [
        plane_index,
        history_index,
    ]
    outputs = [
        image,
        current_board_fen,
        colorbar,
        history_image,
        history_colorbar,
    ]

    compute_cache_button.click(
        compute_cache,
        inputs=[board_fen, action_seq, model_name] + base_inputs,
        outputs=outputs,
    )

    previous_board_button.click(previous_board, inputs=base_inputs, outputs=outputs)
    next_board_button.click(next_board, inputs=base_inputs, outputs=outputs)

    plane_index.change(
        make_plot,
        inputs=plane_index,
        outputs=[image, current_board_fen, colorbar],
    )
    history_index.change(
        make_history_plot,
        inputs=history_index,
        outputs=[history_image, history_colorbar],
    )
app/encoding_interface.py
ADDED
@@ -0,0 +1,82 @@
"""
Gradio interface for plotting encodings.
"""

import chess
import gradio as gr

from demo import constants, visualisation
from lczerolens import board_encodings


def make_encoding_plot(
    board_fen,
    action_seq,
    plane_index,
    color_flip,
):
    try:
        board = chess.Board(board_fen)
    except ValueError:
        board = chess.Board()
        gr.Warning("Invalid FEN, using starting position.")
    if action_seq:
        try:
            for action in action_seq.split():
                board.push_uci(action)
        except ValueError:
            gr.Warning("Invalid action sequence, using starting position.")
            board = chess.Board()
    board_tensor = board_encodings.board_to_input_tensor(board)
    heatmap = board_tensor[plane_index]
    if color_flip and board.turn == chess.BLACK:
        heatmap = heatmap.flip(0)
    svg_board, fig = visualisation.render_heatmap(board, heatmap.view(64), vmin=0.0, vmax=1.0)
    with open(f"{constants.FIGURE_DIRECTORY}/encoding.svg", "w") as f:
        f.write(svg_board)
    return f"{constants.FIGURE_DIRECTORY}/encoding.svg", fig


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            board_fen = gr.Textbox(
                label="Board starting FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            action_seq = gr.Textbox(
                label="Action sequence",
                lines=1,
                max_lines=1,
                value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
            )
            with gr.Group():
                with gr.Row():
                    plane_index = gr.Slider(
                        label="Plane index",
                        minimum=0,
                        maximum=111,
                        step=1,
                        value=0,
                        scale=3,
                    )
                    color_flip = gr.Checkbox(label="Color flip", value=True, scale=1)

            colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            image = gr.Image(label="Board")

    policy_inputs = [
        board_fen,
        action_seq,
        plane_index,
        color_flip,
    ]
    policy_outputs = [image, colorbar]
    board_fen.submit(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
    action_seq.submit(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
    plane_index.change(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
    color_flip.change(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
    interface.load(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
app/figures/.gitignore
ADDED
@@ -0,0 +1,2 @@
*
!.gitignore
app/leela_models/.gitignore
ADDED
@@ -0,0 +1,2 @@
*
!.gitignore
app/lrp_interface.py
ADDED
@@ -0,0 +1,279 @@
"""
Gradio interface for plotting LRP heatmaps.
"""

import copy

import chess
import gradio as gr

from demo import constants, utils, visualisation

cache = None
boards = None
board_index = 0


def list_models():
    """
    List the models in the model directory.
    """
    models_info = utils.get_models_info(leela=False)
    return sorted([[model_info[0]] for model_info in models_info])


def on_select_model_df(
    evt: gr.SelectData,
):
    """
    When a model is selected, update the statement.
    """
    return evt.value


def compute_cache(
    board_fen,
    action_seq,
    model_name,
    plane_index,
    history_index,
):
    global cache
    global boards
    if model_name == "":
        gr.Warning("No model selected.")
        return None, None, None, None, None
    try:
        board = chess.Board(board_fen)
    except ValueError:
        board = chess.Board()
        gr.Warning("Invalid FEN, using starting position.")
    boards = [board.copy()]
    if action_seq:
        try:
            if action_seq.startswith("1."):
                for action in action_seq.split():
                    if action.endswith("."):
                        continue
                    board.push_san(action)
                    boards.append(board.copy())
            else:
                for action in action_seq.split():
                    board.push_uci(action)
                    boards.append(board.copy())
        except ValueError:
            gr.Warning(f"Invalid action {action} stopping before it.")
    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "lrp")
    cache = []
    for board in boards:
        relevance = lens.compute_heatmap(board, wrapper)
        cache.append(copy.deepcopy(relevance))
    return (
        *make_plot(
            plane_index,
        ),
        *make_history_plot(
            history_index,
        ),
    )


def make_plot(
    plane_index,
):
    global cache
    global boards
    global board_index

    if cache is None:
        gr.Warning("Cache not computed!")
        return None, None, None

    board = boards[board_index]
    relevance_tensor = cache[board_index]
    a_max = relevance_tensor.abs().max()
    if a_max != 0:
        relevance_tensor = relevance_tensor / a_max
    vmin = -1
    vmax = 1
    heatmap = relevance_tensor[plane_index - 1].view(64)
    if board.turn == chess.BLACK:
        heatmap = heatmap.view(8, 8).flip(0).view(64)
    svg_board, fig = visualisation.render_heatmap(board, heatmap, vmin=vmin, vmax=vmax)
    with open(f"{constants.FIGURE_DIRECTORY}/lrp.svg", "w") as f:
        f.write(svg_board)
    return f"{constants.FIGURE_DIRECTORY}/lrp.svg", board.fen(), fig


def make_history_plot(
    history_index,
):
    global cache
    global boards
    global board_index

    if cache is None:
        gr.Warning("Cache not computed!")
        return None, None

    board = boards[board_index]
    relevance_tensor = cache[board_index]
    a_max = relevance_tensor.abs().max()
    if a_max != 0:
        relevance_tensor = relevance_tensor / a_max
    vmin = -1
    vmax = 1
    heatmap = relevance_tensor[13 * (history_index - 1) : 13 * history_index - 1].sum(dim=0).view(64)
    if board.turn == chess.BLACK:
        heatmap = heatmap.view(8, 8).flip(0).view(64)
    if board_index - history_index + 1 < 0:
        history_board = chess.Board(fen=None)
    else:
        history_board = boards[board_index - history_index + 1]
    svg_board, fig = visualisation.render_heatmap(history_board, heatmap, vmin=vmin, vmax=vmax)
    with open(f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", "w") as f:
        f.write(svg_board)
    return f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", fig


def previous_board(
    plane_index,
    history_index,
):
    global board_index
    board_index -= 1
    if board_index < 0:
        gr.Warning("Already at first board.")
        board_index = 0
    return (
        *make_plot(
            plane_index,
        ),
        *make_history_plot(
            history_index,
        ),
    )


def next_board(
    plane_index,
    history_index,
):
    global board_index
    board_index += 1
    if board_index >= len(boards):
        gr.Warning("Already at last board.")
        board_index = len(boards) - 1
    return (
        *make_plot(
            plane_index,
        ),
        *make_history_plot(
            history_index,
        ),
    )


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column(scale=2):
            model_df = gr.Dataframe(
                headers=["Available models"],
                datatype=["str"],
                interactive=False,
                type="array",
                value=list_models,
            )
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)

    model_df.select(
        on_select_model_df,
        None,
        model_name,
    )

    with gr.Row():
        with gr.Column():
            board_fen = gr.Textbox(
                label="Board starting FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            action_seq = gr.Textbox(
                label="Action sequence",
                lines=1,
                max_lines=1,
                value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
            )
            compute_cache_button = gr.Button("Compute heatmaps")

            with gr.Group():
                with gr.Row():
                    plane_index = gr.Slider(
                        label="Plane index",
                        minimum=1,
                        maximum=112,
                        step=1,
                        value=1,
                    )
            with gr.Row():
                previous_board_button = gr.Button("Previous board")
                next_board_button = gr.Button("Next board")
            current_board_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
            )
            colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            image = gr.Image(label="Board")

    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    history_index = gr.Slider(
                        label="History index",
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=1,
                    )
            history_colorbar = gr.Plot(label="Colorbar")
        with gr.Column():
            history_image = gr.Image(label="Board")

    base_inputs = [
        plane_index,
        history_index,
    ]
    outputs = [
        image,
        current_board_fen,
        colorbar,
        history_image,
        history_colorbar,
    ]

    compute_cache_button.click(
        compute_cache,
        inputs=[board_fen, action_seq, model_name] + base_inputs,
        outputs=outputs,
    )

    previous_board_button.click(previous_board, inputs=base_inputs, outputs=outputs)
    next_board_button.click(next_board, inputs=base_inputs, outputs=outputs)

    plane_index.change(
        make_plot,
        inputs=plane_index,
        outputs=[image, current_board_fen, colorbar],
    )
    history_index.change(
        make_history_plot,
        inputs=history_index,
        outputs=[history_image, history_colorbar],
    )
app/main.py
ADDED
@@ -0,0 +1,50 @@
"""
Gradio demo for lczero-easy.
"""

import gradio as gr

from . import (
    attention_interface,
    backend_interface,
    board_interface,
    convert_interface,
    crp_interface,
    encoding_interface,
    lrp_interface,
    policy_interface,
    statistics_interface,
)

demo = gr.TabbedInterface(
    [
        crp_interface.interface,
        statistics_interface.interface,
        lrp_interface.interface,
        attention_interface.interface,
        policy_interface.interface,
        backend_interface.interface,
        encoding_interface.interface,
        board_interface.interface,
        convert_interface.interface,
    ],
    [
        "CRP",
        "Statistics",
        "LRP",
        "Attention",
        "Policy",
        "Backend",
        "Encoding",
        "Board",
        "Convert",
    ],
    title="LczeroLens Demo",
    analytics_enabled=False,
)

if __name__ == "__main__":
    demo.launch(
        server_port=8000,
        server_name="0.0.0.0",
    )
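The modules in this commit import from a `demo` package while the files live under `app/`, so the Space presumably exposes this directory under that name. A minimal local-launch sketch under that assumption (the import path is hypothetical, not part of the commit):

from demo.main import demo  # hypothetical import path; adjust to how the package is installed

demo.launch(server_port=8000, server_name="0.0.0.0")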
app/onnx_models/.gitignore
ADDED
@@ -0,0 +1,2 @@
*
!.gitignore
app/policy_interface.py
ADDED
@@ -0,0 +1,276 @@
"""
Gradio interface for visualizing the policy of a model.
"""

import chess
import chess.svg
import gradio as gr
import torch

from demo import constants, utils, visualisation
from lczerolens import move_encodings
from lczerolens.xai import PolicyLens

current_board = None
current_raw_policy = None
current_policy = None
current_value = None
current_outcome = None


def list_models():
    """
    List the models in the model directory.
    """
    models_info = utils.get_models_info(leela=False)
    return sorted([[model_info[0]] for model_info in models_info])


def on_select_model_df(
    evt: gr.SelectData,
):
    """
    When a model is selected, update the statement.
    """
    return evt.value


def compute_policy(
    board_fen,
    action_seq,
    model_name,
):
    global current_board
    global current_policy
    global current_raw_policy
    global current_value
    global current_outcome
    if model_name == "":
        gr.Warning(
            "Please select a model.",
        )
        return (
            None,
            None,
            "",
        )
    try:
        board = chess.Board(board_fen)
    except ValueError:
        gr.Warning("Invalid FEN.")
        return (None, None, "", None)
    if action_seq:
        try:
            for action in action_seq.split():
                board.push_uci(action)
        except ValueError:
            gr.Warning("Invalid action sequence.")
            return (None, None, "", None)
    wrapper = utils.get_wrapper_from_state(model_name)
    (output,) = wrapper.predict(board)
    current_raw_policy = output["policy"][0]
    policy = torch.softmax(output["policy"][0], dim=-1)

    filtered_policy = torch.full((1858,), 0.0)
    legal_moves = [move_encodings.encode_move(move, (board.turn, not board.turn)) for move in board.legal_moves]
    filtered_policy[legal_moves] = policy[legal_moves]
    policy = filtered_policy

    current_board = board
    current_policy = policy
    current_value = output.get("value", None)
    current_outcome = output.get("wdl", None)


def make_plot(
    view,
    aggregate_topk,
    move_to_play,
):
    global current_board
    global current_policy
    global current_raw_policy
    global current_value
    global current_outcome

    if current_board is None or current_policy is None:
        gr.Warning("Please compute a policy first.")
        return (None, None, "", None)

    pickup_agg, dropoff_agg = PolicyLens.aggregate_policy(current_policy, int(aggregate_topk))

    if view == "from":
        if current_board.turn == chess.WHITE:
            heatmap = pickup_agg
        else:
            heatmap = pickup_agg.view(8, 8).flip(0).view(64)
    else:
        if current_board.turn == chess.WHITE:
            heatmap = dropoff_agg
        else:
            heatmap = dropoff_agg.view(8, 8).flip(0).view(64)
    us_them = (current_board.turn, not current_board.turn)
    topk_moves = torch.topk(current_policy, 50)
    move = move_encodings.decode_move(topk_moves.indices[move_to_play - 1], us_them)
    arrows = [(move.from_square, move.to_square)]
    svg_board, fig = visualisation.render_heatmap(current_board, heatmap, arrows=arrows)
    with open(f"{constants.FIGURE_DIRECTORY}/policy.svg", "w") as f:
        f.write(svg_board)
    fig_dist = visualisation.render_policy_distribution(
        current_raw_policy,
        [move_encodings.encode_move(move, us_them) for move in current_board.legal_moves],
    )
    return (
        f"{constants.FIGURE_DIRECTORY}/policy.svg",
        fig,
        (f"Value: {current_value} - WDL: {current_outcome}"),
        fig_dist,
    )


def make_policy_plot(
    board_fen,
    action_seq,
    view,
    model_name,
    aggregate_topk,
    move_to_play,
):
    compute_policy(
        board_fen,
        action_seq,
        model_name,
    )
    return make_plot(
        view,
        aggregate_topk,
        move_to_play,
    )


def play_move(
    board_fen,
    action_seq,
    view,
    model_name,
    aggregate_topk,
    move_to_play,
):
    global current_board
    global current_policy

    move = move_encodings.decode_move(
        current_policy.topk(50).indices[move_to_play - 1],
        (current_board.turn, not current_board.turn),
    )
    current_board.push(move)
    action_seq = f"{action_seq} {move.uci()}"
    compute_policy(
        board_fen,
        action_seq,
        model_name,
    )
    return [
        *make_plot(
            view,
            aggregate_topk,
            1,
        ),
        action_seq,
        1,
    ]


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column(scale=2):
            model_df = gr.Dataframe(
                headers=["Available models"],
                datatype=["str"],
                interactive=False,
                type="array",
                value=list_models,
            )
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
    model_df.select(
        on_select_model_df,
        None,
        model_name,
    )

    with gr.Row():
        with gr.Column():
            board_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            action_seq = gr.Textbox(
                label="Action sequence",
                lines=1,
                value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
            )
            with gr.Group():
                with gr.Row():
                    aggregate_topk = gr.Slider(
                        label="Aggregate top k",
                        minimum=1,
                        maximum=1858,
                        step=1,
                        value=1858,
                        scale=3,
                    )
                    view = gr.Radio(
                        label="View",
                        choices=["from", "to"],
                        value="from",
                        scale=1,
                    )
                with gr.Row():
                    move_to_play = gr.Slider(
                        label="Move to play",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=1,
                        scale=3,
                    )
                    play_button = gr.Button("Play")

            policy_button = gr.Button("Compute policy")
            colorbar = gr.Plot(label="Colorbar")
            game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="")
        with gr.Column():
            image = gr.Image(label="Board")
            density_plot = gr.Plot(label="Density")

    policy_inputs = [
        board_fen,
        action_seq,
        view,
        model_name,
        aggregate_topk,
        move_to_play,
    ]
    policy_outputs = [image, colorbar, game_info, density_plot]
    policy_button.click(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
    board_fen.submit(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
    action_seq.submit(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)

    fast_inputs = [
        view,
        aggregate_topk,
        move_to_play,
    ]
    aggregate_topk.change(make_plot, inputs=fast_inputs, outputs=policy_outputs)
    view.change(make_plot, inputs=fast_inputs, outputs=policy_outputs)
    move_to_play.change(make_plot, inputs=fast_inputs, outputs=policy_outputs)

    play_button.click(
        play_move,
        inputs=policy_inputs,
275 |
+
outputs=policy_outputs + [action_seq, move_to_play],
|
276 |
+
)
|
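For reference, the "from"/"to" heatmaps come from summing move probabilities per source or destination square. The snippet below is not part of the committed files: it sketches that aggregation with python-chess moves instead of the 1858-dimensional encoding handled by `PolicyLens.aggregate_policy`, and the helper name is illustrative only.

# Illustrative sketch: aggregate move probabilities per from/to square.
import chess
import torch


def aggregate_per_square(moves, probs, by="from"):
    """Sum the probability mass touching each of the 64 squares."""
    heatmap = torch.zeros(64)
    for move, prob in zip(moves, probs):
        square = move.from_square if by == "from" else move.to_square
        heatmap[square] += prob
    return heatmap


board = chess.Board()
moves = list(board.legal_moves)
probs = torch.full((len(moves),), 1.0 / len(moves))  # uniform dummy policy
pickup = aggregate_per_square(moves, probs, by="from")
dropoff = aggregate_per_square(moves, probs, by="to")
print(pickup.view(8, 8), dropoff.view(8, 8))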
app/state.py
ADDED
@@ -0,0 +1,18 @@
"""
Global state for the demo application.
"""

from typing import Dict

from lczerolens import Lens, ModelWrapper

wrappers: Dict[str, ModelWrapper] = {}

lenses: Dict[str, Dict[str, Lens]] = {
    "activation": {},
    "lrp": {},
    "crp": {},
    "policy": {},
    "probing": {},
    "patching": {},
}
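These dictionaries act as process-wide caches shared by the interfaces. A minimal sketch of the get-or-create pattern they support follows (not part of the committed files; `load_model` is a hypothetical placeholder for the real loading done in app/utils.py):

# Sketch only: reuse a cached wrapper, loading it on first use.
from demo import state


def get_cached_wrapper(model_name, load_model):
    if model_name not in state.wrappers:
        state.wrappers[model_name] = load_model(model_name)
    return state.wrappers[model_name]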
app/statistics_interface.py
ADDED
@@ -0,0 +1,189 @@
"""
Gradio interface for visualising model statistics.
"""

import gradio as gr

from demo import utils, visualisation
from lczerolens import GameDataset
from lczerolens.xai import ConceptDataset, HasThreatConcept

current_policy_statistics = None
current_lrp_statistics = None
current_probing_statistics = None
dataset = GameDataset("assets/test_stockfish_10.jsonl")
check_concept = HasThreatConcept("K", relative=True)
unique_check_dataset = ConceptDataset.from_game_dataset(dataset)
unique_check_dataset.set_concept(check_concept)


def list_models():
    """
    List the models in the model directory.
    """
    models_info = utils.get_models_info(leela=False)
    return sorted([[model_info[0]] for model_info in models_info])


def on_select_model_df(
    evt: gr.SelectData,
):
    """
    When a model is selected, update the statement.
    """
    return evt.value


def compute_policy_statistics(
    model_name,
):
    global current_policy_statistics
    global dataset

    if model_name == "":
        gr.Warning(
            "Please select a model.",
        )
        return None
    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "policy")
    current_policy_statistics = lens.analyse_dataset(dataset, wrapper, 10)
    return make_policy_plot()


def make_policy_plot():
    global current_policy_statistics

    if current_policy_statistics is None:
        gr.Warning(
            "Please compute policy statistics first.",
        )
        return None
    else:
        return visualisation.render_policy_statistics(current_policy_statistics)


def compute_lrp_statistics(
    model_name,
):
    global current_lrp_statistics
    global dataset

    if model_name == "":
        gr.Warning(
            "Please select a model.",
        )
        return None, None, None
    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "lrp")
    current_lrp_statistics = lens.compute_statistics(dataset, wrapper, 10)
    return make_lrp_plot()


def make_lrp_plot():
    global current_lrp_statistics

    if current_lrp_statistics is None:
        gr.Warning(
            "Please compute LRP statistics first.",
        )
        return None, None, None
    else:
        return visualisation.render_relevance_proportion(current_lrp_statistics)


def compute_probing_statistics(
    model_name,
):
    global current_probing_statistics
    global check_concept
    global unique_check_dataset

    if model_name == "":
        gr.Warning(
            "Please select a model.",
        )
        return None
    wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "probing", concept=check_concept)
    current_probing_statistics = lens.compute_statistics(unique_check_dataset, wrapper, 10)
    return make_probing_plot()


def make_probing_plot():
    global current_probing_statistics

    if current_probing_statistics is None:
        gr.Warning(
            "Please compute probing statistics first.",
        )
        return None
    else:
        return visualisation.render_probing_statistics(current_probing_statistics)


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column(scale=2):
            model_df = gr.Dataframe(
                headers=["Available models"],
                datatype=["str"],
                interactive=False,
                type="array",
                value=list_models,
            )
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
            model_df.select(
                on_select_model_df,
                None,
                model_name,
            )

    with gr.Row():
        with gr.Column():
            policy_plot = gr.Plot(label="Policy statistics")
            policy_compute_button = gr.Button(value="Compute policy statistics")
            policy_plot_button = gr.Button(value="Plot policy statistics")

            policy_compute_button.click(
                compute_policy_statistics,
                inputs=[model_name],
                outputs=[policy_plot],
            )
            policy_plot_button.click(make_policy_plot, outputs=[policy_plot])

        with gr.Column():
            lrp_plot_hist = gr.Plot(label="LRP history statistics")

            with gr.Row():
                with gr.Column():
                    lrp_plot_planes = gr.Plot(label="LRP planes statistics")

                with gr.Column():
                    lrp_plot_pieces = gr.Plot(label="LRP pieces statistics")

            with gr.Row():
                lrp_compute_button = gr.Button(value="Compute LRP statistics")
            with gr.Row():
                lrp_plot_button = gr.Button(value="Plot LRP statistics")

            lrp_compute_button.click(
                compute_lrp_statistics,
                inputs=[model_name],
                outputs=[lrp_plot_hist, lrp_plot_planes, lrp_plot_pieces],
            )
            lrp_plot_button.click(
                make_lrp_plot,
                outputs=[lrp_plot_hist, lrp_plot_planes, lrp_plot_pieces],
            )

        with gr.Column():
            probing_plot = gr.Plot(label="Probing statistics")
            probing_compute_button = gr.Button(value="Compute probing statistics")
            probing_plot_button = gr.Button(value="Plot probing statistics")

            probing_compute_button.click(
                compute_probing_statistics,
                inputs=[model_name],
                outputs=[probing_plot],
            )
            probing_plot_button.click(make_probing_plot, outputs=[probing_plot])
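The probing tab builds on `HasThreatConcept("K", relative=True)`, which labels each position by whether the side to move has its king under attack. The standalone sketch below is not part of the committed files and is not the lczerolens implementation; it only illustrates that labelling idea with python-chess, and the helper name is hypothetical.

# Standalone sketch of a "king is threatened" concept label using python-chess.
import chess


def is_relative_king_threatened(board: chess.Board) -> int:
    """Return 1 if the side to move has its king attacked, else 0."""
    king_square = board.king(board.turn)
    if king_square is None:
        return 0
    return int(board.is_attacked_by(not board.turn, king_square))


board = chess.Board()
print(is_relative_king_threatened(board))  # 0 in the starting position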
app/utils.py
ADDED
@@ -0,0 +1,121 @@
"""
Utils for the demo app.
"""

import os
import re
import subprocess

from demo import constants, state
from lczerolens import LensFactory, LczeroModel
from lczerolens.model import lczero as lczero_utils


def get_models_info(onnx=True, leela=True):
    """
    Get the names of the models in the model directory.
    """
    model_df = []
    exp = r"(?P<n_filters>\d+)x(?P<n_blocks>\d+)"
    if onnx:
        for filename in os.listdir(constants.MODEL_DIRECTORY):
            if filename.endswith(".onnx"):
                match = re.search(exp, filename)
                if match is None:
                    n_filters = -1
                    n_blocks = -1
                else:
                    n_filters = int(match.group("n_filters"))
                    n_blocks = int(match.group("n_blocks"))
                model_df.append(
                    [
                        filename,
                        "ONNX",
                        n_blocks,
                        n_filters,
                    ]
                )
    if leela:
        for filename in os.listdir(constants.LEELA_MODEL_DIRECTORY):
            if filename.endswith(".pb.gz"):
                match = re.search(exp, filename)
                if match is None:
                    n_filters = -1
                    n_blocks = -1
                else:
                    n_filters = int(match.group("n_filters"))
                    n_blocks = int(match.group("n_blocks"))
                model_df.append(
                    [
                        filename,
                        "LEELA",
                        n_blocks,
                        n_filters,
                    ]
                )
    return model_df


def save_model(tmp_file_path):
    """
    Save the model to the model directory.
    """
    popen = subprocess.Popen(
        ["file", tmp_file_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    popen.wait()
    if popen.returncode != 0:
        raise RuntimeError
    file_desc = popen.stdout.read().decode("utf-8").split(tmp_file_path)[1].strip()
    rename_match = re.search(r"was\s\"(?P<name>.+)\"", file_desc)
    type_match = re.search(r"\:\s(?P<type>[a-zA-Z]+)", file_desc)
    if rename_match is None or type_match is None:
        raise RuntimeError
    model_name = rename_match.group("name")
    model_type = type_match.group("type")
    if model_type != "gzip":
        raise RuntimeError
    os.rename(
        tmp_file_path,
        f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz",
    )
    try:
        lczero_utils.describenet(
            f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz",
        )
    except RuntimeError:
        os.remove(f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz")
        raise RuntimeError


def get_wrapper_from_state(model_name):
    """
    Get the model wrapper from the state.
    """
    if model_name in state.wrappers:
        return state.wrappers[model_name]
    else:
        wrapper = LczeroModel.from_path(f"{constants.MODEL_DIRECTORY}/{model_name}")
        state.wrappers[model_name] = wrapper
        return wrapper


def get_wrapper_lens_from_state(model_name, lens_type, lens_name="lens", **kwargs):
    """
    Get the model wrapper and lens from the state.
    """
    if model_name in state.wrappers:
        wrapper = state.wrappers[model_name]
    else:
        wrapper = LczeroModel.from_path(f"{constants.MODEL_DIRECTORY}/{model_name}")
        state.wrappers[model_name] = wrapper
    if lens_name in state.lenses[lens_type]:
        lens = state.lenses[lens_type][lens_name]
    else:
        lens = LensFactory.from_name(lens_type, **kwargs)
        if not lens.is_compatible(wrapper):
            raise ValueError(f"Lens of type {lens_type} not compatible with model.")
        state.lenses[lens_type][lens_name] = lens
    return wrapper, lens
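`get_models_info` infers network size from filenames that embed a `<filters>x<blocks>` tag through the `(?P<n_filters>\d+)x(?P<n_blocks>\d+)` pattern. A quick standalone check of that parsing (the filename below is illustrative only):

# Quick check of the filters-by-blocks parsing used in get_models_info.
import re

exp = r"(?P<n_filters>\d+)x(?P<n_blocks>\d+)"
match = re.search(exp, "64x6-example-net.onnx")  # illustrative filename
if match is not None:
    print(int(match.group("n_filters")), int(match.group("n_blocks")))  # 64 6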
app/visualisation.py
ADDED
@@ -0,0 +1,303 @@
"""
Visualisation utils.
"""

import chess
import chess.svg
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchviz

from . import constants

COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
ALPHA = 1.0


def render_heatmap(
    board,
    heatmap,
    square=None,
    vmin=None,
    vmax=None,
    arrows=None,
    normalise="none",
):
    """
    Render a heatmap on the board.
    """
    if normalise == "abs":
        a_max = heatmap.abs().max()
        if a_max != 0:
            heatmap = heatmap / a_max
        vmin = -1
        vmax = 1
    if vmin is None:
        vmin = heatmap.min()
    if vmax is None:
        vmax = heatmap.max()
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

    color_dict = {}
    for square_index in range(64):
        color = COLOR_MAP(norm(heatmap[square_index]))
        color = (*color[:3], ALPHA)
        color_dict[square_index] = matplotlib.colors.to_hex(color, keep_alpha=True)
    fig = plt.figure(figsize=(6, 0.6))
    ax = plt.gca()
    ax.axis("off")
    fig.colorbar(
        matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP),
        ax=ax,
        orientation="horizontal",
        fraction=1.0,
    )
    if square is not None:
        try:
            check = chess.parse_square(square)
        except ValueError:
            check = None
    else:
        check = None
    if arrows is None:
        arrows = []
    plt.close()
    return (
        chess.svg.board(
            board,
            check=check,
            fill=color_dict,
            size=350,
            arrows=arrows,
        ),
        fig,
    )


def render_architecture(model, name: str = "model", directory: str = ""):
    """
    Render the architecture of the model.
    """
    out = model(torch.zeros(1, 112, 8, 8))
    if len(out) == 2:
        policy, outcome_probs = out
        value = torch.zeros(outcome_probs.shape[0], 1)
    else:
        policy, outcome_probs, value = out
    torchviz.make_dot(policy, params=dict(list(model.named_parameters()))).render(
        f"{directory}/{name}_policy", format="svg"
    )
    torchviz.make_dot(outcome_probs, params=dict(list(model.named_parameters()))).render(
        f"{directory}/{name}_outcome_probs", format="svg"
    )
    torchviz.make_dot(value, params=dict(list(model.named_parameters()))).render(
        f"{directory}/{name}_value", format="svg"
    )


def render_policy_distribution(
    policy,
    legal_moves,
    n_bins=20,
):
    """
    Render the policy distribution histogram.
    """
    legal_mask = torch.Tensor([move in legal_moves for move in range(1858)]).bool()
    fig = plt.figure(figsize=(6, 6))
    ax = plt.gca()
    _, bins = np.histogram(policy, bins=n_bins)
    ax.hist(
        policy[~legal_mask],
        bins=bins,
        alpha=0.5,
        density=True,
        label="Illegal moves",
    )
    ax.hist(
        policy[legal_mask],
        bins=bins,
        alpha=0.5,
        density=True,
        label="Legal moves",
    )
    plt.xlabel("Policy")
    plt.ylabel("Density")
    plt.legend()
    plt.yscale("log")
    return fig


def render_policy_statistics(
    statistics,
):
    """
    Render the policy statistics.
    """
    fig = plt.figure(figsize=(6, 6))
    ax = plt.gca()
    move_indices = list(statistics["mean_legal_logits"].keys())
    legal_means_avg = [np.mean(statistics["mean_legal_logits"][move_idx]) for move_idx in move_indices]
    illegal_means_avg = [np.mean(statistics["mean_illegal_logits"][move_idx]) for move_idx in move_indices]
    legal_means_std = [np.std(statistics["mean_legal_logits"][move_idx]) for move_idx in move_indices]
    illegal_means_std = [np.std(statistics["mean_illegal_logits"][move_idx]) for move_idx in move_indices]
    ax.errorbar(
        move_indices,
        legal_means_avg,
        yerr=legal_means_std,
        label="Legal moves",
    )
    ax.errorbar(
        move_indices,
        illegal_means_avg,
        yerr=illegal_means_std,
        label="Illegal moves",
    )
    plt.xlabel("Move index")
    plt.ylabel("Mean policy logits")
    plt.legend()
    return fig


def render_relevance_proportion(statistics, scaled=True):
    """
    Render the relevance proportion statistics.
    """
    norm = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=False)
    fig_hist = plt.figure(figsize=(6, 6))
    ax = plt.gca()
    move_indices = list(statistics["planes_relevance_proportion"].keys())
    for h in range(8):
        relevance_proportion_avg = [
            np.mean([rel[13 * h : 13 * (h + 1)].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
            for move_idx in move_indices
        ]
        relevance_proportion_std = [
            np.std([rel[13 * h : 13 * (h + 1)].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
            for move_idx in move_indices
        ]
        ax.errorbar(
            move_indices[h + 1 :],
            relevance_proportion_avg[h + 1 :],
            yerr=relevance_proportion_std[h + 1 :],
            label=f"History {h}",
            c=COLOR_MAP(norm(h / 9)),
        )

    relevance_proportion_avg = [
        np.mean([rel[104:108].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
        for move_idx in move_indices
    ]
    relevance_proportion_std = [
        np.std([rel[104:108].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
        for move_idx in move_indices
    ]
    ax.errorbar(
        move_indices,
        relevance_proportion_avg,
        yerr=relevance_proportion_std,
        label="Castling rights",
        c=COLOR_MAP(norm(8 / 9)),
    )
    relevance_proportion_avg = [
        np.mean([rel[108:].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
        for move_idx in move_indices
    ]
    relevance_proportion_std = [
        np.std([rel[108:].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
        for move_idx in move_indices
    ]
    ax.errorbar(
        move_indices,
        relevance_proportion_avg,
        yerr=relevance_proportion_std,
        label="Remaining planes",
        c=COLOR_MAP(norm(9 / 9)),
    )
    plt.xlabel("Move index")
    plt.ylabel("Absolute relevance proportion")
    plt.yscale("log")
    plt.legend()

    if scaled:
        stat_key = "planes_relevance_proportion_scaled"
    else:
        stat_key = "planes_relevance_proportion"
    fig_planes = plt.figure(figsize=(6, 6))
    ax = plt.gca()
    move_indices = list(statistics[stat_key].keys())
    for p in range(13):
        relevance_proportion_avg = [
            np.mean([rel[p].item() for rel in statistics[stat_key][move_idx]]) for move_idx in move_indices
        ]
        relevance_proportion_std = [
            np.std([rel[p].item() for rel in statistics[stat_key][move_idx]]) for move_idx in move_indices
        ]
        ax.errorbar(
            move_indices,
            relevance_proportion_avg,
            yerr=relevance_proportion_std,
            label=constants.PLANE_NAMES[p],
            c=COLOR_MAP(norm(p / 12)),
        )

    plt.xlabel("Move index")
    plt.ylabel("Absolute relevance proportion")
    plt.yscale("log")
    plt.legend()

    fig_pieces = plt.figure(figsize=(6, 6))
    ax = plt.gca()
    for p in range(1, 13):
        stat_key = f"configuration_relevance_proportion_threatened_piece{p}"
        n_attackers = list(statistics[stat_key].keys())
        relevance_proportion_avg = [
            np.mean(statistics[f"configuration_relevance_proportion_threatened_piece{p}"][n]) for n in n_attackers
        ]
        relevance_proportion_std = [np.std(statistics[stat_key][n]) for n in n_attackers]
        ax.errorbar(
            n_attackers,
            relevance_proportion_avg,
            yerr=relevance_proportion_std,
            label="PNBRQKpnbrqk"[p - 1],
            c=COLOR_MAP(norm(p / 12)),
        )

    plt.xlabel("Number of attackers")
    plt.ylabel("Absolute configuration relevance proportion")
    plt.yscale("log")
    plt.legend()

    return fig_hist, fig_planes, fig_pieces


def render_probing_statistics(
    statistics,
):
    """
    Render the probing statistics.
    """
    fig = plt.figure(figsize=(6, 6))
    ax = plt.gca()
    n_blocks = len(statistics["metrics"])
    for metric in statistics["metrics"]["block0"]:
        avg = []
        std = []
        for block_idx in range(n_blocks):
            metrics = statistics["metrics"]
            block_data = metrics[f"block{block_idx}"]
            avg.append(np.mean(block_data[metric]))
            std.append(np.std(block_data[metric]))
        ax.errorbar(
            range(n_blocks),
            avg,
            yerr=std,
            label=metric,
        )
    plt.xlabel("Block index")
    plt.ylabel("Metric")
    plt.yscale("log")
    plt.legend()
    return fig
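A rough usage sketch for `render_heatmap` (not part of the committed files), assuming the `demo` package is importable and the working directory is writable:

# Rough usage sketch: render a dummy per-square heatmap on the starting position.
import chess
import torch

from demo import visualisation  # assumes the package layout used by this app

board = chess.Board()
heatmap = torch.rand(64)  # one value per square
svg_board, colorbar_fig = visualisation.render_heatmap(board, heatmap, arrows=[(chess.E2, chess.E4)])
with open("heatmap.svg", "w") as f:
    f.write(svg_board)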