# Unpacking SDXL Turbo: Interpreting Text-to-Image Models with Sparse Autoencoders

![Modification demonstration](resourses/image.png)

This repository contains code to reproduce results from our paper on using sparse autoencoders (SAEs) to analyze and interpret the internal representations of text-to-image diffusion models, specifically SDXL Turbo.

## Repository Structure

```
|-- SAE/                    # Core sparse autoencoder implementation
|-- SDLens/                 # Tools for analyzing diffusion models
|   `-- hooked_sd_pipeline.py   # Modified stable diffusion pipeline
|-- scripts/
|   |-- collect_latents_dataset.py  # Generate training data
|   `-- train_sae.py                    # Train SAE models
|-- utils/
|   `-- hooks.py           # Hook utility functions
|-- checkpoints/           # Pretrained SAE model checkpoints
|-- app.py                # Demo application
|-- app.ipynb             # Interactive notebook demo
|-- example.ipynb         # Usage examples
`-- requirements.txt      # Python dependencies
```

## Installation

```bash
pip install -r requirements.txt
```

## Demo Application

You can try our Gradio demo (`app.ipynb`) to browse and experiment with the 20K+ features of our trained SAEs out of the box. The same notebook is also available on [Google Colab](https://colab.research.google.com/drive/1Sd-g3w2Fwv7pc_fxgeQOR3S_RKr18qMP?usp=sharing).

## Usage

1. Collect latent data from SDXL Turbo:
```bash
python scripts/collect_latents_dataset.py --save_path={your_save_path}
```

2. Train sparse autoencoders:

    2.1. Set the path to the stored latents and the directory in which to save checkpoints in `SAE/config.json`.

    2.2. Run the training script:

```bash
python scripts/train_sae.py
```
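
For readers new to SAEs, the following is a minimal NumPy sketch of the kind of forward pass that training optimizes. It assumes a top-k activation in the style of the `openai/sparse_autoencoder` work; the function name `topk_sae_forward`, the toy dimensions, and the tied initialization are illustrative assumptions, not this repository's actual code or settings, which live in `SAE/` and `SAE/config.json` and may differ.

```python
import numpy as np

def topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k):
    """One forward pass of a top-k sparse autoencoder (illustrative only).

    x has shape (batch, d_model); returns (codes, reconstruction).
    """
    # Encode: center the input on the decoder bias, then project to feature space.
    pre = (x - b_dec) @ W_enc + b_enc
    # Find the indices of the k largest pre-activations in each row.
    idx = np.argpartition(pre, -k, axis=1)[:, -k:]
    rows = np.arange(pre.shape[0])[:, None]
    # Keep only those k entries (passed through a ReLU); all others stay zero.
    codes = np.zeros_like(pre)
    codes[rows, idx] = np.maximum(pre[rows, idx], 0.0)
    # Decode: reconstruct the input as a sparse combination of feature directions.
    recon = codes @ W_dec + b_dec
    return codes, recon

rng = np.random.default_rng(0)
d_model, n_features, k = 16, 64, 4      # toy sizes, NOT the repository's settings
W_enc = rng.normal(scale=0.1, size=(d_model, n_features))
W_dec = W_enc.T.copy()                  # tied initialization (an assumption)
b_enc = np.zeros(n_features)
b_dec = np.zeros(d_model)

x = rng.normal(size=(8, d_model))       # random stand-in for collected latents
codes, recon = topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k)
mse = float(np.mean((recon - x) ** 2))  # reconstruction error that training would minimize
print(codes.shape, recon.shape)
```

Each row of `codes` has at most `k` non-zero entries, which is what makes individual features easier to inspect than dense activations.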

## Pretrained Models

We provide pretrained SAE checkpoints for 4 key transformer blocks in SDXL Turbo's U-Net. See `example.ipynb` for analysis examples and visualization of learned features.


## Citation

If you find this code useful in your research, please cite our paper:

```bibtex
[Citation placeholder]
```

## Acknowledgements

The SAE component is based on the [`openai/sparse_autoencoder`](https://github.com/openai/sparse_autoencoder) repository.