## Unpacking SDXL Turbo
### Interpreting Text-to-Image Models with Sparse Autoencoders

This Colab notebook prepares and launches the Gradio demo app for exploring SAE features.

### Cloning code and installing dependencies

In [None]:
# if we are on colab then we need to clone the repo
import sys
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    !git clone https://github.com/surkovv/sdxl-unbox.git
    %cd sdxl-unbox

!pip install -r requirements.txt

In [None]:
%load_ext gradio

### Loading SDXL Turbo and SAEs

In [None]:
import os
import gradio as gr
import torch
from SDLens import HookedStableDiffusionXLPipeline
from SAE import SparseAutoencoder
from app import create_demo

assert torch.cuda.is_available(), "No GPU is available. If you are using Colab, switch to a GPU runtime."

In [None]:
# The SAEs were trained in torch.float32, but they also work in torch.float16.
# Change this value to torch.float32 if you have access to a GPU with >30GB of memory.
dtype = torch.float16
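
The choice between the two precisions can be sketched as a small helper. This is only an illustration: `pick_dtype_name` is a hypothetical function that is not part of the repository, and it treats the 30 GB figure from the comment above as a hard threshold measured in decimal gigabytes, which is an assumption.

```python
def pick_dtype_name(total_memory_bytes: int, threshold_gb: float = 30.0) -> str:
    """Return "float32" if the GPU has more than `threshold_gb` GB of memory, else "float16".

    Hypothetical helper: the threshold mirrors the ">30GB" hint in the notebook comment.
    """
    total_gb = total_memory_bytes / 1e9  # assumption: decimal gigabytes
    return "float32" if total_gb > threshold_gb else "float16"


# Made-up memory sizes for illustration, not measurements
print(pick_dtype_name(16 * 10**9))
print(pick_dtype_name(40 * 10**9))
```

In the notebook itself, the byte count could be read from `torch.cuda.get_device_properties(0).total_memory`, and the returned name turned back into a dtype with `getattr(torch, name)`.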

In [None]:
pipe = HookedStableDiffusionXLPipeline.from_pretrained(
    'stabilityai/sdxl-turbo',
    torch_dtype=dtype,
    device_map="balanced",
    variant=("fp16" if dtype==torch.float16 else None)
)
pipe.set_progress_bar_config(disable=True)

In [None]:
path_to_checkpoints = './checkpoints/'

code_to_block = {
    "down.2.1": "unet.down_blocks.2.attentions.1",
    "mid.0": "unet.mid_block.attentions.0",
    "up.0.1": "unet.up_blocks.0.attentions.1",
    "up.0.0": "unet.up_blocks.0.attentions.0"
}

saes_dict = {}
means_dict = {}

for code, block in code_to_block.items():
    # Each SAE checkpoint lives in a directory named after the hooked block
    checkpoint_dir = os.path.join(
        path_to_checkpoints, f"{block}_k10_hidden5120_auxk256_bs4096_lr0.0001", "final"
    )
    sae = SparseAutoencoder.load_from_disk(checkpoint_dir)
    # The means are stored next to the SAE weights
    means = torch.load(os.path.join(checkpoint_dir, "mean.pt"), weights_only=True)
    saes_dict[code] = sae.to('cuda', dtype=dtype)
    means_dict[code] = means.to('cuda', dtype=dtype)
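
The checkpoint directory names used above appear to encode the hooked block together with the SAE training settings. The sketch below splits such a name into its parts. `parse_checkpoint_name` is a hypothetical helper that is not part of the repository, and the field names are copied from the directory name as-is; what each field means (for example, that `k` is a sparsity level or `bs` a batch size) is an assumption the notebook does not confirm.

```python
import re

# Pattern for names like "<block>_k10_hidden5120_auxk256_bs4096_lr0.0001"
_CHECKPOINT_PATTERN = re.compile(
    r"^(?P<block>.+)_k(?P<k>\d+)_hidden(?P<hidden>\d+)"
    r"_auxk(?P<auxk>\d+)_bs(?P<bs>\d+)_lr(?P<lr>[0-9.eE+-]+)$"
)


def parse_checkpoint_name(name: str) -> dict:
    """Split a checkpoint directory name into the block path and its numeric fields."""
    match = _CHECKPOINT_PATTERN.match(name)
    if match is None:
        raise ValueError(f"Unrecognized checkpoint name: {name!r}")
    fields = match.groupdict()
    return {
        "block": fields["block"],
        "k": int(fields["k"]),
        "hidden": int(fields["hidden"]),
        "auxk": int(fields["auxk"]),
        "bs": int(fields["bs"]),
        "lr": float(fields["lr"]),
    }


info = parse_checkpoint_name(
    "unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001"
)
print(info)
```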

### Launching Demo
Once the demo is running, you can also open it through the public link printed in the output of the cell below.

Interesting features to look at:
- `down.2.1 #4998`: cartoon feature
- `down.2.1 #230`: "fury" feature
- `down.2.1 #89`: "muscleman" feature
- `down.2.1 #4074`: anime feature
- `up.0.1   #4977`: tiger stripes
- `up.0.1   #90`: fur
- `up.0.1   #2165`: twilight blur
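
Each reference in the list above pairs a block code with a feature index. As a minimal sketch, assuming the `<code> #<index>` format shown above, a reference can be resolved to the full block path and an integer index. `parse_feature_ref` is a hypothetical helper that is not part of the repository; the mapping is copied from the `code_to_block` dictionary defined earlier.

```python
import re
from typing import Tuple

# Copied from the notebook's code_to_block dictionary
CODE_TO_BLOCK = {
    "down.2.1": "unet.down_blocks.2.attentions.1",
    "mid.0": "unet.mid_block.attentions.0",
    "up.0.1": "unet.up_blocks.0.attentions.1",
    "up.0.0": "unet.up_blocks.0.attentions.0",
}


def parse_feature_ref(ref: str) -> Tuple[str, int]:
    """Turn a reference such as "down.2.1 #4998" into (block path, feature index)."""
    match = re.fullmatch(r"\s*(\S+)\s+#(\d+)\s*", ref)
    if match is None:
        raise ValueError(f"Expected '<code> #<index>', got {ref!r}")
    code, index = match.group(1), int(match.group(2))
    if code not in CODE_TO_BLOCK:
        raise KeyError(f"Unknown block code: {code!r}")
    return CODE_TO_BLOCK[code], index


print(parse_feature_ref("down.2.1 #4998"))
print(parse_feature_ref("up.0.1   #90"))
```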

In [None]:
demo = create_demo(pipe, saes_dict, means_dict)
demo.launch(share=True, height=1200)

In [None]:
demo.close()

## Citation

If you find this notebook useful in your research, please cite our paper:

```bibtex
[Citation placeholder]
```