Xmaster6y commited on
Commit
7277ab2
·
unverified ·
0 Parent(s):

reset space

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pipenv
85
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
86
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
87
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
88
+ # install all needed dependencies.
89
+ #Pipfile.lock
90
+
91
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
92
+ __pypackages__/
93
+
94
+ # Celery stuff
95
+ celerybeat-schedule
96
+ celerybeat.pid
97
+
98
+ # SageMath parsed files
99
+ *.sage.py
100
+
101
+ # Environments
102
+ .env
103
+ .venv
104
+ env/
105
+ venv/
106
+ ENV/
107
+ env.bak/
108
+ venv.bak/
109
+
110
+ # Spyder project settings
111
+ .spyderproject
112
+ .spyproject
113
+
114
+ # Rope project settings
115
+ .ropeproject
116
+
117
+ # mkdocs documentation
118
+ /site
119
+
120
+ # mypy
121
+ .mypy_cache/
122
+ .dmypy.json
123
+ dmypy.json
124
+
125
+ # Pyre type checker
126
+ .pyre/
127
+
128
+ # Pickle files
129
+ *.pkl
130
+
131
+ # Various files
132
+ ignored
133
+ debug
134
+ *.zip
135
+ lc0
136
+ !bin/lc0
137
+ wandb
138
+ **/.DS_Store
139
+
140
+ *secret*
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Lczerolens Demo
3
+ emoji: 🔬
4
+ colorFrom: gray
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 5.28.0
8
+ app_file: app.py
9
+ pinned: true
10
+ license: mit
11
+ short_description: Demo lczerolens features
12
+ ---
13
+
14
+ See the documentation [here](https://lczerolens.readthedocs.io/).
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio demo for lczerolens.
3
+ """
4
+
5
+ import gradio as gr
6
+
7
+ from demo.interfaces import (
8
+ board,
9
+ encodings,
10
+ gradients,
11
+ play,
12
+ activations,
13
+ )
14
+
15
+ demo = gr.TabbedInterface(
16
+ [
17
+ board.interface,
18
+ play.interface,
19
+ encodings.interface,
20
+ activations.interface,
21
+ gradients.interface,
22
+ ],
23
+ [
24
+ "Board",
25
+ "Play",
26
+ "Encodings",
27
+ "Activations",
28
+ "Gradients",
29
+ ],
30
+ title="lczerolens Demo",
31
+ analytics_enabled=False,
32
+ )
33
+
34
+
35
+ if __name__ == "__main__":
36
+ demo.launch()
demo/__init__.py ADDED
File without changes
demo/constants.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Constants for the demo.
3
+ """
4
+
5
+ import os
6
+ from pathlib import Path
7
+
8
+ DEMO_DIRECTORY = Path(__file__).parent
9
+
10
+ ONNX_MODEL_DIRECTORY = DEMO_DIRECTORY / "onnx-models"
11
+ LEELA_MODEL_DIRECTORY = DEMO_DIRECTORY / "leela-models"
12
+ FIGURE_DIRECTORY = DEMO_DIRECTORY / "figures"
13
+
14
+ ONNX_MODEL_NAMES = [
15
+ f for f in os.listdir(ONNX_MODEL_DIRECTORY)
16
+ if f.endswith(".onnx")
17
+ ]
18
+ LEELA_MODEL_NAMES = [
19
+ f for f in os.listdir(LEELA_MODEL_DIRECTORY)
20
+ if f.endswith(".pb.gz")
21
+ ]
demo/figures/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
demo/interfaces/__init__.py ADDED
File without changes
demo/interfaces/activations.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for plotting attention.
3
+ """
4
+
5
+ import chess
6
+ import chess.pgn
7
+ import io
8
+ import gradio as gr
9
+ import os
10
+ import torch
11
+
12
+ from lczerolens import LczeroBoard, LczeroModel, Lens
13
+
14
+ from demo import constants
15
+ from demo.utils import get_info
16
+
17
+ def get_model(model_name: str):
18
+ return LczeroModel.from_onnx_path(os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name))
19
+
20
+ def get_activations(model: LczeroModel, board: LczeroBoard):
21
+ lens = Lens.from_name("activation", "block\d/conv2/relu")
22
+ with torch.no_grad():
23
+ results = lens.analyse(model, board)
24
+ return [results[f"block{i}/conv2/relu_output"][0] for i in range(len(results))]
25
+
26
+ def get_board(game_pgn:str, board_fen:str):
27
+ if game_pgn:
28
+ try:
29
+ board = LczeroBoard()
30
+ pgn = io.StringIO(game_pgn)
31
+ game = chess.pgn.read_game(pgn)
32
+ for move in game.mainline_moves():
33
+ board.push(move)
34
+ except Exception as e:
35
+ print(e)
36
+ gr.Warning("Error parsing PGN, using starting position.")
37
+ board = LczeroBoard()
38
+ else:
39
+ try:
40
+ board = LczeroBoard(board_fen)
41
+ except Exception as e:
42
+ print(e)
43
+ gr.Warning("Invalid FEN, using starting position.")
44
+ board = LczeroBoard()
45
+ return board
46
+
47
+ def render_activations(board: LczeroBoard, activations, layer_index:int, channel_index:int):
48
+ if layer_index >= len(activations):
49
+ safe_layer_index = len(activations) - 1
50
+ gr.Warning(f"Layer index {layer_index} out of range, using last layer ({safe_layer_index}).")
51
+ else:
52
+ safe_layer_index = layer_index
53
+ if channel_index >= activations[safe_layer_index].shape[0]:
54
+ safe_channel_index = activations[safe_layer_index].shape[0] - 1
55
+ gr.Warning(f"Channel index {channel_index} out of range, using last channel ({safe_channel_index}).")
56
+ else:
57
+ safe_channel_index = channel_index
58
+ heatmap = activations[safe_layer_index][safe_channel_index].view(64)
59
+ board.render_heatmap(
60
+ heatmap,
61
+ save_to=f"{constants.FIGURE_DIRECTORY}/activations.svg",
62
+ )
63
+ return f"{constants.FIGURE_DIRECTORY}/activations_board.svg", f"{constants.FIGURE_DIRECTORY}/activations_colorbar.svg"
64
+
65
+ def initial_load(model_name: str, board_fen: str, game_pgn: str, layer_index: int, channel_index: int):
66
+ model = get_model(model_name)
67
+ board = get_board(game_pgn, board_fen)
68
+ activations = get_activations(model, board)
69
+ info = get_info(model, board)
70
+ plots = render_activations(board, activations, layer_index, channel_index)
71
+ return model, board, activations, info, *plots
72
+
73
+ def on_board_change(model: LczeroModel, game_pgn: str, board_fen: str, layer_index: int, channel_index: int):
74
+ board = get_board(game_pgn, board_fen)
75
+ activations = get_activations(model, board)
76
+ info = get_info(model, board)
77
+ plots = render_activations(board, activations, layer_index, channel_index)
78
+ return board, activations, info, *plots
79
+
80
+ def on_model_change(model_name: str, board: LczeroBoard, layer_index: int, channel_index: int):
81
+ model = get_model(model_name)
82
+ activations = get_activations(model, board)
83
+ info = get_info(model, board)
84
+ plots = render_activations(board, activations, layer_index, channel_index)
85
+ return model, activations, info, *plots
86
+
87
+ with gr.Blocks() as interface:
88
+ with gr.Row():
89
+ with gr.Column():
90
+ with gr.Group():
91
+ gr.Markdown(
92
+ "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
93
+ )
94
+ game_pgn = gr.Textbox(
95
+ label="Game PGN",
96
+ lines=1,
97
+ value="",
98
+ )
99
+ board_fen = gr.Textbox(
100
+ label="Board FEN",
101
+ lines=1,
102
+ max_lines=1,
103
+ value=chess.STARTING_FEN,
104
+ )
105
+ model_name = gr.Dropdown(
106
+ label="Model",
107
+ choices=constants.ONNX_MODEL_NAMES,
108
+ )
109
+ with gr.Group():
110
+ info = gr.Textbox(label="Info", lines=1, value="")
111
+ with gr.Group():
112
+ layer_index = gr.Slider(
113
+ label="Layer index",
114
+ minimum=0,
115
+ maximum=19,
116
+ step=1,
117
+ value=0,
118
+ )
119
+ channel_index = gr.Slider(
120
+ label="Channel index",
121
+ minimum=0,
122
+ maximum=200,
123
+ step=1,
124
+ value=0,
125
+ )
126
+ with gr.Column():
127
+ image_board = gr.Image(label="Board", interactive=False)
128
+ colorbar = gr.Image(label="Colorbar", interactive=False)
129
+
130
+ model = gr.State(value=None)
131
+ board = gr.State(value=None)
132
+ activations = gr.State(value=None)
133
+
134
+ interface.load(
135
+ initial_load,
136
+ inputs=[model_name, game_pgn, board_fen, layer_index, channel_index],
137
+ outputs=[model, board, activations, info, image_board, colorbar],
138
+ concurrency_limit=1,
139
+ concurrency_id="trace_queue"
140
+ )
141
+ game_pgn.submit(
142
+ on_board_change,
143
+ inputs=[model, game_pgn, board_fen, layer_index, channel_index],
144
+ outputs=[board, activations, info, image_board, colorbar],
145
+ concurrency_id="trace_queue"
146
+ )
147
+ board_fen.submit(
148
+ on_board_change,
149
+ inputs=[model, game_pgn, board_fen, layer_index, channel_index],
150
+ outputs=[board, activations, info, image_board, colorbar],
151
+ concurrency_id="trace_queue"
152
+ )
153
+ model_name.change(
154
+ on_model_change,
155
+ inputs=[model_name, board, layer_index, channel_index],
156
+ outputs=[model, activations, info, image_board, colorbar],
157
+ concurrency_id="trace_queue"
158
+ )
159
+ layer_index.change(
160
+ render_activations,
161
+ inputs=[board, activations, layer_index, channel_index],
162
+ outputs=[image_board, colorbar],
163
+ )
164
+ channel_index.change(
165
+ render_activations,
166
+ inputs=[board, activations, layer_index, channel_index],
167
+ outputs=[image_board, colorbar],
168
+ )
demo/interfaces/board.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for plotting a board.
3
+ """
4
+
5
+ import chess
6
+ import chess.svg
7
+ import gradio as gr
8
+
9
+ from lczerolens.board import LczeroBoard
10
+
11
+ from ..utils import create_board_figure
12
+
13
+ def make_board_plot(board_fen, arrows, square):
14
+ try:
15
+ board = LczeroBoard(board_fen)
16
+ except ValueError:
17
+ board = LczeroBoard()
18
+ gr.Warning("Invalid FEN, using starting position.")
19
+ filepath = create_board_figure(board, arrows=arrows, square=square, name="board")
20
+ return filepath
21
+
22
+
23
+ with gr.Blocks() as interface:
24
+ with gr.Row():
25
+ with gr.Column():
26
+ board_fen = gr.Textbox(
27
+ label="Board starting FEN",
28
+ lines=1,
29
+ max_lines=1,
30
+ value=chess.STARTING_FEN,
31
+ )
32
+ arrows = gr.Textbox(
33
+ label="Arrows",
34
+ lines=1,
35
+ max_lines=1,
36
+ value="",
37
+ placeholder="e2e4 e7e5",
38
+ )
39
+ square = gr.Textbox(
40
+ label="Square",
41
+ lines=1,
42
+ max_lines=1,
43
+ value="",
44
+ placeholder="e4",
45
+ )
46
+ with gr.Column():
47
+ image = gr.Image(label="Board", interactive=False)
48
+
49
+ inputs = [
50
+ board_fen,
51
+ arrows,
52
+ square,
53
+ ]
54
+ interface.load(make_board_plot, inputs=inputs, outputs=image)
55
+ board_fen.submit(make_board_plot, inputs=inputs, outputs=image)
56
+ arrows.submit(make_board_plot, inputs=inputs, outputs=image)
57
+ square.submit(make_board_plot, inputs=inputs, outputs=image)
demo/interfaces/encodings.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for plotting attention.
3
+ """
4
+
5
+ import chess
6
+ import chess.pgn
7
+ import io
8
+ import gradio as gr
9
+
10
+ from lczerolens.board import LczeroBoard
11
+
12
+ from ..constants import FIGURE_DIRECTORY
13
+
14
+ def make_render(game_pgn:str, board_fen:str, plane_index:int):
15
+ if game_pgn:
16
+ try:
17
+ board = LczeroBoard()
18
+ pgn = io.StringIO(game_pgn)
19
+ game = chess.pgn.read_game(pgn)
20
+ for move in game.mainline_moves():
21
+ board.push(move)
22
+ except Exception as e:
23
+ print(e)
24
+ gr.Warning("Error parsing PGN, using starting position.")
25
+ board = LczeroBoard()
26
+ else:
27
+ try:
28
+ board = LczeroBoard(board_fen)
29
+ except Exception as e:
30
+ print(e)
31
+ gr.Warning("Invalid FEN, using starting position.")
32
+ board = LczeroBoard()
33
+ return board, *make_board_plot(board, plane_index)
34
+
35
+ def make_board_plot(board:LczeroBoard, plane_index:int):
36
+ input_tensor = board.to_input_tensor()
37
+ board.render_heatmap(
38
+ input_tensor[plane_index].view(64),
39
+ save_to=f"{FIGURE_DIRECTORY}/encodings.svg",
40
+ vmin=0,
41
+ vmax=1,
42
+ )
43
+ return f"{FIGURE_DIRECTORY}/encodings_board.svg", f"{FIGURE_DIRECTORY}/encodings_colorbar.svg"
44
+
45
+ with gr.Blocks() as interface:
46
+ with gr.Row():
47
+ with gr.Column():
48
+ with gr.Group():
49
+ gr.Markdown(
50
+ "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
51
+ )
52
+ game_pgn = gr.Textbox(
53
+ label="Game PGN",
54
+ lines=1,
55
+ value="",
56
+ )
57
+ board_fen = gr.Textbox(
58
+ label="Board FEN",
59
+ lines=1,
60
+ max_lines=1,
61
+ value=chess.STARTING_FEN,
62
+ )
63
+ with gr.Group():
64
+ with gr.Row():
65
+ plane_index = gr.Slider(
66
+ label="Plane index",
67
+ minimum=0,
68
+ maximum=111,
69
+ step=1,
70
+ value=0,
71
+ )
72
+ with gr.Column():
73
+ image_board = gr.Image(label="Board", interactive=False)
74
+ colorbar = gr.Image(label="Colorbar", interactive=False)
75
+
76
+ state_board = gr.State(value=LczeroBoard())
77
+
78
+ render_inputs = [game_pgn, board_fen, plane_index]
79
+ render_outputs = [state_board, image_board, colorbar]
80
+ interface.load(
81
+ make_render,
82
+ inputs=render_inputs,
83
+ outputs=render_outputs,
84
+ )
85
+ game_pgn.submit(
86
+ make_render,
87
+ inputs=render_inputs,
88
+ outputs=render_outputs,
89
+ )
90
+ board_fen.submit(
91
+ make_render,
92
+ inputs=render_inputs,
93
+ outputs=render_outputs,
94
+ )
95
+ plane_index.change(
96
+ make_board_plot,
97
+ inputs=[state_board, plane_index],
98
+ outputs=[image_board, colorbar],
99
+ )
demo/interfaces/gradients.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for plotting attention.
3
+ """
4
+
5
+ import chess
6
+ import chess.pgn
7
+ import io
8
+ import gradio as gr
9
+ import os
10
+
11
+ from lczerolens import LczeroBoard, LczeroModel, Lens
12
+
13
+ from demo import constants
14
+ from demo.utils import get_info
15
+
16
+ def get_model(model_name: str):
17
+ return LczeroModel.from_onnx_path(os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name))
18
+
19
+ def get_gradients(model: LczeroModel, board: LczeroBoard, target: str):
20
+ lens = Lens.from_name("gradient")
21
+
22
+ def init_target(model):
23
+ if target == "best_move":
24
+ return getattr(model, "output/policy").output.max(dim=1).values
25
+ else:
26
+ wdl_index = {"win": 0, "draw": 1, "loss": 2}[target]
27
+ return getattr(model, "output/wdl").output[:, wdl_index]
28
+ results = lens.analyse(model, board, init_target=init_target)
29
+
30
+ return results["input_grad"]
31
+
32
+ def get_board(game_pgn:str, board_fen:str):
33
+ if game_pgn:
34
+ try:
35
+ board = LczeroBoard()
36
+ pgn = io.StringIO(game_pgn)
37
+ game = chess.pgn.read_game(pgn)
38
+ for move in game.mainline_moves():
39
+ board.push(move)
40
+ except Exception as e:
41
+ print(e)
42
+ gr.Warning("Error parsing PGN, using starting position.")
43
+ board = LczeroBoard()
44
+ else:
45
+ try:
46
+ board = LczeroBoard(board_fen)
47
+ except Exception as e:
48
+ print(e)
49
+ gr.Warning("Invalid FEN, using starting position.")
50
+ board = LczeroBoard()
51
+ return board
52
+
53
+ def render_gradients(board: LczeroBoard, gradients, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index:int):
54
+ if average_over_planes:
55
+ heatmap = gradients[0, begin_average_index:end_average_index].mean(dim=0).view(64)
56
+ else:
57
+ heatmap = gradients[0, plane_index].view(64)
58
+ board.render_heatmap(
59
+ heatmap,
60
+ save_to=f"{constants.FIGURE_DIRECTORY}/gradients.svg",
61
+ )
62
+ return f"{constants.FIGURE_DIRECTORY}/gradients_board.svg", f"{constants.FIGURE_DIRECTORY}/gradients_colorbar.svg"
63
+
64
+ def initial_load(model_name: str, board_fen: str, game_pgn: str, target: str, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index: int):
65
+ model = get_model(model_name)
66
+ board = get_board(game_pgn, board_fen)
67
+ gradients = get_gradients(model, board, target)
68
+ info = get_info(model, board)
69
+ plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
70
+ return model, board, gradients, info, *plots
71
+
72
+ def on_board_change(model: LczeroModel, game_pgn: str, board_fen: str, target: str, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index: int):
73
+ board = get_board(game_pgn, board_fen)
74
+ gradients = get_gradients(model, board, target)
75
+ info = get_info(model, board)
76
+ plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
77
+ return board, gradients, info, *plots
78
+
79
+ def on_model_change(model_name: str, board: LczeroBoard, target: str, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index: int):
80
+ model = get_model(model_name)
81
+ gradients = get_gradients(model, board, target)
82
+ info = get_info(model, board)
83
+ plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
84
+ return model, gradients, info, *plots
85
+
86
+ def on_target_change(model: LczeroModel, board: LczeroBoard, target: str, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index: int):
87
+ gradients = get_gradients(model, board, target)
88
+ plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
89
+ return gradients, *plots
90
+
91
+ with gr.Blocks() as interface:
92
+ with gr.Row():
93
+ with gr.Column():
94
+ with gr.Group():
95
+ gr.Markdown(
96
+ "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
97
+ )
98
+ game_pgn = gr.Textbox(
99
+ label="Game PGN",
100
+ lines=1,
101
+ value="",
102
+ )
103
+ board_fen = gr.Textbox(
104
+ label="Board FEN",
105
+ lines=1,
106
+ max_lines=1,
107
+ value=chess.STARTING_FEN,
108
+ )
109
+ model_name = gr.Dropdown(
110
+ label="Model",
111
+ choices=constants.ONNX_MODEL_NAMES,
112
+ )
113
+ with gr.Group():
114
+ info = gr.Textbox(label="Info", lines=1, value="")
115
+ with gr.Group():
116
+ target = gr.Radio(
117
+ ["win", "draw", "loss", "best_move"], label="Target",
118
+ value="win",
119
+ )
120
+ average_over_planes = gr.Checkbox(label="Average over Planes", value=False)
121
+ with gr.Accordion("Average over planes", open=False):
122
+ begin_average_index = gr.Slider(
123
+ label="Begin average index",
124
+ minimum=0,
125
+ maximum=111,
126
+ step=1,
127
+ value=0,
128
+ )
129
+ end_average_index = gr.Slider(
130
+ label="End average index",
131
+ minimum=0,
132
+ maximum=111,
133
+ step=1,
134
+ value=111,
135
+ )
136
+ plane_index = gr.Slider(
137
+ label="Plane index",
138
+ minimum=0,
139
+ maximum=111,
140
+ step=1,
141
+ value=0,
142
+ )
143
+ with gr.Column():
144
+ image_board = gr.Image(label="Board", interactive=False)
145
+ colorbar = gr.Image(label="Colorbar", interactive=False)
146
+
147
+ model = gr.State(value=None)
148
+ board = gr.State(value=None)
149
+ gradients = gr.State(value=None)
150
+
151
+ interface.load(
152
+ initial_load,
153
+ inputs=[model_name, game_pgn, board_fen, target, average_over_planes, begin_average_index, end_average_index, plane_index],
154
+ outputs=[model, board, gradients, info, image_board, colorbar],
155
+ concurrency_id="trace_queue"
156
+ )
157
+ game_pgn.submit(
158
+ on_board_change,
159
+ inputs=[model, game_pgn, board_fen, target, average_over_planes, begin_average_index, end_average_index, plane_index],
160
+ outputs=[board, gradients, info, image_board, colorbar],
161
+ concurrency_id="trace_queue"
162
+ )
163
+ board_fen.submit(
164
+ on_board_change,
165
+ inputs=[model, game_pgn, board_fen, target, average_over_planes, begin_average_index, end_average_index, plane_index],
166
+ outputs=[board, gradients, info, image_board, colorbar],
167
+ concurrency_id="trace_queue"
168
+ )
169
+ model_name.change(
170
+ on_model_change,
171
+ inputs=[model_name, board, target, average_over_planes, begin_average_index, end_average_index, plane_index],
172
+ outputs=[model, gradients, info, image_board, colorbar],
173
+ concurrency_id="trace_queue"
174
+ )
175
+ target.change(
176
+ on_target_change,
177
+ inputs=[model, board, target, average_over_planes, begin_average_index, end_average_index, plane_index],
178
+ outputs=[gradients, image_board, colorbar],
179
+ concurrency_id="trace_queue"
180
+ )
181
+ for render_arg in [average_over_planes, begin_average_index, end_average_index, plane_index]:
182
+ render_arg.change(
183
+ render_gradients,
184
+ inputs=[board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index],
185
+ outputs=[image_board, colorbar],
186
+ )
demo/interfaces/play.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Interface to play against the model.
2
+ """
3
+
4
+ import os
5
+
6
+ import chess
7
+ import chess.pgn
8
+ import random
9
+ import gradio as gr
10
+
11
+ from lczerolens import LczeroBoard, LczeroModel
12
+ from lczerolens.play import PolicySampler
13
+
14
+ from demo import constants
15
+ from demo.utils import create_board_figure
16
+
17
+
18
+ def get_sampler(model_name: str):
19
+ model = LczeroModel.from_onnx_path(os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name))
20
+ return PolicySampler(model)
21
+
22
+ def get_pgn(board: LczeroBoard):
23
+ game = chess.pgn.Game()
24
+ for move in board.move_stack:
25
+ game.add_variation(move)
26
+ return str(game).split("\n")[-1]
27
+
28
+ def render_board(
29
+ board: LczeroBoard,
30
+ ):
31
+ player = board.turn
32
+ if len(board.move_stack) > 0:
33
+ last_move_uci = board.peek().uci()
34
+ else:
35
+ last_move_uci = None
36
+
37
+ if board.is_check():
38
+ check = board.king(board.turn)
39
+ else:
40
+ check = None
41
+ filepath = create_board_figure(
42
+ board,
43
+ orientation=player,
44
+ arrows=last_move_uci,
45
+ square=check,
46
+ name="play_board",
47
+ )
48
+ return filepath
49
+
50
+ def gather_outputs(board: LczeroBoard, sampler: PolicySampler):
51
+ return sampler, board, board.fen(), get_pgn(board), render_board(board), ""
52
+
53
+ def get_init(model_name: str):
54
+ sampler = get_sampler(model_name)
55
+ is_ai_white = random.choice([True, False])
56
+ init_board = LczeroBoard()
57
+ if is_ai_white:
58
+ play_ai_move(init_board, sampler)
59
+ return gather_outputs(init_board, sampler)
60
+
61
+ def play_user_move_then_ai_move(
62
+ uci_move: str,
63
+ board: LczeroBoard,
64
+ sampler: PolicySampler,
65
+ ):
66
+ board.push_uci(uci_move)
67
+ play_ai_move(board, sampler)
68
+ return gather_outputs(board, sampler)
69
+
70
+
71
+ def play_ai_move(
72
+ board: LczeroBoard,
73
+ sampler: PolicySampler,
74
+ ):
75
+ move, _ = next(iter(sampler.get_next_moves([board])))
76
+ board.push(move)
77
+
78
+ with gr.Blocks() as interface:
79
+ with gr.Row():
80
+ with gr.Column():
81
+ current_fen = gr.Textbox(
82
+ label="Board FEN",
83
+ lines=1,
84
+ max_lines=1,
85
+ value=chess.STARTING_FEN,
86
+ )
87
+ current_pgn = gr.Textbox(
88
+ label="Action sequence",
89
+ lines=1,
90
+ value="",
91
+ )
92
+ model_name = gr.Dropdown(
93
+ label="Model",
94
+ choices=constants.ONNX_MODEL_NAMES,
95
+ )
96
+ with gr.Column():
97
+ move_to_play = gr.Textbox(
98
+ label="Move to play (UCI)",
99
+ lines=1,
100
+ max_lines=1,
101
+ value="",
102
+ )
103
+ play_button = gr.Button("Play")
104
+ reset_button = gr.Button("Reset")
105
+ with gr.Column():
106
+ image_board = gr.Image(label="Board", interactive=False)
107
+
108
+ sampler = gr.State(value=None)
109
+ board = gr.State(value=None)
110
+
111
+ outputs = [sampler, board, current_fen, current_pgn, image_board, move_to_play]
112
+
113
+ play_button.click(
114
+ play_user_move_then_ai_move,
115
+ inputs=[move_to_play, board, sampler],
116
+ outputs=outputs,
117
+ )
118
+ move_to_play.submit(
119
+ play_user_move_then_ai_move,
120
+ inputs=[move_to_play, board, sampler],
121
+ outputs=outputs,
122
+ )
123
+
124
+ model_name.change(
125
+ get_sampler,
126
+ inputs=[model_name],
127
+ outputs=[sampler],
128
+ )
129
+
130
+ reset_button.click(
131
+ get_init,
132
+ inputs=[model_name],
133
+ outputs=outputs,
134
+ )
135
+ interface.load(
136
+ get_init,
137
+ inputs=[model_name],
138
+ outputs=outputs,
139
+ )
demo/leela-models/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *
2
+ !.gitignore
demo/onnx-models/lc0-10-4238.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:096bc5b833572de5357961f8cb2d059fe4f0ff303598131f20278fb23bc58919
3
+ size 15152179
demo/onnx-models/lc0-19-1876.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df7e2b87f0f4525d7de7201ee2e946fc10172200ef97c6786b3fdcfb8610b88a
3
+ size 97139290
demo/onnx-models/lc0-19-4508.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bcf7c13451f00fc9d1859f98a52a37056aa3d10ac7f874408ca46d0cde50c241
3
+ size 97139290
demo/onnx-models/look-ahead-lc0.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a3dad19ed0546106e3b97d7aa7dbf0e3cc7041a6c63979f923360ee71cb4a24
3
+ size 378669573
demo/onnx-models/maia-1100.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1a4ac9b99aee3c127285b7cdc2024d054a6de72abdfa2fe3db5e4f7078c96ad
3
+ size 3484716
demo/onnx-models/maia-1200.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82c21e9cf65f0de9a72469eaae8845c3e92a32fe93e1e626a064cd3d24cca4cf
3
+ size 3483901
demo/onnx-models/maia-1300.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7c3b8d9da7a2cc586f409b0264abcb9d00390276ffe5f39a09d591787555e50
3
+ size 3483901
demo/onnx-models/maia-1400.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3651ae981d8bc05cf735c79d1a51298d91d3450592ac874d9c2c66e6c90b896b
3
+ size 3483901
demo/onnx-models/maia-1500.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94333b25bb8d25685334645afbb7494b358ae5724ad651d735ae1683fb0a9e3f
3
+ size 3483901
demo/onnx-models/maia-1600.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d46d32c6c1ff2a9f04df5c0652e50ae502b8870c6e0895e6abeeb02bba01d423
3
+ size 3483901
demo/onnx-models/maia-1700.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b01ef1c6bf0136480526f2933f88623bc5b1e2bbf319fc6045be68524f942d8
3
+ size 3483901
demo/onnx-models/maia-1800.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b9b999de7d0fe4b2efc9d6591b30f4010ec7490c36676dd920ee0420f202a67
3
+ size 3483901
demo/onnx-models/maia-1900.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65ee89dcee614d2b7f5bf8fc5950e83050bf855ecb4d34f6e6214b09acc64572
3
+ size 3484716
demo/onnx-models/t1-256x10-distilled-swa-2432500.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5cf50f8d5e08c1f42c9a30bdcbd3a2deeef209f24ca4104d4170eda1b193bad
3
+ size 80896651
demo/utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import chess.svg
4
+
5
+ from lczerolens import LczeroBoard, LczeroModel, Lens
6
+
7
+ from . import constants
8
+
9
+
10
+ def create_board_figure(
11
+ board: LczeroBoard,
12
+ *,
13
+ orientation: bool = chess.WHITE,
14
+ arrows: str = "",
15
+ square: str = "",
16
+ name: str = "board",
17
+ ):
18
+ try:
19
+ if arrows:
20
+ arrows_list = arrows.split(" ")
21
+ chess_arrows = []
22
+ for arrow in arrows_list:
23
+ from_square, to_square = arrow[:2], arrow[2:]
24
+ chess_arrows.append(
25
+ (
26
+ chess.parse_square(from_square),
27
+ chess.parse_square(to_square),
28
+ )
29
+ )
30
+ else:
31
+ chess_arrows = []
32
+ except ValueError:
33
+ chess_arrows = []
34
+ gr.Warning("Invalid arrows, using none.")
35
+
36
+ try:
37
+ color_dict = {chess.parse_square(square): "#FF0000"} if square else {}
38
+ except ValueError:
39
+ color_dict = {}
40
+ gr.Warning("Invalid square, using none.")
41
+
42
+ svg_board = chess.svg.board(
43
+ board,
44
+ size=350,
45
+ orientation=orientation,
46
+ arrows=chess_arrows,
47
+ fill=color_dict,
48
+ )
49
+ with open(f"{constants.FIGURE_DIRECTORY}/{name}.svg", "w") as f:
50
+ f.write(svg_board)
51
+ return f"{constants.FIGURE_DIRECTORY}/{name}.svg"
52
+
53
+
54
+ class OutputLens(Lens):
55
+ def _intervene(self, model: LczeroModel, **kwargs) -> dict:
56
+ return model.output.save()
57
+
58
+ def get_info(model: LczeroModel, board: LczeroBoard):
59
+ lens = OutputLens()
60
+ output = lens.analyse(model, board)
61
+ w = output["wdl"][0,0]
62
+ d = output["wdl"][0,1]
63
+ l = output["wdl"][0,2]
64
+ legal_indices = board.get_legal_indices()
65
+ best_move_idx = output["policy"].gather(dim=1, index=legal_indices.unsqueeze(0)).argmax(dim=1).item()
66
+ best_move = board.decode_move(legal_indices[best_move_idx])
67
+ info = f"w: {w:.2f}, d: {d:.2f}, l: {l:.2f}, best: {best_move}"
68
+ return info
pyproject.toml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "lczerolens-demo"
3
+ version = "0.1.0"
4
+ description = "Demo lczerolens features."
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "gdown>=5.2.0",
9
+ "gradio>=5.20.1",
10
+ "lczerolens[viz]>=0.3.3",
11
+ ]
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ lczerolens[viz]>=0.3.3
2
+ gdown
resolve-assets.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gdown 1cxC8_8vw7akfPyc9cZxwaAbLG2Zl4XiT -O demo/onnx-models/lc0-10-4238.onnx
2
+ gdown 15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X -O demo/onnx-models/lc0-19-1876.onnx
3
+ gdown 1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd -O demo/onnx-models/lc0-19-4508.onnx
4
+ gdown 1TI429e9mr2de7LjHp2IIl7ouMoUaDjjZ -O demo/onnx-models/maia-1100.onnx
5
+ gdown 1-8IJ5WYMPpcxOsHfIKY8xKskwk2z_yrY -O demo/onnx-models/maia-1900.onnx
uv.lock ADDED
The diff for this file is too large to render. See raw diff