Spaces:
Sleeping
Sleeping
Commit
·
7277ab2
unverified
·
0
Parent(s):
reset space
Browse files- .gitattributes +35 -0
- .gitignore +140 -0
- README.md +14 -0
- app.py +36 -0
- demo/__init__.py +0 -0
- demo/constants.py +21 -0
- demo/figures/.gitignore +2 -0
- demo/interfaces/__init__.py +0 -0
- demo/interfaces/activations.py +168 -0
- demo/interfaces/board.py +57 -0
- demo/interfaces/encodings.py +99 -0
- demo/interfaces/gradients.py +186 -0
- demo/interfaces/play.py +139 -0
- demo/leela-models/.gitignore +2 -0
- demo/onnx-models/lc0-10-4238.onnx +3 -0
- demo/onnx-models/lc0-19-1876.onnx +3 -0
- demo/onnx-models/lc0-19-4508.onnx +3 -0
- demo/onnx-models/look-ahead-lc0.onnx +3 -0
- demo/onnx-models/maia-1100.onnx +3 -0
- demo/onnx-models/maia-1200.onnx +3 -0
- demo/onnx-models/maia-1300.onnx +3 -0
- demo/onnx-models/maia-1400.onnx +3 -0
- demo/onnx-models/maia-1500.onnx +3 -0
- demo/onnx-models/maia-1600.onnx +3 -0
- demo/onnx-models/maia-1700.onnx +3 -0
- demo/onnx-models/maia-1800.onnx +3 -0
- demo/onnx-models/maia-1900.onnx +3 -0
- demo/onnx-models/t1-256x10-distilled-swa-2432500.onnx +3 -0
- demo/utils.py +68 -0
- pyproject.toml +11 -0
- requirements.txt +2 -0
- resolve-assets.sh +5 -0
- uv.lock +0 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pipenv
|
85 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
86 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
87 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
88 |
+
# install all needed dependencies.
|
89 |
+
#Pipfile.lock
|
90 |
+
|
91 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
92 |
+
__pypackages__/
|
93 |
+
|
94 |
+
# Celery stuff
|
95 |
+
celerybeat-schedule
|
96 |
+
celerybeat.pid
|
97 |
+
|
98 |
+
# SageMath parsed files
|
99 |
+
*.sage.py
|
100 |
+
|
101 |
+
# Environments
|
102 |
+
.env
|
103 |
+
.venv
|
104 |
+
env/
|
105 |
+
venv/
|
106 |
+
ENV/
|
107 |
+
env.bak/
|
108 |
+
venv.bak/
|
109 |
+
|
110 |
+
# Spyder project settings
|
111 |
+
.spyderproject
|
112 |
+
.spyproject
|
113 |
+
|
114 |
+
# Rope project settings
|
115 |
+
.ropeproject
|
116 |
+
|
117 |
+
# mkdocs documentation
|
118 |
+
/site
|
119 |
+
|
120 |
+
# mypy
|
121 |
+
.mypy_cache/
|
122 |
+
.dmypy.json
|
123 |
+
dmypy.json
|
124 |
+
|
125 |
+
# Pyre type checker
|
126 |
+
.pyre/
|
127 |
+
|
128 |
+
# Pickle files
|
129 |
+
*.pkl
|
130 |
+
|
131 |
+
# Various files
|
132 |
+
ignored
|
133 |
+
debug
|
134 |
+
*.zip
|
135 |
+
lc0
|
136 |
+
!bin/lc0
|
137 |
+
wandb
|
138 |
+
**/.DS_Store
|
139 |
+
|
140 |
+
*secret*
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Lczerolens Demo
|
3 |
+
emoji: 🔬
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.28.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
license: mit
|
11 |
+
short_description: Demo lczerolens features
|
12 |
+
---
|
13 |
+
|
14 |
+
See the documentation [here](https://lczerolens.readthedocs.io/).
|
app.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Gradio demo for lczerolens.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
from demo.interfaces import (
|
8 |
+
board,
|
9 |
+
encodings,
|
10 |
+
gradients,
|
11 |
+
play,
|
12 |
+
activations,
|
13 |
+
)
|
14 |
+
|
15 |
+
demo = gr.TabbedInterface(
|
16 |
+
[
|
17 |
+
board.interface,
|
18 |
+
play.interface,
|
19 |
+
encodings.interface,
|
20 |
+
activations.interface,
|
21 |
+
gradients.interface,
|
22 |
+
],
|
23 |
+
[
|
24 |
+
"Board",
|
25 |
+
"Play",
|
26 |
+
"Encodings",
|
27 |
+
"Activations",
|
28 |
+
"Gradients",
|
29 |
+
],
|
30 |
+
title="lczerolens Demo",
|
31 |
+
analytics_enabled=False,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
if __name__ == "__main__":
|
36 |
+
demo.launch()
|
demo/__init__.py
ADDED
File without changes
|
demo/constants.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Constants for the demo.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
DEMO_DIRECTORY = Path(__file__).parent
|
9 |
+
|
10 |
+
ONNX_MODEL_DIRECTORY = DEMO_DIRECTORY / "onnx-models"
|
11 |
+
LEELA_MODEL_DIRECTORY = DEMO_DIRECTORY / "leela-models"
|
12 |
+
FIGURE_DIRECTORY = DEMO_DIRECTORY / "figures"
|
13 |
+
|
14 |
+
ONNX_MODEL_NAMES = [
|
15 |
+
f for f in os.listdir(ONNX_MODEL_DIRECTORY)
|
16 |
+
if f.endswith(".onnx")
|
17 |
+
]
|
18 |
+
LEELA_MODEL_NAMES = [
|
19 |
+
f for f in os.listdir(LEELA_MODEL_DIRECTORY)
|
20 |
+
if f.endswith(".pb.gz")
|
21 |
+
]
|
demo/figures/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
demo/interfaces/__init__.py
ADDED
File without changes
|
demo/interfaces/activations.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Gradio interface for plotting attention.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import chess
|
6 |
+
import chess.pgn
|
7 |
+
import io
|
8 |
+
import gradio as gr
|
9 |
+
import os
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from lczerolens import LczeroBoard, LczeroModel, Lens
|
13 |
+
|
14 |
+
from demo import constants
|
15 |
+
from demo.utils import get_info
|
16 |
+
|
17 |
+
def get_model(model_name: str):
|
18 |
+
return LczeroModel.from_onnx_path(os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name))
|
19 |
+
|
20 |
+
def get_activations(model: LczeroModel, board: LczeroBoard):
|
21 |
+
lens = Lens.from_name("activation", "block\d/conv2/relu")
|
22 |
+
with torch.no_grad():
|
23 |
+
results = lens.analyse(model, board)
|
24 |
+
return [results[f"block{i}/conv2/relu_output"][0] for i in range(len(results))]
|
25 |
+
|
26 |
+
def get_board(game_pgn:str, board_fen:str):
|
27 |
+
if game_pgn:
|
28 |
+
try:
|
29 |
+
board = LczeroBoard()
|
30 |
+
pgn = io.StringIO(game_pgn)
|
31 |
+
game = chess.pgn.read_game(pgn)
|
32 |
+
for move in game.mainline_moves():
|
33 |
+
board.push(move)
|
34 |
+
except Exception as e:
|
35 |
+
print(e)
|
36 |
+
gr.Warning("Error parsing PGN, using starting position.")
|
37 |
+
board = LczeroBoard()
|
38 |
+
else:
|
39 |
+
try:
|
40 |
+
board = LczeroBoard(board_fen)
|
41 |
+
except Exception as e:
|
42 |
+
print(e)
|
43 |
+
gr.Warning("Invalid FEN, using starting position.")
|
44 |
+
board = LczeroBoard()
|
45 |
+
return board
|
46 |
+
|
47 |
+
def render_activations(board: LczeroBoard, activations, layer_index:int, channel_index:int):
|
48 |
+
if layer_index >= len(activations):
|
49 |
+
safe_layer_index = len(activations) - 1
|
50 |
+
gr.Warning(f"Layer index {layer_index} out of range, using last layer ({safe_layer_index}).")
|
51 |
+
else:
|
52 |
+
safe_layer_index = layer_index
|
53 |
+
if channel_index >= activations[safe_layer_index].shape[0]:
|
54 |
+
safe_channel_index = activations[safe_layer_index].shape[0] - 1
|
55 |
+
gr.Warning(f"Channel index {channel_index} out of range, using last channel ({safe_channel_index}).")
|
56 |
+
else:
|
57 |
+
safe_channel_index = channel_index
|
58 |
+
heatmap = activations[safe_layer_index][safe_channel_index].view(64)
|
59 |
+
board.render_heatmap(
|
60 |
+
heatmap,
|
61 |
+
save_to=f"{constants.FIGURE_DIRECTORY}/activations.svg",
|
62 |
+
)
|
63 |
+
return f"{constants.FIGURE_DIRECTORY}/activations_board.svg", f"{constants.FIGURE_DIRECTORY}/activations_colorbar.svg"
|
64 |
+
|
65 |
+
def initial_load(model_name: str, board_fen: str, game_pgn: str, layer_index: int, channel_index: int):
|
66 |
+
model = get_model(model_name)
|
67 |
+
board = get_board(game_pgn, board_fen)
|
68 |
+
activations = get_activations(model, board)
|
69 |
+
info = get_info(model, board)
|
70 |
+
plots = render_activations(board, activations, layer_index, channel_index)
|
71 |
+
return model, board, activations, info, *plots
|
72 |
+
|
73 |
+
def on_board_change(model: LczeroModel, game_pgn: str, board_fen: str, layer_index: int, channel_index: int):
|
74 |
+
board = get_board(game_pgn, board_fen)
|
75 |
+
activations = get_activations(model, board)
|
76 |
+
info = get_info(model, board)
|
77 |
+
plots = render_activations(board, activations, layer_index, channel_index)
|
78 |
+
return board, activations, info, *plots
|
79 |
+
|
80 |
+
def on_model_change(model_name: str, board: LczeroBoard, layer_index: int, channel_index: int):
|
81 |
+
model = get_model(model_name)
|
82 |
+
activations = get_activations(model, board)
|
83 |
+
info = get_info(model, board)
|
84 |
+
plots = render_activations(board, activations, layer_index, channel_index)
|
85 |
+
return model, activations, info, *plots
|
86 |
+
|
87 |
+
with gr.Blocks() as interface:
|
88 |
+
with gr.Row():
|
89 |
+
with gr.Column():
|
90 |
+
with gr.Group():
|
91 |
+
gr.Markdown(
|
92 |
+
"Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
|
93 |
+
)
|
94 |
+
game_pgn = gr.Textbox(
|
95 |
+
label="Game PGN",
|
96 |
+
lines=1,
|
97 |
+
value="",
|
98 |
+
)
|
99 |
+
board_fen = gr.Textbox(
|
100 |
+
label="Board FEN",
|
101 |
+
lines=1,
|
102 |
+
max_lines=1,
|
103 |
+
value=chess.STARTING_FEN,
|
104 |
+
)
|
105 |
+
model_name = gr.Dropdown(
|
106 |
+
label="Model",
|
107 |
+
choices=constants.ONNX_MODEL_NAMES,
|
108 |
+
)
|
109 |
+
with gr.Group():
|
110 |
+
info = gr.Textbox(label="Info", lines=1, value="")
|
111 |
+
with gr.Group():
|
112 |
+
layer_index = gr.Slider(
|
113 |
+
label="Layer index",
|
114 |
+
minimum=0,
|
115 |
+
maximum=19,
|
116 |
+
step=1,
|
117 |
+
value=0,
|
118 |
+
)
|
119 |
+
channel_index = gr.Slider(
|
120 |
+
label="Channel index",
|
121 |
+
minimum=0,
|
122 |
+
maximum=200,
|
123 |
+
step=1,
|
124 |
+
value=0,
|
125 |
+
)
|
126 |
+
with gr.Column():
|
127 |
+
image_board = gr.Image(label="Board", interactive=False)
|
128 |
+
colorbar = gr.Image(label="Colorbar", interactive=False)
|
129 |
+
|
130 |
+
model = gr.State(value=None)
|
131 |
+
board = gr.State(value=None)
|
132 |
+
activations = gr.State(value=None)
|
133 |
+
|
134 |
+
interface.load(
|
135 |
+
initial_load,
|
136 |
+
inputs=[model_name, game_pgn, board_fen, layer_index, channel_index],
|
137 |
+
outputs=[model, board, activations, info, image_board, colorbar],
|
138 |
+
concurrency_limit=1,
|
139 |
+
concurrency_id="trace_queue"
|
140 |
+
)
|
141 |
+
game_pgn.submit(
|
142 |
+
on_board_change,
|
143 |
+
inputs=[model, game_pgn, board_fen, layer_index, channel_index],
|
144 |
+
outputs=[board, activations, info, image_board, colorbar],
|
145 |
+
concurrency_id="trace_queue"
|
146 |
+
)
|
147 |
+
board_fen.submit(
|
148 |
+
on_board_change,
|
149 |
+
inputs=[model, game_pgn, board_fen, layer_index, channel_index],
|
150 |
+
outputs=[board, activations, info, image_board, colorbar],
|
151 |
+
concurrency_id="trace_queue"
|
152 |
+
)
|
153 |
+
model_name.change(
|
154 |
+
on_model_change,
|
155 |
+
inputs=[model_name, board, layer_index, channel_index],
|
156 |
+
outputs=[model, activations, info, image_board, colorbar],
|
157 |
+
concurrency_id="trace_queue"
|
158 |
+
)
|
159 |
+
layer_index.change(
|
160 |
+
render_activations,
|
161 |
+
inputs=[board, activations, layer_index, channel_index],
|
162 |
+
outputs=[image_board, colorbar],
|
163 |
+
)
|
164 |
+
channel_index.change(
|
165 |
+
render_activations,
|
166 |
+
inputs=[board, activations, layer_index, channel_index],
|
167 |
+
outputs=[image_board, colorbar],
|
168 |
+
)
|
demo/interfaces/board.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Gradio interface for plotting a board.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import chess
|
6 |
+
import chess.svg
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
from lczerolens.board import LczeroBoard
|
10 |
+
|
11 |
+
from ..utils import create_board_figure
|
12 |
+
|
13 |
+
def make_board_plot(board_fen, arrows, square):
|
14 |
+
try:
|
15 |
+
board = LczeroBoard(board_fen)
|
16 |
+
except ValueError:
|
17 |
+
board = LczeroBoard()
|
18 |
+
gr.Warning("Invalid FEN, using starting position.")
|
19 |
+
filepath = create_board_figure(board, arrows=arrows, square=square, name="board")
|
20 |
+
return filepath
|
21 |
+
|
22 |
+
|
23 |
+
with gr.Blocks() as interface:
|
24 |
+
with gr.Row():
|
25 |
+
with gr.Column():
|
26 |
+
board_fen = gr.Textbox(
|
27 |
+
label="Board starting FEN",
|
28 |
+
lines=1,
|
29 |
+
max_lines=1,
|
30 |
+
value=chess.STARTING_FEN,
|
31 |
+
)
|
32 |
+
arrows = gr.Textbox(
|
33 |
+
label="Arrows",
|
34 |
+
lines=1,
|
35 |
+
max_lines=1,
|
36 |
+
value="",
|
37 |
+
placeholder="e2e4 e7e5",
|
38 |
+
)
|
39 |
+
square = gr.Textbox(
|
40 |
+
label="Square",
|
41 |
+
lines=1,
|
42 |
+
max_lines=1,
|
43 |
+
value="",
|
44 |
+
placeholder="e4",
|
45 |
+
)
|
46 |
+
with gr.Column():
|
47 |
+
image = gr.Image(label="Board", interactive=False)
|
48 |
+
|
49 |
+
inputs = [
|
50 |
+
board_fen,
|
51 |
+
arrows,
|
52 |
+
square,
|
53 |
+
]
|
54 |
+
interface.load(make_board_plot, inputs=inputs, outputs=image)
|
55 |
+
board_fen.submit(make_board_plot, inputs=inputs, outputs=image)
|
56 |
+
arrows.submit(make_board_plot, inputs=inputs, outputs=image)
|
57 |
+
square.submit(make_board_plot, inputs=inputs, outputs=image)
|
demo/interfaces/encodings.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Gradio interface for plotting attention.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import chess
|
6 |
+
import chess.pgn
|
7 |
+
import io
|
8 |
+
import gradio as gr
|
9 |
+
|
10 |
+
from lczerolens.board import LczeroBoard
|
11 |
+
|
12 |
+
from ..constants import FIGURE_DIRECTORY
|
13 |
+
|
14 |
+
def make_render(game_pgn:str, board_fen:str, plane_index:int):
|
15 |
+
if game_pgn:
|
16 |
+
try:
|
17 |
+
board = LczeroBoard()
|
18 |
+
pgn = io.StringIO(game_pgn)
|
19 |
+
game = chess.pgn.read_game(pgn)
|
20 |
+
for move in game.mainline_moves():
|
21 |
+
board.push(move)
|
22 |
+
except Exception as e:
|
23 |
+
print(e)
|
24 |
+
gr.Warning("Error parsing PGN, using starting position.")
|
25 |
+
board = LczeroBoard()
|
26 |
+
else:
|
27 |
+
try:
|
28 |
+
board = LczeroBoard(board_fen)
|
29 |
+
except Exception as e:
|
30 |
+
print(e)
|
31 |
+
gr.Warning("Invalid FEN, using starting position.")
|
32 |
+
board = LczeroBoard()
|
33 |
+
return board, *make_board_plot(board, plane_index)
|
34 |
+
|
35 |
+
def make_board_plot(board:LczeroBoard, plane_index:int):
|
36 |
+
input_tensor = board.to_input_tensor()
|
37 |
+
board.render_heatmap(
|
38 |
+
input_tensor[plane_index].view(64),
|
39 |
+
save_to=f"{FIGURE_DIRECTORY}/encodings.svg",
|
40 |
+
vmin=0,
|
41 |
+
vmax=1,
|
42 |
+
)
|
43 |
+
return f"{FIGURE_DIRECTORY}/encodings_board.svg", f"{FIGURE_DIRECTORY}/encodings_colorbar.svg"
|
44 |
+
|
45 |
+
with gr.Blocks() as interface:
|
46 |
+
with gr.Row():
|
47 |
+
with gr.Column():
|
48 |
+
with gr.Group():
|
49 |
+
gr.Markdown(
|
50 |
+
"Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
|
51 |
+
)
|
52 |
+
game_pgn = gr.Textbox(
|
53 |
+
label="Game PGN",
|
54 |
+
lines=1,
|
55 |
+
value="",
|
56 |
+
)
|
57 |
+
board_fen = gr.Textbox(
|
58 |
+
label="Board FEN",
|
59 |
+
lines=1,
|
60 |
+
max_lines=1,
|
61 |
+
value=chess.STARTING_FEN,
|
62 |
+
)
|
63 |
+
with gr.Group():
|
64 |
+
with gr.Row():
|
65 |
+
plane_index = gr.Slider(
|
66 |
+
label="Plane index",
|
67 |
+
minimum=0,
|
68 |
+
maximum=111,
|
69 |
+
step=1,
|
70 |
+
value=0,
|
71 |
+
)
|
72 |
+
with gr.Column():
|
73 |
+
image_board = gr.Image(label="Board", interactive=False)
|
74 |
+
colorbar = gr.Image(label="Colorbar", interactive=False)
|
75 |
+
|
76 |
+
state_board = gr.State(value=LczeroBoard())
|
77 |
+
|
78 |
+
render_inputs = [game_pgn, board_fen, plane_index]
|
79 |
+
render_outputs = [state_board, image_board, colorbar]
|
80 |
+
interface.load(
|
81 |
+
make_render,
|
82 |
+
inputs=render_inputs,
|
83 |
+
outputs=render_outputs,
|
84 |
+
)
|
85 |
+
game_pgn.submit(
|
86 |
+
make_render,
|
87 |
+
inputs=render_inputs,
|
88 |
+
outputs=render_outputs,
|
89 |
+
)
|
90 |
+
board_fen.submit(
|
91 |
+
make_render,
|
92 |
+
inputs=render_inputs,
|
93 |
+
outputs=render_outputs,
|
94 |
+
)
|
95 |
+
plane_index.change(
|
96 |
+
make_board_plot,
|
97 |
+
inputs=[state_board, plane_index],
|
98 |
+
outputs=[image_board, colorbar],
|
99 |
+
)
|
demo/interfaces/gradients.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Gradio interface for plotting attention.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import chess
|
6 |
+
import chess.pgn
|
7 |
+
import io
|
8 |
+
import gradio as gr
|
9 |
+
import os
|
10 |
+
|
11 |
+
from lczerolens import LczeroBoard, LczeroModel, Lens
|
12 |
+
|
13 |
+
from demo import constants
|
14 |
+
from demo.utils import get_info
|
15 |
+
|
16 |
+
def get_model(model_name: str):
|
17 |
+
return LczeroModel.from_onnx_path(os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name))
|
18 |
+
|
19 |
+
def get_gradients(model: LczeroModel, board: LczeroBoard, target: str):
|
20 |
+
lens = Lens.from_name("gradient")
|
21 |
+
|
22 |
+
def init_target(model):
|
23 |
+
if target == "best_move":
|
24 |
+
return getattr(model, "output/policy").output.max(dim=1).values
|
25 |
+
else:
|
26 |
+
wdl_index = {"win": 0, "draw": 1, "loss": 2}[target]
|
27 |
+
return getattr(model, "output/wdl").output[:, wdl_index]
|
28 |
+
results = lens.analyse(model, board, init_target=init_target)
|
29 |
+
|
30 |
+
return results["input_grad"]
|
31 |
+
|
32 |
+
def get_board(game_pgn:str, board_fen:str):
|
33 |
+
if game_pgn:
|
34 |
+
try:
|
35 |
+
board = LczeroBoard()
|
36 |
+
pgn = io.StringIO(game_pgn)
|
37 |
+
game = chess.pgn.read_game(pgn)
|
38 |
+
for move in game.mainline_moves():
|
39 |
+
board.push(move)
|
40 |
+
except Exception as e:
|
41 |
+
print(e)
|
42 |
+
gr.Warning("Error parsing PGN, using starting position.")
|
43 |
+
board = LczeroBoard()
|
44 |
+
else:
|
45 |
+
try:
|
46 |
+
board = LczeroBoard(board_fen)
|
47 |
+
except Exception as e:
|
48 |
+
print(e)
|
49 |
+
gr.Warning("Invalid FEN, using starting position.")
|
50 |
+
board = LczeroBoard()
|
51 |
+
return board
|
52 |
+
|
53 |
+
def render_gradients(board: LczeroBoard, gradients, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index:int):
|
54 |
+
if average_over_planes:
|
55 |
+
heatmap = gradients[0, begin_average_index:end_average_index].mean(dim=0).view(64)
|
56 |
+
else:
|
57 |
+
heatmap = gradients[0, plane_index].view(64)
|
58 |
+
board.render_heatmap(
|
59 |
+
heatmap,
|
60 |
+
save_to=f"{constants.FIGURE_DIRECTORY}/gradients.svg",
|
61 |
+
)
|
62 |
+
return f"{constants.FIGURE_DIRECTORY}/gradients_board.svg", f"{constants.FIGURE_DIRECTORY}/gradients_colorbar.svg"
|
63 |
+
|
64 |
+
def initial_load(model_name: str, board_fen: str, game_pgn: str, target: str, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index: int):
|
65 |
+
model = get_model(model_name)
|
66 |
+
board = get_board(game_pgn, board_fen)
|
67 |
+
gradients = get_gradients(model, board, target)
|
68 |
+
info = get_info(model, board)
|
69 |
+
plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
|
70 |
+
return model, board, gradients, info, *plots
|
71 |
+
|
72 |
+
def on_board_change(model: LczeroModel, game_pgn: str, board_fen: str, target: str, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index: int):
|
73 |
+
board = get_board(game_pgn, board_fen)
|
74 |
+
gradients = get_gradients(model, board, target)
|
75 |
+
info = get_info(model, board)
|
76 |
+
plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
|
77 |
+
return board, gradients, info, *plots
|
78 |
+
|
79 |
+
def on_model_change(model_name: str, board: LczeroBoard, target: str, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index: int):
|
80 |
+
model = get_model(model_name)
|
81 |
+
gradients = get_gradients(model, board, target)
|
82 |
+
info = get_info(model, board)
|
83 |
+
plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
|
84 |
+
return model, gradients, info, *plots
|
85 |
+
|
86 |
+
def on_target_change(model: LczeroModel, board: LczeroBoard, target: str, average_over_planes:bool, begin_average_index:int, end_average_index:int, plane_index: int):
|
87 |
+
gradients = get_gradients(model, board, target)
|
88 |
+
plots = render_gradients(board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index)
|
89 |
+
return gradients, *plots
|
90 |
+
|
91 |
+
with gr.Blocks() as interface:
|
92 |
+
with gr.Row():
|
93 |
+
with gr.Column():
|
94 |
+
with gr.Group():
|
95 |
+
gr.Markdown(
|
96 |
+
"Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
|
97 |
+
)
|
98 |
+
game_pgn = gr.Textbox(
|
99 |
+
label="Game PGN",
|
100 |
+
lines=1,
|
101 |
+
value="",
|
102 |
+
)
|
103 |
+
board_fen = gr.Textbox(
|
104 |
+
label="Board FEN",
|
105 |
+
lines=1,
|
106 |
+
max_lines=1,
|
107 |
+
value=chess.STARTING_FEN,
|
108 |
+
)
|
109 |
+
model_name = gr.Dropdown(
|
110 |
+
label="Model",
|
111 |
+
choices=constants.ONNX_MODEL_NAMES,
|
112 |
+
)
|
113 |
+
with gr.Group():
|
114 |
+
info = gr.Textbox(label="Info", lines=1, value="")
|
115 |
+
with gr.Group():
|
116 |
+
target = gr.Radio(
|
117 |
+
["win", "draw", "loss", "best_move"], label="Target",
|
118 |
+
value="win",
|
119 |
+
)
|
120 |
+
average_over_planes = gr.Checkbox(label="Average over Planes", value=False)
|
121 |
+
with gr.Accordion("Average over planes", open=False):
|
122 |
+
begin_average_index = gr.Slider(
|
123 |
+
label="Begin average index",
|
124 |
+
minimum=0,
|
125 |
+
maximum=111,
|
126 |
+
step=1,
|
127 |
+
value=0,
|
128 |
+
)
|
129 |
+
end_average_index = gr.Slider(
|
130 |
+
label="End average index",
|
131 |
+
minimum=0,
|
132 |
+
maximum=111,
|
133 |
+
step=1,
|
134 |
+
value=111,
|
135 |
+
)
|
136 |
+
plane_index = gr.Slider(
|
137 |
+
label="Plane index",
|
138 |
+
minimum=0,
|
139 |
+
maximum=111,
|
140 |
+
step=1,
|
141 |
+
value=0,
|
142 |
+
)
|
143 |
+
with gr.Column():
|
144 |
+
image_board = gr.Image(label="Board", interactive=False)
|
145 |
+
colorbar = gr.Image(label="Colorbar", interactive=False)
|
146 |
+
|
147 |
+
model = gr.State(value=None)
|
148 |
+
board = gr.State(value=None)
|
149 |
+
gradients = gr.State(value=None)
|
150 |
+
|
151 |
+
interface.load(
|
152 |
+
initial_load,
|
153 |
+
inputs=[model_name, game_pgn, board_fen, target, average_over_planes, begin_average_index, end_average_index, plane_index],
|
154 |
+
outputs=[model, board, gradients, info, image_board, colorbar],
|
155 |
+
concurrency_id="trace_queue"
|
156 |
+
)
|
157 |
+
game_pgn.submit(
|
158 |
+
on_board_change,
|
159 |
+
inputs=[model, game_pgn, board_fen, target, average_over_planes, begin_average_index, end_average_index, plane_index],
|
160 |
+
outputs=[board, gradients, info, image_board, colorbar],
|
161 |
+
concurrency_id="trace_queue"
|
162 |
+
)
|
163 |
+
board_fen.submit(
|
164 |
+
on_board_change,
|
165 |
+
inputs=[model, game_pgn, board_fen, target, average_over_planes, begin_average_index, end_average_index, plane_index],
|
166 |
+
outputs=[board, gradients, info, image_board, colorbar],
|
167 |
+
concurrency_id="trace_queue"
|
168 |
+
)
|
169 |
+
model_name.change(
|
170 |
+
on_model_change,
|
171 |
+
inputs=[model_name, board, target, average_over_planes, begin_average_index, end_average_index, plane_index],
|
172 |
+
outputs=[model, gradients, info, image_board, colorbar],
|
173 |
+
concurrency_id="trace_queue"
|
174 |
+
)
|
175 |
+
target.change(
|
176 |
+
on_target_change,
|
177 |
+
inputs=[model, board, target, average_over_planes, begin_average_index, end_average_index, plane_index],
|
178 |
+
outputs=[gradients, image_board, colorbar],
|
179 |
+
concurrency_id="trace_queue"
|
180 |
+
)
|
181 |
+
for render_arg in [average_over_planes, begin_average_index, end_average_index, plane_index]:
|
182 |
+
render_arg.change(
|
183 |
+
render_gradients,
|
184 |
+
inputs=[board, gradients, average_over_planes, begin_average_index, end_average_index, plane_index],
|
185 |
+
outputs=[image_board, colorbar],
|
186 |
+
)
|
demo/interfaces/play.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Interface to play against the model.
|
2 |
+
"""
|
3 |
+
|
4 |
+
import os
|
5 |
+
|
6 |
+
import chess
|
7 |
+
import chess.pgn
|
8 |
+
import random
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
from lczerolens import LczeroBoard, LczeroModel
|
12 |
+
from lczerolens.play import PolicySampler
|
13 |
+
|
14 |
+
from demo import constants
|
15 |
+
from demo.utils import create_board_figure
|
16 |
+
|
17 |
+
|
18 |
+
def get_sampler(model_name: str):
|
19 |
+
model = LczeroModel.from_onnx_path(os.path.join(constants.ONNX_MODEL_DIRECTORY, model_name))
|
20 |
+
return PolicySampler(model)
|
21 |
+
|
22 |
+
def get_pgn(board: LczeroBoard):
|
23 |
+
game = chess.pgn.Game()
|
24 |
+
for move in board.move_stack:
|
25 |
+
game.add_variation(move)
|
26 |
+
return str(game).split("\n")[-1]
|
27 |
+
|
28 |
+
def render_board(
|
29 |
+
board: LczeroBoard,
|
30 |
+
):
|
31 |
+
player = board.turn
|
32 |
+
if len(board.move_stack) > 0:
|
33 |
+
last_move_uci = board.peek().uci()
|
34 |
+
else:
|
35 |
+
last_move_uci = None
|
36 |
+
|
37 |
+
if board.is_check():
|
38 |
+
check = board.king(board.turn)
|
39 |
+
else:
|
40 |
+
check = None
|
41 |
+
filepath = create_board_figure(
|
42 |
+
board,
|
43 |
+
orientation=player,
|
44 |
+
arrows=last_move_uci,
|
45 |
+
square=check,
|
46 |
+
name="play_board",
|
47 |
+
)
|
48 |
+
return filepath
|
49 |
+
|
50 |
+
def gather_outputs(board: LczeroBoard, sampler: PolicySampler):
|
51 |
+
return sampler, board, board.fen(), get_pgn(board), render_board(board), ""
|
52 |
+
|
53 |
+
def get_init(model_name: str):
|
54 |
+
sampler = get_sampler(model_name)
|
55 |
+
is_ai_white = random.choice([True, False])
|
56 |
+
init_board = LczeroBoard()
|
57 |
+
if is_ai_white:
|
58 |
+
play_ai_move(init_board, sampler)
|
59 |
+
return gather_outputs(init_board, sampler)
|
60 |
+
|
61 |
+
def play_user_move_then_ai_move(
|
62 |
+
uci_move: str,
|
63 |
+
board: LczeroBoard,
|
64 |
+
sampler: PolicySampler,
|
65 |
+
):
|
66 |
+
board.push_uci(uci_move)
|
67 |
+
play_ai_move(board, sampler)
|
68 |
+
return gather_outputs(board, sampler)
|
69 |
+
|
70 |
+
|
71 |
+
def play_ai_move(
|
72 |
+
board: LczeroBoard,
|
73 |
+
sampler: PolicySampler,
|
74 |
+
):
|
75 |
+
move, _ = next(iter(sampler.get_next_moves([board])))
|
76 |
+
board.push(move)
|
77 |
+
|
78 |
+
with gr.Blocks() as interface:
|
79 |
+
with gr.Row():
|
80 |
+
with gr.Column():
|
81 |
+
current_fen = gr.Textbox(
|
82 |
+
label="Board FEN",
|
83 |
+
lines=1,
|
84 |
+
max_lines=1,
|
85 |
+
value=chess.STARTING_FEN,
|
86 |
+
)
|
87 |
+
current_pgn = gr.Textbox(
|
88 |
+
label="Action sequence",
|
89 |
+
lines=1,
|
90 |
+
value="",
|
91 |
+
)
|
92 |
+
model_name = gr.Dropdown(
|
93 |
+
label="Model",
|
94 |
+
choices=constants.ONNX_MODEL_NAMES,
|
95 |
+
)
|
96 |
+
with gr.Column():
|
97 |
+
move_to_play = gr.Textbox(
|
98 |
+
label="Move to play (UCI)",
|
99 |
+
lines=1,
|
100 |
+
max_lines=1,
|
101 |
+
value="",
|
102 |
+
)
|
103 |
+
play_button = gr.Button("Play")
|
104 |
+
reset_button = gr.Button("Reset")
|
105 |
+
with gr.Column():
|
106 |
+
image_board = gr.Image(label="Board", interactive=False)
|
107 |
+
|
108 |
+
sampler = gr.State(value=None)
|
109 |
+
board = gr.State(value=None)
|
110 |
+
|
111 |
+
outputs = [sampler, board, current_fen, current_pgn, image_board, move_to_play]
|
112 |
+
|
113 |
+
play_button.click(
|
114 |
+
play_user_move_then_ai_move,
|
115 |
+
inputs=[move_to_play, board, sampler],
|
116 |
+
outputs=outputs,
|
117 |
+
)
|
118 |
+
move_to_play.submit(
|
119 |
+
play_user_move_then_ai_move,
|
120 |
+
inputs=[move_to_play, board, sampler],
|
121 |
+
outputs=outputs,
|
122 |
+
)
|
123 |
+
|
124 |
+
model_name.change(
|
125 |
+
get_sampler,
|
126 |
+
inputs=[model_name],
|
127 |
+
outputs=[sampler],
|
128 |
+
)
|
129 |
+
|
130 |
+
reset_button.click(
|
131 |
+
get_init,
|
132 |
+
inputs=[model_name],
|
133 |
+
outputs=outputs,
|
134 |
+
)
|
135 |
+
interface.load(
|
136 |
+
get_init,
|
137 |
+
inputs=[model_name],
|
138 |
+
outputs=outputs,
|
139 |
+
)
|
demo/leela-models/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
demo/onnx-models/lc0-10-4238.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:096bc5b833572de5357961f8cb2d059fe4f0ff303598131f20278fb23bc58919
|
3 |
+
size 15152179
|
demo/onnx-models/lc0-19-1876.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:df7e2b87f0f4525d7de7201ee2e946fc10172200ef97c6786b3fdcfb8610b88a
|
3 |
+
size 97139290
|
demo/onnx-models/lc0-19-4508.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bcf7c13451f00fc9d1859f98a52a37056aa3d10ac7f874408ca46d0cde50c241
|
3 |
+
size 97139290
|
demo/onnx-models/look-ahead-lc0.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2a3dad19ed0546106e3b97d7aa7dbf0e3cc7041a6c63979f923360ee71cb4a24
|
3 |
+
size 378669573
|
demo/onnx-models/maia-1100.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b1a4ac9b99aee3c127285b7cdc2024d054a6de72abdfa2fe3db5e4f7078c96ad
|
3 |
+
size 3484716
|
demo/onnx-models/maia-1200.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:82c21e9cf65f0de9a72469eaae8845c3e92a32fe93e1e626a064cd3d24cca4cf
|
3 |
+
size 3483901
|
demo/onnx-models/maia-1300.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d7c3b8d9da7a2cc586f409b0264abcb9d00390276ffe5f39a09d591787555e50
|
3 |
+
size 3483901
|
demo/onnx-models/maia-1400.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3651ae981d8bc05cf735c79d1a51298d91d3450592ac874d9c2c66e6c90b896b
|
3 |
+
size 3483901
|
demo/onnx-models/maia-1500.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:94333b25bb8d25685334645afbb7494b358ae5724ad651d735ae1683fb0a9e3f
|
3 |
+
size 3483901
|
demo/onnx-models/maia-1600.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d46d32c6c1ff2a9f04df5c0652e50ae502b8870c6e0895e6abeeb02bba01d423
|
3 |
+
size 3483901
|
demo/onnx-models/maia-1700.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9b01ef1c6bf0136480526f2933f88623bc5b1e2bbf319fc6045be68524f942d8
|
3 |
+
size 3483901
|
demo/onnx-models/maia-1800.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8b9b999de7d0fe4b2efc9d6591b30f4010ec7490c36676dd920ee0420f202a67
|
3 |
+
size 3483901
|
demo/onnx-models/maia-1900.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:65ee89dcee614d2b7f5bf8fc5950e83050bf855ecb4d34f6e6214b09acc64572
|
3 |
+
size 3484716
|
demo/onnx-models/t1-256x10-distilled-swa-2432500.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a5cf50f8d5e08c1f42c9a30bdcbd3a2deeef209f24ca4104d4170eda1b193bad
|
3 |
+
size 80896651
|
demo/utils.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
import chess.svg
|
4 |
+
|
5 |
+
from lczerolens import LczeroBoard, LczeroModel, Lens
|
6 |
+
|
7 |
+
from . import constants
|
8 |
+
|
9 |
+
|
10 |
+
def create_board_figure(
|
11 |
+
board: LczeroBoard,
|
12 |
+
*,
|
13 |
+
orientation: bool = chess.WHITE,
|
14 |
+
arrows: str = "",
|
15 |
+
square: str = "",
|
16 |
+
name: str = "board",
|
17 |
+
):
|
18 |
+
try:
|
19 |
+
if arrows:
|
20 |
+
arrows_list = arrows.split(" ")
|
21 |
+
chess_arrows = []
|
22 |
+
for arrow in arrows_list:
|
23 |
+
from_square, to_square = arrow[:2], arrow[2:]
|
24 |
+
chess_arrows.append(
|
25 |
+
(
|
26 |
+
chess.parse_square(from_square),
|
27 |
+
chess.parse_square(to_square),
|
28 |
+
)
|
29 |
+
)
|
30 |
+
else:
|
31 |
+
chess_arrows = []
|
32 |
+
except ValueError:
|
33 |
+
chess_arrows = []
|
34 |
+
gr.Warning("Invalid arrows, using none.")
|
35 |
+
|
36 |
+
try:
|
37 |
+
color_dict = {chess.parse_square(square): "#FF0000"} if square else {}
|
38 |
+
except ValueError:
|
39 |
+
color_dict = {}
|
40 |
+
gr.Warning("Invalid square, using none.")
|
41 |
+
|
42 |
+
svg_board = chess.svg.board(
|
43 |
+
board,
|
44 |
+
size=350,
|
45 |
+
orientation=orientation,
|
46 |
+
arrows=chess_arrows,
|
47 |
+
fill=color_dict,
|
48 |
+
)
|
49 |
+
with open(f"{constants.FIGURE_DIRECTORY}/{name}.svg", "w") as f:
|
50 |
+
f.write(svg_board)
|
51 |
+
return f"{constants.FIGURE_DIRECTORY}/{name}.svg"
|
52 |
+
|
53 |
+
|
54 |
+
class OutputLens(Lens):
|
55 |
+
def _intervene(self, model: LczeroModel, **kwargs) -> dict:
|
56 |
+
return model.output.save()
|
57 |
+
|
58 |
+
def get_info(model: LczeroModel, board: LczeroBoard):
|
59 |
+
lens = OutputLens()
|
60 |
+
output = lens.analyse(model, board)
|
61 |
+
w = output["wdl"][0,0]
|
62 |
+
d = output["wdl"][0,1]
|
63 |
+
l = output["wdl"][0,2]
|
64 |
+
legal_indices = board.get_legal_indices()
|
65 |
+
best_move_idx = output["policy"].gather(dim=1, index=legal_indices.unsqueeze(0)).argmax(dim=1).item()
|
66 |
+
best_move = board.decode_move(legal_indices[best_move_idx])
|
67 |
+
info = f"w: {w:.2f}, d: {d:.2f}, l: {l:.2f}, best: {best_move}"
|
68 |
+
return info
|
pyproject.toml
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
name = "lczerolens-demo"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = "Demo lczerolens features."
|
5 |
+
readme = "README.md"
|
6 |
+
requires-python = ">=3.11"
|
7 |
+
dependencies = [
|
8 |
+
"gdown>=5.2.0",
|
9 |
+
"gradio>=5.20.1",
|
10 |
+
"lczerolens[viz]>=0.3.3",
|
11 |
+
]
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
lczerolens[viz]>=0.3.3
|
2 |
+
gdown
|
resolve-assets.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gdown 1cxC8_8vw7akfPyc9cZxwaAbLG2Zl4XiT -O demo/onnx-models/lc0-10-4238.onnx
|
2 |
+
gdown 15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X -O demo/onnx-models/lc0-19-1876.onnx
|
3 |
+
gdown 1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd -O demo/onnx-models/lc0-19-4508.onnx
|
4 |
+
gdown 1TI429e9mr2de7LjHp2IIl7ouMoUaDjjZ -O demo/onnx-models/maia-1100.onnx
|
5 |
+
gdown 1-8IJ5WYMPpcxOsHfIKY8xKskwk2z_yrY -O demo/onnx-models/maia-1900.onnx
|
uv.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|