Xmaster6y commited on
Commit
343fa36
·
1 Parent(s): fa9d807
app/__init__.py ADDED
File without changes
app/attention_interface.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for plotting attention.
3
+ """
4
+
5
+ import copy
6
+
7
+ import chess
8
+ import gradio as gr
9
+
10
+ from demo import constants, utils, visualisation
11
+
12
+
13
+ def list_models():
14
+ """
15
+ List the models in the model directory.
16
+ """
17
+ models_info = utils.get_models_info(leela=False)
18
+ return sorted([[model_info[0]] for model_info in models_info])
19
+
20
+
21
+ def on_select_model_df(
22
+ evt: gr.SelectData,
23
+ ):
24
+ """
25
+ When a model is selected, update the statement.
26
+ """
27
+ return evt.value
28
+
29
+
30
+ def compute_cache(
31
+ board_fen,
32
+ action_seq,
33
+ model_name,
34
+ attention_layer,
35
+ attention_head,
36
+ square,
37
+ state_board_index,
38
+ state_boards,
39
+ state_cache,
40
+ ):
41
+ if model_name == "":
42
+ gr.Warning("No model selected.")
43
+ return None, None, None, state_boards, state_cache
44
+
45
+ try:
46
+ board = chess.Board(board_fen)
47
+ except ValueError:
48
+ board = chess.Board()
49
+ gr.Warning("Invalid FEN, using starting position.")
50
+ state_boards = [board.copy()]
51
+ if action_seq:
52
+ try:
53
+ if action_seq.startswith("1."):
54
+ for action in action_seq.split():
55
+ if action.endswith("."):
56
+ continue
57
+ board.push_san(action)
58
+ state_boards.append(board.copy())
59
+ else:
60
+ for action in action_seq.split():
61
+ board.push_uci(action)
62
+ state_boards.append(board.copy())
63
+ except ValueError:
64
+ gr.Warning(f"Invalid action {action} stopping before it.")
65
+ try:
66
+ wrapper, lens = utils.get_wrapper_lens_from_state(
67
+ model_name,
68
+ "activation",
69
+ lens_name="attention",
70
+ module_exp=r"encoder\d+/mha/QK/softmax",
71
+ )
72
+ except ValueError:
73
+ gr.Warning("Could not load model.")
74
+ return None, None, None, state_boards, state_cache
75
+ state_cache = []
76
+ for board in state_boards:
77
+ attention_cache = copy.deepcopy(lens.analyse_board(board, wrapper))
78
+ state_cache.append(attention_cache)
79
+ return (
80
+ *make_plot(
81
+ attention_layer,
82
+ attention_head,
83
+ square,
84
+ state_board_index,
85
+ state_boards,
86
+ state_cache,
87
+ ),
88
+ state_boards,
89
+ state_cache,
90
+ )
91
+
92
+
93
+ def make_plot(
94
+ attention_layer,
95
+ attention_head,
96
+ square,
97
+ state_board_index,
98
+ state_boards,
99
+ state_cache,
100
+ ):
101
+ if state_cache == []:
102
+ gr.Warning("No cache available.")
103
+ return None, None, None
104
+
105
+ board = state_boards[state_board_index]
106
+ num_attention_layers = len(state_cache[state_board_index])
107
+ if attention_layer > num_attention_layers:
108
+ gr.Warning(
109
+ f"Attention layer {attention_layer} does not exist, " f"using layer {num_attention_layers} instead."
110
+ )
111
+ attention_layer = num_attention_layers
112
+
113
+ key = f"encoder{attention_layer-1}/mha/QK/softmax"
114
+ try:
115
+ attention_tensor = state_cache[state_board_index][key]
116
+ except KeyError:
117
+ gr.Warning(f"Combination {key} does not exist.")
118
+ return None, None, None
119
+ if attention_head > attention_tensor.shape[1]:
120
+ gr.Warning(
121
+ f"Attention head {attention_head} does not exist, " f"using head {attention_tensor.shape[1]+1} instead."
122
+ )
123
+ attention_head = attention_tensor.shape[1]
124
+ try:
125
+ square_index = chess.SQUARE_NAMES.index(square)
126
+ except ValueError:
127
+ gr.Warning(f"Invalid square {square}, using a1 instead.")
128
+ square_index = 0
129
+ square = "a1"
130
+ if board.turn == chess.BLACK:
131
+ square_index = chess.square_mirror(square_index)
132
+
133
+ heatmap = attention_tensor[0, attention_head - 1, square_index]
134
+ if board.turn == chess.BLACK:
135
+ heatmap = heatmap.view(8, 8).flip(0).view(64)
136
+ svg_board, fig = visualisation.render_heatmap(board, heatmap, square=square)
137
+ with open(f"{constants.FIGURE_DIRECTORY}/attention.svg", "w") as f:
138
+ f.write(svg_board)
139
+ return f"{constants.FIGURE_DIRECTORY}/attention.svg", board.fen(), fig
140
+
141
+
142
+ def previous_board(
143
+ attention_layer,
144
+ attention_head,
145
+ square,
146
+ state_board_index,
147
+ state_boards,
148
+ state_cache,
149
+ ):
150
+ state_board_index -= 1
151
+ if state_board_index < 0:
152
+ gr.Warning("Already at first board.")
153
+ state_board_index = 0
154
+ return (
155
+ *make_plot(
156
+ attention_layer,
157
+ attention_head,
158
+ square,
159
+ state_board_index,
160
+ state_boards,
161
+ state_cache,
162
+ ),
163
+ state_board_index,
164
+ )
165
+
166
+
167
+ def next_board(
168
+ attention_layer,
169
+ attention_head,
170
+ square,
171
+ state_board_index,
172
+ state_boards,
173
+ state_cache,
174
+ ):
175
+ state_board_index += 1
176
+ if state_board_index >= len(state_boards):
177
+ gr.Warning("Already at last board.")
178
+ state_board_index = len(state_boards) - 1
179
+ return (
180
+ *make_plot(
181
+ attention_layer,
182
+ attention_head,
183
+ square,
184
+ state_board_index,
185
+ state_boards,
186
+ state_cache,
187
+ ),
188
+ state_board_index,
189
+ )
190
+
191
+
192
+ with gr.Blocks() as interface:
193
+ with gr.Row():
194
+ with gr.Column(scale=2):
195
+ model_df = gr.Dataframe(
196
+ headers=["Available models"],
197
+ datatype=["str"],
198
+ interactive=False,
199
+ type="array",
200
+ value=list_models,
201
+ )
202
+ with gr.Column(scale=1):
203
+ with gr.Row():
204
+ model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
205
+
206
+ model_df.select(
207
+ on_select_model_df,
208
+ None,
209
+ model_name,
210
+ )
211
+
212
+ with gr.Row():
213
+ with gr.Column():
214
+ board_fen = gr.Textbox(
215
+ label="Board starting FEN",
216
+ lines=1,
217
+ max_lines=1,
218
+ value=chess.STARTING_FEN,
219
+ )
220
+ action_seq = gr.Textbox(
221
+ label="Action sequence",
222
+ lines=1,
223
+ max_lines=1,
224
+ value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
225
+ )
226
+ compute_cache_button = gr.Button("Compute cache")
227
+
228
+ with gr.Group():
229
+ with gr.Row():
230
+ attention_layer = gr.Slider(
231
+ label="Attention layer",
232
+ minimum=1,
233
+ maximum=24,
234
+ step=1,
235
+ value=1,
236
+ )
237
+ attention_head = gr.Slider(
238
+ label="Attention head",
239
+ minimum=1,
240
+ maximum=24,
241
+ step=1,
242
+ value=1,
243
+ )
244
+ with gr.Row():
245
+ square = gr.Textbox(
246
+ label="Square",
247
+ lines=1,
248
+ max_lines=1,
249
+ value="a1",
250
+ scale=1,
251
+ )
252
+ with gr.Row():
253
+ previous_board_button = gr.Button("Previous board")
254
+ next_board_button = gr.Button("Next board")
255
+ current_board_fen = gr.Textbox(
256
+ label="Board FEN",
257
+ lines=1,
258
+ max_lines=1,
259
+ )
260
+ colorbar = gr.Plot(label="Colorbar")
261
+ with gr.Column():
262
+ image = gr.Image(label="Board")
263
+
264
+ state_board_index = gr.State(0)
265
+ state_boards = gr.State([])
266
+ state_cache = gr.State([])
267
+ base_inputs = [
268
+ attention_layer,
269
+ attention_head,
270
+ square,
271
+ state_board_index,
272
+ state_boards,
273
+ state_cache,
274
+ ]
275
+ outputs = [image, current_board_fen, colorbar]
276
+
277
+ compute_cache_button.click(
278
+ compute_cache,
279
+ inputs=[board_fen, action_seq, model_name] + base_inputs,
280
+ outputs=outputs + [state_boards, state_cache],
281
+ )
282
+
283
+ previous_board_button.click(
284
+ previous_board,
285
+ inputs=base_inputs,
286
+ outputs=outputs + [state_board_index],
287
+ )
288
+ next_board_button.click(next_board, inputs=base_inputs, outputs=outputs + [state_board_index])
289
+
290
+ attention_layer.change(make_plot, inputs=base_inputs, outputs=outputs)
291
+ attention_head.change(make_plot, inputs=base_inputs, outputs=outputs)
292
+ square.submit(make_plot, inputs=base_inputs, outputs=outputs)
app/backend_interface.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for visualizing the policy of a model.
3
+ """
4
+
5
+ import chess
6
+ import chess.svg
7
+ import gradio as gr
8
+ import torch
9
+ from lczero.backends import Backend, GameState, Weights
10
+
11
+ from demo import constants, utils, visualisation
12
+ from lczerolens import move_encodings
13
+ from lczerolens.model import lczero as lczero_utils
14
+ from lczerolens.xai import PolicyLens
15
+
16
+
17
+ def list_models():
18
+ """
19
+ List the models in the model directory.
20
+ """
21
+ models_info = utils.get_models_info(onnx=False)
22
+ return sorted([[model_info[0]] for model_info in models_info])
23
+
24
+
25
+ def on_select_model_df(
26
+ evt: gr.SelectData,
27
+ ):
28
+ """
29
+ When a model is selected, update the statement.
30
+ """
31
+ return evt.value
32
+
33
+
34
+ def make_policy_plot(
35
+ board_fen,
36
+ action_seq,
37
+ view,
38
+ model_name,
39
+ depth,
40
+ use_softmax,
41
+ aggregate_topk,
42
+ render_bestk,
43
+ only_legal,
44
+ ):
45
+ if model_name == "":
46
+ gr.Warning(
47
+ "Please select a model.",
48
+ )
49
+ return (
50
+ None,
51
+ None,
52
+ "",
53
+ )
54
+ try:
55
+ board = chess.Board(board_fen)
56
+ except ValueError:
57
+ board = chess.Board()
58
+ gr.Warning("Invalid FEN, using starting position.")
59
+ if action_seq:
60
+ try:
61
+ for action in action_seq.split():
62
+ board.push_uci(action)
63
+ except ValueError:
64
+ gr.Warning("Invalid action sequence, using starting position.")
65
+ board = chess.Board()
66
+ lczero_weights = Weights(f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}")
67
+ lczero_backend = Backend(lczero_weights)
68
+ uci_moves = [move.uci() for move in board.move_stack]
69
+ lczero_game = GameState(moves=uci_moves)
70
+ policy, value = lczero_utils.prediction_from_backend(
71
+ lczero_backend,
72
+ lczero_game,
73
+ softmax=use_softmax,
74
+ only_legal=only_legal,
75
+ illegal_value=0,
76
+ )
77
+ pickup_agg, dropoff_agg = PolicyLens.aggregate_policy(policy, int(aggregate_topk))
78
+
79
+ if view == "from":
80
+ if board.turn == chess.WHITE:
81
+ heatmap = pickup_agg
82
+ else:
83
+ heatmap = pickup_agg.view(8, 8).flip(0).view(64)
84
+ else:
85
+ if board.turn == chess.WHITE:
86
+ heatmap = dropoff_agg
87
+ else:
88
+ heatmap = dropoff_agg.view(8, 8).flip(0).view(64)
89
+ us_them = (board.turn, not board.turn)
90
+ if only_legal:
91
+ legal_moves = [move_encodings.encode_move(move, us_them) for move in board.legal_moves]
92
+ filtered_policy = torch.zeros(1858)
93
+ filtered_policy[legal_moves] = policy[legal_moves]
94
+ if (filtered_policy < 0).any():
95
+ gr.Warning("Some legal moves have negative policy.")
96
+ topk_moves = torch.topk(filtered_policy, render_bestk)
97
+ else:
98
+ topk_moves = torch.topk(policy, render_bestk)
99
+ arrows = []
100
+ for move_index in topk_moves.indices:
101
+ move = move_encodings.decode_move(move_index, us_them)
102
+ arrows.append((move.from_square, move.to_square))
103
+ svg_board, fig = visualisation.render_heatmap(board, heatmap, arrows=arrows)
104
+ with open(f"{constants.FIGURE_DIRECTORY}/policy.svg", "w") as f:
105
+ f.write(svg_board)
106
+ raw_policy, _ = lczero_utils.prediction_from_backend(
107
+ lczero_backend,
108
+ lczero_game,
109
+ softmax=False,
110
+ only_legal=False,
111
+ illegal_value=0,
112
+ )
113
+ fig_dist = visualisation.render_policy_distribution(
114
+ raw_policy,
115
+ [move_encodings.encode_move(move, us_them) for move in board.legal_moves],
116
+ )
117
+ return (
118
+ f"{constants.FIGURE_DIRECTORY}/policy.svg",
119
+ fig,
120
+ (f"Value: {value:.2f}"),
121
+ fig_dist,
122
+ )
123
+
124
+
125
+ with gr.Blocks() as interface:
126
+ with gr.Row():
127
+ with gr.Column(scale=2):
128
+ model_df = gr.Dataframe(
129
+ headers=["Available models"],
130
+ datatype=["str"],
131
+ interactive=False,
132
+ type="array",
133
+ value=list_models,
134
+ )
135
+ with gr.Column(scale=1):
136
+ with gr.Row():
137
+ model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
138
+
139
+ model_df.select(
140
+ on_select_model_df,
141
+ None,
142
+ model_name,
143
+ )
144
+ with gr.Row():
145
+ with gr.Column():
146
+ board_fen = gr.Textbox(
147
+ label="Board FEN",
148
+ lines=1,
149
+ max_lines=1,
150
+ value=chess.STARTING_FEN,
151
+ )
152
+ action_seq = gr.Textbox(
153
+ label="Action sequence",
154
+ lines=1,
155
+ max_lines=1,
156
+ value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
157
+ )
158
+ with gr.Group():
159
+ with gr.Row():
160
+ depth = gr.Radio(label="Depth", choices=[0], value=0)
161
+ use_softmax = gr.Checkbox(label="Use softmax", value=True)
162
+ with gr.Row():
163
+ aggregate_topk = gr.Slider(
164
+ label="Aggregate top k",
165
+ minimum=1,
166
+ maximum=1858,
167
+ step=1,
168
+ value=1858,
169
+ scale=3,
170
+ )
171
+ view = gr.Radio(
172
+ label="View",
173
+ choices=["from", "to"],
174
+ value="from",
175
+ scale=1,
176
+ )
177
+ with gr.Row():
178
+ render_bestk = gr.Slider(
179
+ label="Render best k",
180
+ minimum=1,
181
+ maximum=5,
182
+ step=1,
183
+ value=5,
184
+ scale=3,
185
+ )
186
+ only_legal = gr.Checkbox(label="Only legal", value=True, scale=1)
187
+
188
+ policy_button = gr.Button("Plot policy")
189
+ colorbar = gr.Plot(label="Colorbar")
190
+ game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="")
191
+ with gr.Column():
192
+ image = gr.Image(label="Board")
193
+ density_plot = gr.Plot(label="Density")
194
+
195
+ policy_inputs = [
196
+ board_fen,
197
+ action_seq,
198
+ view,
199
+ model_name,
200
+ depth,
201
+ use_softmax,
202
+ aggregate_topk,
203
+ render_bestk,
204
+ only_legal,
205
+ ]
206
+ policy_outputs = [image, colorbar, game_info, density_plot]
207
+ policy_button.click(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
app/board_interface.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for plotting a board.
3
+ """
4
+
5
+ import chess
6
+ import gradio as gr
7
+
8
+ from demo import constants
9
+
10
+
11
+ def make_board_plot(board_fen, arrows, square):
12
+ try:
13
+ board = chess.Board(board_fen)
14
+ except ValueError:
15
+ board = chess.Board()
16
+ gr.Warning("Invalid FEN, using starting position.")
17
+ try:
18
+ if arrows:
19
+ arrows_list = arrows.split(" ")
20
+ chess_arrows = []
21
+ for arrow in arrows_list:
22
+ from_square, to_square = arrow[:2], arrow[2:]
23
+ chess_arrows.append(
24
+ (
25
+ chess.parse_square(from_square),
26
+ chess.parse_square(to_square),
27
+ )
28
+ )
29
+ else:
30
+ chess_arrows = []
31
+ except ValueError:
32
+ chess_arrows = []
33
+ gr.Warning("Invalid arrows, using none.")
34
+
35
+ color_dict = {chess.parse_square(square): "#FF0000"} if square else {}
36
+ svg_board = chess.svg.board(
37
+ board,
38
+ size=350,
39
+ arrows=chess_arrows,
40
+ fill=color_dict,
41
+ )
42
+ with open(f"{constants.FIGURE_DIRECTORY}/board.svg", "w") as f:
43
+ f.write(svg_board)
44
+ return f"{constants.FIGURE_DIRECTORY}/board.svg"
45
+
46
+
47
+ with gr.Blocks() as interface:
48
+ with gr.Row():
49
+ with gr.Column():
50
+ board_fen = gr.Textbox(
51
+ label="Board starting FEN",
52
+ lines=1,
53
+ max_lines=1,
54
+ value=chess.STARTING_FEN,
55
+ )
56
+ arrows = gr.Textbox(
57
+ label="Arrows",
58
+ lines=1,
59
+ max_lines=1,
60
+ value="",
61
+ placeholder="e2e4 e7e5",
62
+ )
63
+ square = gr.Textbox(
64
+ label="Square",
65
+ lines=1,
66
+ max_lines=1,
67
+ value="",
68
+ placeholder="e4",
69
+ )
70
+ with gr.Column():
71
+ image = gr.Image(label="Board", interactive=False)
72
+
73
+ inputs = [
74
+ board_fen,
75
+ arrows,
76
+ square,
77
+ ]
78
+ board_fen.submit(make_board_plot, inputs=inputs, outputs=image)
79
+ arrows.submit(make_board_plot, inputs=inputs, outputs=image)
80
+ interface.load(make_board_plot, inputs=inputs, outputs=image)
app/constants.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ Constants for the demo.
3
+ """
4
+
5
+ MODEL_DIRECTORY = "demo/onnx_models"
6
+ LEELA_MODEL_DIRECTORY = "demo/leela_models"
7
+ FIGURE_DIRECTORY = "demo/figures"
app/convert_interface.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for converting models.
3
+ """
4
+
5
+ import os
6
+ import uuid
7
+
8
+ import gradio as gr
9
+
10
+ from demo import constants, utils
11
+ from lczerolens.model import lczero as lczero_utils
12
+
13
+
14
+ def list_models():
15
+ """
16
+ List the models in the model directory.
17
+ """
18
+ models_info = utils.get_models_info()
19
+ return sorted([[model_info[0]] for model_info in models_info])
20
+
21
+
22
+ def on_select_model_df(
23
+ evt: gr.SelectData,
24
+ ):
25
+ """
26
+ When a model is selected, update the statement.
27
+ """
28
+ return evt.value
29
+
30
+
31
+ def convert_model(
32
+ model_name: str,
33
+ ):
34
+ """
35
+ Convert the model.
36
+ """
37
+ if model_name == "":
38
+ gr.Warning(
39
+ "Please select a model.",
40
+ )
41
+ return list_models(), ""
42
+ if model_name.endswith(".onnx"):
43
+ gr.Warning(
44
+ "ONNX conversion not implemented.",
45
+ )
46
+ return list_models(), ""
47
+ try:
48
+ lczero_utils.convert_to_onnx(
49
+ f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}",
50
+ f"{constants.MODEL_DIRECTORY}/{model_name[:-6]}.onnx",
51
+ )
52
+ except RuntimeError:
53
+ gr.Warning(
54
+ f"Could not convert net at `{model_name}`.",
55
+ )
56
+ return list_models(), "Conversion failed"
57
+ return list_models(), "Conversion successful"
58
+
59
+
60
+ def upload_model(
61
+ model_file: gr.File,
62
+ ):
63
+ """
64
+ Convert the model.
65
+ """
66
+ if model_file is None:
67
+ gr.Warning(
68
+ "File not uploaded.",
69
+ )
70
+ return list_models()
71
+ try:
72
+ id = uuid.uuid4()
73
+ tmp_file_path = f"{constants.LEELA_MODEL_DIRECTORY}/{id}"
74
+ with open(
75
+ tmp_file_path,
76
+ "wb",
77
+ ) as f:
78
+ f.write(model_file)
79
+ utils.save_model(tmp_file_path)
80
+ except RuntimeError:
81
+ gr.Warning(
82
+ "Invalid file type.",
83
+ )
84
+ finally:
85
+ if os.path.exists(tmp_file_path):
86
+ os.remove(tmp_file_path)
87
+ return list_models()
88
+
89
+
90
+ def get_model_description(
91
+ model_name: str,
92
+ ):
93
+ """
94
+ Get the model description.
95
+ """
96
+ if model_name == "":
97
+ gr.Warning(
98
+ "Please select a model.",
99
+ )
100
+ return ""
101
+ if model_name.endswith(".onnx"):
102
+ gr.Warning(
103
+ "ONNX description not implemented.",
104
+ )
105
+ return ""
106
+ try:
107
+ description = lczero_utils.describenet(
108
+ f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}",
109
+ )
110
+ except RuntimeError:
111
+ raise gr.Error(
112
+ f"Could not describe net at `{model_name}`.",
113
+ )
114
+ return description
115
+
116
+
117
+ def get_model_path(
118
+ model_name: str,
119
+ ):
120
+ """
121
+ Get the model path.
122
+ """
123
+ if model_name == "":
124
+ gr.Warning(
125
+ "Please select a model.",
126
+ )
127
+ return None
128
+ if model_name.endswith(".onnx"):
129
+ return f"{constants.MODEL_DIRECTORY}/{model_name}"
130
+ else:
131
+ return f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}"
132
+
133
+
134
+ with gr.Blocks() as interface:
135
+ model_file = gr.File(type="binary")
136
+ upload_button = gr.Button(
137
+ value="Upload",
138
+ )
139
+ with gr.Row():
140
+ with gr.Column(scale=2):
141
+ model_df = gr.Dataframe(
142
+ headers=["Available models"],
143
+ datatype=["str"],
144
+ interactive=False,
145
+ type="array",
146
+ value=list_models,
147
+ )
148
+ with gr.Column(scale=1):
149
+ with gr.Row():
150
+ model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
151
+ conversion_status = gr.Textbox(
152
+ label="Conversion status",
153
+ lines=1,
154
+ interactive=False,
155
+ )
156
+
157
+ convert_button = gr.Button(
158
+ value="Convert",
159
+ )
160
+ describe_button = gr.Button(
161
+ value="Describe model",
162
+ )
163
+ model_description = gr.Textbox(
164
+ label="Model description",
165
+ lines=1,
166
+ interactive=False,
167
+ )
168
+ download_button = gr.Button(
169
+ value="Get download link",
170
+ )
171
+ download_file = gr.File(
172
+ type="filepath",
173
+ label="Download link",
174
+ interactive=False,
175
+ )
176
+
177
+ model_df.select(
178
+ on_select_model_df,
179
+ None,
180
+ model_name,
181
+ )
182
+ upload_button.click(
183
+ upload_model,
184
+ model_file,
185
+ model_df,
186
+ )
187
+ convert_button.click(
188
+ convert_model,
189
+ model_name,
190
+ [model_df, conversion_status],
191
+ )
192
+ describe_button.click(
193
+ get_model_description,
194
+ model_name,
195
+ model_description,
196
+ )
197
+ download_button.click(
198
+ get_model_path,
199
+ model_name,
200
+ download_file,
201
+ )
app/crp_interface.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for plotting policy.
3
+ """
4
+
5
+ import copy
6
+
7
+ import chess
8
+ import gradio as gr
9
+
10
+ from demo import constants, utils, visualisation
11
+
12
+ cache = None
13
+ boards = None
14
+ board_index = 0
15
+
16
+
17
+ def list_models():
18
+ """
19
+ List the models in the model directory.
20
+ """
21
+ models_info = utils.get_models_info(leela=False)
22
+ return sorted([[model_info[0]] for model_info in models_info])
23
+
24
+
25
+ def on_select_model_df(
26
+ evt: gr.SelectData,
27
+ ):
28
+ """
29
+ When a model is selected, update the statement.
30
+ """
31
+ return evt.value
32
+
33
+
34
+ def compute_cache(
35
+ board_fen,
36
+ action_seq,
37
+ model_name,
38
+ plane_index,
39
+ history_index,
40
+ ):
41
+ global cache
42
+ global boards
43
+ if model_name == "":
44
+ gr.Warning("No model selected.")
45
+ return None, None, None, None, None
46
+ try:
47
+ board = chess.Board(board_fen)
48
+ except ValueError:
49
+ board = chess.Board()
50
+ gr.Warning("Invalid FEN, using starting position.")
51
+ boards = [board.copy()]
52
+ if action_seq:
53
+ try:
54
+ if action_seq.startswith("1."):
55
+ for action in action_seq.split():
56
+ if action.endswith("."):
57
+ continue
58
+ board.push_san(action)
59
+ boards.append(board.copy())
60
+ else:
61
+ for action in action_seq.split():
62
+ board.push_uci(action)
63
+ boards.append(board.copy())
64
+ except ValueError:
65
+ gr.Warning(f"Invalid action {action} stopping before it.")
66
+ wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "crp")
67
+ cache = []
68
+ for board in boards:
69
+ relevance = lens.compute_heatmap(board, wrapper)
70
+ cache.append(copy.deepcopy(relevance))
71
+ return (
72
+ *make_plot(
73
+ plane_index,
74
+ ),
75
+ *make_history_plot(
76
+ history_index,
77
+ ),
78
+ )
79
+
80
+
81
+ def make_plot(
82
+ plane_index,
83
+ ):
84
+ global cache
85
+ global boards
86
+ global board_index
87
+
88
+ if cache is None:
89
+ gr.Warning("Cache not computed!")
90
+ return None, None, None
91
+
92
+ board = boards[board_index]
93
+ relevance_tensor = cache[board_index]
94
+ a_max = relevance_tensor.abs().max()
95
+ if a_max != 0:
96
+ relevance_tensor = relevance_tensor / a_max
97
+ vmin = -1
98
+ vmax = 1
99
+ heatmap = relevance_tensor[plane_index - 1].view(64)
100
+ if board.turn == chess.BLACK:
101
+ heatmap = heatmap.view(8, 8).flip(0).view(64)
102
+ svg_board, fig = visualisation.render_heatmap(board, heatmap, vmin=vmin, vmax=vmax)
103
+ with open(f"{constants.FIGURE_DIRECTORY}/lrp.svg", "w") as f:
104
+ f.write(svg_board)
105
+ return f"{constants.FIGURE_DIRECTORY}/lrp.svg", board.fen(), fig
106
+
107
+
108
+ def make_history_plot(
109
+ history_index,
110
+ ):
111
+ global cache
112
+ global boards
113
+ global board_index
114
+
115
+ if cache is None:
116
+ gr.Warning("Cache not computed!")
117
+ return None, None
118
+
119
+ board = boards[board_index]
120
+ relevance_tensor = cache[board_index]
121
+ a_max = relevance_tensor.abs().max()
122
+ if a_max != 0:
123
+ relevance_tensor = relevance_tensor / a_max
124
+ vmin = -1
125
+ vmax = 1
126
+ heatmap = relevance_tensor[13 * (history_index - 1) : 13 * history_index - 1].sum(dim=0).view(64)
127
+ if board.turn == chess.BLACK:
128
+ heatmap = heatmap.view(8, 8).flip(0).view(64)
129
+ if board_index - history_index + 1 < 0:
130
+ history_board = chess.Board(fen=None)
131
+ else:
132
+ history_board = boards[board_index - history_index + 1]
133
+ svg_board, fig = visualisation.render_heatmap(history_board, heatmap, vmin=vmin, vmax=vmax)
134
+ with open(f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", "w") as f:
135
+ f.write(svg_board)
136
+ return f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", fig
137
+
138
+
139
+ def previous_board(
140
+ plane_index,
141
+ history_index,
142
+ ):
143
+ global board_index
144
+ board_index -= 1
145
+ if board_index < 0:
146
+ gr.Warning("Already at first board.")
147
+ board_index = 0
148
+ return (
149
+ *make_plot(
150
+ plane_index,
151
+ ),
152
+ *make_history_plot(
153
+ history_index,
154
+ ),
155
+ )
156
+
157
+
158
+ def next_board(
159
+ plane_index,
160
+ history_index,
161
+ ):
162
+ global board_index
163
+ board_index += 1
164
+ if board_index >= len(boards):
165
+ gr.Warning("Already at last board.")
166
+ board_index = len(boards) - 1
167
+ return (
168
+ *make_plot(
169
+ plane_index,
170
+ ),
171
+ *make_history_plot(
172
+ history_index,
173
+ ),
174
+ )
175
+
176
+
177
+ with gr.Blocks() as interface:
178
+ with gr.Row():
179
+ with gr.Column(scale=2):
180
+ model_df = gr.Dataframe(
181
+ headers=["Available models"],
182
+ datatype=["str"],
183
+ interactive=False,
184
+ type="array",
185
+ value=list_models,
186
+ )
187
+ with gr.Column(scale=1):
188
+ with gr.Row():
189
+ model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
190
+
191
+ model_df.select(
192
+ on_select_model_df,
193
+ None,
194
+ model_name,
195
+ )
196
+
197
+ with gr.Row():
198
+ with gr.Column():
199
+ board_fen = gr.Textbox(
200
+ label="Board starting FEN",
201
+ lines=1,
202
+ max_lines=1,
203
+ value=chess.STARTING_FEN,
204
+ )
205
+ action_seq = gr.Textbox(
206
+ label="Action sequence",
207
+ lines=1,
208
+ max_lines=1,
209
+ value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
210
+ )
211
+ compute_cache_button = gr.Button("Compute heatmaps")
212
+
213
+ with gr.Group():
214
+ with gr.Row():
215
+ plane_index = gr.Slider(
216
+ label="Plane index",
217
+ minimum=1,
218
+ maximum=112,
219
+ step=1,
220
+ value=1,
221
+ )
222
+ with gr.Row():
223
+ previous_board_button = gr.Button("Previous board")
224
+ next_board_button = gr.Button("Next board")
225
+ current_board_fen = gr.Textbox(
226
+ label="Board FEN",
227
+ lines=1,
228
+ max_lines=1,
229
+ )
230
+ colorbar = gr.Plot(label="Colorbar")
231
+ with gr.Column():
232
+ image = gr.Image(label="Board")
233
+
234
+ with gr.Row():
235
+ with gr.Column():
236
+ with gr.Group():
237
+ with gr.Row():
238
+ histroy_index = gr.Slider(
239
+ label="History index",
240
+ minimum=1,
241
+ maximum=8,
242
+ step=1,
243
+ value=1,
244
+ )
245
+ history_colorbar = gr.Plot(label="Colorbar")
246
+ with gr.Column():
247
+ history_image = gr.Image(label="Board")
248
+
249
+ base_inputs = [
250
+ plane_index,
251
+ histroy_index,
252
+ ]
253
+ outputs = [
254
+ image,
255
+ current_board_fen,
256
+ colorbar,
257
+ history_image,
258
+ history_colorbar,
259
+ ]
260
+
261
+ compute_cache_button.click(
262
+ compute_cache,
263
+ inputs=[board_fen, action_seq, model_name] + base_inputs,
264
+ outputs=outputs,
265
+ )
266
+
267
+ previous_board_button.click(previous_board, inputs=base_inputs, outputs=outputs)
268
+ next_board_button.click(next_board, inputs=base_inputs, outputs=outputs)
269
+
270
+ plane_index.change(
271
+ make_plot,
272
+ inputs=plane_index,
273
+ outputs=[image, current_board_fen, colorbar],
274
+ )
275
+ histroy_index.change(
276
+ make_history_plot,
277
+ inputs=histroy_index,
278
+ outputs=[history_image, history_colorbar],
279
+ )
app/encoding_interface.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for plotting encodings.
3
+ """
4
+
5
+ import chess
6
+ import gradio as gr
7
+
8
+ from demo import constants, visualisation
9
+ from lczerolens import board_encodings
10
+
11
+
12
+ def make_encoding_plot(
13
+ board_fen,
14
+ action_seq,
15
+ plane_index,
16
+ color_flip,
17
+ ):
18
+ try:
19
+ board = chess.Board(board_fen)
20
+ except ValueError:
21
+ board = chess.Board()
22
+ gr.Warning("Invalid FEN, using starting position.")
23
+ if action_seq:
24
+ try:
25
+ for action in action_seq.split():
26
+ board.push_uci(action)
27
+ except ValueError:
28
+ gr.Warning("Invalid action sequence, using starting position.")
29
+ board = chess.Board()
30
+ board_tensor = board_encodings.board_to_input_tensor(board)
31
+ heatmap = board_tensor[plane_index]
32
+ if color_flip and board.turn == chess.BLACK:
33
+ heatmap = heatmap.flip(0)
34
+ svg_board, fig = visualisation.render_heatmap(board, heatmap.view(64), vmin=0.0, vmax=1.0)
35
+ with open(f"{constants.FIGURE_DIRECTORY}/encoding.svg", "w") as f:
36
+ f.write(svg_board)
37
+ return f"{constants.FIGURE_DIRECTORY}/encoding.svg", fig
38
+
39
+
40
+ with gr.Blocks() as interface:
41
+ with gr.Row():
42
+ with gr.Column():
43
+ board_fen = gr.Textbox(
44
+ label="Board starting FEN",
45
+ lines=1,
46
+ max_lines=1,
47
+ value=chess.STARTING_FEN,
48
+ )
49
+ action_seq = gr.Textbox(
50
+ label="Action sequence",
51
+ lines=1,
52
+ max_lines=1,
53
+ value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
54
+ )
55
+ with gr.Group():
56
+ with gr.Row():
57
+ plane_index = gr.Slider(
58
+ label="Plane index",
59
+ minimum=0,
60
+ maximum=111,
61
+ step=1,
62
+ value=0,
63
+ scale=3,
64
+ )
65
+ color_flip = gr.Checkbox(label="Color flip", value=True, scale=1)
66
+
67
+ colorbar = gr.Plot(label="Colorbar")
68
+ with gr.Column():
69
+ image = gr.Image(label="Board")
70
+
71
+ policy_inputs = [
72
+ board_fen,
73
+ action_seq,
74
+ plane_index,
75
+ color_flip,
76
+ ]
77
+ policy_outputs = [image, colorbar]
78
+ board_fen.submit(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
79
+ action_seq.submit(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
80
+ plane_index.change(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
81
+ color_flip.change(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
82
+ interface.load(make_encoding_plot, inputs=policy_inputs, outputs=policy_outputs)
app/figures/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
app/leela_models/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
app/lrp_interface.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for plotting policy.
3
+ """
4
+
5
+ import copy
6
+
7
+ import chess
8
+ import gradio as gr
9
+
10
+ from demo import constants, utils, visualisation
11
+
12
+ cache = None
13
+ boards = None
14
+ board_index = 0
15
+
16
+
17
+ def list_models():
18
+ """
19
+ List the models in the model directory.
20
+ """
21
+ models_info = utils.get_models_info(leela=False)
22
+ return sorted([[model_info[0]] for model_info in models_info])
23
+
24
+
25
+ def on_select_model_df(
26
+ evt: gr.SelectData,
27
+ ):
28
+ """
29
+ When a model is selected, update the statement.
30
+ """
31
+ return evt.value
32
+
33
+
34
+ def compute_cache(
35
+ board_fen,
36
+ action_seq,
37
+ model_name,
38
+ plane_index,
39
+ history_index,
40
+ ):
41
+ global cache
42
+ global boards
43
+ if model_name == "":
44
+ gr.Warning("No model selected.")
45
+ return None, None, None, None, None
46
+ try:
47
+ board = chess.Board(board_fen)
48
+ except ValueError:
49
+ board = chess.Board()
50
+ gr.Warning("Invalid FEN, using starting position.")
51
+ boards = [board.copy()]
52
+ if action_seq:
53
+ try:
54
+ if action_seq.startswith("1."):
55
+ for action in action_seq.split():
56
+ if action.endswith("."):
57
+ continue
58
+ board.push_san(action)
59
+ boards.append(board.copy())
60
+ else:
61
+ for action in action_seq.split():
62
+ board.push_uci(action)
63
+ boards.append(board.copy())
64
+ except ValueError:
65
+ gr.Warning(f"Invalid action {action} stopping before it.")
66
+ wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "lrp")
67
+ cache = []
68
+ for board in boards:
69
+ relevance = lens.compute_heatmap(board, wrapper)
70
+ cache.append(copy.deepcopy(relevance))
71
+ return (
72
+ *make_plot(
73
+ plane_index,
74
+ ),
75
+ *make_history_plot(
76
+ history_index,
77
+ ),
78
+ )
79
+
80
+
81
+ def make_plot(
82
+ plane_index,
83
+ ):
84
+ global cache
85
+ global boards
86
+ global board_index
87
+
88
+ if cache is None:
89
+ gr.Warning("Cache not computed!")
90
+ return None, None, None
91
+
92
+ board = boards[board_index]
93
+ relevance_tensor = cache[board_index]
94
+ a_max = relevance_tensor.abs().max()
95
+ if a_max != 0:
96
+ relevance_tensor = relevance_tensor / a_max
97
+ vmin = -1
98
+ vmax = 1
99
+ heatmap = relevance_tensor[plane_index - 1].view(64)
100
+ if board.turn == chess.BLACK:
101
+ heatmap = heatmap.view(8, 8).flip(0).view(64)
102
+ svg_board, fig = visualisation.render_heatmap(board, heatmap, vmin=vmin, vmax=vmax)
103
+ with open(f"{constants.FIGURE_DIRECTORY}/lrp.svg", "w") as f:
104
+ f.write(svg_board)
105
+ return f"{constants.FIGURE_DIRECTORY}/lrp.svg", board.fen(), fig
106
+
107
+
108
+ def make_history_plot(
109
+ history_index,
110
+ ):
111
+ global cache
112
+ global boards
113
+ global board_index
114
+
115
+ if cache is None:
116
+ gr.Warning("Cache not computed!")
117
+ return None, None
118
+
119
+ board = boards[board_index]
120
+ relevance_tensor = cache[board_index]
121
+ a_max = relevance_tensor.abs().max()
122
+ if a_max != 0:
123
+ relevance_tensor = relevance_tensor / a_max
124
+ vmin = -1
125
+ vmax = 1
126
+ heatmap = relevance_tensor[13 * (history_index - 1) : 13 * history_index - 1].sum(dim=0).view(64)
127
+ if board.turn == chess.BLACK:
128
+ heatmap = heatmap.view(8, 8).flip(0).view(64)
129
+ if board_index - history_index + 1 < 0:
130
+ history_board = chess.Board(fen=None)
131
+ else:
132
+ history_board = boards[board_index - history_index + 1]
133
+ svg_board, fig = visualisation.render_heatmap(history_board, heatmap, vmin=vmin, vmax=vmax)
134
+ with open(f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", "w") as f:
135
+ f.write(svg_board)
136
+ return f"{constants.FIGURE_DIRECTORY}/lrp_history.svg", fig
137
+
138
+
139
+ def previous_board(
140
+ plane_index,
141
+ history_index,
142
+ ):
143
+ global board_index
144
+ board_index -= 1
145
+ if board_index < 0:
146
+ gr.Warning("Already at first board.")
147
+ board_index = 0
148
+ return (
149
+ *make_plot(
150
+ plane_index,
151
+ ),
152
+ *make_history_plot(
153
+ history_index,
154
+ ),
155
+ )
156
+
157
+
158
+ def next_board(
159
+ plane_index,
160
+ history_index,
161
+ ):
162
+ global board_index
163
+ board_index += 1
164
+ if board_index >= len(boards):
165
+ gr.Warning("Already at last board.")
166
+ board_index = len(boards) - 1
167
+ return (
168
+ *make_plot(
169
+ plane_index,
170
+ ),
171
+ *make_history_plot(
172
+ history_index,
173
+ ),
174
+ )
175
+
176
+
177
+ with gr.Blocks() as interface:
178
+ with gr.Row():
179
+ with gr.Column(scale=2):
180
+ model_df = gr.Dataframe(
181
+ headers=["Available models"],
182
+ datatype=["str"],
183
+ interactive=False,
184
+ type="array",
185
+ value=list_models,
186
+ )
187
+ with gr.Column(scale=1):
188
+ with gr.Row():
189
+ model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
190
+
191
+ model_df.select(
192
+ on_select_model_df,
193
+ None,
194
+ model_name,
195
+ )
196
+
197
+ with gr.Row():
198
+ with gr.Column():
199
+ board_fen = gr.Textbox(
200
+ label="Board starting FEN",
201
+ lines=1,
202
+ max_lines=1,
203
+ value=chess.STARTING_FEN,
204
+ )
205
+ action_seq = gr.Textbox(
206
+ label="Action sequence",
207
+ lines=1,
208
+ max_lines=1,
209
+ value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
210
+ )
211
+ compute_cache_button = gr.Button("Compute heatmaps")
212
+
213
+ with gr.Group():
214
+ with gr.Row():
215
+ plane_index = gr.Slider(
216
+ label="Plane index",
217
+ minimum=1,
218
+ maximum=112,
219
+ step=1,
220
+ value=1,
221
+ )
222
+ with gr.Row():
223
+ previous_board_button = gr.Button("Previous board")
224
+ next_board_button = gr.Button("Next board")
225
+ current_board_fen = gr.Textbox(
226
+ label="Board FEN",
227
+ lines=1,
228
+ max_lines=1,
229
+ )
230
+ colorbar = gr.Plot(label="Colorbar")
231
+ with gr.Column():
232
+ image = gr.Image(label="Board")
233
+
234
+ with gr.Row():
235
+ with gr.Column():
236
+ with gr.Group():
237
+ with gr.Row():
238
+ histroy_index = gr.Slider(
239
+ label="History index",
240
+ minimum=1,
241
+ maximum=8,
242
+ step=1,
243
+ value=1,
244
+ )
245
+ history_colorbar = gr.Plot(label="Colorbar")
246
+ with gr.Column():
247
+ history_image = gr.Image(label="Board")
248
+
249
+ base_inputs = [
250
+ plane_index,
251
+ histroy_index,
252
+ ]
253
+ outputs = [
254
+ image,
255
+ current_board_fen,
256
+ colorbar,
257
+ history_image,
258
+ history_colorbar,
259
+ ]
260
+
261
+ compute_cache_button.click(
262
+ compute_cache,
263
+ inputs=[board_fen, action_seq, model_name] + base_inputs,
264
+ outputs=outputs,
265
+ )
266
+
267
+ previous_board_button.click(previous_board, inputs=base_inputs, outputs=outputs)
268
+ next_board_button.click(next_board, inputs=base_inputs, outputs=outputs)
269
+
270
+ plane_index.change(
271
+ make_plot,
272
+ inputs=plane_index,
273
+ outputs=[image, current_board_fen, colorbar],
274
+ )
275
+ histroy_index.change(
276
+ make_history_plot,
277
+ inputs=histroy_index,
278
+ outputs=[history_image, history_colorbar],
279
+ )
app/main.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio demo for lczero-easy.
3
+ """
4
+
5
+ import gradio as gr
6
+
7
+ from . import (
8
+ attention_interface,
9
+ backend_interface,
10
+ board_interface,
11
+ convert_interface,
12
+ crp_interface,
13
+ encoding_interface,
14
+ lrp_interface,
15
+ policy_interface,
16
+ statistics_interface,
17
+ )
18
+
19
+ demo = gr.TabbedInterface(
20
+ [
21
+ crp_interface.interface,
22
+ statistics_interface.interface,
23
+ lrp_interface.interface,
24
+ attention_interface.interface,
25
+ policy_interface.interface,
26
+ backend_interface.interface,
27
+ encoding_interface.interface,
28
+ board_interface.interface,
29
+ convert_interface.interface,
30
+ ],
31
+ [
32
+ "CRP",
33
+ "Statistics",
34
+ "LRP",
35
+ "Attention",
36
+ "Policy",
37
+ "Backend",
38
+ "Encoding",
39
+ "Board",
40
+ "Convert",
41
+ ],
42
+ title="LczeroLens Demo",
43
+ analytics_enabled=False,
44
+ )
45
+
46
+ if __name__ == "__main__":
47
+ demo.launch(
48
+ server_port=8000,
49
+ server_name="0.0.0.0",
50
+ )
app/onnx_models/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
app/policy_interface.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for visualizing the policy of a model.
3
+ """
4
+
5
+ import chess
6
+ import chess.svg
7
+ import gradio as gr
8
+ import torch
9
+
10
+ from demo import constants, utils, visualisation
11
+ from lczerolens import move_encodings
12
+ from lczerolens.xai import PolicyLens
13
+
14
+ current_board = None
15
+ current_raw_policy = None
16
+ current_policy = None
17
+ current_value = None
18
+ current_outcome = None
19
+
20
+
21
+ def list_models():
22
+ """
23
+ List the models in the model directory.
24
+ """
25
+ models_info = utils.get_models_info(leela=False)
26
+ return sorted([[model_info[0]] for model_info in models_info])
27
+
28
+
29
+ def on_select_model_df(
30
+ evt: gr.SelectData,
31
+ ):
32
+ """
33
+ When a model is selected, update the statement.
34
+ """
35
+ return evt.value
36
+
37
+
38
+ def compute_policy(
39
+ board_fen,
40
+ action_seq,
41
+ model_name,
42
+ ):
43
+ global current_board
44
+ global current_policy
45
+ global current_raw_policy
46
+ global current_value
47
+ global current_outcome
48
+ if model_name == "":
49
+ gr.Warning(
50
+ "Please select a model.",
51
+ )
52
+ return (
53
+ None,
54
+ None,
55
+ "",
56
+ )
57
+ try:
58
+ board = chess.Board(board_fen)
59
+ except ValueError:
60
+ gr.Warning("Invalid FEN.")
61
+ return (None, None, "", None)
62
+ if action_seq:
63
+ try:
64
+ for action in action_seq.split():
65
+ board.push_uci(action)
66
+ except ValueError:
67
+ gr.Warning("Invalid action sequence.")
68
+ return (None, None, "", None)
69
+ wrapper = utils.get_wrapper_from_state(model_name)
70
+ (output,) = wrapper.predict(board)
71
+ current_raw_policy = output["policy"][0]
72
+ policy = torch.softmax(output["policy"][0], dim=-1)
73
+
74
+ filtered_policy = torch.full((1858,), 0.0)
75
+ legal_moves = [move_encodings.encode_move(move, (board.turn, not board.turn)) for move in board.legal_moves]
76
+ filtered_policy[legal_moves] = policy[legal_moves]
77
+ policy = filtered_policy
78
+
79
+ current_board = board
80
+ current_policy = policy
81
+ current_value = output.get("value", None)
82
+ current_outcome = output.get("wdl", None)
83
+
84
+
85
+ def make_plot(
86
+ view,
87
+ aggregate_topk,
88
+ move_to_play,
89
+ ):
90
+ global current_board
91
+ global current_policy
92
+ global current_raw_policy
93
+ global current_value
94
+ global current_outcome
95
+
96
+ if current_board is None or current_policy is None:
97
+ gr.Warning("Please compute a policy first.")
98
+ return (None, None, "", None)
99
+
100
+ pickup_agg, dropoff_agg = PolicyLens.aggregate_policy(current_policy, int(aggregate_topk))
101
+
102
+ if view == "from":
103
+ if current_board.turn == chess.WHITE:
104
+ heatmap = pickup_agg
105
+ else:
106
+ heatmap = pickup_agg.view(8, 8).flip(0).view(64)
107
+ else:
108
+ if current_board.turn == chess.WHITE:
109
+ heatmap = dropoff_agg
110
+ else:
111
+ heatmap = dropoff_agg.view(8, 8).flip(0).view(64)
112
+ us_them = (current_board.turn, not current_board.turn)
113
+ topk_moves = torch.topk(current_policy, 50)
114
+ move = move_encodings.decode_move(topk_moves.indices[move_to_play - 1], us_them)
115
+ arrows = [(move.from_square, move.to_square)]
116
+ svg_board, fig = visualisation.render_heatmap(current_board, heatmap, arrows=arrows)
117
+ with open(f"{constants.FIGURE_DIRECTORY}/policy.svg", "w") as f:
118
+ f.write(svg_board)
119
+ fig_dist = visualisation.render_policy_distribution(
120
+ current_raw_policy,
121
+ [move_encodings.encode_move(move, us_them) for move in current_board.legal_moves],
122
+ )
123
+ return (
124
+ f"{constants.FIGURE_DIRECTORY}/policy.svg",
125
+ fig,
126
+ (f"Value: {current_value} - WDL: {current_outcome}"),
127
+ fig_dist,
128
+ )
129
+
130
+
131
+ def make_policy_plot(
132
+ board_fen,
133
+ action_seq,
134
+ view,
135
+ model_name,
136
+ aggregate_topk,
137
+ move_to_play,
138
+ ):
139
+ compute_policy(
140
+ board_fen,
141
+ action_seq,
142
+ model_name,
143
+ )
144
+ return make_plot(
145
+ view,
146
+ aggregate_topk,
147
+ move_to_play,
148
+ )
149
+
150
+
151
+ def play_move(
152
+ board_fen,
153
+ action_seq,
154
+ view,
155
+ model_name,
156
+ aggregate_topk,
157
+ move_to_play,
158
+ ):
159
+ global current_board
160
+ global current_policy
161
+
162
+ move = move_encodings.decode_move(
163
+ current_policy.topk(50).indices[move_to_play - 1],
164
+ (current_board.turn, not current_board.turn),
165
+ )
166
+ current_board.push(move)
167
+ action_seq = f"{action_seq} {move.uci()}"
168
+ compute_policy(
169
+ board_fen,
170
+ action_seq,
171
+ model_name,
172
+ )
173
+ return [
174
+ *make_plot(
175
+ view,
176
+ aggregate_topk,
177
+ 1,
178
+ ),
179
+ action_seq,
180
+ 1,
181
+ ]
182
+
183
+
184
+ with gr.Blocks() as interface:
185
+ with gr.Row():
186
+ with gr.Column(scale=2):
187
+ model_df = gr.Dataframe(
188
+ headers=["Available models"],
189
+ datatype=["str"],
190
+ interactive=False,
191
+ type="array",
192
+ value=list_models,
193
+ )
194
+ with gr.Column(scale=1):
195
+ with gr.Row():
196
+ model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
197
+ model_df.select(
198
+ on_select_model_df,
199
+ None,
200
+ model_name,
201
+ )
202
+
203
+ with gr.Row():
204
+ with gr.Column():
205
+ board_fen = gr.Textbox(
206
+ label="Board FEN",
207
+ lines=1,
208
+ max_lines=1,
209
+ value=chess.STARTING_FEN,
210
+ )
211
+ action_seq = gr.Textbox(
212
+ label="Action sequence",
213
+ lines=1,
214
+ value=("e2e3 b8c6 d2d4 e7e5 g1f3 d8e7 " "d4d5 e5e4 f3d4 c6e5 f2f4 e5g6"),
215
+ )
216
+ with gr.Group():
217
+ with gr.Row():
218
+ aggregate_topk = gr.Slider(
219
+ label="Aggregate top k",
220
+ minimum=1,
221
+ maximum=1858,
222
+ step=1,
223
+ value=1858,
224
+ scale=3,
225
+ )
226
+ view = gr.Radio(
227
+ label="View",
228
+ choices=["from", "to"],
229
+ value="from",
230
+ scale=1,
231
+ )
232
+ with gr.Row():
233
+ move_to_play = gr.Slider(
234
+ label="Move to play",
235
+ minimum=1,
236
+ maximum=50,
237
+ step=1,
238
+ value=1,
239
+ scale=3,
240
+ )
241
+ play_button = gr.Button("Play")
242
+
243
+ policy_button = gr.Button("Compute policy")
244
+ colorbar = gr.Plot(label="Colorbar")
245
+ game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="")
246
+ with gr.Column():
247
+ image = gr.Image(label="Board")
248
+ density_plot = gr.Plot(label="Density")
249
+
250
+ policy_inputs = [
251
+ board_fen,
252
+ action_seq,
253
+ view,
254
+ model_name,
255
+ aggregate_topk,
256
+ move_to_play,
257
+ ]
258
+ policy_outputs = [image, colorbar, game_info, density_plot]
259
+ policy_button.click(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
260
+ board_fen.submit(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
261
+ action_seq.submit(make_policy_plot, inputs=policy_inputs, outputs=policy_outputs)
262
+
263
+ fast_inputs = [
264
+ view,
265
+ aggregate_topk,
266
+ move_to_play,
267
+ ]
268
+ aggregate_topk.change(make_plot, inputs=fast_inputs, outputs=policy_outputs)
269
+ view.change(make_plot, inputs=fast_inputs, outputs=policy_outputs)
270
+ move_to_play.change(make_plot, inputs=fast_inputs, outputs=policy_outputs)
271
+
272
+ play_button.click(
273
+ play_move,
274
+ inputs=policy_inputs,
275
+ outputs=policy_outputs + [action_seq, move_to_play],
276
+ )
app/state.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Global state for the demo application.
3
+ """
4
+
5
+ from typing import Dict
6
+
7
+ from lczerolens import Lens, ModelWrapper
8
+
9
+ wrappers: Dict[str, ModelWrapper] = {}
10
+
11
+ lenses: Dict[str, Dict[str, Lens]] = {
12
+ "activation": {},
13
+ "lrp": {},
14
+ "crp": {},
15
+ "policy": {},
16
+ "probing": {},
17
+ "patching": {},
18
+ }
app/statistics_interface.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for visualizing the policy of a model.
3
+ """
4
+
5
+ import gradio as gr
6
+
7
+ from demo import utils, visualisation
8
+ from lczerolens import GameDataset
9
+ from lczerolens.xai import ConceptDataset, HasThreatConcept
10
+
11
+ current_policy_statistics = None
12
+ current_lrp_statistics = None
13
+ current_probing_statistics = None
14
+ dataset = GameDataset("assets/test_stockfish_10.jsonl")
15
+ check_concept = HasThreatConcept("K", relative=True)
16
+ unique_check_dataset = ConceptDataset.from_game_dataset(dataset)
17
+ unique_check_dataset.set_concept(check_concept)
18
+
19
+
20
+ def list_models():
21
+ """
22
+ List the models in the model directory.
23
+ """
24
+ models_info = utils.get_models_info(leela=False)
25
+ return sorted([[model_info[0]] for model_info in models_info])
26
+
27
+
28
+ def on_select_model_df(
29
+ evt: gr.SelectData,
30
+ ):
31
+ """
32
+ When a model is selected, update the statement.
33
+ """
34
+ return evt.value
35
+
36
+
37
+ def compute_policy_statistics(
38
+ model_name,
39
+ ):
40
+ global current_policy_statistics
41
+ global dataset
42
+
43
+ if model_name == "":
44
+ gr.Warning(
45
+ "Please select a model.",
46
+ )
47
+ return None
48
+ wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "policy")
49
+ current_policy_statistics = lens.analyse_dataset(dataset, wrapper, 10)
50
+ return make_policy_plot()
51
+
52
+
53
+ def make_policy_plot():
54
+ global current_policy_statistics
55
+
56
+ if current_policy_statistics is None:
57
+ gr.Warning(
58
+ "Please compute policy statistics first.",
59
+ )
60
+ return None
61
+ else:
62
+ return visualisation.render_policy_statistics(current_policy_statistics)
63
+
64
+
65
+ def compute_lrp_statistics(
66
+ model_name,
67
+ ):
68
+ global current_lrp_statistics
69
+ global dataset
70
+
71
+ if model_name == "":
72
+ gr.Warning(
73
+ "Please select a model.",
74
+ )
75
+ return None, None, None
76
+ wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "lrp")
77
+ current_lrp_statistics = lens.compute_statistics(dataset, wrapper, 10)
78
+ return make_lrp_plot()
79
+
80
+
81
+ def make_lrp_plot():
82
+ global current_lrp_statistics
83
+
84
+ if current_lrp_statistics is None:
85
+ gr.Warning(
86
+ "Please compute LRP statistics first.",
87
+ )
88
+ return None, None, None
89
+ else:
90
+ return visualisation.render_relevance_proportion(current_lrp_statistics)
91
+
92
+
93
+ def compute_probing_statistics(
94
+ model_name,
95
+ ):
96
+ global current_probing_statistics
97
+ global check_concept
98
+ global unique_check_dataset
99
+
100
+ if model_name == "":
101
+ gr.Warning(
102
+ "Please select a model.",
103
+ )
104
+ return None
105
+ wrapper, lens = utils.get_wrapper_lens_from_state(model_name, "probing", concept=check_concept)
106
+ current_probing_statistics = lens.compute_statistics(unique_check_dataset, wrapper, 10)
107
+ return make_probing_plot()
108
+
109
+
110
+ def make_probing_plot():
111
+ global current_probing_statistics
112
+
113
+ if current_probing_statistics is None:
114
+ gr.Warning(
115
+ "Please compute probing statistics first.",
116
+ )
117
+ return None
118
+ else:
119
+ return visualisation.render_probing_statistics(current_probing_statistics)
120
+
121
+
122
+ with gr.Blocks() as interface:
123
+ with gr.Row():
124
+ with gr.Column(scale=2):
125
+ model_df = gr.Dataframe(
126
+ headers=["Available models"],
127
+ datatype=["str"],
128
+ interactive=False,
129
+ type="array",
130
+ value=list_models,
131
+ )
132
+ with gr.Column(scale=1):
133
+ with gr.Row():
134
+ model_name = gr.Textbox(label="Selected model", lines=1, interactive=False, scale=7)
135
+ model_df.select(
136
+ on_select_model_df,
137
+ None,
138
+ model_name,
139
+ )
140
+
141
+ with gr.Row():
142
+ with gr.Column():
143
+ policy_plot = gr.Plot(label="Policy statistics")
144
+ policy_compute_button = gr.Button(value="Compute policy statistics")
145
+ policy_plot_button = gr.Button(value="Plot policy statistics")
146
+
147
+ policy_compute_button.click(
148
+ compute_policy_statistics,
149
+ inputs=[model_name],
150
+ outputs=[policy_plot],
151
+ )
152
+ policy_plot_button.click(make_policy_plot, outputs=[policy_plot])
153
+
154
+ with gr.Column():
155
+ lrp_plot_hist = gr.Plot(label="LRP history statistics")
156
+
157
+ with gr.Row():
158
+ with gr.Column():
159
+ lrp_plot_planes = gr.Plot(label="LRP planes statistics")
160
+
161
+ with gr.Column():
162
+ lrp_plot_pieces = gr.Plot(label="LRP pieces statistics")
163
+
164
+ with gr.Row():
165
+ lrp_compute_button = gr.Button(value="Compute LRP statistics")
166
+ with gr.Row():
167
+ lrp_plot_button = gr.Button(value="Plot LRP statistics")
168
+
169
+ lrp_compute_button.click(
170
+ compute_lrp_statistics,
171
+ inputs=[model_name],
172
+ outputs=[lrp_plot_hist, lrp_plot_planes, lrp_plot_pieces],
173
+ )
174
+ lrp_plot_button.click(
175
+ make_lrp_plot,
176
+ outputs=[lrp_plot_hist, lrp_plot_planes, lrp_plot_pieces],
177
+ )
178
+
179
+ with gr.Column():
180
+ probing_plot = gr.Plot(label="Probing statistics")
181
+ probing_compute_button = gr.Button(value="Compute probing statistics")
182
+ probing_plot_button = gr.Button(value="Plot probing statistics")
183
+
184
+ probing_compute_button.click(
185
+ compute_probing_statistics,
186
+ inputs=[model_name],
187
+ outputs=[probing_plot],
188
+ )
189
+ probing_plot_button.click(make_probing_plot, outputs=[probing_plot])
app/utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utils for the demo app.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import subprocess
8
+
9
+ from demo import constants, state
10
+ from lczerolens import LensFactory, LczeroModel
11
+ from lczerolens.model import lczero as lczero_utils
12
+
13
+
14
+ def get_models_info(onnx=True, leela=True):
15
+ """
16
+ Get the names of the models in the model directory.
17
+ """
18
+ model_df = []
19
+ exp = r"(?P<n_filters>\d+)x(?P<n_blocks>\d+)"
20
+ if onnx:
21
+ for filename in os.listdir(constants.MODEL_DIRECTORY):
22
+ if filename.endswith(".onnx"):
23
+ match = re.search(exp, filename)
24
+ if match is None:
25
+ n_filters = -1
26
+ n_blocks = -1
27
+ else:
28
+ n_filters = int(match.group("n_filters"))
29
+ n_blocks = int(match.group("n_blocks"))
30
+ model_df.append(
31
+ [
32
+ filename,
33
+ "ONNX",
34
+ n_blocks,
35
+ n_filters,
36
+ ]
37
+ )
38
+ if leela:
39
+ for filename in os.listdir(constants.LEELA_MODEL_DIRECTORY):
40
+ if filename.endswith(".pb.gz"):
41
+ match = re.search(exp, filename)
42
+ if match is None:
43
+ n_filters = -1
44
+ n_blocks = -1
45
+ else:
46
+ n_filters = int(match.group("n_filters"))
47
+ n_blocks = int(match.group("n_blocks"))
48
+ model_df.append(
49
+ [
50
+ filename,
51
+ "LEELA",
52
+ n_blocks,
53
+ n_filters,
54
+ ]
55
+ )
56
+ return model_df
57
+
58
+
59
+ def save_model(tmp_file_path):
60
+ """
61
+ Save an uploaded Leela network to the Leela model directory.
62
+ """
63
+ popen = subprocess.Popen(
64
+ ["file", tmp_file_path],
65
+ stdout=subprocess.PIPE,
66
+ stderr=subprocess.PIPE,
67
+ )
68
+ popen.wait()
69
+ if popen.returncode != 0:
70
+ raise RuntimeError("Could not inspect the uploaded file.")
71
+ file_desc = popen.stdout.read().decode("utf-8").split(tmp_file_path)[1].strip()
72
+ rename_match = re.search(r"was\s\"(?P<name>.+)\"", file_desc)
73
+ type_match = re.search(r"\:\s(?P<type>[a-zA-Z]+)", file_desc)
74
+ if rename_match is None or type_match is None:
75
+ raise RuntimeError("Could not parse the uploaded file description.")
76
+ model_name = rename_match.group("name")
77
+ model_type = type_match.group("type")
78
+ if model_type != "gzip":
79
+ raise RuntimeError("Uploaded file is not a gzip archive.")
80
+ os.rename(
81
+ tmp_file_path,
82
+ f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz",
83
+ )
84
+ try:
85
+ lczero_utils.describenet(
86
+ f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz",
87
+ )
88
+ except RuntimeError:
89
+ os.remove(f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz")
90
+ raise  # re-raise the original describenet error with its traceback
91
+
92
+
93
+ def get_wrapper_from_state(model_name):
94
+ """
95
+ Get the model wrapper from the state.
96
+ """
97
+ if model_name in state.wrappers:
98
+ return state.wrappers[model_name]
99
+ else:
100
+ wrapper = LczeroModel.from_path(f"{constants.MODEL_DIRECTORY}/{model_name}")
101
+ state.wrappers[model_name] = wrapper
102
+ return wrapper
103
+
104
+
105
+ def get_wrapper_lens_from_state(model_name, lens_type, lens_name="lens", **kwargs):
106
+ """
107
+ Get the model wrapper and lens from the state.
108
+ """
109
+ if model_name in state.wrappers:
110
+ wrapper = state.wrappers[model_name]
111
+ else:
112
+ wrapper = LczeroModel.from_path(f"{constants.MODEL_DIRECTORY}/{model_name}")
113
+ state.wrappers[model_name] = wrapper
114
+ if lens_name in state.lenses[lens_type]:
115
+ lens = state.lenses[lens_type][lens_name]
116
+ else:
117
+ lens = LensFactory.from_name(lens_type, **kwargs)
118
+ if not lens.is_compatible(wrapper):
119
+ raise ValueError(f"Lens of type {lens_type} not compatible with model.")
120
+ state.lenses[lens_type][lens_name] = lens
121
+ return wrapper, lens
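Both state helpers above memoise their results in module-level dictionaries (state.wrappers and state.lenses), so repeated lookups reuse the already-loaded objects. A minimal usage sketch, assuming an ONNX file named my-model.onnx is present in constants.MODEL_DIRECTORY (the file name is a placeholder):
from demo import utils

# The first call loads the ONNX file and caches it in state.wrappers;
# the second call returns the same cached instance.
wrapper = utils.get_wrapper_from_state("my-model.onnx")
assert utils.get_wrapper_from_state("my-model.onnx") is wrapper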
app/visualisation.py ADDED
@@ -0,0 +1,303 @@
1
+ """
2
+ Visualisation utils.
3
+ """
4
+
5
+ import chess
6
+ import chess.svg
7
+ import matplotlib
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import torch
11
+ import torchviz
12
+
13
+ from . import constants
14
+
15
+ COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
16
+ ALPHA = 1.0
17
+
18
+
19
+ def render_heatmap(
20
+ board,
21
+ heatmap,
22
+ square=None,
23
+ vmin=None,
24
+ vmax=None,
25
+ arrows=None,
26
+ normalise="none",
27
+ ):
28
+ """
29
+ Render a heatmap on the board.
30
+ """
31
+ if normalise == "abs":
32
+ a_max = heatmap.abs().max()
33
+ if a_max != 0:
34
+ heatmap = heatmap / a_max
35
+ vmin = -1
36
+ vmax = 1
37
+ if vmin is None:
38
+ vmin = heatmap.min()
39
+ if vmax is None:
40
+ vmax = heatmap.max()
41
+ norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)
42
+
43
+ color_dict = {}
44
+ for square_index in range(64):
45
+ color = COLOR_MAP(norm(heatmap[square_index]))
46
+ color = (*color[:3], ALPHA)
47
+ color_dict[square_index] = matplotlib.colors.to_hex(color, keep_alpha=True)
48
+ fig = plt.figure(figsize=(6, 0.6))
49
+ ax = plt.gca()
50
+ ax.axis("off")
51
+ fig.colorbar(
52
+ matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP),
53
+ ax=ax,
54
+ orientation="horizontal",
55
+ fraction=1.0,
56
+ )
57
+ if square is not None:
58
+ try:
59
+ check = chess.parse_square(square)
60
+ except ValueError:
61
+ check = None
62
+ else:
63
+ check = None
64
+ if arrows is None:
65
+ arrows = []
66
+ plt.close()
67
+ return (
68
+ chess.svg.board(
69
+ board,
70
+ check=check,
71
+ fill=color_dict,
72
+ size=350,
73
+ arrows=arrows,
74
+ ),
75
+ fig,
76
+ )
77
+
78
+
79
+ def render_architecture(model, name: str = "model", directory: str = ""):
80
+ """
81
+ Render the architecture of the model.
82
+ """
83
+ out = model(torch.zeros(1, 112, 8, 8))
84
+ if len(out) == 2:
85
+ policy, outcome_probs = out
86
+ value = torch.zeros(outcome_probs.shape[0], 1)
87
+ else:
88
+ policy, outcome_probs, value = out
89
+ torchviz.make_dot(policy, params=dict(list(model.named_parameters()))).render(
90
+ f"{directory}/{name}_policy", format="svg"
91
+ )
92
+ torchviz.make_dot(outcome_probs, params=dict(list(model.named_parameters()))).render(
93
+ f"{directory}/{name}_outcome_probs", format="svg"
94
+ )
95
+ torchviz.make_dot(value, params=dict(list(model.named_parameters()))).render(
96
+ f"{directory}/{name}_value", format="svg"
97
+ )
98
+
99
+
100
+ def render_policy_distribution(
101
+ policy,
102
+ legal_moves,
103
+ n_bins=20,
104
+ ):
105
+ """
106
+ Render the policy distribution histogram.
107
+ """
108
+ legal_mask = torch.tensor([move in legal_moves for move in range(1858)], dtype=torch.bool)
109
+ fig = plt.figure(figsize=(6, 6))
110
+ ax = plt.gca()
111
+ _, bins = np.histogram(policy, bins=n_bins)
112
+ ax.hist(
113
+ policy[~legal_mask],
114
+ bins=bins,
115
+ alpha=0.5,
116
+ density=True,
117
+ label="Illegal moves",
118
+ )
119
+ ax.hist(
120
+ policy[legal_mask],
121
+ bins=bins,
122
+ alpha=0.5,
123
+ density=True,
124
+ label="Legal moves",
125
+ )
126
+ plt.xlabel("Policy")
127
+ plt.ylabel("Density")
128
+ plt.legend()
129
+ plt.yscale("log")
130
+ return fig
131
+
132
+
133
+ def render_policy_statistics(
134
+ statistics,
135
+ ):
136
+ """
137
+ Render the policy statistics.
138
+ """
139
+ fig = plt.figure(figsize=(6, 6))
140
+ ax = plt.gca()
141
+ move_indices = list(statistics["mean_legal_logits"].keys())
142
+ legal_means_avg = [np.mean(statistics["mean_legal_logits"][move_idx]) for move_idx in move_indices]
143
+ illegal_means_avg = [np.mean(statistics["mean_illegal_logits"][move_idx]) for move_idx in move_indices]
144
+ legal_means_std = [np.std(statistics["mean_legal_logits"][move_idx]) for move_idx in move_indices]
145
+ illegal_means_std = [np.std(statistics["mean_illegal_logits"][move_idx]) for move_idx in move_indices]
146
+ ax.errorbar(
147
+ move_indices,
148
+ legal_means_avg,
149
+ yerr=legal_means_std,
150
+ label="Legal moves",
151
+ )
152
+ ax.errorbar(
153
+ move_indices,
154
+ illegal_means_avg,
155
+ yerr=illegal_means_std,
156
+ label="Illegal moves",
157
+ )
158
+ plt.xlabel("Move index")
159
+ plt.ylabel("Mean policy logits")
160
+ plt.legend()
161
+ return fig
162
+
163
+
164
+ def render_relevance_proportion(statistics, scaled=True):
165
+ """
166
+ Render the relevance proportion statistics.
167
+ """
168
+ norm = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=False)
169
+ fig_hist = plt.figure(figsize=(6, 6))
170
+ ax = plt.gca()
171
+ move_indices = list(statistics["planes_relevance_proportion"].keys())
172
+ for h in range(8):
173
+ relevance_proportion_avg = [
174
+ np.mean([rel[13 * h : 13 * (h + 1)].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
175
+ for move_idx in move_indices
176
+ ]
177
+ relevance_proportion_std = [
178
+ np.std([rel[13 * h : 13 * (h + 1)].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
179
+ for move_idx in move_indices
180
+ ]
181
+ ax.errorbar(
182
+ move_indices[h + 1 :],
183
+ relevance_proportion_avg[h + 1 :],
184
+ yerr=relevance_proportion_std[h + 1 :],
185
+ label=f"History {h}",
186
+ c=COLOR_MAP(norm(h / 9)),
187
+ )
188
+
189
+ relevance_proportion_avg = [
190
+ np.mean([rel[104:108].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
191
+ for move_idx in move_indices
192
+ ]
193
+ relevance_proportion_std = [
194
+ np.std([rel[104:108].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
195
+ for move_idx in move_indices
196
+ ]
197
+ ax.errorbar(
198
+ move_indices,
199
+ relevance_proportion_avg,
200
+ yerr=relevance_proportion_std,
201
+ label="Castling rights",
202
+ c=COLOR_MAP(norm(8 / 9)),
203
+ )
204
+ relevance_proportion_avg = [
205
+ np.mean([rel[108:].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
206
+ for move_idx in move_indices
207
+ ]
208
+ relevance_proportion_std = [
209
+ np.std([rel[108:].sum() for rel in statistics["planes_relevance_proportion"][move_idx]])
210
+ for move_idx in move_indices
211
+ ]
212
+ ax.errorbar(
213
+ move_indices,
214
+ relevance_proportion_avg,
215
+ yerr=relevance_proportion_std,
216
+ label="Remaining planes",
217
+ c=COLOR_MAP(norm(9 / 9)),
218
+ )
219
+ plt.xlabel("Move index")
220
+ plt.ylabel("Absolute relevance proportion")
221
+ plt.yscale("log")
222
+ plt.legend()
223
+
224
+ if scaled:
225
+ stat_key = "planes_relevance_proportion_scaled"
226
+ else:
227
+ stat_key = "planes_relevance_proportion"
228
+ fig_planes = plt.figure(figsize=(6, 6))
229
+ ax = plt.gca()
230
+ move_indices = list(statistics[stat_key].keys())
231
+ for p in range(13):
232
+ relevance_proportion_avg = [
233
+ np.mean([rel[p].item() for rel in statistics[stat_key][move_idx]]) for move_idx in move_indices
234
+ ]
235
+ relevance_proportion_std = [
236
+ np.std([rel[p].item() for rel in statistics[stat_key][move_idx]]) for move_idx in move_indices
237
+ ]
238
+ ax.errorbar(
239
+ move_indices,
240
+ relevance_proportion_avg,
241
+ yerr=relevance_proportion_std,
242
+ label=constants.PLANE_NAMES[p],
243
+ c=COLOR_MAP(norm(p / 12)),
244
+ )
245
+
246
+ plt.xlabel("Move index")
247
+ plt.ylabel("Absolute relevance proportion")
248
+ plt.yscale("log")
249
+ plt.legend()
250
+
251
+ fig_pieces = plt.figure(figsize=(6, 6))
252
+ ax = plt.gca()
253
+ for p in range(1, 13):
254
+ stat_key = f"configuration_relevance_proportion_threatened_piece{p}"
255
+ n_attackers = list(statistics[stat_key].keys())
256
+ relevance_proportion_avg = [
257
+ np.mean(statistics[f"configuration_relevance_proportion_threatened_piece{p}"][n]) for n in n_attackers
258
+ ]
259
+ relevance_proportion_std = [np.std(statistics[stat_key][n]) for n in n_attackers]
260
+ ax.errorbar(
261
+ n_attackers,
262
+ relevance_proportion_avg,
263
+ yerr=relevance_proportion_std,
264
+ label="PNBRQKpnbrqk"[p - 1],
265
+ c=COLOR_MAP(norm(p / 12)),
266
+ )
267
+
268
+ plt.xlabel("Number of attackers")
269
+ plt.ylabel("Absolute configuration relevance proportion")
270
+ plt.yscale("log")
271
+ plt.legend()
272
+
273
+ return fig_hist, fig_planes, fig_pieces
274
+
275
+
276
+ def render_probing_statistics(
277
+ statistics,
278
+ ):
279
+ """
280
+ Render the probing statistics.
281
+ """
282
+ fig = plt.figure(figsize=(6, 6))
283
+ ax = plt.gca()
284
+ n_blocks = len(statistics["metrics"])
285
+ for metric in statistics["metrics"]["block0"]:
286
+ avg = []
287
+ std = []
288
+ for block_idx in range(n_blocks):
289
+ metrics = statistics["metrics"]
290
+ block_data = metrics[f"block{block_idx}"]
291
+ avg.append(np.mean(block_data[metric]))
292
+ std.append(np.std(block_data[metric]))
293
+ ax.errorbar(
294
+ range(n_blocks),
295
+ avg,
296
+ yerr=std,
297
+ label=metric,
298
+ )
299
+ plt.xlabel("Block index")
300
+ plt.ylabel("Metric")
301
+ plt.yscale("log")
302
+ plt.legend()
303
+ return fig
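As a quick orientation for render_heatmap: it takes a board and a 64-element tensor of per-square values and returns an SVG string plus a matplotlib colorbar figure. The sketch below uses random values and an illustrative output path; neither is part of this commit.
import chess
import torch

from demo import visualisation

board = chess.Board()
heatmap = torch.rand(64)  # one value per square, a1 through h8
svg_board, colorbar_fig = visualisation.render_heatmap(board, heatmap, square="e4")
with open("heatmap.svg", "w") as f:
    f.write(svg_board)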