Spaces: Running on Zero
Upload folder using huggingface_hub
- .gitattributes +1 -0
- .gitignore +164 -0
- LICENSE +21 -0
- README.MD +68 -0
- README.md +3 -9
- SAE/__init__.py +1 -0
- SAE/config.json +23 -0
- SAE/dataset_iterator.py +53 -0
- SAE/sae.py +216 -0
- SAE/sae_utils.py +47 -0
- SDLens/__init__.py +1 -0
- SDLens/hooked_scheduler.py +40 -0
- SDLens/hooked_sd_pipeline.py +319 -0
- app.ipynb +0 -0
- app.py +399 -0
- checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
- checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
- checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
- checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
- checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
- checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
- checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
- checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
- checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
- checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
- checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
- checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
- checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
- checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
- checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
- checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
- example.ipynb +0 -0
- requirements.txt +7 -0
- resourses/image.png +3 -0
- scripts/collect_latents_dataset.py +96 -0
- scripts/train_sae.py +308 -0
- utils/__init__.py +1 -0
- utils/hooks.py +45 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+resourses/image.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+wandb/
LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Viacheslav Surkov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.MD ADDED
@@ -0,0 +1,68 @@
+# Unpacking SDXL Turbo: Interpreting Text-to-Image Models with Sparse Autoencoders
+
+![modification demonstration](resourses/image.png)
+
+This repository contains code to reproduce results from our paper on using sparse autoencoders (SAEs) to analyze and interpret the internal representations of text-to-image diffusion models, specifically SDXL Turbo.
+
+## Repository Structure
+
+```
+|-- SAE/                             # Core sparse autoencoder implementation
+|-- SDLens/                          # Tools for analyzing diffusion models
+|   `-- hooked_sd_pipeline.py        # Modified stable diffusion pipeline
+|-- scripts/
+|   |-- collect_latents_dataset.py   # Generate training data
+|   `-- train_sae.py                 # Train SAE models
+|-- utils/
+|   `-- hooks.py                     # Hook utility functions
+|-- checkpoints/                     # Pretrained SAE model checkpoints
+|-- app.py                           # Demo application
+|-- app.ipynb                        # Interactive notebook demo
+|-- example.ipynb                    # Usage examples
+`-- requirements.txt                 # Python dependencies
+```
+
+## Installation
+
+```bash
+pip install -r requirements.txt
+```
+
+## Demo Application
+
+You can try our gradio demo application (`app.ipynb`) to browse and experiment with 20K+ features of our trained SAEs out of the box. You can find the same notebook on [Google Colab](https://colab.research.google.com/drive/1Sd-g3w2Fwv7pc_fxgeQOR3S_RKr18qMP?usp=sharing).
+
+## Usage
+
+1. Collect latent data from SDXL Turbo:
+```bash
+python scripts/collect_latents_dataset.py --save_path={your_save_path}
+```
+
+2. Train sparse autoencoders:
+
+2.1. Set the path to the stored latents and the checkpoint output directory in `SAE/config.json`.
+
+2.2. Run the training script:
+
+```bash
+python scripts/train_sae.py
+```
+
+## Pretrained Models
+
+We provide pretrained SAE checkpoints for 4 key transformer blocks in SDXL Turbo's U-Net. See `example.ipynb` for analysis examples and visualization of learned features.
+
+## Citation
+
+If you find this code useful in your research, please cite our paper:
+
+```bibtex
+[Citation placeholder]
+```
+
+## Acknowledgements
+
+The SAE component was implemented based on the [`openai/sparse_autoencoder`](https://github.com/openai/sparse_autoencoder) repository.
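For quick orientation, here is a minimal loading sketch that mirrors what `app.py` in this commit does; the checkpoint directory name is taken from the `checkpoints/` entries in this upload, and keeping everything on the default device is an assumption (the demo itself moves the SAE to `cuda`).

```python
from SDLens import HookedStableDiffusionXLPipeline
from SAE import SparseAutoencoder

# Hooked wrapper around diffusers' StableDiffusionXLPipeline
pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')

# One of the pretrained SAE checkpoints shipped in this commit
sae = SparseAutoencoder.load_from_disk(
    './checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final'
)
print(sae.n_dirs, sae.k)  # 5120 dictionary directions, 10 active per position
```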
README.md CHANGED
@@ -1,12 +1,6 @@
 ---
-title:
-emoji: 🦀
-colorFrom: red
-colorTo: gray
-sdk: gradio
-sdk_version: 5.4.0
+title: Unboxing_SDXL_with_SAEs
 app_file: app.py
-
+sdk: gradio
+sdk_version: 4.44.1
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
SAE/__init__.py ADDED
@@ -0,0 +1 @@
+from .sae import SparseAutoencoder
SAE/config.json ADDED
@@ -0,0 +1,23 @@
+{
+    "sae_configs": [
+        {
+            "d_model": 1280,
+            "n_dirs": 5120,
+            "k": 20
+        },
+        {
+            "d_model": 1280,
+            "n_dirs": 640,
+            "k": 20
+        }
+    ],
+    "bs": 4096,
+    "log_interval": 500,
+    "save_interval": 5000,
+
+    "paths_to_latents": [
+        "PASS YOUR PATHS HERE. Example /home/username/latents/<timestamp>. It should contain tar archives with latents."
+    ],
+    "save_path_base": "<Your SAE save path>",
+    "block_name": "unet.down_blocks.2.attentions.1"
+}
SAE/dataset_iterator.py ADDED
@@ -0,0 +1,53 @@
+import webdataset as wds
+import os
+import torch
+
+class ActivationsDataloader:
+    def __init__(self, paths_to_datasets, block_name, batch_size, output_or_diff='diff', num_in_buffer=50):
+        assert output_or_diff in ['diff', 'output'], "Provide 'output' or 'diff'"
+
+        self.dataset = wds.WebDataset(
+            [os.path.join(path_to_dataset, f"{block_name}.tar")
+             for path_to_dataset in paths_to_datasets]
+        ).decode("torch")
+        self.iter = iter(self.dataset)
+        self.buffer = None
+        self.pointer = 0
+        self.num_in_buffer = num_in_buffer
+        self.output_or_diff = output_or_diff
+        self.batch_size = batch_size
+        self.one_size = None
+
+    def renew_buffer(self, to_retrieve):
+        to_merge = []
+        if self.buffer is not None and self.buffer.shape[0] > self.pointer:
+            to_merge = [self.buffer[self.pointer:].clone()]
+            del self.buffer
+        for _ in range(to_retrieve):
+            sample = next(self.iter)
+            latents = sample['output.pth'] if self.output_or_diff == 'output' else sample['diff.pth']
+            latents = latents.permute((0, 1, 3, 4, 2))
+            latents = latents.reshape((-1, latents.shape[-1]))
+            to_merge.append(latents.to('cuda'))
+            self.one_size = latents.shape[0]
+        self.buffer = torch.cat(to_merge, dim=0)
+        shuffled_indices = torch.randperm(self.buffer.shape[0])
+        self.buffer = self.buffer[shuffled_indices]
+        self.pointer = 0
+
+    def iterate(self):
+        while True:
+            if self.buffer is None or self.buffer.shape[0] - self.pointer < self.num_in_buffer * self.one_size * 4 // 5:
+                try:
+                    to_retrieve = self.num_in_buffer if self.buffer is None else self.num_in_buffer // 5
+                    self.renew_buffer(to_retrieve)
+                except StopIteration:
+                    break
+
+            batch = self.buffer[self.pointer: self.pointer + self.batch_size]
+            self.pointer += self.batch_size
+
+            assert batch.shape[0] == self.batch_size
+            yield batch
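A short usage sketch for the loader above. The dataset path is a placeholder in the spirit of `SAE/config.json`; note that `renew_buffer` moves latents to `cuda`, so a GPU is assumed.

```python
from SAE.dataset_iterator import ActivationsDataloader

# Each dataset directory is expected to contain a "<block_name>.tar" webdataset
# whose samples carry 'output.pth' and 'diff.pth' tensors.
loader = ActivationsDataloader(
    paths_to_datasets=['/path/to/latents/<timestamp>'],  # placeholder
    block_name='unet.down_blocks.2.attentions.1',
    batch_size=4096,
    output_or_diff='diff',  # train on the residual the block adds, not its raw output
)

for batch in loader.iterate():
    print(batch.shape)  # (4096, d_model), shuffled within a large buffer
    break
```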
SAE/sae.py ADDED
@@ -0,0 +1,216 @@
+'''
+Adapted from
+https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/model.py
+'''
+
+import torch
+import torch.nn as nn
+import os
+import json
+
+class SparseAutoencoder(nn.Module):
+    """
+    Top-K Autoencoder with sparse kernels. Implements:
+
+        latents = relu(topk(encoder(x - pre_bias) + latent_bias))
+        recons = decoder(latents) + pre_bias
+    """
+
+    def __init__(
+        self,
+        n_dirs_local: int,
+        d_model: int,
+        k: int,
+        auxk: int | None,
+        dead_steps_threshold: int,
+    ):
+        super().__init__()
+        self.n_dirs_local = n_dirs_local
+        self.d_model = d_model
+        self.k = k
+        self.auxk = auxk
+        self.dead_steps_threshold = dead_steps_threshold
+
+        self.encoder = nn.Linear(d_model, n_dirs_local, bias=False)
+        self.decoder = nn.Linear(n_dirs_local, d_model, bias=False)
+
+        self.pre_bias = nn.Parameter(torch.zeros(d_model))
+        self.latent_bias = nn.Parameter(torch.zeros(n_dirs_local))
+
+        self.stats_last_nonzero: torch.Tensor
+        self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long))
+
+        def auxk_mask_fn(x):
+            dead_mask = self.stats_last_nonzero > dead_steps_threshold
+            x.data *= dead_mask  # inplace to save memory
+            return x
+
+        self.auxk_mask_fn = auxk_mask_fn
+
+        ## initialization
+
+        # "tied" init
+        self.decoder.weight.data = self.encoder.weight.data.T.clone()
+
+        # store decoder in column major layout for kernel
+        self.decoder.weight.data = self.decoder.weight.data.T.contiguous().T
+
+        unit_norm_decoder_(self)
+
+    def save_to_disk(self, path: str):
+        PATH_TO_CFG = 'config.json'
+        PATH_TO_WEIGHTS = 'state_dict.pth'
+
+        cfg = {
+            "n_dirs_local": self.n_dirs_local,
+            "d_model": self.d_model,
+            "k": self.k,
+            "auxk": self.auxk,
+            "dead_steps_threshold": self.dead_steps_threshold,
+        }
+
+        os.makedirs(path, exist_ok=True)
+
+        with open(os.path.join(path, PATH_TO_CFG), 'w') as f:
+            json.dump(cfg, f)
+
+        torch.save({
+            "state_dict": self.state_dict(),
+        }, os.path.join(path, PATH_TO_WEIGHTS))
+
+    @classmethod
+    def load_from_disk(cls, path: str):
+        PATH_TO_CFG = 'config.json'
+        PATH_TO_WEIGHTS = 'state_dict.pth'
+
+        with open(os.path.join(path, PATH_TO_CFG), 'r') as f:
+            cfg = json.load(f)
+
+        ae = cls(
+            n_dirs_local=cfg["n_dirs_local"],
+            d_model=cfg["d_model"],
+            k=cfg["k"],
+            auxk=cfg["auxk"],
+            dead_steps_threshold=cfg["dead_steps_threshold"],
+        )
+
+        state_dict = torch.load(os.path.join(path, PATH_TO_WEIGHTS))["state_dict"]
+        ae.load_state_dict(state_dict)
+
+        return ae
+
+    @property
+    def n_dirs(self):
+        return self.n_dirs_local
+
+    def encode(self, x):
+        x = x - self.pre_bias
+        latents_pre_act = self.encoder(x) + self.latent_bias
+
+        vals, inds = torch.topk(
+            latents_pre_act,
+            k=self.k,
+            dim=-1
+        )
+
+        latents = torch.zeros_like(latents_pre_act)
+        latents.scatter_(-1, inds, torch.relu(vals))
+
+        return latents
+
+    def forward(self, x):
+        x = x - self.pre_bias
+        latents_pre_act = self.encoder(x) + self.latent_bias
+        vals, inds = torch.topk(
+            latents_pre_act,
+            k=self.k,
+            dim=-1
+        )
+
+        ## set num nonzero stat ##
+        tmp = torch.zeros_like(self.stats_last_nonzero)
+        tmp.scatter_add_(
+            0,
+            inds.reshape(-1),
+            (vals > 1e-3).to(tmp.dtype).reshape(-1),
+        )
+        self.stats_last_nonzero *= 1 - tmp.clamp(max=1)
+        self.stats_last_nonzero += 1
+        ## end stats ##
+
+        ## auxk
+        if self.auxk is not None:  # for auxk
+            # IMPORTANT: has to go after stats update!
+            # WARN: auxk_mask_fn can mutate latents_pre_act!
+            auxk_vals, auxk_inds = torch.topk(
+                self.auxk_mask_fn(latents_pre_act),
+                k=self.auxk,
+                dim=-1
+            )
+        else:
+            auxk_inds = None
+            auxk_vals = None
+        ## end auxk
+
+        vals = torch.relu(vals)
+        if auxk_vals is not None:
+            auxk_vals = torch.relu(auxk_vals)
+
+        rows, cols = latents_pre_act.size()
+        row_indices = torch.arange(rows).unsqueeze(1).expand(-1, self.k).reshape(-1)
+        vals = vals.reshape(-1)
+        inds = inds.reshape(-1)
+
+        indices = torch.stack([row_indices.to(inds.device), inds])
+
+        sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))
+
+        recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias
+
+        return recons, {
+            "inds": inds,
+            "vals": vals,
+            "auxk_inds": auxk_inds,
+            "auxk_vals": auxk_vals,
+        }
+
+    def decode_sparse(self, inds, vals):
+        rows, cols = inds.shape[0], self.n_dirs
+
+        row_indices = torch.arange(rows).unsqueeze(1).expand(-1, inds.shape[1]).reshape(-1)
+        vals = vals.reshape(-1)
+        inds = inds.reshape(-1)
+
+        indices = torch.stack([row_indices.to(inds.device), inds])
+
+        sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))
+
+        recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias
+        return recons
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+
+def unit_norm_decoder_(autoencoder: SparseAutoencoder) -> None:
+    """
+    Unit normalize the decoder weights of an autoencoder.
+    """
+    autoencoder.decoder.weight.data /= autoencoder.decoder.weight.data.norm(dim=0)
+
+
+def unit_norm_decoder_grad_adjustment_(autoencoder) -> None:
+    """project out gradient information parallel to the dictionary vectors - assumes that the decoder is already unit normed"""
+
+    assert autoencoder.decoder.weight.grad is not None
+
+    autoencoder.decoder.weight.grad += \
+        torch.einsum("bn,bn->n", autoencoder.decoder.weight.data, autoencoder.decoder.weight.grad) * \
+        autoencoder.decoder.weight.data * -1
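To make the encode/decode contract above concrete, a small self-contained round trip with random weights on CPU; the hyperparameters follow the shipped `down.2.1` checkpoint config.

```python
import torch
from SAE import SparseAutoencoder

sae = SparseAutoencoder(n_dirs_local=5120, d_model=1280, k=10,
                        auxk=256, dead_steps_threshold=2441)

x = torch.randn(8, 1280)            # a batch of block activations
latents = sae.encode(x)             # dense (8, 5120), at most k nonzeros per row
recons, info = sae(x)               # sparse reconstruction + flattened top-k inds/vals

assert (latents != 0).sum(dim=-1).le(sae.k).all()

# decode_sparse rebuilds the same reconstruction from the returned indices/values
recons2 = sae.decode_sparse(info["inds"].reshape(8, sae.k),
                            info["vals"].reshape(8, sae.k))
assert torch.allclose(recons, recons2)
```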
SAE/sae_utils.py ADDED
@@ -0,0 +1,47 @@
+import torch
+from dataclasses import dataclass, field
+
+@dataclass
+class SAETrainingConfig:
+    d_model: int
+    n_dirs: int
+    k: int
+    block_name: str
+    bs: int
+    save_path_base: str
+    auxk: int = 256
+    lr: float = 1e-4
+    eps: float = 6.25e-10
+    dead_toks_threshold: int = 10_000_000
+    auxk_coef: float = 1/32
+
+    @property
+    def sae_name(self):
+        return f'{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}'
+
+    @property
+    def save_path(self):
+        return f'/dlabscratch1/surkov/sae_models/{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}'
+
+
+@dataclass
+class Config:
+    saes: list[SAETrainingConfig]
+    paths_to_latents: list[str]
+    log_interval: int
+    save_interval: int
+    bs: int
+    block_name: str
+    wandb_project: str = 'sdxl_sae_train'
+    wandb_name: str = 'multiple_sae'
+
+    def __init__(self, cfg_json):
+        self.saes = [SAETrainingConfig(**sae_cfg, block_name=cfg_json['block_name'], bs=cfg_json['bs'], save_path_base=cfg_json['save_path_base'])
+                     for sae_cfg in cfg_json['sae_configs']]
+
+        self.save_path_base = cfg_json['save_path_base']
+        self.paths_to_latents = cfg_json['paths_to_latents']
+        self.log_interval = cfg_json['log_interval']
+        self.save_interval = cfg_json['save_interval']
+        self.bs = cfg_json['bs']
+        self.block_name = cfg_json['block_name']
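A sketch of how `scripts/train_sae.py` presumably consumes `SAE/config.json` through the `Config` wrapper above; the training loop itself lives in the script, which this page does not expand.

```python
import json
from SAE.sae_utils import Config

with open('SAE/config.json') as f:
    cfg = Config(json.load(f))

for sae_cfg in cfg.saes:  # one SAETrainingConfig per entry in "sae_configs"
    print(sae_cfg.sae_name)
    # e.g. unet.down_blocks.2.attentions.1_k20_hidden5120_auxk256_bs4096_lr0.0001
```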
SDLens/__init__.py ADDED
@@ -0,0 +1 @@
+from .hooked_sd_pipeline import HookedIFPipeline, HookedStableDiffusionXLPipeline
SDLens/hooked_scheduler.py ADDED
@@ -0,0 +1,40 @@
+from diffusers import DDPMScheduler
+import torch
+
+class HookedNoiseScheduler:
+    scheduler: DDPMScheduler
+    pre_hooks: list
+    post_hooks: list
+
+    def __init__(self, scheduler):
+        object.__setattr__(self, 'scheduler', scheduler)
+        object.__setattr__(self, 'pre_hooks', [])
+        object.__setattr__(self, 'post_hooks', [])
+
+    def step(
+        self,
+        model_output, timestep, sample, generator, return_dict
+    ):
+        assert return_dict == False, "return_dict == True is not implemented"
+        for hook in self.pre_hooks:
+            hook_output = hook(model_output, timestep, sample, generator)
+            if hook_output is not None:
+                model_output, timestep, sample, generator = hook_output
+
+        (pred_prev_sample, ) = self.scheduler.step(model_output, timestep, sample, generator, return_dict)
+
+        for hook in self.post_hooks:
+            hook_output = hook(pred_prev_sample)
+            if hook_output is not None:
+                pred_prev_sample = hook_output
+
+        return (pred_prev_sample, )
+
+    def __getattr__(self, name):
+        return getattr(self.scheduler, name)
+
+    def __setattr__(self, name, value):
+        if name in {'scheduler', 'pre_hooks', 'post_hooks'}:
+            object.__setattr__(self, name, value)
+        else:
+            setattr(self.scheduler, name, value)
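The wrapper above only intercepts `step`; hooks reach it through the `scheduler_pre`/`scheduler_post` positions handled by `_register_general_hook` in `hooked_sd_pipeline.py` below. A minimal sketch, assuming the pipeline is constructed with `use_hooked_scheduler=True` (note that `from_pretrained` does not forward that flag, so the wrapper is built explicitly here):

```python
from diffusers import StableDiffusionXLPipeline
from SDLens import HookedStableDiffusionXLPipeline

raw = StableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
pipe = HookedStableDiffusionXLPipeline(raw, use_hooked_scheduler=True)

def log_step(pred_prev_sample):
    # post-hook: observe (or replace, by returning a tensor) the denoised sample
    print(pred_prev_sample.mean().item())

out = pipe.run_with_hooks(
    "a photo of a cat",
    position_hook_dict={'scheduler_post': log_step},
    num_inference_steps=1,
    guidance_scale=0.0,
)
```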
SDLens/hooked_sd_pipeline.py ADDED
@@ -0,0 +1,319 @@
+import einops
+from diffusers import StableDiffusionXLPipeline, IFPipeline
+from typing import List, Dict, Callable, Union
+import torch
+from .hooked_scheduler import HookedNoiseScheduler
+
+def retrieve(io):
+    if isinstance(io, tuple):
+        if len(io) == 1:
+            return io[0]
+        else:
+            raise ValueError("A tuple should have length of 1")
+    elif isinstance(io, torch.Tensor):
+        return io
+    else:
+        raise ValueError("Input/Output must be a tensor, or 1-element tuple")
+
+
+class HookedDiffusionAbstractPipeline:
+    parent_cls = None
+    pipe = None
+
+    def __init__(self, pipe: parent_cls, use_hooked_scheduler: bool = False):
+        if use_hooked_scheduler:
+            pipe.scheduler = HookedNoiseScheduler(pipe.scheduler)
+        self.__dict__['pipe'] = pipe
+        self.use_hooked_scheduler = use_hooked_scheduler
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        return cls(cls.parent_cls.from_pretrained(*args, **kwargs))
+
+    def run_with_hooks(self,
+        *args,
+        position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
+        **kwargs
+    ):
+        '''
+        Run the pipeline with hooks at specified positions.
+        Returns the final output.
+
+        Args:
+            *args: Arguments to pass to the pipeline.
+            position_hook_dict: A dictionary mapping positions to hooks.
+                The keys are positions in the pipeline where the hooks should be registered.
+                The values are either a single hook or a list of hooks to be registered at the specified position.
+                Each hook should be a callable that takes three arguments: (module, input, output).
+            **kwargs: Keyword arguments to pass to the pipeline.
+        '''
+        hooks = []
+        for position, hook in position_hook_dict.items():
+            if isinstance(hook, list):
+                for h in hook:
+                    hooks.append(self._register_general_hook(position, h))
+            else:
+                hooks.append(self._register_general_hook(position, hook))
+
+        hooks = [hook for hook in hooks if hook is not None]
+
+        try:
+            output = self.pipe(*args, **kwargs)
+        finally:
+            for hook in hooks:
+                hook.remove()
+            if self.use_hooked_scheduler:
+                self.pipe.scheduler.pre_hooks = []
+                self.pipe.scheduler.post_hooks = []
+
+        return output
+
+    def run_with_cache(self,
+        *args,
+        positions_to_cache: List[str],
+        save_input: bool = False,
+        save_output: bool = True,
+        **kwargs
+    ):
+        '''
+        Run the pipeline with caching at specified positions.
+
+        This method allows you to cache the intermediate inputs and/or outputs of the pipeline
+        at certain positions. The final output of the pipeline and a dictionary of cached values
+        are returned.
+
+        Args:
+            *args: Arguments to pass to the pipeline.
+            positions_to_cache (List[str]): A list of positions in the pipeline where intermediate
+                inputs/outputs should be cached.
+            save_input (bool, optional): If True, caches the input at each specified position.
+                Defaults to False.
+            save_output (bool, optional): If True, caches the output at each specified position.
+                Defaults to True.
+            **kwargs: Keyword arguments to pass to the pipeline.
+
+        Returns:
+            final_output: The final output of the pipeline after execution.
+            cache_dict (Dict[str, Dict[str, Any]]): A dictionary where keys are the specified positions
+                and values are dictionaries containing the cached 'input' and/or 'output' at each position,
+                depending on the flags `save_input` and `save_output`.
+        '''
+        cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
+        hooks = [
+            self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
+        ]
+        hooks = [hook for hook in hooks if hook is not None]
+        output = self.pipe(*args, **kwargs)
+        for hook in hooks:
+            hook.remove()
+        if self.use_hooked_scheduler:
+            self.pipe.scheduler.pre_hooks = []
+            self.pipe.scheduler.post_hooks = []
+
+        cache_dict = {}
+        if save_input:
+            for position, block in cache_input.items():
+                cache_input[position] = torch.stack(block, dim=1)
+            cache_dict['input'] = cache_input
+
+        if save_output:
+            for position, block in cache_output.items():
+                cache_output[position] = torch.stack(block, dim=1)
+            cache_dict['output'] = cache_output
+        return output, cache_dict
+
+    def run_with_hooks_and_cache(self,
+        *args,
+        position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
+        positions_to_cache: List[str] = [],
+        save_input: bool = False,
+        save_output: bool = True,
+        **kwargs
+    ):
+        '''
+        Run the pipeline with hooks and caching at specified positions.
+
+        This method allows you to register hooks at certain positions in the pipeline and
+        cache intermediate inputs and/or outputs at specified positions. Hooks can be used
+        for inspecting or modifying the pipeline's execution, and caching stores intermediate
+        values for later inspection or use.
+
+        Args:
+            *args: Arguments to pass to the pipeline.
+            position_hook_dict (Dict[str, Union[Callable, List[Callable]]]):
+                A dictionary where the keys are the positions in the pipeline, and the values
+                are hooks (either a single hook or a list of hooks) to be registered at those positions.
+                Each hook should be a callable that accepts three arguments: (module, input, output).
+            positions_to_cache (List[str], optional): A list of positions in the pipeline where
+                intermediate inputs/outputs should be cached. Defaults to an empty list.
+            save_input (bool, optional): If True, caches the input at each specified position.
+                Defaults to False.
+            save_output (bool, optional): If True, caches the output at each specified position.
+                Defaults to True.
+            **kwargs: Additional keyword arguments to pass to the pipeline.
+
+        Returns:
+            final_output: The final output of the pipeline after execution.
+            cache_dict (Dict[str, Dict[str, Any]]): A dictionary where keys are the specified positions
+                and values are dictionaries containing the cached 'input' and/or 'output' at each position,
+                depending on the flags `save_input` and `save_output`.
+        '''
+        cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
+        hooks = [
+            self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
+        ]
+
+        for position, hook in position_hook_dict.items():
+            if isinstance(hook, list):
+                for h in hook:
+                    hooks.append(self._register_general_hook(position, h))
+            else:
+                hooks.append(self._register_general_hook(position, hook))
+
+        hooks = [hook for hook in hooks if hook is not None]
+        output = self.pipe(*args, **kwargs)
+        for hook in hooks:
+            hook.remove()
+        if self.use_hooked_scheduler:
+            self.pipe.scheduler.pre_hooks = []
+            self.pipe.scheduler.post_hooks = []
+
+        cache_dict = {}
+        if save_input:
+            for position, block in cache_input.items():
+                cache_input[position] = torch.stack(block, dim=1)
+            cache_dict['input'] = cache_input
+
+        if save_output:
+            for position, block in cache_output.items():
+                cache_output[position] = torch.stack(block, dim=1)
+            cache_dict['output'] = cache_output
+
+        return output, cache_dict
+
+    def _locate_block(self, position: str):
+        '''
+        Locate the block at the specified position in the pipeline.
+        '''
+        block = self.pipe
+        for step in position.split('.'):
+            if step.isdigit():
+                step = int(step)
+                block = block[step]
+            else:
+                block = getattr(block, step)
+        return block
+
+    def _register_cache_hook(self, position: str, cache_input: Dict, cache_output: Dict):
+
+        if position.endswith('$self_attention') or position.endswith('$cross_attention'):
+            return self._register_cache_attention_hook(position, cache_output)
+
+        if position == 'noise':
+            def hook(model_output, timestep, sample, generator):
+                if position not in cache_output:
+                    cache_output[position] = []
+                cache_output[position].append(sample)
+
+            if self.use_hooked_scheduler:
+                self.pipe.scheduler.post_hooks.append(hook)
+            else:
+                raise ValueError('Cannot cache noise without using hooked scheduler')
+            return
+
+        block = self._locate_block(position)
+
+        def hook(module, input, kwargs, output):
+            if cache_input is not None:
+                if position not in cache_input:
+                    cache_input[position] = []
+                cache_input[position].append(retrieve(input))
+
+            if cache_output is not None:
+                if position not in cache_output:
+                    cache_output[position] = []
+                cache_output[position].append(retrieve(output))
+
+        return block.register_forward_hook(hook, with_kwargs=True)
+
+    def _register_cache_attention_hook(self, position, cache):
+        attn_block = self._locate_block(position.split('$')[0])
+        if position.endswith('$self_attention'):
+            attn_block = attn_block.attn1
+        elif position.endswith('$cross_attention'):
+            attn_block = attn_block.attn2
+        else:
+            raise ValueError('Wrong attention type')
+
+        def hook(module, args, kwargs, output):
+            hidden_states = args[0]
+            encoder_hidden_states = kwargs['encoder_hidden_states']
+            attention_mask = kwargs['attention_mask']
+            batch_size, sequence_length, _ = hidden_states.shape
+            attention_mask = attn_block.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            query = attn_block.to_q(hidden_states)
+
+            if encoder_hidden_states is None:
+                encoder_hidden_states = hidden_states
+            elif attn_block.norm_cross is not None:
+                encoder_hidden_states = attn_block.norm_cross(encoder_hidden_states)
+
+            key = attn_block.to_k(encoder_hidden_states)
+            value = attn_block.to_v(encoder_hidden_states)
+
+            query = attn_block.head_to_batch_dim(query)
+            key = attn_block.head_to_batch_dim(key)
+            value = attn_block.head_to_batch_dim(value)
+
+            attention_probs = attn_block.get_attention_scores(query, key, attention_mask)
+            attention_probs = attention_probs.view(
+                batch_size,
+                attention_probs.shape[0] // batch_size,
+                attention_probs.shape[1],
+                attention_probs.shape[2]
+            )
+            if position not in cache:
+                cache[position] = []
+            cache[position].append(attention_probs)
+
+        return attn_block.register_forward_hook(hook, with_kwargs=True)
+
+    def _register_general_hook(self, position, hook):
+        if position == 'scheduler_pre':
+            if not self.use_hooked_scheduler:
+                raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
+            self.pipe.scheduler.pre_hooks.append(hook)
+            return
+        elif position == 'scheduler_post':
+            if not self.use_hooked_scheduler:
+                raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
+            self.pipe.scheduler.post_hooks.append(hook)
+            return
+
+        block = self._locate_block(position)
+        return block.register_forward_hook(hook)
+
+    def to(self, *args, **kwargs):
+        self.pipe = self.pipe.to(*args, **kwargs)
+        return self
+
+    def __getattr__(self, name):
+        return getattr(self.pipe, name)
+
+    def __setattr__(self, name, value):
+        return setattr(self.pipe, name, value)
+
+    def __call__(self, *args, **kwargs):
+        return self.pipe(*args, **kwargs)
+
+
+class HookedStableDiffusionXLPipeline(HookedDiffusionAbstractPipeline):
+    parent_cls = StableDiffusionXLPipeline
+
+
+class HookedIFPipeline(HookedDiffusionAbstractPipeline):
+    parent_cls = IFPipeline
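A short sketch of the caching API defined above, following the call pattern `app.py` uses later in this commit:

```python
from SDLens import HookedStableDiffusionXLPipeline

pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')

block = 'unet.down_blocks.2.attentions.1'
out, cache = pipe.run_with_cache(
    "a corgi wearing sunglasses",
    positions_to_cache=[block],
    save_input=True,
    save_output=True,
    num_inference_steps=1,
    guidance_scale=0.0,
)

# Cached tensors are stacked along dim=1, one slice per forward call.
residual = cache['output'][block] - cache['input'][block]  # what the block adds
print(out.images[0].size, residual.shape)
```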
app.ipynb ADDED
The diff for this file is too large to render.
app.py
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
from SDLens import HookedStableDiffusionXLPipeline
|
6 |
+
from SAE import SparseAutoencoder
|
7 |
+
from utils import add_feature_on_area
|
8 |
+
import numpy as np
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from matplotlib.colors import ListedColormap
|
11 |
+
from utils import add_feature_on_area, replace_with_feature
|
12 |
+
import threading
|
13 |
+
|
14 |
+
code_to_block = {
|
15 |
+
"down.2.1": "unet.down_blocks.2.attentions.1",
|
16 |
+
"mid.0": "unet.mid_block.attentions.0",
|
17 |
+
"up.0.1": "unet.up_blocks.0.attentions.1",
|
18 |
+
"up.0.0": "unet.up_blocks.0.attentions.0"
|
19 |
+
}
|
20 |
+
lock = threading.Lock()
|
21 |
+
|
22 |
+
def process_cache(cache, saes_dict):
|
23 |
+
|
24 |
+
top_features_dict = {}
|
25 |
+
sparse_maps_dict = {}
|
26 |
+
|
27 |
+
for code in code_to_block.keys():
|
28 |
+
block = code_to_block[code]
|
29 |
+
sae = saes_dict[code]
|
30 |
+
|
31 |
+
diff = cache["output"][block] - cache["input"][block]
|
32 |
+
diff = diff.permute(0, 1, 3, 4, 2).squeeze(0).squeeze(0)
|
33 |
+
with torch.no_grad():
|
34 |
+
sparse_maps = sae.encode(diff)
|
35 |
+
averages = torch.mean(sparse_maps, dim=(0, 1))
|
36 |
+
|
37 |
+
top_features = torch.topk(averages, 10).indices
|
38 |
+
|
39 |
+
top_features_dict[code] = top_features.cpu().tolist()
|
40 |
+
sparse_maps_dict[code] = sparse_maps.cpu().numpy()
|
41 |
+
|
42 |
+
return top_features_dict, sparse_maps_dict
|
43 |
+
|
44 |
+
|
45 |
+
def plot_image_heatmap(cache, block_select, radio):
|
46 |
+
code = block_select.split()[0]
|
47 |
+
feature = int(radio)
|
48 |
+
block = code_to_block[code]
|
49 |
+
|
50 |
+
heatmap = cache["heatmaps"][code][:, :, feature]
|
51 |
+
heatmap = np.kron(heatmap, np.ones((32, 32)))
|
52 |
+
image = cache["image"].convert("RGBA")
|
53 |
+
|
54 |
+
jet = plt.cm.jet
|
55 |
+
cmap = jet(np.arange(jet.N))
|
56 |
+
cmap[:1, -1] = 0
|
57 |
+
cmap[1:, -1] = 0.6
|
58 |
+
cmap = ListedColormap(cmap)
|
59 |
+
heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap))
|
60 |
+
heatmap_rgba = cmap(heatmap)
|
61 |
+
heatmap_image = Image.fromarray((heatmap_rgba * 255).astype(np.uint8))
|
62 |
+
heatmap_with_transparency = Image.alpha_composite(image, heatmap_image)
|
63 |
+
|
64 |
+
return heatmap_with_transparency
|
65 |
+
|
66 |
+
|
67 |
+
def create_prompt_part(pipe, saes_dict, demo):
|
68 |
+
def image_gen(prompt):
|
69 |
+
lock.acquire()
|
70 |
+
try:
|
71 |
+
images, cache = pipe.run_with_cache(
|
72 |
+
prompt,
|
73 |
+
positions_to_cache=list(code_to_block.values()),
|
74 |
+
num_inference_steps=1,
|
75 |
+
generator=torch.Generator(device="cpu").manual_seed(42),
|
76 |
+
guidance_scale=0.0,
|
77 |
+
save_input=True,
|
78 |
+
save_output=True
|
79 |
+
)
|
80 |
+
finally:
|
81 |
+
lock.release()
|
82 |
+
|
83 |
+
top_features_dict, top_sparse_maps_dict = process_cache(cache, saes_dict)
|
84 |
+
return images.images[0], {
|
85 |
+
"image": images.images[0],
|
86 |
+
"heatmaps": top_sparse_maps_dict,
|
87 |
+
"features": top_features_dict
|
88 |
+
}
|
89 |
+
|
90 |
+
def update_radio(cache, block_select):
|
91 |
+
code = block_select.split()[0]
|
92 |
+
return gr.update(choices=cache["features"][code])
|
93 |
+
|
94 |
+
def update_img(cache, block_select, radio):
|
95 |
+
new_img = plot_image_heatmap(cache, block_select, radio)
|
96 |
+
return new_img
|
97 |
+
|
98 |
+
with gr.Tab("Explore", elem_classes="tabs") as explore_tab:
|
99 |
+
cache = gr.State(value={
|
100 |
+
"image": None,
|
101 |
+
"heatmaps": None,
|
102 |
+
"features": []
|
103 |
+
})
|
104 |
+
with gr.Row():
|
105 |
+
with gr.Column(scale=7):
|
106 |
+
with gr.Row(equal_height=True):
|
107 |
+
prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A cinematic shot of a professor sloth wearing a tuxedo at a BBQ party and eathing a dish with peas.")
|
108 |
+
button = gr.Button("Generate", elem_classes="generate_button1")
|
109 |
+
|
110 |
+
with gr.Row():
|
111 |
+
image = gr.Image(width=512, height=512, image_mode="RGB", label="Generated image")
|
112 |
+
|
113 |
+
with gr.Column(scale=4):
|
114 |
+
block_select = gr.Dropdown(
|
115 |
+
choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
|
116 |
+
value="down.2.1 (composition)",
|
117 |
+
label="Select block",
|
118 |
+
elem_id="block_select",
|
119 |
+
interactive=True
|
120 |
+
)
|
121 |
+
radio = gr.Radio(choices=[], label="Select a feature", interactive=True)
|
122 |
+
|
123 |
+
button.click(image_gen, [prompt_field], outputs=[image, cache])
|
124 |
+
cache.change(update_radio, [cache, block_select], outputs=[radio])
|
125 |
+
block_select.select(update_radio, [cache, block_select], outputs=[radio])
|
126 |
+
radio.select(update_img, [cache, block_select, radio], outputs=[image])
|
127 |
+
demo.load(image_gen, [prompt_field], outputs=[image, cache])
|
128 |
+
|
129 |
+
return explore_tab
|
130 |
+
|
131 |
+
def downsample_mask(image, factor):
|
132 |
+
downsampled = image.reshape(
|
133 |
+
(image.shape[0] // factor, factor,
|
134 |
+
image.shape[1] // factor, factor)
|
135 |
+
)
|
136 |
+
downsampled = downsampled.mean(axis=(1, 3))
|
137 |
+
return downsampled
|
138 |
+
|
139 |
+
def create_intervene_part(pipe: HookedStableDiffusionXLPipeline, saes_dict, means_dict, demo):
|
140 |
+
def image_gen(prompt, num_steps):
|
141 |
+
lock.acquire()
|
142 |
+
try:
|
143 |
+
images = pipe.run_with_hooks(
|
144 |
+
prompt,
|
145 |
+
position_hook_dict={},
|
146 |
+
num_inference_steps=num_steps,
|
147 |
+
generator=torch.Generator(device="cpu").manual_seed(42),
|
148 |
+
guidance_scale=0.0
|
149 |
+
)
|
150 |
+
finally:
|
151 |
+
lock.release()
|
152 |
+
return images.images[0]
|
153 |
+
|
154 |
+
def image_mod(prompt, block_str, brush_index, strength, num_steps, input_image):
|
155 |
+
block = block_str.split(" ")[0]
|
156 |
+
|
157 |
+
mask = (input_image["layers"][0] > 0)[:, :, -1].astype(float)
|
158 |
+
mask = downsample_mask(mask, 32)
|
159 |
+
mask = torch.tensor(mask, dtype=torch.float32, device="cuda")
|
160 |
+
|
161 |
+
if mask.sum() == 0:
|
162 |
+
gr.Info("No mask selected, please draw on the input image")
|
163 |
+
|
164 |
+
def hook(module, input, output):
|
165 |
+
return add_feature_on_area(
|
166 |
+
saes_dict[block],
|
167 |
+
brush_index,
|
168 |
+
mask * means_dict[block][brush_index] * strength,
|
169 |
+
module,
|
170 |
+
input,
|
171 |
+
output
|
172 |
+
)
|
173 |
+
|
174 |
+
lock.acquire()
|
175 |
+
try:
|
176 |
+
image = pipe.run_with_hooks(
|
177 |
+
prompt,
|
178 |
+
position_hook_dict={code_to_block[block]: hook},
|
179 |
+
num_inference_steps=num_steps,
|
180 |
+
generator=torch.Generator(device="cpu").manual_seed(42),
|
181 |
+
guidance_scale=0.0
|
182 |
+
).images[0]
|
183 |
+
finally:
|
184 |
+
lock.release()
|
185 |
+
return image
|
186 |
+
|
187 |
+
def feature_icon(block_str, brush_index):
|
188 |
+
block = block_str.split(" ")[0]
|
189 |
+
if block in ["mid.0", "up.0.0"]:
|
190 |
+
gr.Info("Note that Feature Icon works best with down.2.1 and up.0.1 blocks but feel free to explore", duration=3)
|
191 |
+
|
192 |
+
def hook(module, input, output):
|
193 |
+
return replace_with_feature(
|
194 |
+
saes_dict[block],
|
195 |
+
brush_index,
|
196 |
+
means_dict[block][brush_index] * saes_dict[block].k,
|
197 |
+
module,
|
198 |
+
input,
|
199 |
+
output
|
200 |
+
)
|
201 |
+
|
202 |
+
lock.acquire()
|
203 |
+
try:
|
204 |
+
image = pipe.run_with_hooks(
|
205 |
+
"",
|
206 |
+
position_hook_dict={code_to_block[block]: hook},
|
207 |
+
num_inference_steps=1,
|
208 |
+
generator=torch.Generator(device="cpu").manual_seed(42),
|
209 |
+
guidance_scale=0.0
|
210 |
+
).images[0]
|
211 |
+
finally:
|
212 |
+
lock.release()
|
213 |
+
return image
|
214 |
+
|
215 |
+
with gr.Tab("Paint!", elem_classes="tabs") as intervene_tab:
|
216 |
+
image_state = gr.State(value=None)
|
217 |
+
with gr.Row():
|
218 |
+
with gr.Column(scale=3):
|
219 |
+
# Generation column
|
220 |
+
with gr.Row():
|
221 |
+
# prompt and num_steps
|
222 |
+
prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A dog plays with a ball, cartoon", elem_id="prompt_input")
|
223 |
+
num_steps = gr.Number(value=1, label="Number of steps", minimum=1, maximum=4, elem_id="num_steps", precision=0)
|
224 |
+
with gr.Row():
|
225 |
+
# Generate button
|
226 |
+
button_generate = gr.Button("Generate", elem_id="generate_button")
|
227 |
+
with gr.Column(scale=3):
|
228 |
+
# Intervention column
|
229 |
+
with gr.Row():
|
230 |
+
# dropdowns and number inputs
|
231 |
+
with gr.Column(scale=7):
|
232 |
+
with gr.Row():
|
233 |
+
block_select = gr.Dropdown(
|
234 |
+
choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
|
235 |
+
value="down.2.1 (composition)",
|
236 |
+
label="Select block",
|
237 |
+
elem_id="block_select"
|
238 |
+
)
|
239 |
+
brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=5119, elem_id="brush_index", precision=0)
|
240 |
+
with gr.Row():
|
241 |
+
button_icon = gr.Button('Feature Icon', elem_id="feature_icon_button")
|
242 |
+
with gr.Column(scale=3):
|
243 |
+
with gr.Row():
|
244 |
+
            strength = gr.Number(value=10, label="Strength", minimum=-40, maximum=40, elem_id="strength", precision=2)
        with gr.Row():
            button = gr.Button('Apply', elem_id="apply_button")

        with gr.Row():
            with gr.Column():
                # Input image
                i_image = gr.Sketchpad(
                    height=610,
                    layers=False, transforms=[], placeholder="Generate and paint!",
                    brush=gr.Brush(default_size=64, color_mode="fixed", colors=['black']),
                    container=False,
                    canvas_size=(512, 512),
                    label="Input Image")
                clear_button = gr.Button("Clear")
                clear_button.click(lambda x: x, [image_state], [i_image])
            # Output image
            o_image = gr.Image(width=512, height=512, label="Output Image")

        # Set up the click events
        button_generate.click(image_gen, inputs=[prompt_field, num_steps], outputs=[image_state])
        image_state.change(lambda x: x, [image_state], [i_image])
        button.click(image_mod,
                     inputs=[prompt_field, block_select, brush_index, strength, num_steps, i_image],
                     outputs=o_image)
        button_icon.click(feature_icon, inputs=[block_select, brush_index], outputs=o_image)
        demo.load(image_gen, [prompt_field, num_steps], outputs=[image_state])

    return intervene_tab


def create_top_images_part(demo):
    def update_top_images(block_select, brush_index):
        block = block_select.split(" ")[0]
        url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{brush_index}.jpg"
        return url

    with gr.Tab("Top Images", elem_classes="tabs") as top_images_tab:
        with gr.Row():
            block_select = gr.Dropdown(
                choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
                value="down.2.1 (composition)",
                label="Select block"
            )
            brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=5119, precision=0)
        with gr.Row():
            image = gr.Image(width=600, height=600, label="Top Images")

        block_select.select(update_top_images, [block_select, brush_index], outputs=[image])
        brush_index.change(update_top_images, [block_select, brush_index], outputs=[image])
        demo.load(update_top_images, [block_select, brush_index], outputs=[image])
    return top_images_tab


def create_intro_part():
    with gr.Tab("Instructions", elem_classes="tabs") as intro_tab:
        gr.Markdown(
            '''# Unpacking SDXL Turbo with Sparse Autoencoders
## Demo Overview
This demo showcases the use of Sparse Autoencoders (SAEs) to understand the features learned by the Stable Diffusion XL Turbo model.

## How to Use
### Explore
* Enter a prompt in the text box and click the "Generate" button to generate an image.
* You can observe the active features in different blocks plotted on top of the generated image.
### Top Images
* For each feature, you can view the top images that activate that feature most strongly.
### Paint!
* Generate an image using the prompt.
* Paint on the generated image to apply interventions.
* Use the "Feature Icon" button to understand how the selected brush functions.

### Remarks
* Not all brushes mix well with all images. Experiment with different brushes and strengths.
* Feature Icon works best with the `down.2.1 (composition)` and `up.0.1 (style)` blocks.
* This demo is provided for research purposes only. We do not take responsibility for the content generated by the demo.

### Interesting features to try
To get started, try the following features:
- down.2.1 (composition): 2301 (evil), 3747 (image frame), 4998 (cartoon)
- up.0.1 (style): 4977 (tiger stripes), 90 (fur), 2615 (twilight blur)
'''
        )

    return intro_tab


def create_demo(pipe, saes_dict, means_dict):
    custom_css = """
    .tabs button {
        font-size: 20px !important;   /* Larger font for the tab text */
        padding: 10px !important;     /* Extra padding to make the tabs bigger */
        font-weight: bold !important; /* Bold tab labels */
    }
    .generate_button1 {
        max-width: 160px !important;
        margin-top: 20px !important;
        margin-bottom: 20px !important;
    }
    """

    with gr.Blocks(css=custom_css) as demo:
        with create_intro_part():
            pass
        with create_prompt_part(pipe, saes_dict, demo):
            pass
        with create_top_images_part(demo):
            pass
        with create_intervene_part(pipe, saes_dict, means_dict, demo):
            pass

    return demo


if __name__ == "__main__":
    import os

    import gradio as gr
    import torch
    from SDLens import HookedStableDiffusionXLPipeline
    from SAE import SparseAutoencoder

    dtype = torch.float32
    pipe = HookedStableDiffusionXLPipeline.from_pretrained(
        'stabilityai/sdxl-turbo',
        torch_dtype=dtype,
        device_map="balanced",
        variant=("fp16" if dtype == torch.float16 else None)
    )
    pipe.set_progress_bar_config(disable=True)

    path_to_checkpoints = './checkpoints/'

    code_to_block = {
        "down.2.1": "unet.down_blocks.2.attentions.1",
        "mid.0": "unet.mid_block.attentions.0",
        "up.0.1": "unet.up_blocks.0.attentions.1",
        "up.0.0": "unet.up_blocks.0.attentions.0"
    }

    saes_dict = {}
    means_dict = {}

    for code, block in code_to_block.items():
        sae = SparseAutoencoder.load_from_disk(
            os.path.join(path_to_checkpoints, f"{block}_k10_hidden5120_auxk256_bs4096_lr0.0001", "final"),
        )
        means = torch.load(
            os.path.join(path_to_checkpoints, f"{block}_k10_hidden5120_auxk256_bs4096_lr0.0001", "final", "mean.pt"),
            weights_only=True
        )
        saes_dict[code] = sae.to('cuda', dtype=dtype)
        means_dict[code] = means.to('cuda', dtype=dtype)

    demo = create_demo(pipe, saes_dict, means_dict)
    demo.launch()
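For orientation, a minimal sketch (not part of app.py) of inspecting which SAE features fire for a prompt, reusing `pipe` and `saes_dict` from the block above. The block input/output are assumed to be (batch, channels, height, width) tensors, as the hooks in utils/hooks.py assume; `run_with_cache` arguments mirror scripts/collect_latents_dataset.py:

    block = "unet.down_blocks.2.attentions.1"
    out, cache = pipe.run_with_cache(
        prompt="a cat", positions_to_cache=[block],
        save_input=True, save_output=True,
        num_inference_steps=1, guidance_scale=0.0,
    )
    diff = (cache['output'][block] - cache['input'][block]).permute(0, 2, 3, 1)
    acts = saes_dict["down.2.1"].encode(diff.to('cuda'))      # (batch, h, w, 5120), k-sparse
    print(acts.sum(dim=(0, 1, 2)).topk(5).indices.tolist())   # most active feature indices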
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json
ADDED
@@ -0,0 +1 @@
{"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
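These fields match the `SparseAutoencoder` constructor arguments used in scripts/train_sae.py one-to-one, so a checkpoint's architecture can be rebuilt from its config alone (a minimal sketch; the abbreviated path is illustrative):

    import json
    from SAE.sae import SparseAutoencoder

    with open('checkpoints/<block>_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json') as f:
        sae = SparseAutoencoder(**json.load(f))  # n_dirs_local=5120, d_model=1280, k=10, ...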
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:387f2b6f8c4e4a6f1227921f28f00dfa4beb2bd4e422b7eb592cd8627af0e58f
size 21581

checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:39e3c6d17aa572a53368ca8ba8f82757947a3caf14fe654e84b175d0dc0a4650
size 52497831

checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6ca694c9504a7a8aa827004d3fdec5c1cb8fcf3904acc3562d1861fc6e65c19
size 21576

checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json
ADDED
@@ -0,0 +1 @@
{"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}

checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:80790481d0e56ac3fa36599703cee7a05cfb4cc078db57c8f9180e860c330e1d
size 21581

checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:49d38d9178c2a2780e04a5482a2feb9548c6e9a636ed1bf85291acf42e0ffa34
size 52497831

checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bb6bfc7ce5e596f8aa048ab262ca56841868c222bf07eb2ed35b6e4f7094fea6
size 21576

checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json
ADDED
@@ -0,0 +1 @@
{"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}

checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:de036d0fb9ee663f7bdf60e4a5d89d038516dae637531676b53ff75d05eab46b
size 21581

checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:14c45efd9cce0258f014c49babdcd0e9ce8b266fe31eed72db1a45b990a1a0f8
size 52497831

checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cb9c04499ccae041987cc262894e254c2f04288857a8a0470cfb1b86a8ecfa09
size 21576

checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json
ADDED
@@ -0,0 +1 @@
{"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}

checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:96dbf6fffe9d62c3b3352f8e4fe48c54dfd69906cf8ad6828d5ce93db9a5f0dc
size 21581

checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f8eed82f4bcb2f010ae9075f10a1ece801ee3dec46dba7fadccc35f6c0a7836b
size 52497831

checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fe5c5be0c4c2d2b57e7888319053cb64929559f947c8ce445ddd6a397302afab
size 21576
example.ipynb
ADDED
The diff for this file is too large to render. See raw diff.

requirements.txt
ADDED
@@ -0,0 +1,7 @@
diffusers==0.29.2
gradio==4.44.1
torch>=2.4.0
numpy
matplotlib
pillow
wandb

resourses/image.png
ADDED
Git LFS Details
scripts/collect_latents_dataset.py
ADDED
@@ -0,0 +1,96 @@
import os
import sys
import io
import tarfile
import torch
import webdataset as wds
import numpy as np

from tqdm import tqdm
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from SDLens.hooked_sd_pipeline import HookedStableDiffusionXLPipeline

import datetime
from datasets import load_dataset
from torch.utils.data import DataLoader
import diffusers
import fire

def main(save_path, start_at=0, finish_at=30000, dataset_batch_size=50):
    blocks_to_save = [
        'unet.down_blocks.2.attentions.1',
        'unet.mid_block.attentions.0',
        'unet.up_blocks.0.attentions.0',
        'unet.up_blocks.0.attentions.1',
    ]

    # Initialization
    dataset = load_dataset("guangyil/laion-coco-aesthetic", split="train", columns=["caption"], streaming=True).shuffle(seed=42)
    pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
    pipe.to('cuda')
    pipe.set_progress_bar_config(disable=True)
    dataloader = DataLoader(dataset, batch_size=dataset_batch_size)

    ct = datetime.datetime.now()
    save_path = os.path.join(save_path, str(ct))
    # Collecting dataset
    os.makedirs(save_path, exist_ok=True)

    writers = {
        block: wds.TarWriter(f'{save_path}/{block}.tar') for block in blocks_to_save
    }
    writers.update({'images': wds.TarWriter(f'{save_path}/images.tar')})

    def to_kwargs(kwargs_to_save):
        # Replace the JSON-serializable 'seed' entry with the torch.Generator the pipeline expects.
        kwargs = kwargs_to_save.copy()
        seed = kwargs.pop('seed')
        kwargs['generator'] = torch.Generator(device="cpu").manual_seed(seed)
        return kwargs

    for num_document, batch in tqdm(enumerate(dataloader)):
        if num_document < start_at:
            continue
        if num_document >= finish_at:
            break

        kwargs_to_save = {
            'prompt': batch['caption'],
            'positions_to_cache': blocks_to_save,
            'save_input': True,
            'save_output': True,
            'num_inference_steps': 1,
            'guidance_scale': 0.0,
            'seed': num_document,
            'output_type': 'pil'
        }

        kwargs = to_kwargs(kwargs_to_save)

        output, cache = pipe.run_with_cache(**kwargs)

        blocks = cache['input'].keys()
        for block in blocks:
            sample = {
                "__key__": f"sample_{num_document}",
                "output.pth": cache['output'][block],
                "diff.pth": cache['output'][block] - cache['input'][block],
                "gen_args.json": kwargs_to_save
            }
            writers[block].write(sample)
        writers['images'].write({
            "__key__": f"sample_{num_document}",
            "images.npy": np.stack(output.images)
        })

    for block, writer in writers.items():
        writer.close()

if __name__ == '__main__':
    fire.Fire(main)
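To read one of the written shards back, a minimal sketch using webdataset (an assumption about consumption; it is not shown in the repo). `.pth` entries arrive as raw bytes keyed by extension, so they are deserialized with `torch.load`:

    import io
    import torch
    import webdataset as wds

    ds = wds.WebDataset("unet.mid_block.attentions.0.tar")
    sample = next(iter(ds))                             # dict of raw bytes per extension
    out = torch.load(io.BytesIO(sample["output.pth"]))  # cached block output
    diff = torch.load(io.BytesIO(sample["diff.pth"]))   # block output minus block input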
scripts/train_sae.py
ADDED
@@ -0,0 +1,308 @@
'''
Adapted from
https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/train.py
'''


import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from typing import Callable, Iterable, Iterator

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ReduceOp
from SAE.dataset_iterator import ActivationsDataloader
from SAE.sae import SparseAutoencoder, unit_norm_decoder_, unit_norm_decoder_grad_adjustment_
from SAE.sae_utils import SAETrainingConfig, Config

from types import SimpleNamespace
from typing import Optional, List
import json

import tqdm

def weighted_average(points: torch.Tensor, weights: torch.Tensor):
    weights = weights / weights.sum()
    return (points * weights.view(-1, 1)).sum(dim=0)


@torch.no_grad()
def geometric_median_objective(
    median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:
    norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
    return (norms * weights).sum()


def compute_geometric_median(
    points: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    maxiter: int = 100,
    ftol: float = 1e-20,
    do_log: bool = False,
):
    """
    :param points: ``torch.Tensor`` of shape ``(n, d)``
    :param weights: Optional ``torch.Tensor`` of shape ``(n,)``.
    :param eps: Smallest allowed value of the denominator, to avoid divide by zero.
        Equivalently, this is a smoothing parameter. Default 1e-6.
    :param maxiter: Maximum number of Weiszfeld iterations. Default 100.
    :param ftol: If the objective value does not improve by at least this fraction, terminate the algorithm. Default 1e-20.
    :param do_log: If true, returns a log of the objective values encountered over the course of the algorithm.
    :return: SimpleNamespace object with fields
        - `median`: estimate of the geometric median, a ``torch.Tensor`` of shape ``(d,)``
        - `termination`: string explaining how the algorithm terminated
        - `logs`: list of objective values encountered over the course of the algorithm (None if do_log is false)
    """
    with torch.no_grad():
        if weights is None:
            weights = torch.ones((points.shape[0],), device=points.device)
        # initialize the median estimate at the weighted mean
        new_weights = weights
        median = weighted_average(points, weights)
        objective_value = geometric_median_objective(median, points, weights)
        logs = [objective_value] if do_log else None

        # Weiszfeld iterations
        early_termination = False
        pbar = tqdm.tqdm(range(maxiter))
        for _ in pbar:
            prev_obj_value = objective_value

            norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
            new_weights = weights / torch.clamp(norms, min=eps)
            median = weighted_average(points, new_weights)
            objective_value = geometric_median_objective(median, points, weights)

            if logs is not None:
                logs.append(objective_value)
            if abs(prev_obj_value - objective_value) <= ftol * objective_value:
                early_termination = True
                break

            pbar.set_description(f"Objective value: {objective_value:.4f}")

    median = weighted_average(points, new_weights)  # allow autodiff to track it
    return SimpleNamespace(
        median=median,
        new_weights=new_weights,
        termination=(
            "function value converged within tolerance"
            if early_termination
            else "maximum iterations reached"
        ),
        logs=logs,
    )
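
# Illustrative usage of compute_geometric_median (a minimal sketch, not part of the
# original training script; the sizes below are arbitrary):
#
#   pts = torch.randn(1024, 1280)
#   res = compute_geometric_median(pts, maxiter=50, do_log=True)
#   print(res.termination, res.median.shape)  # -> "...converged...", torch.Size([1280])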

def maybe_transpose(x):
    return x.T if not x.is_contiguous() and x.T.is_contiguous() else x

import wandb

RANK = 0

class Logger:
    def __init__(self, sae_name, **kws):
        self.vals = {}
        self.enabled = (RANK == 0) and not kws.pop("dummy", False)
        self.sae_name = sae_name

    def logkv(self, k, v):
        if self.enabled:
            self.vals[f'{self.sae_name}/{k}'] = v.detach() if isinstance(v, torch.Tensor) else v
        return v

    def dumpkvs(self, step):
        if self.enabled:
            wandb.log(self.vals, step=step)
            self.vals = {}


class FeaturesStats:
    def __init__(self, dim, logger):
        self.dim = dim
        self.logger = logger
        self.reinit()

    def reinit(self):
        self.n_activated = torch.zeros(self.dim, dtype=torch.long, device="cuda")
        self.n = 0

    def update(self, inds):
        self.n += inds.shape[0]
        inds = inds.flatten().detach()
        self.n_activated.scatter_add_(0, inds, torch.ones_like(inds))

    def log(self):
        self.logger.logkv('activated', (self.n_activated / self.n + 1e-9).log10().cpu().numpy())

def training_loop_(
    aes,
    train_acts_iter,
    loss_fn,
    log_interval,
    save_interval,
    loggers,
    sae_cfgs,
):
    sae_packs = []
    for ae, cfg, logger in zip(aes, sae_cfgs, loggers):
        pbar = tqdm.tqdm(unit=" steps", desc="Training Loss: ")
        fstats = FeaturesStats(ae.n_dirs, logger)
        opt = torch.optim.Adam(ae.parameters(), lr=cfg.lr, eps=cfg.eps, fused=True)
        sae_packs.append((ae, cfg, logger, pbar, fstats, opt))

    for i, flat_acts_train_batch in enumerate(train_acts_iter):
        flat_acts_train_batch = flat_acts_train_batch.cuda()

        for ae, cfg, logger, pbar, fstats, opt in sae_packs:
            recons, info = ae(flat_acts_train_batch)
            loss = loss_fn(ae, cfg, flat_acts_train_batch, recons, info, logger)

            fstats.update(info['inds'])

            bs = flat_acts_train_batch.shape[0]
            logger.logkv('not-activated 1e4', (ae.stats_last_nonzero > 1e4 / bs).mean(dtype=float).item())
            logger.logkv('not-activated 1e6', (ae.stats_last_nonzero > 1e6 / bs).mean(dtype=float).item())
            logger.logkv('not-activated 1e7', (ae.stats_last_nonzero > 1e7 / bs).mean(dtype=float).item())

            logger.logkv('explained variance', explained_variance(recons, flat_acts_train_batch))
            logger.logkv('l2_div', (torch.linalg.norm(recons, dim=1) / torch.linalg.norm(flat_acts_train_batch, dim=1)).mean())

            if (i + 1) % log_interval == 0:
                fstats.log()
                fstats.reinit()

            if (i + 1) % save_interval == 0:
                ae.save_to_disk(f"{cfg.save_path}/{i + 1}")

            loss.backward()

            unit_norm_decoder_(ae)
            unit_norm_decoder_grad_adjustment_(ae)

            opt.step()
            opt.zero_grad()
            logger.dumpkvs(i)

            pbar.set_description(f"Training Loss {loss.item():.4f}")
            pbar.update(1)

    for ae, cfg, logger, pbar, fstats, opt in sae_packs:
        pbar.close()
        ae.save_to_disk(f"{cfg.save_path}/final")


def init_from_data_(ae, stats_acts_sample):
    ae.pre_bias.data = (
        compute_geometric_median(stats_acts_sample[:32768].float().cpu()).median.cuda().float()
    )


def mse(recons, x):
    # return ((recons - x) ** 2).sum(dim=-1).mean()
    return ((recons - x) ** 2).mean()

def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
    # only used for auxk
    xs_mu = xs.mean(dim=0)

    loss = mse(recon, xs) / mse(
        xs_mu[None, :].broadcast_to(xs.shape), xs
    )

    return loss

def explained_variance(recons, x):
    # Compute the variance of the difference
    diff = x - recons
    diff_var = torch.var(diff, dim=0, unbiased=False)

    # Compute the variance of the original tensor
    x_var = torch.var(x, dim=0, unbiased=False)

    # Avoid division by zero
    explained_var = 1 - diff_var / (x_var + 1e-8)

    return explained_var.mean()


def main():
    cfg = Config(json.load(open('SAE/config.json')))

    dataloader = ActivationsDataloader(cfg.paths_to_latents, cfg.block_name, cfg.bs)

    acts_iter = dataloader.iterate()
    stats_acts_sample = torch.cat([
        next(acts_iter).cpu() for _ in range(10)
    ], dim=0)

    aes = [
        SparseAutoencoder(
            n_dirs_local=sae.n_dirs,
            d_model=sae.d_model,
            k=sae.k,
            auxk=sae.auxk,
            dead_steps_threshold=sae.dead_toks_threshold // cfg.bs,
        ).cuda()
        for sae in cfg.saes
    ]

    for ae in aes:
        init_from_data_(ae, stats_acts_sample)

    mse_scale = (
        1 / ((stats_acts_sample.float().mean(dim=0) - stats_acts_sample.float()) ** 2).mean()
    )
    mse_scale = mse_scale.item()
    del stats_acts_sample

    wandb.init(
        project=cfg.wandb_project,
        name=cfg.wandb_name,
    )

    loggers = [Logger(
        sae_name=cfg_sae.sae_name,
        dummy=False,
    ) for cfg_sae in cfg.saes]

    training_loop_(
        aes,
        acts_iter,
        lambda ae, cfg_sae, flat_acts_train_batch, recons, info, logger: (
            # MSE
            logger.logkv("train_recons", mse_scale * mse(recons, flat_acts_train_batch))
            # AuxK
            + logger.logkv(
                "train_maxk_recons",
                cfg_sae.auxk_coef
                * normalized_mse(
                    ae.decode_sparse(
                        info["auxk_inds"],
                        info["auxk_vals"],
                    ),
                    flat_acts_train_batch - recons.detach() + ae.pre_bias.detach(),
                ).nan_to_num(0),
            )
        ),
        sae_cfgs=cfg.saes,
        loggers=loggers,
        log_interval=cfg.log_interval,
        save_interval=cfg.save_interval,
    )


if __name__ == "__main__":
    main()
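As a quick sanity check of `normalized_mse` (a minimal sketch, not part of the repository; the shapes are arbitrary): a reconstruction equal to the batch mean scores exactly 1.0, so the AuxK term measures improvement relative to that baseline.

    import torch
    xs = torch.randn(4096, 1280)
    baseline = xs.mean(dim=0)[None, :].broadcast_to(xs.shape)
    print(normalized_mse(baseline, xs).item())  # 1.0 by construction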
utils/__init__.py
ADDED
@@ -0,0 +1 @@
from .hooks import *
utils/hooks.py
ADDED
@@ -0,0 +1,45 @@
import torch

@torch.no_grad()
def add_feature(sae, feature_idx, value, module, input, output):
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    mask = torch.zeros_like(activated, device=diff.device)
    mask[..., feature_idx] = value
    to_add = mask @ sae.decoder.weight.T
    return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)


@torch.no_grad()
def add_feature_on_area(sae, feature_idx, activation_map, module, input, output):
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    mask = torch.zeros_like(activated, device=diff.device)
    if activation_map.ndim == 2:
        # add a batch dimension to a single (h, w) map
        activation_map = activation_map.unsqueeze(0)
    mask[..., feature_idx] = activation_map.to(mask.device)
    to_add = mask @ sae.decoder.weight.T
    return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)


@torch.no_grad()
def replace_with_feature(sae, feature_idx, value, module, input, output):
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    mask = torch.zeros_like(activated, device=diff.device)
    mask[..., feature_idx] = value
    to_add = mask @ sae.decoder.weight.T
    return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)


@torch.no_grad()
def reconstruct_sae_hook(sae, module, input, output):
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    reconstructed = sae.decoder(activated) + sae.pre_bias
    return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)


@torch.no_grad()
def ablate_block(module, input, output):
    return input
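A minimal sketch (assumed usage, not taken from this repo) of attaching one of these hooks with a plain PyTorch forward hook; the block path matches `code_to_block` in app.py, and the feature index and strength are arbitrary:

    from functools import partial

    block = pipe.unet.down_blocks[2].attentions[1]
    handle = block.register_forward_hook(
        partial(add_feature, saes_dict["down.2.1"], 2301, 10.0)  # feature 2301 ("evil"), strength 10
    )
    image = pipe(prompt="a cat", num_inference_steps=1, guidance_scale=0.0).images[0]
    handle.remove()  # detach the hook once done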