Spaces:
Runtime error
Runtime error
new repo structure
Browse files- .gitignore +139 -0
- app.py +4 -0
- assets/.gitignore +2 -0
- figures/.gitignore +2 -0
- requirements.txt +7 -0
- src/constants.py +18 -0
- src/global_variables.py +54 -0
- src/helpers/__init__.py +4 -0
- src/helpers/generator.py +59 -0
- src/helpers/sae.py +93 -0
- src/interfaces/__init__.py +2 -0
- src/interfaces/feature_interface.py +121 -0
- src/interfaces/stats_interface.py +0 -0
- src/visualisation.py +108 -0
.gitignore
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pipenv
|
85 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
86 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
87 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
88 |
+
# install all needed dependencies.
|
89 |
+
#Pipfile.lock
|
90 |
+
|
91 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
92 |
+
__pypackages__/
|
93 |
+
|
94 |
+
# Celery stuff
|
95 |
+
celerybeat-schedule
|
96 |
+
celerybeat.pid
|
97 |
+
|
98 |
+
# SageMath parsed files
|
99 |
+
*.sage.py
|
100 |
+
|
101 |
+
# Environments
|
102 |
+
.env
|
103 |
+
.venv
|
104 |
+
env/
|
105 |
+
venv/
|
106 |
+
ENV/
|
107 |
+
env.bak/
|
108 |
+
venv.bak/
|
109 |
+
|
110 |
+
# Spyder project settings
|
111 |
+
.spyderproject
|
112 |
+
.spyproject
|
113 |
+
|
114 |
+
# Rope project settings
|
115 |
+
.ropeproject
|
116 |
+
|
117 |
+
# mkdocs documentation
|
118 |
+
/site
|
119 |
+
|
120 |
+
# mypy
|
121 |
+
.mypy_cache/
|
122 |
+
.dmypy.json
|
123 |
+
dmypy.json
|
124 |
+
|
125 |
+
# Pyre type checker
|
126 |
+
.pyre/
|
127 |
+
|
128 |
+
# Pickle files
|
129 |
+
*.pkl
|
130 |
+
|
131 |
+
# Various files
|
132 |
+
ignored
|
133 |
+
debug
|
134 |
+
*.zip
|
135 |
+
lc0
|
136 |
+
!bin/lc0
|
137 |
+
wandb
|
138 |
+
|
139 |
+
*secret*
|
app.py
CHANGED
@@ -4,11 +4,15 @@ Main Gradio module.
|
|
4 |
|
5 |
import gradio as gr
|
6 |
|
|
|
|
|
7 |
|
8 |
demo = gr.TabbedInterface(
|
9 |
[
|
|
|
10 |
],
|
11 |
[
|
|
|
12 |
],
|
13 |
title="Lczero Planning Demo",
|
14 |
analytics_enabled=False,
|
|
|
4 |
|
5 |
import gradio as gr
|
6 |
|
7 |
+
from src.interfaces import feature_interface
|
8 |
+
|
9 |
|
10 |
demo = gr.TabbedInterface(
|
11 |
[
|
12 |
+
feature_interface,
|
13 |
],
|
14 |
[
|
15 |
+
"Feature Activation",
|
16 |
],
|
17 |
title="Lczero Planning Demo",
|
18 |
analytics_enabled=False,
|
assets/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
figures/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/Xmaster6y/lczerolens
|
2 |
+
chess
|
3 |
+
matplotlib
|
4 |
+
numpy
|
5 |
+
torch
|
6 |
+
tensordict
|
7 |
+
einops
|
src/constants.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Manage constants for the app.
|
2 |
+
"""
|
3 |
+
|
4 |
+
import os
|
5 |
+
import pathlib
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
ASSETS_FOLDER = pathlib.Path(__file__).parent.parent / "assets"
|
10 |
+
FIGURES_FOLER = pathlib.Path(__file__).parent.parent / "figures"
|
11 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
12 |
+
|
13 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
+
|
15 |
+
MODEL_NAME = "lc0-10-4238.onnx"
|
16 |
+
SAE_CONFIG = "debug"
|
17 |
+
LAYER = 9
|
18 |
+
N_FEATURES = 7680
|
src/global_variables.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Manage global variables for the app.
|
2 |
+
"""
|
3 |
+
|
4 |
+
from huggingface_hub import HfApi
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
from lczerolens import ModelWrapper
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from src.constants import HF_TOKEN, ASSETS_FOLDER, DEVICE, MODEL_NAME, SAE_CONFIG, LAYER
|
11 |
+
from src.helpers import SparseAutoEncoder, OutputGenerator
|
12 |
+
|
13 |
+
hf_api: HfApi
|
14 |
+
wrapper: ModelWrapper
|
15 |
+
sae: SparseAutoEncoder
|
16 |
+
generator: OutputGenerator
|
17 |
+
|
18 |
+
|
19 |
+
def setup():
|
20 |
+
global hf_api
|
21 |
+
global wrapper
|
22 |
+
global sae
|
23 |
+
global generator
|
24 |
+
|
25 |
+
hf_api = HfApi(token=HF_TOKEN)
|
26 |
+
hf_api.snapshot_download(
|
27 |
+
local_dir=f"{ASSETS_FOLDER}/models",
|
28 |
+
repo_id="Xmaster6y/lczero-planning-models",
|
29 |
+
repo_type="model",
|
30 |
+
)
|
31 |
+
hf_api.snapshot_download(
|
32 |
+
local_dir=f"{ASSETS_FOLDER}/saes",
|
33 |
+
repo_id="Xmaster6y/lczero-planning-saes",
|
34 |
+
repo_type="model",
|
35 |
+
)
|
36 |
+
|
37 |
+
wrapper = ModelWrapper.from_onnx_path(f"{ASSETS_FOLDER}/models/{MODEL_NAME}").to(DEVICE)
|
38 |
+
sae_dict = torch.load(
|
39 |
+
f"{ASSETS_FOLDER}/saes/{SAE_CONFIG}/model.pt",
|
40 |
+
map_location=DEVICE,
|
41 |
+
weights_only=True
|
42 |
+
)
|
43 |
+
sae = SparseAutoEncoder()
|
44 |
+
sae.load_state_dict(
|
45 |
+
sae_dict
|
46 |
+
)
|
47 |
+
generator = OutputGenerator(
|
48 |
+
sae=sae,
|
49 |
+
wrapper=wrapper,
|
50 |
+
module_exp=rf".*block{LAYER}/conv2/relu"
|
51 |
+
)
|
52 |
+
|
53 |
+
if gr.NO_RELOAD:
|
54 |
+
setup()
|
src/helpers/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
from .generator import OutputGenerator
|
4 |
+
from .sae import SparseAutoEncoder
|
src/helpers/generator.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Script to generate features for a given board state.
|
2 |
+
"""
|
3 |
+
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
from lczerolens import ModelWrapper
|
7 |
+
from lczerolens.xai import ActivationLens
|
8 |
+
from lczerolens.encodings import InputEncoding
|
9 |
+
import chess
|
10 |
+
import einops
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from .sae import SparseAutoEncoder
|
14 |
+
|
15 |
+
|
16 |
+
class OutputGenerator:
|
17 |
+
|
18 |
+
def __init__(self, sae: SparseAutoEncoder, wrapper: ModelWrapper, module_exp: Optional[str] = None):
|
19 |
+
self.sae = sae
|
20 |
+
self.wrapper = wrapper
|
21 |
+
self.lens = ActivationLens(module_exp=module_exp)
|
22 |
+
|
23 |
+
def generate(
|
24 |
+
self,
|
25 |
+
root_fen: Optional[str] = None,
|
26 |
+
traj_fen: Optional[str] = None,
|
27 |
+
root_board: Optional[chess.Board] = None,
|
28 |
+
traj_board: Optional[chess.Board] = None,
|
29 |
+
):
|
30 |
+
if root_board is not None and traj_board is not None:
|
31 |
+
input_encoding = InputEncoding.INPUT_CLASSICAL_112_PLANE
|
32 |
+
elif root_fen is not None and traj_fen is not None:
|
33 |
+
root_board = chess.Board(root_fen)
|
34 |
+
traj_board = chess.Board(traj_fen)
|
35 |
+
input_encoding = InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED
|
36 |
+
else:
|
37 |
+
raise ValueError
|
38 |
+
iter_boards = iter([[root_board, traj_board]])
|
39 |
+
act_dict, (model_output,) = self.lens.analyse_batched_boards(
|
40 |
+
iter_boards,
|
41 |
+
self.wrapper,
|
42 |
+
{
|
43 |
+
"return_output": True,
|
44 |
+
"wrapper_kwargs": {
|
45 |
+
"input_encoding": input_encoding,
|
46 |
+
}
|
47 |
+
}
|
48 |
+
)
|
49 |
+
if len(act_dict) == 0:
|
50 |
+
raise ValueError("No module matced the given expression.")
|
51 |
+
elif len(act_dict) > 1:
|
52 |
+
raise ValueError("Multiple modules matched the given expression.")
|
53 |
+
acts = next(iter(act_dict.values()))
|
54 |
+
root_acts = einops.rearrange(acts[0], "c h w -> (h w) c")
|
55 |
+
traj_acts = einops.rearrange(acts[1], "c h w -> (h w) c")
|
56 |
+
pixel_acts = torch.cat([root_acts, traj_acts], dim=1)
|
57 |
+
sae_output = self.sae(pixel_acts, output_features=True)
|
58 |
+
return model_output, pixel_acts, sae_output
|
59 |
+
|
src/helpers/sae.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Defines the dictionary classes
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from tensordict import TensorDict
|
8 |
+
|
9 |
+
|
10 |
+
class SparseAutoEncoder(nn.Module):
|
11 |
+
"""
|
12 |
+
A 2-layer sparse autoencoder.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
activation_dim,
|
18 |
+
dict_size,
|
19 |
+
pre_bias=False,
|
20 |
+
init_normalise_dict=None,
|
21 |
+
):
|
22 |
+
super().__init__()
|
23 |
+
self.activation_dim = activation_dim
|
24 |
+
self.dict_size = dict_size
|
25 |
+
self.pre_bias = pre_bias
|
26 |
+
self.init_normalise_dict = init_normalise_dict
|
27 |
+
|
28 |
+
self.b_enc = nn.Parameter(torch.zeros(self.dict_size))
|
29 |
+
self.relu = nn.ReLU()
|
30 |
+
|
31 |
+
self.W_dec = nn.Parameter(
|
32 |
+
torch.nn.init.kaiming_uniform_(
|
33 |
+
torch.empty(
|
34 |
+
self.dict_size,
|
35 |
+
self.activation_dim,
|
36 |
+
)
|
37 |
+
)
|
38 |
+
)
|
39 |
+
if init_normalise_dict == "l2":
|
40 |
+
self.normalize_dict_(less_than_1=False)
|
41 |
+
self.W_dec *= 0.1
|
42 |
+
elif init_normalise_dict == "less_than_1":
|
43 |
+
self.normalize_dict_(less_than_1=True)
|
44 |
+
|
45 |
+
self.W_enc = nn.Parameter(self.W_dec.t())
|
46 |
+
self.b_dec = nn.Parameter(
|
47 |
+
torch.zeros(
|
48 |
+
self.activation_dim,
|
49 |
+
)
|
50 |
+
)
|
51 |
+
|
52 |
+
@torch.no_grad()
|
53 |
+
def normalize_dict_(
|
54 |
+
self,
|
55 |
+
less_than_1=False,
|
56 |
+
):
|
57 |
+
norm = self.W_dec.norm(dim=1)
|
58 |
+
positive_mask = norm != 0
|
59 |
+
if less_than_1:
|
60 |
+
greater_than_1_mask = (norm > 1) & (positive_mask)
|
61 |
+
self.W_dec[greater_than_1_mask] /= norm[greater_than_1_mask].unsqueeze(1)
|
62 |
+
else:
|
63 |
+
self.W_dec[positive_mask] /= norm[positive_mask].unsqueeze(1)
|
64 |
+
|
65 |
+
def encode(self, x):
|
66 |
+
return x @ self.W_enc + self.b_enc
|
67 |
+
|
68 |
+
def decode(self, f):
|
69 |
+
return f @ self.W_dec + self.b_dec
|
70 |
+
|
71 |
+
def forward(self, x, output_features=False, ghost_mask=None):
|
72 |
+
"""
|
73 |
+
Forward pass of an autoencoder.
|
74 |
+
x : activations to be autoencoded
|
75 |
+
output_features : if True, return the encoded features as well
|
76 |
+
as the decoded x
|
77 |
+
ghost_mask : if not None, run this autoencoder in "ghost mode"
|
78 |
+
where features are masked
|
79 |
+
"""
|
80 |
+
if self.pre_bias:
|
81 |
+
x = x - self.b_dec
|
82 |
+
f_pre = self.encode(x)
|
83 |
+
out = TensorDict({}, batch_size=x.shape[0])
|
84 |
+
if ghost_mask is not None:
|
85 |
+
f_ghost = torch.exp(f_pre) * ghost_mask.to(f_pre)
|
86 |
+
x_ghost = f_ghost @ self.W_dec
|
87 |
+
out["x_ghost"] = x_ghost
|
88 |
+
f = self.relu(f_pre)
|
89 |
+
if output_features:
|
90 |
+
out["features"] = f
|
91 |
+
x_hat = self.decode(f)
|
92 |
+
out["x_hat"] = x_hat
|
93 |
+
return out
|
src/interfaces/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from .feature_interface import interface as feature_interface
|
src/interfaces/feature_interface.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Gradio interface for plotting policy.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import chess
|
6 |
+
import gradio as gr
|
7 |
+
import uuid
|
8 |
+
|
9 |
+
from lczerolens.encodings import encode_move
|
10 |
+
|
11 |
+
from src import constants, global_variables, visualisation
|
12 |
+
|
13 |
+
|
14 |
+
def compute_features_fn(
|
15 |
+
features,
|
16 |
+
model_output,
|
17 |
+
file_id,
|
18 |
+
root_fen,
|
19 |
+
traj_fen,
|
20 |
+
feature_index
|
21 |
+
):
|
22 |
+
model_output, _, sae_output = global_variables.generator.generate(
|
23 |
+
root_fen=root_fen,
|
24 |
+
traj_fen=traj_fen
|
25 |
+
)
|
26 |
+
features = sae_output["f"]
|
27 |
+
first_output = render_feature_index(
|
28 |
+
features,
|
29 |
+
model_output,
|
30 |
+
file_id,
|
31 |
+
feature_index,
|
32 |
+
traj_fen,
|
33 |
+
)
|
34 |
+
game_info = f"WDL: {model_output.get('wdl')}"
|
35 |
+
return *first_output, game_info
|
36 |
+
|
37 |
+
|
38 |
+
def render_feature_index(
|
39 |
+
features,
|
40 |
+
model_output,
|
41 |
+
file_id,
|
42 |
+
feature_index,
|
43 |
+
traj_fen,
|
44 |
+
):
|
45 |
+
if file_id is None:
|
46 |
+
file_id = str(uuid.uuid4())
|
47 |
+
board = chess.Board(traj_fen)
|
48 |
+
pixel_features = features[:,feature_index]
|
49 |
+
if board.turn:
|
50 |
+
heatmap = pixel_features.view(64)
|
51 |
+
else:
|
52 |
+
heatmap = pixel_features.view(8,8).flip(0).view(64)
|
53 |
+
|
54 |
+
best_legal_logit = None
|
55 |
+
best_legal_move = None
|
56 |
+
for move in board.legal_moves:
|
57 |
+
move_index = encode_move(move, (board.turn, not board.turn))
|
58 |
+
logit = model_output["policy"][1,move_index].item()
|
59 |
+
if best_legal_logit is None:
|
60 |
+
best_legal_logit = logit
|
61 |
+
else:
|
62 |
+
best_legal_move = move
|
63 |
+
|
64 |
+
svg_board, fig = visualisation.render_heatmap(
|
65 |
+
board,
|
66 |
+
heatmap,
|
67 |
+
arrows=[(best_legal_move.from_square, best_legal_move.to_square)],
|
68 |
+
)
|
69 |
+
with open(f"{constants.FIGURES_FOLER}/{file_id}.svg", "w") as f:
|
70 |
+
f.write(svg_board)
|
71 |
+
return (
|
72 |
+
features,
|
73 |
+
model_output,
|
74 |
+
file_id,
|
75 |
+
f"{constants.FIGURES_FOLER}/{file_id}.svg",
|
76 |
+
fig
|
77 |
+
)
|
78 |
+
|
79 |
+
with gr.Blocks() as interface:
|
80 |
+
with gr.Row():
|
81 |
+
with gr.Column():
|
82 |
+
root_fen = gr.Textbox(
|
83 |
+
label="Root FEN",
|
84 |
+
lines=1,
|
85 |
+
max_lines=1,
|
86 |
+
value=chess.STARTING_FEN,
|
87 |
+
)
|
88 |
+
traj_fen = gr.Textbox(
|
89 |
+
label="Trajectory FEN",
|
90 |
+
lines=1,
|
91 |
+
max_lines=1,
|
92 |
+
value="rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1",
|
93 |
+
)
|
94 |
+
compute_features = gr.Button("Compute features")
|
95 |
+
|
96 |
+
with gr.Group():
|
97 |
+
with gr.Row():
|
98 |
+
feature_index = gr.Slider(
|
99 |
+
label="Feature index",
|
100 |
+
minimum=0,
|
101 |
+
maximum=constants.N_FEATURES,
|
102 |
+
step=1,
|
103 |
+
value=0,
|
104 |
+
)
|
105 |
+
|
106 |
+
with gr.Group():
|
107 |
+
with gr.Row():
|
108 |
+
game_info = gr.Textbox(label="Game info", lines=1, max_lines=1, value="")
|
109 |
+
with gr.Row():
|
110 |
+
colorbar = gr.Plot(label="Colorbar")
|
111 |
+
with gr.Column():
|
112 |
+
board_image = gr.Image(label="Board")
|
113 |
+
|
114 |
+
features = gr.State(None)
|
115 |
+
model_output = gr.State(None)
|
116 |
+
file_id = gr.State(None)
|
117 |
+
compute_features.click(
|
118 |
+
compute_features_fn,
|
119 |
+
inputs=[features, model_output, file_id, root_fen, traj_fen, feature_index],
|
120 |
+
outputs=[features, model_output, file_id, board_image, colorbar, game_info],
|
121 |
+
)
|
src/interfaces/stats_interface.py
ADDED
File without changes
|
src/visualisation.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Visualisation utils.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import chess
|
6 |
+
import chess.svg
|
7 |
+
import matplotlib
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
|
12 |
+
|
13 |
+
COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
|
14 |
+
ALPHA = 1.0
|
15 |
+
NORM = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=False)
|
16 |
+
|
17 |
+
|
18 |
+
def render_heatmap(
|
19 |
+
board,
|
20 |
+
heatmap,
|
21 |
+
square=None,
|
22 |
+
vmin=None,
|
23 |
+
vmax=None,
|
24 |
+
arrows=None,
|
25 |
+
normalise="none",
|
26 |
+
):
|
27 |
+
"""
|
28 |
+
Render a heatmap on the board.
|
29 |
+
"""
|
30 |
+
if normalise == "abs":
|
31 |
+
a_max = heatmap.abs().max()
|
32 |
+
if a_max != 0:
|
33 |
+
heatmap = heatmap / a_max
|
34 |
+
vmin = -1
|
35 |
+
vmax = 1
|
36 |
+
if vmin is None:
|
37 |
+
vmin = heatmap.min()
|
38 |
+
if vmax is None:
|
39 |
+
vmax = heatmap.max()
|
40 |
+
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)
|
41 |
+
|
42 |
+
color_dict = {}
|
43 |
+
for square_index in range(64):
|
44 |
+
color = COLOR_MAP(norm(heatmap[square_index]))
|
45 |
+
color = (*color[:3], ALPHA)
|
46 |
+
color_dict[square_index] = matplotlib.colors.to_hex(color, keep_alpha=True)
|
47 |
+
fig = plt.figure(figsize=(6, 0.6))
|
48 |
+
ax = plt.gca()
|
49 |
+
ax.axis("off")
|
50 |
+
fig.colorbar(
|
51 |
+
matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP),
|
52 |
+
ax=ax,
|
53 |
+
orientation="horizontal",
|
54 |
+
fraction=1.0,
|
55 |
+
)
|
56 |
+
if square is not None:
|
57 |
+
try:
|
58 |
+
check = chess.parse_square(square)
|
59 |
+
except ValueError:
|
60 |
+
check = None
|
61 |
+
else:
|
62 |
+
check = None
|
63 |
+
if arrows is None:
|
64 |
+
arrows = []
|
65 |
+
plt.close()
|
66 |
+
return (
|
67 |
+
chess.svg.board(
|
68 |
+
board,
|
69 |
+
check=check,
|
70 |
+
fill=color_dict,
|
71 |
+
size=350,
|
72 |
+
arrows=arrows,
|
73 |
+
),
|
74 |
+
fig,
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
def render_policy_distribution(
|
79 |
+
policy,
|
80 |
+
legal_moves,
|
81 |
+
n_bins=20,
|
82 |
+
):
|
83 |
+
"""
|
84 |
+
Render the policy distribution histogram.
|
85 |
+
"""
|
86 |
+
legal_mask = torch.Tensor([move in legal_moves for move in range(1858)]).bool()
|
87 |
+
fig = plt.figure(figsize=(6, 6))
|
88 |
+
ax = plt.gca()
|
89 |
+
_, bins = np.histogram(policy, bins=n_bins)
|
90 |
+
ax.hist(
|
91 |
+
policy[~legal_mask],
|
92 |
+
bins=bins,
|
93 |
+
alpha=0.5,
|
94 |
+
density=True,
|
95 |
+
label="Illegal moves",
|
96 |
+
)
|
97 |
+
ax.hist(
|
98 |
+
policy[legal_mask],
|
99 |
+
bins=bins,
|
100 |
+
alpha=0.5,
|
101 |
+
density=True,
|
102 |
+
label="Legal moves",
|
103 |
+
)
|
104 |
+
plt.xlabel("Policy")
|
105 |
+
plt.ylabel("Density")
|
106 |
+
plt.legend()
|
107 |
+
plt.yscale("log")
|
108 |
+
return fig
|