finish writeup
Browse files- .gitignore +2 -1
- 11_resid/README.md +9 -0
- 14_resid/README.md +9 -0
- 17_resid/README.md +9 -0
- 20_resid/README.md +9 -0
- 22_resid/README.md +2 -2
- 2_resid/README.md +8 -0
- 5_resid/README.md +9 -0
- 8_resid/README.md +9 -0
- README.md +54 -18
- media/active-feature-proportion.png +3 -0
- media/layer_22_individually_scaled.png +3 -0
- media/layer_22_training_outputs.png +3 -0
- media/sort-eval.png +3 -0
.gitignore
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
.DS_Store
|
2 |
-
pad.ipynb
|
|
|
|
1 |
.DS_Store
|
2 |
+
pad.ipynb
|
3 |
+
cruft
|
11_resid/README.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Layer 11
|
2 |
+
|
3 |
+
## Examples
|
4 |
+
|
5 |
+
3x3 grids cherry picked from the first 100 features which activate on the CLS token on more than 9/500,000 laion images.
|
6 |
+
|
7 |
+
![](./examples/1259_grid.png)
|
8 |
+
![](./examples/1462_grid.png)
|
9 |
+
![](./examples/628_grid.png)
|
14_resid/README.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Layer 14
|
2 |
+
|
3 |
+
## Examples
|
4 |
+
|
5 |
+
3x3 grids cherry picked from the first 100 features which activate on the CLS token on more than 9/500,000 laion images.
|
6 |
+
|
7 |
+
![](./examples/476_grid.png)
|
8 |
+
![](./examples/70_grid.jpg)
|
9 |
+
![](./examples/843_grid.png)
|
17_resid/README.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Layer 17
|
2 |
+
|
3 |
+
## Examples
|
4 |
+
|
5 |
+
3x3 grids cherry picked from the first 100 features which activate on the CLS token on more than 9/500,000 laion images.
|
6 |
+
|
7 |
+
![](./examples/157_grid.png)
|
8 |
+
![](./examples/568_grid.png)
|
9 |
+
![](./examples/606_grid.png)
|
20_resid/README.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Layer 20
|
2 |
+
|
3 |
+
## Examples
|
4 |
+
|
5 |
+
3x3 grids cherry picked from the first 100 features which activate on the CLS token on more than 9/500,000 laion images.
|
6 |
+
|
7 |
+
![](./examples/195_grid.png)
|
8 |
+
![](./examples/363_grid.jpg)
|
9 |
+
![](./examples/392_grid.png)
|
22_resid/README.md
CHANGED
@@ -2,8 +2,8 @@
|
|
2 |
|
3 |
## Examples
|
4 |
|
5 |
-
3x3 grids cherry picked from the first 100 features which activate on more than 9/500,000 laion images.
|
6 |
|
7 |
![](./examples/308_grid.png)
|
8 |
![](./examples/464_grid.png)
|
9 |
-
![](./examples/575_grid.png)
|
|
|
2 |
|
3 |
## Examples
|
4 |
|
5 |
+
3x3 grids cherry picked from the first 100 features which activate on the CLS token on more than 9/500,000 laion images.
|
6 |
|
7 |
![](./examples/308_grid.png)
|
8 |
![](./examples/464_grid.png)
|
9 |
+
![](./examples/575_grid.png)
|
2_resid/README.md
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Layer 2
|
2 |
+
|
3 |
+
## Examples
|
4 |
+
|
5 |
+
3x3 grids cherry picked from the first 100 features which activate on the CLS token on more than 9/500,000 laion images. Only 2 features from among the first 2048 actually activated in this case.
|
6 |
+
|
7 |
+
![](./examples/1162_grid.png)
|
8 |
+
![](./examples/1173_grid.png)
|
5_resid/README.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Layer 5
|
2 |
+
|
3 |
+
## Examples
|
4 |
+
|
5 |
+
3x3 grids cherry picked from the first 100 features which activate on the CLS token on more than 9/500,000 laion images.
|
6 |
+
|
7 |
+
![](./examples/1952_grid.png)
|
8 |
+
![](./examples/2023_grid.png)
|
9 |
+
![](./examples/648_grid.png)
|
8_resid/README.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Layer 8
|
2 |
+
|
3 |
+
## Examples
|
4 |
+
|
5 |
+
3x3 grids cherry picked from the first 100 features which activate on the CLS token on more than 9/500,000 laion images.
|
6 |
+
|
7 |
+
![](./examples/1597_grid.png)
|
8 |
+
![](./examples/1735_grid.png)
|
9 |
+
![](./examples/187_grid.png)
|
README.md
CHANGED
@@ -10,36 +10,40 @@ Heavily inspired by [google/gemma-scope](https://huggingface.co/google/gemma-sco
|
|
10 |
|
11 |
![](./media/mse.png)
|
12 |
|
13 |
-
| Layer | MSE | Explained Variance | Dead Feature Proportion |
|
14 |
-
|-------|----------|--------------------|-------------------------|
|
15 |
-
| 2 | 267.95 | 0.763 | 0.000912 |
|
16 |
-
| 5 | 354.46 | 0.665 | 0 |
|
17 |
-
| 8 | 357.58 | 0.642 | 0 |
|
18 |
-
| 11 | 321.23 | 0.674 | 0 |
|
19 |
-
| 14 | 319.64 | 0.689 | 0 |
|
20 |
-
| 17 | 261.20 | 0.731 | 0 |
|
21 |
-
| 20 | 278.06 | 0.706 | 0.0000763 |
|
22 |
-
| 22 | 299.96 | 0.684 | 0 |
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
-
|
|
|
|
|
|
|
|
|
26 |
|
27 |
### Vital Statistics:
|
28 |
|
29 |
- Number of tokens trained per autoencoder: 1.2 Billion
|
30 |
-
- Token type: all 257 image tokens (as opposed to just the
|
31 |
- Number of unique images trained per autoencoder: 4.5 Million
|
32 |
- Training Dataset: [Laion-2b](https://huggingface.co/datasets/laion/laion2B-multi-joined-translated-to-en)
|
33 |
- SAE Architecture: topk with k=32
|
34 |
- Layer Location: always the residual stream
|
35 |
- Training Checkpoints: every ~25 million tokens
|
36 |
-
- Number of features: 65536
|
37 |
|
38 |
## Usage
|
39 |
|
40 |
-
First install our pypi package and PIL
|
41 |
|
42 |
-
`pip install clipscope
|
43 |
|
44 |
Then
|
45 |
|
@@ -73,9 +77,9 @@ print('latent shape', output['latent'].shape) # (1, 65536)
|
|
73 |
print('reconstruction shape', output['reconstruction'].shape) # (1, 1024)
|
74 |
```
|
75 |
|
76 |
-
##
|
77 |
|
78 |
-
We calculate MSE as `(batch - reconstruction).pow(2).sum(dim=-1).mean()` i.e. The MSE between the batch and the
|
79 |
|
80 |
We calculate Explained Variance as
|
81 |
|
@@ -89,9 +93,41 @@ We calculate dead feature proportion as the proportion of features which have no
|
|
89 |
|
90 |
## Subjective Interpretability
|
91 |
|
92 |
-
To give an intuitive feel for the interpretability of these models we run 500,000 images from laion2b selected at random through the final trained SAE for each layer and record the latent activations for each. We then winnow down to the first 100 features which activate for at least 9 images. We cherry pick 3 of these and display them in a 3x3 grid for each layer.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
|
|
94 |
|
95 |
## Automated Sort EVALs
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
## Token-wise MSE
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
![](./media/mse.png)
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
+
| Layer | MSE | Explained Variance | Dead Feature Proportion | Active Feature Proportion | Sort Eval Accuracy (CLS token) |
|
15 |
+
|-------|---------|--------------------|-------------------------|---------------------------|-------------------------------|
|
16 |
+
| 2 | 267.95 | 0.763 | 0.000912 | 0.001 | - |
|
17 |
+
| 5 | 354.46 | 0.665 | 0 | 0.0034 | - |
|
18 |
+
| 8 | 357.58 | 0.642 | 0 | 0.01074 | - |
|
19 |
+
| 11 | 321.23 | 0.674 | 0 | 0.0415 | 0.7334 |
|
20 |
+
| 14 | 319.64 | 0.689 | 0 | 0.07866 | 0.7427 |
|
21 |
+
| 17 | 261.20 | 0.731 | 0 | 0.1477 | 0.8689 |
|
22 |
+
| 20 | 278.06 | 0.706 | 0.0000763 | 0.2036 | 0.9149 |
|
23 |
+
| 22 | 299.96 | 0.684 | 0 | 0.1588 | 0.8641 |
|
24 |
|
25 |
+
|
26 |
+
|
27 |
+
![](./media/sort-eval.png)
|
28 |
+
|
29 |
+
Training logs are available [via wandb](https://wandb.ai/lewington/ViT-L-14-laion2B-s32B-b82K/workspace) and training code is available on [github](https://github.com/Lewington-pitsos/vitsae). The training process is heavily reliant on [AWS ECS](https://aws.amazon.com/ecs/) so the weights and biases logs may contain some strange artifacts when a spot instance is killed and the training is resumed by another instance. Some of the code is ripped directly from [Hugo Fry](https://github.com/HugoFry/mats_sae_training_for_ViTs).
|
30 |
|
31 |
### Vital Statistics:
|
32 |
|
33 |
- Number of tokens trained per autoencoder: 1.2 Billion
|
34 |
+
- Token type: all 257 image tokens (as opposed to just the CLS token)
|
35 |
- Number of unique images trained per autoencoder: 4.5 Million
|
36 |
- Training Dataset: [Laion-2b](https://huggingface.co/datasets/laion/laion2B-multi-joined-translated-to-en)
|
37 |
- SAE Architecture: topk with k=32
|
38 |
- Layer Location: always the residual stream
|
39 |
- Training Checkpoints: every ~25 million tokens
|
40 |
+
- Number of features per autoencoder: 65536 (expansion factor 16)
|
41 |
|
42 |
## Usage
|
43 |
|
44 |
+
First install our pypi package and PIL (pillow)
|
45 |
|
46 |
+
`pip install clipscope pillow`
|
47 |
|
48 |
Then
|
49 |
|
|
|
77 |
print('reconstruction shape', output['reconstruction'].shape) # (1, 1024)
|
78 |
```
|
79 |
|
80 |
+
## Formulae
|
81 |
|
82 |
+
We calculate MSE as `(batch - reconstruction).pow(2).sum(dim=-1).mean()` i.e. The MSE between the batch and the reconstruction, summed across features. We use batch norm to bring all activations into a similar range. We normalize the input before training using batch norm.
|
83 |
|
84 |
We calculate Explained Variance as
|
85 |
|
|
|
93 |
|
94 |
## Subjective Interpretability
|
95 |
|
96 |
+
To give an intuitive feel for the interpretability of these models we run 500,000 images from laion2b selected at random through the final trained SAE for each layer and record the latent activations on the CLS token for each. We then winnow down to the first 100 features which activate for at least 9 images. We cherry pick 3 of these and display them in a 3x3 grid for each layer. Below are the 3 grids for feature 22. Other grids are available for each layer in the `README.md` for that layer.
|
97 |
+
|
98 |
+
**feature 308**
|
99 |
+
|
100 |
+
![](./22_resid/examples/308_grid.png)
|
101 |
+
|
102 |
+
**feature 464**
|
103 |
+
|
104 |
+
![](./22_resid/examples/464_grid.png)
|
105 |
+
|
106 |
+
**feature 575**
|
107 |
|
108 |
+
![](./22_resid/examples/575_grid.png)
|
109 |
|
110 |
## Automated Sort EVALs
|
111 |
|
112 |
+
We performed automated [Sort Evals](https://transformer-circuits.pub/2024/august-update/index.html) following Anthropic except with the *formatted dataset examples* for both features being replaced by a 3x3 grid of the top 9 activating images and the *formatted query example* being replaced by an image activating for only one of those features but not included in the 3x3s. Our methodology was as follows:
|
113 |
+
|
114 |
+
1. pass 500,000 laion images from the training dataset through the SAE, record the first 2048 latent activations for each image (only 2048 to conserve space)
|
115 |
+
2. ignore all features which do not activate for at least 10 of those images
|
116 |
+
3. for the remaining features, if there are more than 100, select the first 100
|
117 |
+
4. if there are fewer than 100, the comparison will not be fair so we cannot proceed (this is why we do not have sort eval scores for layers 8, 5 and 2)
|
118 |
+
5. randomly select 500 pairs from among the 100 features, and perform a sort eval for each pair using gpt-4o
|
119 |
+
4. select 400 samples randomly from among these 500 sort evals 5 times, each time recording the accuracy (n correct / 400) for that subsample of 400
|
120 |
+
5. calculate the mean and standard deviation of these 5 accuracies.
|
121 |
+
|
122 |
+
The outcomes are plotted below. Active Feature Proportion is the proportion of features which activate for at least 10 images across the 500,000 image dataset for that layer. For the CLS token at layer 2 only 2/2048 features were "active" in this sense.
|
123 |
+
|
124 |
+
![](./media/sort-eval.png)
|
125 |
+
|
126 |
+
![](./media/active-feature-proportion.png)
|
127 |
+
|
128 |
## Token-wise MSE
|
129 |
+
|
130 |
+
All layers were trained across all 257 image patches. Below we provide plots demonstrating the reconstruction MSE for each token (other than the CLS token) as training progressed. It seems that throughout training the outer tokens are easier to reconstruct than those in the middle, presumably because these tokens capture more important information (i.e. foreground objects) and are therefore more information rich.
|
131 |
+
|
132 |
+
![](./media/layer_22_training_outputs.png)
|
133 |
+
![](./media/layer_22_individually_scaled.png)
|
media/active-feature-proportion.png
ADDED
Git LFS Details
|
media/layer_22_individually_scaled.png
ADDED
Git LFS Details
|
media/layer_22_training_outputs.png
ADDED
Git LFS Details
|
media/sort-eval.png
ADDED
Git LFS Details
|