surokpro2 committed on
Commit
8cd00a9
1 Parent(s): c65c6b4

Upload folder using huggingface_hub

Files changed (38)
  1. .gitattributes +1 -0
  2. .gitignore +164 -0
  3. LICENSE +21 -0
  4. README.MD +68 -0
  5. README.md +3 -9
  6. SAE/__init__.py +1 -0
  7. SAE/config.json +23 -0
  8. SAE/dataset_iterator.py +53 -0
  9. SAE/sae.py +216 -0
  10. SAE/sae_utils.py +47 -0
  11. SDLens/__init__.py +1 -0
  12. SDLens/hooked_scheduler.py +40 -0
  13. SDLens/hooked_sd_pipeline.py +319 -0
  14. app.ipynb +0 -0
  15. app.py +399 -0
  16. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  17. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  18. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  19. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  20. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  21. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  22. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  23. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  24. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  25. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  26. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  27. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  28. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  29. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  30. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  31. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  32. example.ipynb +0 -0
  33. requirements.txt +7 -0
  34. resourses/image.png +3 -0
  35. scripts/collect_latents_dataset.py +96 -0
  36. scripts/train_sae.py +308 -0
  37. utils/__init__.py +1 -0
  38. utils/hooks.py +45 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ resourses/image.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
163
+
164
+ wandb/
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Viacheslav Surkov
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.MD ADDED
@@ -0,0 +1,68 @@
+ # Unpacking SDXL Turbo: Interpreting Text-to-Image Models with Sparse Autoencoders
+
+ ![modification demonstration](resourses/image.png)
+
+ This repository contains code to reproduce results from our paper on using sparse autoencoders (SAEs) to analyze and interpret the internal representations of text-to-image diffusion models, specifically SDXL Turbo.
+
+ ## Repository Structure
+
+ ```
+ |-- SAE/                             # Core sparse autoencoder implementation
+ |-- SDLens/                          # Tools for analyzing diffusion models
+ |   `-- hooked_sd_pipeline.py        # Modified stable diffusion pipeline
+ |-- scripts/
+ |   |-- collect_latents_dataset.py   # Generate training data
+ |   `-- train_sae.py                 # Train SAE models
+ |-- utils/
+ |   `-- hooks.py                     # Hook utility functions
+ |-- checkpoints/                     # Pretrained SAE model checkpoints
+ |-- app.py                           # Demo application
+ |-- app.ipynb                        # Interactive notebook demo
+ |-- example.ipynb                    # Usage examples
+ `-- requirements.txt                 # Python dependencies
+ ```
+
+ ## Installation
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Demo Application
+
+ You can try our Gradio demo application (`app.ipynb`) to browse and experiment with 20K+ features of our trained SAEs out of the box. You can find the same notebook on [Google Colab](https://colab.research.google.com/drive/1Sd-g3w2Fwv7pc_fxgeQOR3S_RKr18qMP?usp=sharing).
+
+ ## Usage
+
+ 1. Collect latent data from SDXL Turbo:
+ ```bash
+ python scripts/collect_latents_dataset.py --save_path={your_save_path}
+ ```
+
+ 2. Train sparse autoencoders:
+
+ 2.1. Insert the path to the stored latents and the directory for storing checkpoints in `SAE/config.json`
+
+ 2.2. Run the training script:
+
+ ```bash
+ python scripts/train_sae.py
+ ```
+
+ ## Pretrained Models
+
+ We provide pretrained SAE checkpoints for 4 key transformer blocks in SDXL Turbo's U-Net. See `example.ipynb` for analysis examples and visualization of learned features.
+
+
+ ## Citation
+
+ If you find this code useful in your research, please cite our paper:
+
+ ```bibtex
+ [Citation placeholder]
+ ```
+
+ ## Acknowledgements
+
+ The SAE component was implemented based on the [`openai/sparse_autoencoder`](https://github.com/openai/sparse_autoencoder) repository.
+
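For orientation (not part of the commit itself), here is a minimal sketch of loading one of the checkpoints uploaded in this commit with the `SparseAutoencoder` class added in `SAE/sae.py` below. The checkpoint directory name is copied from the file list above; the random input is only meant to show the expected shapes, and a repository root on `PYTHONPATH` is assumed.

```python
# Hypothetical usage sketch; assumes the repository root is importable and the
# checkpoints/ folder from this commit is present locally.
import torch
from SAE import SparseAutoencoder

ckpt = "checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final"
sae = SparseAutoencoder.load_from_disk(ckpt)                       # reads config.json and state_dict.pth
feature_means = torch.load(f"{ckpt}/mean.pt", weights_only=True)   # per-feature statistics, as used by app.py

x = torch.randn(4, 1280)           # d_model is 1280 for these blocks
latents = sae.encode(x)            # shape (4, 5120), at most k = 10 nonzeros per row
print(latents.shape, (latents > 0).sum(dim=-1))
```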
README.md CHANGED
@@ -1,12 +1,6 @@
  ---
- title: Unboxing SDXL With SAEs
- emoji: 🦀
- colorFrom: red
- colorTo: gray
- sdk: gradio
- sdk_version: 5.4.0
+ title: Unboxing_SDXL_with_SAEs
  app_file: app.py
- pinned: false
+ sdk: gradio
+ sdk_version: 4.44.1
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
SAE/__init__.py ADDED
@@ -0,0 +1 @@
+ from .sae import SparseAutoencoder
SAE/config.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "sae_configs": [
+         {
+             "d_model": 1280,
+             "n_dirs": 5120,
+             "k": 20
+         },
+         {
+             "d_model": 1280,
+             "n_dirs": 640,
+             "k": 20
+         }
+     ],
+     "bs": 4096,
+     "log_interval": 500,
+     "save_interval": 5000,
+
+     "paths_to_latents": [
+         "PASS YOUR PATHS HERE. Example /home/username/latents/<timestamp>. It should contain tar archives with latents."
+     ],
+     "save_path_base": "<Your SAE save path>",
+     "block_name": "unet.down_blocks.2.attentions.1"
+ }
SAE/dataset_iterator.py ADDED
@@ -0,0 +1,53 @@
+ import webdataset as wds
+ import os
+ import torch
+
+ class ActivationsDataloader:
+     def __init__(self, paths_to_datasets, block_name, batch_size, output_or_diff='diff', num_in_buffer=50):
+         assert output_or_diff in ['diff', 'output'], "Provide 'output' or 'diff'"
+
+         self.dataset = wds.WebDataset(
+             [os.path.join(path_to_dataset, f"{block_name}.tar")
+              for path_to_dataset in paths_to_datasets]
+         ).decode("torch")
+         self.iter = iter(self.dataset)
+         self.buffer = None
+         self.pointer = 0
+         self.num_in_buffer = num_in_buffer
+         self.output_or_diff = output_or_diff
+         self.batch_size = batch_size
+         self.one_size = None
+
+     def renew_buffer(self, to_retrieve):
+         to_merge = []
+         if self.buffer is not None and self.buffer.shape[0] > self.pointer:
+             to_merge = [self.buffer[self.pointer:].clone()]
+             del self.buffer
+         for _ in range(to_retrieve):
+             sample = next(self.iter)
+             latents = sample['output.pth'] if self.output_or_diff == 'output' else sample['diff.pth']
+             latents = latents.permute((0, 1, 3, 4, 2))
+             latents = latents.reshape((-1, latents.shape[-1]))
+             to_merge.append(latents.to('cuda'))
+             self.one_size = latents.shape[0]
+         self.buffer = torch.cat(to_merge, dim=0)
+         shuffled_indices = torch.randperm(self.buffer.shape[0])
+         self.buffer = self.buffer[shuffled_indices]
+         self.pointer = 0
+
+     def iterate(self):
+         while True:
+             if self.buffer is None or self.buffer.shape[0] - self.pointer < self.num_in_buffer * self.one_size * 4 // 5:
+                 try:
+                     to_retrieve = self.num_in_buffer if self.buffer is None else self.num_in_buffer // 5
+                     self.renew_buffer(to_retrieve)
+                 except StopIteration:
+                     break
+
+             batch = self.buffer[self.pointer: self.pointer + self.batch_size]
+             self.pointer += self.batch_size
+
+             assert batch.shape[0] == self.batch_size
+             yield batch
+
+
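A hypothetical driver for the `ActivationsDataloader` defined above; the dataset path and block name are placeholders, and a CUDA device is required because the class moves every shard to `'cuda'` as written.

```python
# Sketch only: paths_to_datasets must point at directories written by
# scripts/collect_latents_dataset.py (each containing <block_name>.tar).
from SAE.dataset_iterator import ActivationsDataloader

loader = ActivationsDataloader(
    paths_to_datasets=["/path/to/latents/<timestamp>"],
    block_name="unet.down_blocks.2.attentions.1",
    batch_size=4096,
    output_or_diff="diff",      # train on the block's (output - input) residual
)
for batch in loader.iterate():
    # batch: (4096, 1280) activation vectors, shuffled across prompts and spatial positions
    pass
```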
SAE/sae.py ADDED
@@ -0,0 +1,216 @@
+ '''
+ Adapted from
+ https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/model.py
+ '''
+
+ import torch
+ import torch.nn as nn
+ import os
+ import json
+
+ class SparseAutoencoder(nn.Module):
+     """
+     Top-K Autoencoder with sparse kernels. Implements:
+
+         latents = relu(topk(encoder(x - pre_bias) + latent_bias))
+         recons = decoder(latents) + pre_bias
+     """
+
+     def __init__(
+         self,
+         n_dirs_local: int,
+         d_model: int,
+         k: int,
+         auxk: int | None,
+         dead_steps_threshold: int,
+     ):
+         super().__init__()
+         self.n_dirs_local = n_dirs_local
+         self.d_model = d_model
+         self.k = k
+         self.auxk = auxk
+         self.dead_steps_threshold = dead_steps_threshold
+
+         self.encoder = nn.Linear(d_model, n_dirs_local, bias=False)
+         self.decoder = nn.Linear(n_dirs_local, d_model, bias=False)
+
+         self.pre_bias = nn.Parameter(torch.zeros(d_model))
+         self.latent_bias = nn.Parameter(torch.zeros(n_dirs_local))
+
+         self.stats_last_nonzero: torch.Tensor
+         self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long))
+
+         def auxk_mask_fn(x):
+             dead_mask = self.stats_last_nonzero > dead_steps_threshold
+             x.data *= dead_mask  # inplace to save memory
+             return x
+
+         self.auxk_mask_fn = auxk_mask_fn
+
+         ## initialization
+
+         # "tied" init
+         self.decoder.weight.data = self.encoder.weight.data.T.clone()
+
+         # store decoder in column major layout for kernel
+         self.decoder.weight.data = self.decoder.weight.data.T.contiguous().T
+
+         unit_norm_decoder_(self)
+
+     def save_to_disk(self, path: str):
+         PATH_TO_CFG = 'config.json'
+         PATH_TO_WEIGHTS = 'state_dict.pth'
+
+         cfg = {
+             "n_dirs_local": self.n_dirs_local,
+             "d_model": self.d_model,
+             "k": self.k,
+             "auxk": self.auxk,
+             "dead_steps_threshold": self.dead_steps_threshold,
+         }
+
+         os.makedirs(path, exist_ok=True)
+
+         with open(os.path.join(path, PATH_TO_CFG), 'w') as f:
+             json.dump(cfg, f)
+
+         torch.save({
+             "state_dict": self.state_dict(),
+         }, os.path.join(path, PATH_TO_WEIGHTS))
+
+     @classmethod
+     def load_from_disk(cls, path: str):
+         PATH_TO_CFG = 'config.json'
+         PATH_TO_WEIGHTS = 'state_dict.pth'
+
+         with open(os.path.join(path, PATH_TO_CFG), 'r') as f:
+             cfg = json.load(f)
+
+         ae = cls(
+             n_dirs_local=cfg["n_dirs_local"],
+             d_model=cfg["d_model"],
+             k=cfg["k"],
+             auxk=cfg["auxk"],
+             dead_steps_threshold=cfg["dead_steps_threshold"],
+         )
+
+         state_dict = torch.load(os.path.join(path, PATH_TO_WEIGHTS))["state_dict"]
+         ae.load_state_dict(state_dict)
+
+         return ae
+
+     @property
+     def n_dirs(self):
+         return self.n_dirs_local
+
+     def encode(self, x):
+         x = x - self.pre_bias
+         latents_pre_act = self.encoder(x) + self.latent_bias
+
+         vals, inds = torch.topk(
+             latents_pre_act,
+             k=self.k,
+             dim=-1
+         )
+
+         latents = torch.zeros_like(latents_pre_act)
+         latents.scatter_(-1, inds, torch.relu(vals))
+
+         return latents
+
+     def forward(self, x):
+         x = x - self.pre_bias
+         latents_pre_act = self.encoder(x) + self.latent_bias
+         vals, inds = torch.topk(
+             latents_pre_act,
+             k=self.k,
+             dim=-1
+         )
+
+         ## set num nonzero stat ##
+         tmp = torch.zeros_like(self.stats_last_nonzero)
+         tmp.scatter_add_(
+             0,
+             inds.reshape(-1),
+             (vals > 1e-3).to(tmp.dtype).reshape(-1),
+         )
+         self.stats_last_nonzero *= 1 - tmp.clamp(max=1)
+         self.stats_last_nonzero += 1
+         ## end stats ##
+
+         ## auxk
+         if self.auxk is not None:  # for auxk
+             # IMPORTANT: has to go after stats update!
+             # WARN: auxk_mask_fn can mutate latents_pre_act!
+             auxk_vals, auxk_inds = torch.topk(
+                 self.auxk_mask_fn(latents_pre_act),
+                 k=self.auxk,
+                 dim=-1
+             )
+         else:
+             auxk_inds = None
+             auxk_vals = None
+
+         ## end auxk
+
+         vals = torch.relu(vals)
+         if auxk_vals is not None:
+             auxk_vals = torch.relu(auxk_vals)
+
+         rows, cols = latents_pre_act.size()
+         row_indices = torch.arange(rows).unsqueeze(1).expand(-1, self.k).reshape(-1)
+         vals = vals.reshape(-1)
+         inds = inds.reshape(-1)
+
+         indices = torch.stack([row_indices.to(inds.device), inds])
+
+         sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))
+
+         recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias
+
+         return recons, {
+             "inds": inds,
+             "vals": vals,
+             "auxk_inds": auxk_inds,
+             "auxk_vals": auxk_vals,
+         }
+
+     def decode_sparse(self, inds, vals):
+         rows, cols = inds.shape[0], self.n_dirs
+
+         row_indices = torch.arange(rows).unsqueeze(1).expand(-1, inds.shape[1]).reshape(-1)
+         vals = vals.reshape(-1)
+         inds = inds.reshape(-1)
+
+         indices = torch.stack([row_indices.to(inds.device), inds])
+
+         sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))
+
+         recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias
+         return recons
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+
+ def unit_norm_decoder_(autoencoder: SparseAutoencoder) -> None:
+     """
+     Unit normalize the decoder weights of an autoencoder.
+     """
+     autoencoder.decoder.weight.data /= autoencoder.decoder.weight.data.norm(dim=0)
+
+
+ def unit_norm_decoder_grad_adjustment_(autoencoder) -> None:
+     """project out gradient information parallel to the dictionary vectors - assumes that the decoder is already unit normed"""
+
+     assert autoencoder.decoder.weight.grad is not None
+
+     autoencoder.decoder.weight.grad +=\
+         torch.einsum("bn,bn->n", autoencoder.decoder.weight.data, autoencoder.decoder.weight.grad) *\
+         autoencoder.decoder.weight.data * -1
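A hypothetical round trip through the class above, using the hyperparameters of the checkpoints shipped in this commit (n_dirs=5120, d_model=1280, k=10, auxk=256, dead_steps_threshold=2441); the point is the shapes and the sparse output format, not training.

```python
# Sketch under the assumption that SAE/sae.py is importable as shown in this commit.
import torch
from SAE.sae import SparseAutoencoder

sae = SparseAutoencoder(
    n_dirs_local=5120, d_model=1280, k=10, auxk=256, dead_steps_threshold=2441
)
x = torch.randn(8, 1280)

recons, info = sae(x)                       # reconstruction plus flattened top-k indices/values
assert recons.shape == x.shape
latents = sae.encode(x)                     # dense (8, 5120), at most k nonzeros per row
recons_again = sae.decode_sparse(
    info["inds"].reshape(8, -1),            # forward() flattens indices/values, so reshape per row
    info["vals"].reshape(8, -1),
)
```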
SAE/sae_utils.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ from dataclasses import dataclass, field
+
+ @dataclass
+ class SAETrainingConfig:
+     d_model: int
+     n_dirs: int
+     k: int
+     block_name: str
+     bs: int
+     save_path_base: str
+     auxk: int = 256
+     lr: float = 1e-4
+     eps: float = 6.25e-10
+     dead_toks_threshold: int = 10_000_000
+     auxk_coef: float = 1/32
+
+     @property
+     def sae_name(self):
+         return f'{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}'
+
+     @property
+     def save_path(self):
+         return f'/dlabscratch1/surkov/sae_models/{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}'
+
+
+ @dataclass
+ class Config:
+     saes: list[SAETrainingConfig]
+     paths_to_latents: list[str]
+     log_interval: int
+     save_interval: int
+     bs: int
+     block_name: str
+     wandb_project: str = 'sdxl_sae_train'
+     wandb_name: str = 'multiple_sae'
+
+     def __init__(self, cfg_json):
+         self.saes = [SAETrainingConfig(**sae_cfg, block_name=cfg_json['block_name'], bs=cfg_json['bs'], save_path_base=cfg_json['save_path_base'])
+                      for sae_cfg in cfg_json['sae_configs']]
+
+         self.save_path_base = cfg_json['save_path_base']
+         self.paths_to_latents = cfg_json['paths_to_latents']
+         self.log_interval = cfg_json['log_interval']
+         self.save_interval = cfg_json['save_interval']
+         self.bs = cfg_json['bs']
+         self.block_name = cfg_json['block_name']
SDLens/__init__.py ADDED
@@ -0,0 +1 @@
+ from .hooked_sd_pipeline import HookedIFPipeline, HookedStableDiffusionXLPipeline
SDLens/hooked_scheduler.py ADDED
@@ -0,0 +1,40 @@
+ from diffusers import DDPMScheduler
+ import torch
+
+ class HookedNoiseScheduler:
+     scheduler: DDPMScheduler
+     pre_hooks: list
+     post_hooks: list
+
+     def __init__(self, scheduler):
+         object.__setattr__(self, 'scheduler', scheduler)
+         object.__setattr__(self, 'pre_hooks', [])
+         object.__setattr__(self, 'post_hooks', [])
+
+     def step(
+         self,
+         model_output, timestep, sample, generator, return_dict
+     ):
+         assert return_dict == False, "return_dict == True is not implemented"
+         for hook in self.pre_hooks:
+             hook_output = hook(model_output, timestep, sample, generator)
+             if hook_output is not None:
+                 model_output, timestep, sample, generator = hook_output
+
+         (pred_prev_sample, ) = self.scheduler.step(model_output, timestep, sample, generator, return_dict)
+
+         for hook in self.post_hooks:
+             hook_output = hook(pred_prev_sample)
+             if hook_output is not None:
+                 pred_prev_sample = hook_output
+
+         return (pred_prev_sample, )
+
+     def __getattr__(self, name):
+         return getattr(self.scheduler, name)
+
+     def __setattr__(self, name, value):
+         if name in {'scheduler', 'pre_hooks', 'post_hooks'}:
+             object.__setattr__(self, name, value)
+         else:
+             setattr(self.scheduler, name, value)
SDLens/hooked_sd_pipeline.py ADDED
@@ -0,0 +1,319 @@
1
+ import einops
2
+ from diffusers import StableDiffusionXLPipeline, IFPipeline
3
+ from typing import List, Dict, Callable, Union
4
+ import torch
5
+ from .hooked_scheduler import HookedNoiseScheduler
6
+
7
+ def retrieve(io):
8
+ if isinstance(io, tuple):
9
+ if len(io) == 1:
10
+ return io[0]
11
+ else:
12
+ raise ValueError("A tuple should have length of 1")
13
+ elif isinstance(io, torch.Tensor):
14
+ return io
15
+ else:
16
+ raise ValueError("Input/Output must be a tensor, or 1-element tuple")
17
+
18
+
19
+ class HookedDiffusionAbstractPipeline:
20
+ parent_cls = None
21
+ pipe = None
22
+
23
+ def __init__(self, pipe: parent_cls, use_hooked_scheduler: bool = False):
24
+ if use_hooked_scheduler:
25
+ pipe.scheduler = HookedNoiseScheduler(pipe.scheduler)
26
+ self.__dict__['pipe'] = pipe
27
+ self.use_hooked_scheduler = use_hooked_scheduler
28
+
29
+ @classmethod
30
+ def from_pretrained(cls, *args, **kwargs):
31
+ return cls(cls.parent_cls.from_pretrained(*args, **kwargs))
32
+
33
+
34
+ def run_with_hooks(self,
35
+ *args,
36
+ position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
37
+ **kwargs
38
+ ):
39
+ '''
40
+ Run the pipeline with hooks at specified positions.
41
+ Returns the final output.
42
+
43
+ Args:
44
+ *args: Arguments to pass to the pipeline.
45
+ position_hook_dict: A dictionary mapping positions to hooks.
46
+ The keys are positions in the pipeline where the hooks should be registered.
47
+ The values are either a single hook or a list of hooks to be registered at the specified position.
48
+ Each hook should be a callable that takes three arguments: (module, input, output).
49
+ **kwargs: Keyword arguments to pass to the pipeline.
50
+ '''
51
+ hooks = []
52
+ for position, hook in position_hook_dict.items():
53
+ if isinstance(hook, list):
54
+ for h in hook:
55
+ hooks.append(self._register_general_hook(position, h))
56
+ else:
57
+ hooks.append(self._register_general_hook(position, hook))
58
+
59
+ hooks = [hook for hook in hooks if hook is not None]
60
+
61
+ try:
62
+ output = self.pipe(*args, **kwargs)
63
+ finally:
64
+ for hook in hooks:
65
+ hook.remove()
66
+ if self.use_hooked_scheduler:
67
+ self.pipe.scheduler.pre_hooks = []
68
+ self.pipe.scheduler.post_hooks = []
69
+
70
+ return output
71
+
72
+ def run_with_cache(self,
73
+ *args,
74
+ positions_to_cache: List[str],
75
+ save_input: bool = False,
76
+ save_output: bool = True,
77
+ **kwargs
78
+ ):
79
+ '''
80
+ Run the pipeline with caching at specified positions.
81
+
82
+ This method allows you to cache the intermediate inputs and/or outputs of the pipeline
83
+ at certain positions. The final output of the pipeline and a dictionary of cached values
84
+ are returned.
85
+
86
+ Args:
87
+ *args: Arguments to pass to the pipeline.
88
+ positions_to_cache (List[str]): A list of positions in the pipeline where intermediate
89
+ inputs/outputs should be cached.
90
+ save_input (bool, optional): If True, caches the input at each specified position.
91
+ Defaults to False.
92
+ save_output (bool, optional): If True, caches the output at each specified position.
93
+ Defaults to True.
94
+ **kwargs: Keyword arguments to pass to the pipeline.
95
+
96
+ Returns:
97
+ final_output: The final output of the pipeline after execution.
98
+ cache_dict (Dict[str, Dict[str, Any]]): A dictionary where keys are the specified positions
99
+ and values are dictionaries containing the cached 'input' and/or 'output' at each position,
100
+ depending on the flags `save_input` and `save_output`.
101
+ '''
102
+ cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
103
+ hooks = [
104
+ self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
105
+ ]
106
+ hooks = [hook for hook in hooks if hook is not None]
107
+ output = self.pipe(*args, **kwargs)
108
+ for hook in hooks:
109
+ hook.remove()
110
+ if self.use_hooked_scheduler:
111
+ self.pipe.scheduler.pre_hooks = []
112
+ self.pipe.scheduler.post_hooks = []
113
+
114
+ cache_dict = {}
115
+ if save_input:
116
+ for position, block in cache_input.items():
117
+ cache_input[position] = torch.stack(block, dim=1)
118
+ cache_dict['input'] = cache_input
119
+
120
+ if save_output:
121
+ for position, block in cache_output.items():
122
+ cache_output[position] = torch.stack(block, dim=1)
123
+ cache_dict['output'] = cache_output
124
+ return output, cache_dict
125
+
126
+ def run_with_hooks_and_cache(self,
127
+ *args,
128
+ position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
129
+ positions_to_cache: List[str] = [],
130
+ save_input: bool = False,
131
+ save_output: bool = True,
132
+ **kwargs
133
+ ):
134
+ '''
135
+ Run the pipeline with hooks and caching at specified positions.
136
+
137
+ This method allows you to register hooks at certain positions in the pipeline and
138
+ cache intermediate inputs and/or outputs at specified positions. Hooks can be used
139
+ for inspecting or modifying the pipeline's execution, and caching stores intermediate
140
+ values for later inspection or use.
141
+
142
+ Args:
143
+ *args: Arguments to pass to the pipeline.
144
+ position_hook_dict Dict[str, Union[Callable, List[Callable]]]:
145
+ A dictionary where the keys are the positions in the pipeline, and the values
146
+ are hooks (either a single hook or a list of hooks) to be registered at those positions.
147
+ Each hook should be a callable that accepts three arguments: (module, input, output).
148
+ positions_to_cache (List[str], optional): A list of positions in the pipeline where
149
+ intermediate inputs/outputs should be cached. Defaults to an empty list.
150
+ save_input (bool, optional): If True, caches the input at each specified position.
151
+ Defaults to False.
152
+ save_output (bool, optional): If True, caches the output at each specified position.
153
+ Defaults to True.
154
+ **kwargs: Additional keyword arguments to pass to the pipeline.
155
+
156
+ Returns:
157
+ final_output: The final output of the pipeline after execution.
158
+ cache_dict (Dict[str, Dict[str, Any]]): A dictionary where keys are the specified positions
159
+ and values are dictionaries containing the cached 'input' and/or 'output' at each position,
160
+ depending on the flags `save_input` and `save_output`.
161
+ '''
162
+ cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
163
+ hooks = [
164
+ self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
165
+ ]
166
+
167
+ for position, hook in position_hook_dict.items():
168
+ if isinstance(hook, list):
169
+ for h in hook:
170
+ hooks.append(self._register_general_hook(position, h))
171
+ else:
172
+ hooks.append(self._register_general_hook(position, hook))
173
+
174
+ hooks = [hook for hook in hooks if hook is not None]
175
+ output = self.pipe(*args, **kwargs)
176
+ for hook in hooks:
177
+ hook.remove()
178
+ if self.use_hooked_scheduler:
179
+ self.pipe.scheduler.pre_hooks = []
180
+ self.pipe.scheduler.post_hooks = []
181
+
182
+ cache_dict = {}
183
+ if save_input:
184
+ for position, block in cache_input.items():
185
+ cache_input[position] = torch.stack(block, dim=1)
186
+ cache_dict['input'] = cache_input
187
+
188
+ if save_output:
189
+ for position, block in cache_output.items():
190
+ cache_output[position] = torch.stack(block, dim=1)
191
+ cache_dict['output'] = cache_output
192
+
193
+ return output, cache_dict
194
+
195
+
196
+ def _locate_block(self, position: str):
197
+ '''
198
+ Locate the block at the specified position in the pipeline.
199
+ '''
200
+ block = self.pipe
201
+ for step in position.split('.'):
202
+ if step.isdigit():
203
+ step = int(step)
204
+ block = block[step]
205
+ else:
206
+ block = getattr(block, step)
207
+ return block
208
+
209
+
210
+ def _register_cache_hook(self, position: str, cache_input: Dict, cache_output: Dict):
211
+
212
+ if position.endswith('$self_attention') or position.endswith('$cross_attention'):
213
+ return self._register_cache_attention_hook(position, cache_output)
214
+
215
+ if position == 'noise':
216
+ def hook(model_output, timestep, sample, generator):
217
+ if position not in cache_output:
218
+ cache_output[position] = []
219
+ cache_output[position].append(sample)
220
+
221
+ if self.use_hooked_scheduler:
222
+ self.pipe.scheduler.post_hooks.append(hook)
223
+ else:
224
+ raise ValueError('Cannot cache noise without using hooked scheduler')
225
+ return
226
+
227
+ block = self._locate_block(position)
228
+
229
+ def hook(module, input, kwargs, output):
230
+ if cache_input is not None:
231
+ if position not in cache_input:
232
+ cache_input[position] = []
233
+ cache_input[position].append(retrieve(input))
234
+
235
+ if cache_output is not None:
236
+ if position not in cache_output:
237
+ cache_output[position] = []
238
+ cache_output[position].append(retrieve(output))
239
+
240
+ return block.register_forward_hook(hook, with_kwargs=True)
241
+
242
+ def _register_cache_attention_hook(self, position, cache):
243
+ attn_block = self._locate_block(position.split('$')[0])
244
+ if position.endswith('$self_attention'):
245
+ attn_block = attn_block.attn1
246
+ elif position.endswith('$cross_attention'):
247
+ attn_block = attn_block.attn2
248
+ else:
249
+ raise ValueError('Wrong attention type')
250
+
251
+ def hook(module, args, kwargs, output):
252
+ hidden_states = args[0]
253
+ encoder_hidden_states = kwargs['encoder_hidden_states']
254
+ attention_mask = kwargs['attention_mask']
255
+ batch_size, sequence_length, _ = hidden_states.shape
256
+ attention_mask = attn_block.prepare_attention_mask(attention_mask, sequence_length, batch_size)
257
+ query = attn_block.to_q(hidden_states)
258
+
259
+
260
+ if encoder_hidden_states is None:
261
+ encoder_hidden_states = hidden_states
262
+ elif attn_block.norm_cross is not None:
263
+ encoder_hidden_states = attn_block.norm_cross(encoder_hidden_states)
264
+
265
+ key = attn_block.to_k(encoder_hidden_states)
266
+ value = attn_block.to_v(encoder_hidden_states)
267
+
268
+ query = attn_block.head_to_batch_dim(query)
269
+ key = attn_block.head_to_batch_dim(key)
270
+ value = attn_block.head_to_batch_dim(value)
271
+
272
+ attention_probs = attn_block.get_attention_scores(query, key, attention_mask)
273
+ attention_probs = attention_probs.view(
274
+ batch_size,
275
+ attention_probs.shape[0] // batch_size,
276
+ attention_probs.shape[1],
277
+ attention_probs.shape[2]
278
+ )
279
+ if position not in cache:
280
+ cache[position] = []
281
+ cache[position].append(attention_probs)
282
+
283
+ return attn_block.register_forward_hook(hook, with_kwargs=True)
284
+
285
+ def _register_general_hook(self, position, hook):
286
+ if position == 'scheduler_pre':
287
+ if not self.use_hooked_scheduler:
288
+ raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
289
+ self.pipe.scheduler.pre_hooks.append(hook)
290
+ return
291
+ elif position == 'scheduler_post':
292
+ if not self.use_hooked_scheduler:
293
+ raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
294
+ self.pipe.scheduler.post_hooks.append(hook)
295
+ return
296
+
297
+ block = self._locate_block(position)
298
+ return block.register_forward_hook(hook)
299
+
300
+ def to(self, *args, **kwargs):
301
+ self.pipe = self.pipe.to(*args, **kwargs)
302
+ return self
303
+
304
+ def __getattr__(self, name):
305
+ return getattr(self.pipe, name)
306
+
307
+ def __setattr__(self, name, value):
308
+ return setattr(self.pipe, name, value)
309
+
310
+ def __call__(self, *args, **kwargs):
311
+ return self.pipe(*args, **kwargs)
312
+
313
+
314
+ class HookedStableDiffusionXLPipeline(HookedDiffusionAbstractPipeline):
315
+ parent_cls = StableDiffusionXLPipeline
316
+
317
+
318
+ class HookedIFPipeline(HookedDiffusionAbstractPipeline):
319
+ parent_cls = IFPipeline
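A hypothetical sketch of the hook interface documented in the docstrings above; the model id, block path, and one-step/zero-guidance settings mirror how `app.py` below drives the pipeline, and availability of the SDXL Turbo weights plus a CUDA device is assumed.

```python
# Sketch only: a read-only forward hook that prints the cached block's output shape.
import torch
from SDLens import HookedStableDiffusionXLPipeline

pipe = HookedStableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo")
pipe.to("cuda")

def print_shape_hook(module, input, output):
    # Forward hooks receive (module, input, output); returning None leaves the output unchanged.
    out = output[0] if isinstance(output, tuple) else output
    print(type(module).__name__, tuple(out.shape))

result = pipe.run_with_hooks(
    "a photo of a corgi",
    position_hook_dict={"unet.down_blocks.2.attentions.1": print_shape_hook},
    num_inference_steps=1,
    guidance_scale=0.0,
    generator=torch.Generator(device="cpu").manual_seed(42),
)
image = result.images[0]
```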
app.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,399 @@
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ from PIL import Image
5
+ from SDLens import HookedStableDiffusionXLPipeline
6
+ from SAE import SparseAutoencoder
7
+ from utils import add_feature_on_area
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from matplotlib.colors import ListedColormap
11
+ from utils import add_feature_on_area, replace_with_feature
12
+ import threading
13
+
14
+ code_to_block = {
15
+ "down.2.1": "unet.down_blocks.2.attentions.1",
16
+ "mid.0": "unet.mid_block.attentions.0",
17
+ "up.0.1": "unet.up_blocks.0.attentions.1",
18
+ "up.0.0": "unet.up_blocks.0.attentions.0"
19
+ }
20
+ lock = threading.Lock()
21
+
22
+ def process_cache(cache, saes_dict):
23
+
24
+ top_features_dict = {}
25
+ sparse_maps_dict = {}
26
+
27
+ for code in code_to_block.keys():
28
+ block = code_to_block[code]
29
+ sae = saes_dict[code]
30
+
31
+ diff = cache["output"][block] - cache["input"][block]
32
+ diff = diff.permute(0, 1, 3, 4, 2).squeeze(0).squeeze(0)
33
+ with torch.no_grad():
34
+ sparse_maps = sae.encode(diff)
35
+ averages = torch.mean(sparse_maps, dim=(0, 1))
36
+
37
+ top_features = torch.topk(averages, 10).indices
38
+
39
+ top_features_dict[code] = top_features.cpu().tolist()
40
+ sparse_maps_dict[code] = sparse_maps.cpu().numpy()
41
+
42
+ return top_features_dict, sparse_maps_dict
43
+
44
+
45
+ def plot_image_heatmap(cache, block_select, radio):
46
+ code = block_select.split()[0]
47
+ feature = int(radio)
48
+ block = code_to_block[code]
49
+
50
+ heatmap = cache["heatmaps"][code][:, :, feature]
51
+ heatmap = np.kron(heatmap, np.ones((32, 32)))
52
+ image = cache["image"].convert("RGBA")
53
+
54
+ jet = plt.cm.jet
55
+ cmap = jet(np.arange(jet.N))
56
+ cmap[:1, -1] = 0
57
+ cmap[1:, -1] = 0.6
58
+ cmap = ListedColormap(cmap)
59
+ heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap))
60
+ heatmap_rgba = cmap(heatmap)
61
+ heatmap_image = Image.fromarray((heatmap_rgba * 255).astype(np.uint8))
62
+ heatmap_with_transparency = Image.alpha_composite(image, heatmap_image)
63
+
64
+ return heatmap_with_transparency
65
+
66
+
67
+ def create_prompt_part(pipe, saes_dict, demo):
68
+ def image_gen(prompt):
69
+ lock.acquire()
70
+ try:
71
+ images, cache = pipe.run_with_cache(
72
+ prompt,
73
+ positions_to_cache=list(code_to_block.values()),
74
+ num_inference_steps=1,
75
+ generator=torch.Generator(device="cpu").manual_seed(42),
76
+ guidance_scale=0.0,
77
+ save_input=True,
78
+ save_output=True
79
+ )
80
+ finally:
81
+ lock.release()
82
+
83
+ top_features_dict, top_sparse_maps_dict = process_cache(cache, saes_dict)
84
+ return images.images[0], {
85
+ "image": images.images[0],
86
+ "heatmaps": top_sparse_maps_dict,
87
+ "features": top_features_dict
88
+ }
89
+
90
+ def update_radio(cache, block_select):
91
+ code = block_select.split()[0]
92
+ return gr.update(choices=cache["features"][code])
93
+
94
+ def update_img(cache, block_select, radio):
95
+ new_img = plot_image_heatmap(cache, block_select, radio)
96
+ return new_img
97
+
98
+ with gr.Tab("Explore", elem_classes="tabs") as explore_tab:
99
+ cache = gr.State(value={
100
+ "image": None,
101
+ "heatmaps": None,
102
+ "features": []
103
+ })
104
+ with gr.Row():
105
+ with gr.Column(scale=7):
106
+ with gr.Row(equal_height=True):
107
+ prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A cinematic shot of a professor sloth wearing a tuxedo at a BBQ party and eathing a dish with peas.")
108
+ button = gr.Button("Generate", elem_classes="generate_button1")
109
+
110
+ with gr.Row():
111
+ image = gr.Image(width=512, height=512, image_mode="RGB", label="Generated image")
112
+
113
+ with gr.Column(scale=4):
114
+ block_select = gr.Dropdown(
115
+ choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
116
+ value="down.2.1 (composition)",
117
+ label="Select block",
118
+ elem_id="block_select",
119
+ interactive=True
120
+ )
121
+ radio = gr.Radio(choices=[], label="Select a feature", interactive=True)
122
+
123
+ button.click(image_gen, [prompt_field], outputs=[image, cache])
124
+ cache.change(update_radio, [cache, block_select], outputs=[radio])
125
+ block_select.select(update_radio, [cache, block_select], outputs=[radio])
126
+ radio.select(update_img, [cache, block_select, radio], outputs=[image])
127
+ demo.load(image_gen, [prompt_field], outputs=[image, cache])
128
+
129
+ return explore_tab
130
+
131
+ def downsample_mask(image, factor):
132
+ downsampled = image.reshape(
133
+ (image.shape[0] // factor, factor,
134
+ image.shape[1] // factor, factor)
135
+ )
136
+ downsampled = downsampled.mean(axis=(1, 3))
137
+ return downsampled
138
+
139
+ def create_intervene_part(pipe: HookedStableDiffusionXLPipeline, saes_dict, means_dict, demo):
140
+ def image_gen(prompt, num_steps):
141
+ lock.acquire()
142
+ try:
143
+ images = pipe.run_with_hooks(
144
+ prompt,
145
+ position_hook_dict={},
146
+ num_inference_steps=num_steps,
147
+ generator=torch.Generator(device="cpu").manual_seed(42),
148
+ guidance_scale=0.0
149
+ )
150
+ finally:
151
+ lock.release()
152
+ return images.images[0]
153
+
154
+ def image_mod(prompt, block_str, brush_index, strength, num_steps, input_image):
155
+ block = block_str.split(" ")[0]
156
+
157
+ mask = (input_image["layers"][0] > 0)[:, :, -1].astype(float)
158
+ mask = downsample_mask(mask, 32)
159
+ mask = torch.tensor(mask, dtype=torch.float32, device="cuda")
160
+
161
+ if mask.sum() == 0:
162
+ gr.Info("No mask selected, please draw on the input image")
163
+
164
+ def hook(module, input, output):
165
+ return add_feature_on_area(
166
+ saes_dict[block],
167
+ brush_index,
168
+ mask * means_dict[block][brush_index] * strength,
169
+ module,
170
+ input,
171
+ output
172
+ )
173
+
174
+ lock.acquire()
175
+ try:
176
+ image = pipe.run_with_hooks(
177
+ prompt,
178
+ position_hook_dict={code_to_block[block]: hook},
179
+ num_inference_steps=num_steps,
180
+ generator=torch.Generator(device="cpu").manual_seed(42),
181
+ guidance_scale=0.0
182
+ ).images[0]
183
+ finally:
184
+ lock.release()
185
+ return image
186
+
187
+ def feature_icon(block_str, brush_index):
188
+ block = block_str.split(" ")[0]
189
+ if block in ["mid.0", "up.0.0"]:
190
+ gr.Info("Note that Feature Icon works best with down.2.1 and up.0.1 blocks but feel free to explore", duration=3)
191
+
192
+ def hook(module, input, output):
193
+ return replace_with_feature(
194
+ saes_dict[block],
195
+ brush_index,
196
+ means_dict[block][brush_index] * saes_dict[block].k,
197
+ module,
198
+ input,
199
+ output
200
+ )
201
+
202
+ lock.acquire()
203
+ try:
204
+ image = pipe.run_with_hooks(
205
+ "",
206
+ position_hook_dict={code_to_block[block]: hook},
207
+ num_inference_steps=1,
208
+ generator=torch.Generator(device="cpu").manual_seed(42),
209
+ guidance_scale=0.0
210
+ ).images[0]
211
+ finally:
212
+ lock.release()
213
+ return image
214
+
215
+ with gr.Tab("Paint!", elem_classes="tabs") as intervene_tab:
216
+ image_state = gr.State(value=None)
217
+ with gr.Row():
218
+ with gr.Column(scale=3):
219
+ # Generation column
220
+ with gr.Row():
221
+ # prompt and num_steps
222
+ prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A dog plays with a ball, cartoon", elem_id="prompt_input")
223
+ num_steps = gr.Number(value=1, label="Number of steps", minimum=1, maximum=4, elem_id="num_steps", precision=0)
224
+ with gr.Row():
225
+ # Generate button
226
+ button_generate = gr.Button("Generate", elem_id="generate_button")
227
+ with gr.Column(scale=3):
228
+ # Intervention column
229
+ with gr.Row():
230
+ # dropdowns and number inputs
231
+ with gr.Column(scale=7):
232
+ with gr.Row():
233
+ block_select = gr.Dropdown(
234
+ choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
235
+ value="down.2.1 (composition)",
236
+ label="Select block",
237
+ elem_id="block_select"
238
+ )
239
+ brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=5119, elem_id="brush_index", precision=0)
240
+ with gr.Row():
241
+ button_icon = gr.Button('Feature Icon', elem_id="feature_icon_button")
242
+ with gr.Column(scale=3):
243
+ with gr.Row():
244
+ strength = gr.Number(value=10, label="Strength", minimum=-40, maximum=40, elem_id="strength", precision=2)
245
+ with gr.Row():
246
+ button = gr.Button('Apply', elem_id="apply_button")
247
+
248
+ with gr.Row():
249
+ with gr.Column():
250
+ # Input image
251
+ i_image = gr.Sketchpad(
252
+ height=610,
253
+ layers=False, transforms=[], placeholder="Generate and paint!",
254
+ brush=gr.Brush(default_size=64, color_mode="fixed", colors=['black']),
255
+ container=False,
256
+ canvas_size=(512, 512),
257
+ label="Input Image")
258
+ clear_button = gr.Button("Clear")
259
+ clear_button.click(lambda x: x, [image_state], [i_image])
260
+ # Output image
261
+ o_image = gr.Image(width=512, height=512, label="Output Image")
262
+
263
+ # Set up the click events
264
+ button_generate.click(image_gen, inputs=[prompt_field, num_steps], outputs=[image_state])
265
+ image_state.change(lambda x: x, [image_state], [i_image])
266
+ button.click(image_mod,
267
+ inputs=[prompt_field, block_select, brush_index, strength, num_steps, i_image],
268
+ outputs=o_image)
269
+ button_icon.click(feature_icon, inputs=[block_select, brush_index], outputs=o_image)
270
+ demo.load(image_gen, [prompt_field, num_steps], outputs=[image_state])
271
+
272
+
273
+ return intervene_tab
274
+
275
+
276
+ def create_top_images_part(demo):
277
+ def update_top_images(block_select, brush_index):
278
+ block = block_select.split(" ")[0]
279
+ url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{brush_index}.jpg"
280
+ return url
281
+
282
+ with gr.Tab("Top Images", elem_classes="tabs") as top_images_tab:
283
+ with gr.Row():
284
+ block_select = gr.Dropdown(
285
+ choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
286
+ value="down.2.1 (composition)",
287
+ label="Select block"
288
+ )
289
+ brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=5119, precision=0)
290
+ with gr.Row():
291
+ image = gr.Image(width=600, height=600, label="Top Images")
292
+
293
+ block_select.select(update_top_images, [block_select, brush_index], outputs=[image])
294
+ brush_index.change(update_top_images, [block_select, brush_index], outputs=[image])
295
+ demo.load(update_top_images, [block_select, brush_index], outputs=[image])
296
+ return top_images_tab
297
+
298
+
299
+ def create_intro_part():
300
+ with gr.Tab("Instructions", elem_classes="tabs") as intro_tab:
301
+ gr.Markdown(
302
+ '''# Unpacking SDXL Turbo with Sparse Autoencoders
303
+ ## Demo Overview
304
+ This demo showcases the use of Sparse Autoencoders (SAEs) to understand the features learned by the Stable Diffusion XL Turbo model.
305
+
306
+ ## How to Use
307
+ ### Explore
308
+ * Enter a prompt in the text box and click on the "Generate" button to generate an image.
309
+ * You can observe the active features in different blocks plot on top of the generated image.
310
+ ### Top Images
311
+ * For each feature, you can view the top images that activate the feature the most.
312
+ ### Paint!
313
+ * Generate an image using the prompt.
314
+ * Paint on the generated image to apply interventions.
315
+ * Use the "Feature Icon" button to understand how the selected brush functions.
316
+
317
+ ### Remarks
318
+ * Not all brushes mix well with all images. Experiment with different brushes and strengths.
319
+ * Feature Icon works best with `down.2.1 (composition)` and `up.0.1 (style)` blocks.
320
+ * This demo is provided for research purposes only. We do not take responsibility for the content generated by the demo.
321
+
322
+ ### Interesting features to try
323
+ To get started, try the following features:
324
+ - down.2.1 (composition): 2301 (evil) 3747 (image frame) 4998 (cartoon)
325
+ - up.0.1 (style): 4977 (tiger stripes) 90 (fur) 2615 (twilight blur)
326
+ '''
327
+ )
328
+
329
+ return intro_tab
330
+
331
+
332
+ def create_demo(pipe, saes_dict, means_dict):
333
+ custom_css = """
334
+ .tabs button {
335
+ font-size: 20px !important; /* Adjust font size for tab text */
336
+ padding: 10px !important; /* Adjust padding to make the tabs bigger */
337
+ font-weight: bold !important; /* Adjust font weight to make the text bold */
338
+ }
339
+ .generate_button1 {
340
+ max-width: 160px !important;
341
+ margin-top: 20px !important;
342
+ margin-bottom: 20px !important;
343
+ }
344
+ """
345
+
346
+ with gr.Blocks(css=custom_css) as demo:
347
+ with create_intro_part():
348
+ pass
349
+ with create_prompt_part(pipe, saes_dict, demo):
350
+ pass
351
+ with create_top_images_part(demo):
352
+ pass
353
+ with create_intervene_part(pipe, saes_dict, means_dict, demo):
354
+ pass
355
+
356
+ return demo
357
+
358
+
359
+ if __name__ == "__main__":
360
+ import os
361
+ import gradio as gr
362
+ import torch
363
+ from SDLens import HookedStableDiffusionXLPipeline
364
+ from SAE import SparseAutoencoder
365
+
366
+ dtype=torch.float32
367
+ pipe = HookedStableDiffusionXLPipeline.from_pretrained(
368
+ 'stabilityai/sdxl-turbo',
369
+ torch_dtype=dtype,
370
+ device_map="balanced",
371
+ variant=("fp16" if dtype==torch.float16 else None)
372
+ )
373
+ pipe.set_progress_bar_config(disable=True)
374
+
375
+ path_to_checkpoints = './checkpoints/'
376
+
377
+ code_to_block = {
378
+ "down.2.1": "unet.down_blocks.2.attentions.1",
379
+ "mid.0": "unet.mid_block.attentions.0",
380
+ "up.0.1": "unet.up_blocks.0.attentions.1",
381
+ "up.0.0": "unet.up_blocks.0.attentions.0"
382
+ }
383
+
384
+ saes_dict = {}
385
+ means_dict = {}
386
+
387
+ for code, block in code_to_block.items():
388
+ sae = SparseAutoencoder.load_from_disk(
389
+ os.path.join(path_to_checkpoints, f"{block}_k10_hidden5120_auxk256_bs4096_lr0.0001", "final"),
390
+ )
391
+ means = torch.load(
392
+ os.path.join(path_to_checkpoints, f"{block}_k10_hidden5120_auxk256_bs4096_lr0.0001", "final", "mean.pt"),
393
+ weights_only=True
394
+ )
395
+ saes_dict[code] = sae.to('cuda', dtype=dtype)
396
+ means_dict[code] = means.to('cuda', dtype=dtype)
397
+
398
+ demo = create_demo(pipe, saes_dict, means_dict)
399
+ demo.launch()
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:387f2b6f8c4e4a6f1227921f28f00dfa4beb2bd4e422b7eb592cd8627af0e58f
3
+ size 21581
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39e3c6d17aa572a53368ca8ba8f82757947a3caf14fe654e84b175d0dc0a4650
3
+ size 52497831
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6ca694c9504a7a8aa827004d3fdec5c1cb8fcf3904acc3562d1861fc6e65c19
3
+ size 21576
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80790481d0e56ac3fa36599703cee7a05cfb4cc078db57c8f9180e860c330e1d
3
+ size 21581
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49d38d9178c2a2780e04a5482a2feb9548c6e9a636ed1bf85291acf42e0ffa34
3
+ size 52497831
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb6bfc7ce5e596f8aa048ab262ca56841868c222bf07eb2ed35b6e4f7094fea6
3
+ size 21576
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de036d0fb9ee663f7bdf60e4a5d89d038516dae637531676b53ff75d05eab46b
3
+ size 21581
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14c45efd9cce0258f014c49babdcd0e9ce8b266fe31eed72db1a45b990a1a0f8
3
+ size 52497831
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb9c04499ccae041987cc262894e254c2f04288857a8a0470cfb1b86a8ecfa09
3
+ size 21576
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96dbf6fffe9d62c3b3352f8e4fe48c54dfd69906cf8ad6828d5ce93db9a5f0dc
3
+ size 21581
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8eed82f4bcb2f010ae9075f10a1ece801ee3dec46dba7fadccc35f6c0a7836b
3
+ size 52497831
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe5c5be0c4c2d2b57e7888319053cb64929559f947c8ce445ddd6a397302afab
3
+ size 21576
example.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ diffusers==0.29.2
+ gradio==4.44.1
+ torch>=2.4.0
+ numpy
+ matplotlib
+ pillow
+ wandb
resourses/image.png ADDED

Git LFS Details

  • SHA256: 86594c5876d61a3eac5238b739eeec41418995c7696b6453d70b4e683ebd82df
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
scripts/collect_latents_dataset.py ADDED
@@ -0,0 +1,96 @@
+ import os
+ import sys
+ import io
+ import tarfile
+ import torch
+ import webdataset as wds
+ import numpy as np
+
+ from tqdm import tqdm
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+ from SDLens.hooked_sd_pipeline import HookedStableDiffusionXLPipeline
+
+ import datetime
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader
+ import diffusers
+ import fire
+
+ def main(save_path, start_at=0, finish_at=30000, dataset_batch_size=50):
+     blocks_to_save = [
+         'unet.down_blocks.2.attentions.1',
+         'unet.mid_block.attentions.0',
+         'unet.up_blocks.0.attentions.0',
+         'unet.up_blocks.0.attentions.1',
+     ]
+
+     # Initialization
+     dataset = load_dataset("guangyil/laion-coco-aesthetic", split="train", columns=["caption"], streaming=True).shuffle(seed=42)
+     pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
+     pipe.to('cuda')
+     pipe.set_progress_bar_config(disable=True)
+     dataloader = DataLoader(dataset, batch_size=dataset_batch_size)
+
+     ct = datetime.datetime.now()
+     save_path = os.path.join(save_path, str(ct))
+     # Collecting dataset
+     os.makedirs(save_path, exist_ok=True)
+
+     writers = {
+         block: wds.TarWriter(f'{save_path}/{block}.tar') for block in blocks_to_save
+     }
+
+     writers.update({'images': wds.TarWriter(f'{save_path}/images.tar')})
+
+     def to_kwargs(kwargs_to_save):
+         kwargs = kwargs_to_save.copy()
+         seed = kwargs['seed']
+         del kwargs['seed']
+         kwargs['generator'] = torch.Generator(device="cpu").manual_seed(seed)
+         return kwargs
+
+     dataloader_iter = iter(dataloader)
+     for num_document, batch in tqdm(enumerate(dataloader)):
+         if num_document < start_at:
+             continue
+
+         if num_document >= finish_at:
+             break
+
+         kwargs_to_save = {
+             'prompt': batch['caption'],
+             'positions_to_cache': blocks_to_save,
+             'save_input': True,
+             'save_output': True,
+             'num_inference_steps': 1,
+             'guidance_scale': 0.0,
+             'seed': num_document,
+             'output_type': 'pil'
+         }
+
+         kwargs = to_kwargs(kwargs_to_save)
+
+         output, cache = pipe.run_with_cache(
+             **kwargs
+         )
+
+         blocks = cache['input'].keys()
+         for block in blocks:
+             sample = {
+                 "__key__": f"sample_{num_document}",
+                 "output.pth": cache['output'][block],
+                 "diff.pth": cache['output'][block] - cache['input'][block],
+                 "gen_args.json": kwargs_to_save
+             }
+
+             writers[block].write(sample)
+         writers['images'].write({
+             "__key__": f"sample_{num_document}",
+             "images.npy": np.stack(output.images)
+         })
+
+     for block, writer in writers.items():
+         writer.close()
+
+ if __name__ == '__main__':
+     fire.Fire(main)
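
The collector is a `fire` CLI (`main(save_path, start_at=0, finish_at=30000, dataset_batch_size=50)`), normally launched as `python scripts/collect_latents_dataset.py --save_path <dir> ...`. A minimal sketch of the equivalent direct call follows; it is not part of the commit, the output directory is a placeholder, and `scripts/` is assumed to be importable from the repository root:

import sys
sys.path.append('scripts')  # assumption: run from the repository root
from collect_latents_dataset import main

# Writes one WebDataset tar per hooked block, plus images.tar, under <save_path>/<timestamp>/
main(save_path='./latents', start_at=0, finish_at=100, dataset_batch_size=50)
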
scripts/train_sae.py ADDED
@@ -0,0 +1,308 @@
+ '''
+ Adapted from
+ https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/train.py
+ '''
+
+
+ import os
+ import sys
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+ from typing import Callable, Iterable, Iterator
+
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.distributed import ReduceOp
+ from SAE.dataset_iterator import ActivationsDataloader
+ from SAE.sae import SparseAutoencoder, unit_norm_decoder_, unit_norm_decoder_grad_adjustment_
+ from SAE.sae_utils import SAETrainingConfig, Config
+
+ from types import SimpleNamespace
+ from typing import Optional, List
+ import json
+
+ import tqdm
+
+ def weighted_average(points: torch.Tensor, weights: torch.Tensor):
+     weights = weights / weights.sum()
+     return (points * weights.view(-1, 1)).sum(dim=0)
+
+
+ @torch.no_grad()
+ def geometric_median_objective(
+     median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
+ ) -> torch.Tensor:
+
+     norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
+
+     return (norms * weights).sum()
+
+
+ def compute_geometric_median(
+     points: torch.Tensor,
+     weights: Optional[torch.Tensor] = None,
+     eps: float = 1e-6,
+     maxiter: int = 100,
+     ftol: float = 1e-20,
+     do_log: bool = False,
+ ):
+     """
+     :param points: ``torch.Tensor`` of shape ``(n, d)``
+     :param weights: Optional ``torch.Tensor`` of shape :math:``(n,)``.
+     :param eps: Smallest allowed value of denominator, to avoid divide by zero.
+         Equivalently, this is a smoothing parameter. Default 1e-6.
+     :param maxiter: Maximum number of Weiszfeld iterations. Default 100
+     :param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
+     :param do_log: If true will return a log of function values encountered through the course of the algorithm
+     :return: SimpleNamespace object with fields
+         - `median`: estimate of the geometric median, which is a ``torch.Tensor`` object of shape :math:``(d,)``
+         - `termination`: string explaining how the algorithm terminated.
+         - `logs`: function values encountered through the course of the algorithm in a list (None if do_log is false).
+     """
+     with torch.no_grad():
+
+         if weights is None:
+             weights = torch.ones((points.shape[0],), device=points.device)
+         # initialize median estimate at mean
+         new_weights = weights
+         median = weighted_average(points, weights)
+         objective_value = geometric_median_objective(median, points, weights)
+         if do_log:
+             logs = [objective_value]
+         else:
+             logs = None
+
+         # Weiszfeld iterations
+         early_termination = False
+         pbar = tqdm.tqdm(range(maxiter))
+         for _ in pbar:
+             prev_obj_value = objective_value
+
+             norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
+             new_weights = weights / torch.clamp(norms, min=eps)
+             median = weighted_average(points, new_weights)
+             objective_value = geometric_median_objective(median, points, weights)
+
+             if logs is not None:
+                 logs.append(objective_value)
+             if abs(prev_obj_value - objective_value) <= ftol * objective_value:
+                 early_termination = True
+                 break
+
+             pbar.set_description(f"Objective value: {objective_value:.4f}")
+
+     median = weighted_average(points, new_weights)  # allow autodiff to track it
+     return SimpleNamespace(
+         median=median,
+         new_weights=new_weights,
+         termination=(
+             "function value converged within tolerance"
+             if early_termination
+             else "maximum iterations reached"
+         ),
+         logs=logs,
+     )
+
+ def maybe_transpose(x):
+     return x.T if not x.is_contiguous() and x.T.is_contiguous() else x
+
+ import wandb
+
+ RANK = 0
+
+ class Logger:
+     def __init__(self, sae_name, **kws):
+         self.vals = {}
+         self.enabled = (RANK == 0) and not kws.pop("dummy", False)
+         self.sae_name = sae_name
+
+     def logkv(self, k, v):
+         if self.enabled:
+             self.vals[f'{self.sae_name}/{k}'] = v.detach() if isinstance(v, torch.Tensor) else v
+         return v
+
+     def dumpkvs(self, step):
+         if self.enabled:
+             wandb.log(self.vals, step=step)
+             self.vals = {}
+
+
+ class FeaturesStats:
+     def __init__(self, dim, logger):
+         self.dim = dim
+         self.logger = logger
+         self.reinit()
+
+     def reinit(self):
+         self.n_activated = torch.zeros(self.dim, dtype=torch.long, device="cuda")
+         self.n = 0
+
+     def update(self, inds):
+         self.n += inds.shape[0]
+         inds = inds.flatten().detach()
+         self.n_activated.scatter_add_(0, inds, torch.ones_like(inds))
+
+     def log(self):
+         self.logger.logkv('activated', (self.n_activated / self.n + 1e-9).log10().cpu().numpy())
+
+ def training_loop_(
+     aes,
+     train_acts_iter,
+     loss_fn,
+     log_interval,
+     save_interval,
+     loggers,
+     sae_cfgs,
+ ):
+     sae_packs = []
+     for ae, cfg, logger in zip(aes, sae_cfgs, loggers):
+         pbar = tqdm.tqdm(unit=" steps", desc="Training Loss: ")
+         fstats = FeaturesStats(ae.n_dirs, logger)
+         opt = torch.optim.Adam(ae.parameters(), lr=cfg.lr, eps=cfg.eps, fused=True)
+         sae_packs.append((ae, cfg, logger, pbar, fstats, opt))
+
+     for i, flat_acts_train_batch in enumerate(train_acts_iter):
+         flat_acts_train_batch = flat_acts_train_batch.cuda()
+
+         for ae, cfg, logger, pbar, fstats, opt in sae_packs:
+             recons, info = ae(flat_acts_train_batch)
+             loss = loss_fn(ae, cfg, flat_acts_train_batch, recons, info, logger)
+
+             fstats.update(info['inds'])
+
+             bs = flat_acts_train_batch.shape[0]
+             logger.logkv('not-activated 1e4', (ae.stats_last_nonzero > 1e4 / bs).mean(dtype=float).item())
+             logger.logkv('not-activated 1e6', (ae.stats_last_nonzero > 1e6 / bs).mean(dtype=float).item())
+             logger.logkv('not-activated 1e7', (ae.stats_last_nonzero > 1e7 / bs).mean(dtype=float).item())
+
+             logger.logkv('explained variance', explained_variance(recons, flat_acts_train_batch))
+             logger.logkv('l2_div', (torch.linalg.norm(recons, dim=1) / torch.linalg.norm(flat_acts_train_batch, dim=1)).mean())
+
+             if (i + 1) % log_interval == 0:
+                 fstats.log()
+                 fstats.reinit()
+
+             if (i + 1) % save_interval == 0:
+                 ae.save_to_disk(f"{cfg.save_path}/{i + 1}")
+
+             loss.backward()
+
+             unit_norm_decoder_(ae)
+             unit_norm_decoder_grad_adjustment_(ae)
+
+             opt.step()
+             opt.zero_grad()
+             logger.dumpkvs(i)
+
+             pbar.set_description(f"Training Loss {loss.item():.4f}")
+             pbar.update(1)
+
+
+     for ae, cfg, logger, pbar, fstats, opt in sae_packs:
+         pbar.close()
+         ae.save_to_disk(f"{cfg.save_path}/final")
+
+
+ def init_from_data_(ae, stats_acts_sample):
+     ae.pre_bias.data = (
+         compute_geometric_median(stats_acts_sample[:32768].float().cpu()).median.cuda().float()
+     )
+
+
+ def mse(recons, x):
+     # return ((recons - x) ** 2).sum(dim=-1).mean()
+     return ((recons - x) ** 2).mean()
+
+ def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
+     # only used for auxk
+     xs_mu = xs.mean(dim=0)
+
+     loss = mse(recon, xs) / mse(
+         xs_mu[None, :].broadcast_to(xs.shape), xs
+     )
+
+     return loss
+
+ def explained_variance(recons, x):
+     # Compute the variance of the difference
+     diff = x - recons
+     diff_var = torch.var(diff, dim=0, unbiased=False)
+
+     # Compute the variance of the original tensor
+     x_var = torch.var(x, dim=0, unbiased=False)
+
+     # Avoid division by zero
+     explained_var = 1 - diff_var / (x_var + 1e-8)
+
+     return explained_var.mean()
+
+
+ def main():
+     cfg = Config(json.load(open('SAE/config.json')))
+
+     dataloader = ActivationsDataloader(cfg.paths_to_latents, cfg.block_name, cfg.bs)
+
+     acts_iter = dataloader.iterate()
+     stats_acts_sample = torch.cat([
+         next(acts_iter).cpu() for _ in range(10)
+     ], dim=0)
+
+     aes = [
+         SparseAutoencoder(
+             n_dirs_local=sae.n_dirs,
+             d_model=sae.d_model,
+             k=sae.k,
+             auxk=sae.auxk,
+             dead_steps_threshold=sae.dead_toks_threshold // cfg.bs,
+         ).cuda()
+         for sae in cfg.saes
+     ]
+
+     for ae in aes:
+         init_from_data_(ae, stats_acts_sample)
+
+     mse_scale = (
+         1 / ((stats_acts_sample.float().mean(dim=0) - stats_acts_sample.float()) ** 2).mean()
+     )
+     mse_scale = mse_scale.item()
+     del stats_acts_sample
+
+     wandb.init(
+         project=cfg.wandb_project,
+         name=cfg.wandb_name,
+     )
+
+     loggers = [Logger(
+         sae_name=cfg_sae.sae_name,
+         dummy=False,
+     ) for cfg_sae in cfg.saes]
+
+     training_loop_(
+         aes,
+         acts_iter,
+         lambda ae, cfg_sae, flat_acts_train_batch, recons, info, logger: (
+             # MSE
+             logger.logkv("train_recons", mse_scale * mse(recons, flat_acts_train_batch))
+             # AuxK
+             + logger.logkv(
+                 "train_maxk_recons",
+                 cfg_sae.auxk_coef
+                 * normalized_mse(
+                     ae.decode_sparse(
+                         info["auxk_inds"],
+                         info["auxk_vals"],
+                     ),
+                     flat_acts_train_batch - recons.detach() + ae.pre_bias.detach(),
+                 ).nan_to_num(0),
+             )
+         ),
+         sae_cfgs=cfg.saes,
+         loggers=loggers,
+         log_interval=cfg.log_interval,
+         save_interval=cfg.save_interval,
+     )
+
+
+ if __name__ == "__main__":
+     main()
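
For orientation, `init_from_data_` above seeds each SAE's `pre_bias` with the geometric median of a sample of activations. A minimal sketch of calling the Weiszfeld routine on stand-in data follows; it is not part of the commit, shapes are illustrative (with `d_model=1280` matching the checkpoint configs), and `scripts/` is assumed to be importable:

import sys, torch
sys.path.append('scripts')  # assumption: run from the repository root
from train_sae import compute_geometric_median

acts = torch.randn(4096, 1280)             # stand-in for a sample of block activations
res = compute_geometric_median(acts, maxiter=100, do_log=True)
print(res.termination)                     # how the iteration stopped
pre_bias_init = res.median                 # the vector init_from_data_ copies into ae.pre_bias
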
utils/__init__.py ADDED
@@ -0,0 +1 @@
+ from .hooks import *
utils/hooks.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+
+ @torch.no_grad()
+ def add_feature(sae, feature_idx, value, module, input, output):
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     activated = sae.encode(diff)
+     mask = torch.zeros_like(activated, device=diff.device)
+     mask[..., feature_idx] = value
+     to_add = mask @ sae.decoder.weight.T
+     return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
+
+
+ @torch.no_grad()
+ def add_feature_on_area(sae, feature_idx, activation_map, module, input, output):
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     activated = sae.encode(diff)
+     mask = torch.zeros_like(activated, device=diff.device)
+     if len(activation_map.shape) == 2:
+         activation_map = activation_map.unsqueeze(0)
+     mask[..., feature_idx] = activation_map.to(mask.device)
+     to_add = mask @ sae.decoder.weight.T
+     return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
+
+
+ @torch.no_grad()
+ def replace_with_feature(sae, feature_idx, value, module, input, output):
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     activated = sae.encode(diff)
+     mask = torch.zeros_like(activated, device=diff.device)
+     mask[..., feature_idx] = value
+     to_add = mask @ sae.decoder.weight.T
+     return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
+
+
+ @torch.no_grad()
+ def reconstruct_sae_hook(sae, module, input, output):
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     activated = sae.encode(diff)
+     reconstructed = sae.decoder(activated) + sae.pre_bias
+     return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)
+
+
+ @torch.no_grad()
+ def ablate_block(module, input, output):
+     return input
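
Each hook above takes `(sae, ...intervention args..., module, input, output)`, so binding the leading arguments with `functools.partial` leaves the standard `(module, input, output)` forward-hook signature. A minimal sketch of wiring one up manually follows; it is not part of the commit, assumes an already-loaded `pipe` (HookedStableDiffusionXLPipeline) and a trained `sae` for `unet.down_blocks.2.attentions.1`, and uses an arbitrary feature index and strength:

from functools import partial
from utils.hooks import add_feature

hook_fn = partial(add_feature, sae, 3301, 15.0)   # bind (sae, feature_idx, value) positionally
block = pipe.unet.down_blocks[2].attentions[1]    # the block this SAE was trained on
handle = block.register_forward_hook(hook_fn)     # the returned tuple replaces the block output
try:
    image = pipe('a photo of a cat', num_inference_steps=1, guidance_scale=0.0).images[0]
finally:
    handle.remove()                               # detach the hook when done
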