lewington committed
Commit 00727eb
1 Parent(s): 3b3b2cf

finish writeup

.gitignore CHANGED
@@ -1,2 +1,3 @@
 .DS_Store
- pad.ipynb
+ pad.ipynb
+ cruft
11_resid/README.md ADDED
@@ -0,0 +1,9 @@
+ # Layer 11
+
+ ## Examples
+
+ 3x3 grids cherry-picked from the first 100 features which activate on the CLS token for more than 9 of 500,000 laion images.
+
+ ![](./examples/1259_grid.png)
+ ![](./examples/1462_grid.png)
+ ![](./examples/628_grid.png)
14_resid/README.md ADDED
@@ -0,0 +1,9 @@
+ # Layer 14
+
+ ## Examples
+
+ 3x3 grids cherry-picked from the first 100 features which activate on the CLS token for more than 9 of 500,000 laion images.
+
+ ![](./examples/476_grid.png)
+ ![](./examples/70_grid.jpg)
+ ![](./examples/843_grid.png)
17_resid/README.md ADDED
@@ -0,0 +1,9 @@
+ # Layer 17
+
+ ## Examples
+
+ 3x3 grids cherry-picked from the first 100 features which activate on the CLS token for more than 9 of 500,000 laion images.
+
+ ![](./examples/157_grid.png)
+ ![](./examples/568_grid.png)
+ ![](./examples/606_grid.png)
20_resid/README.md ADDED
@@ -0,0 +1,9 @@
+ # Layer 20
+
+ ## Examples
+
+ 3x3 grids cherry-picked from the first 100 features which activate on the CLS token for more than 9 of 500,000 laion images.
+
+ ![](./examples/195_grid.png)
+ ![](./examples/363_grid.jpg)
+ ![](./examples/392_grid.png)
22_resid/README.md CHANGED
@@ -2,8 +2,8 @@

 ## Examples

- 3x3 grids cherry picked from the first 100 features which activate on more than 9/500,000 laion images.
+ 3x3 grids cherry-picked from the first 100 features which activate on the CLS token for more than 9 of 500,000 laion images.

 ![](./examples/308_grid.png)
 ![](./examples/464_grid.png)
- ![](./examples/575_grid.png)
+ ![](./examples/575_grid.png)
2_resid/README.md ADDED
@@ -0,0 +1,8 @@
+ # Layer 2
+
+ ## Examples
+
+ 3x3 grids cherry-picked from the first 100 features which activate on the CLS token for more than 9 of 500,000 laion images. Only 2 of the first 2048 features actually activated in this case.
+
+ ![](./examples/1162_grid.png)
+ ![](./examples/1173_grid.png)
5_resid/README.md ADDED
@@ -0,0 +1,9 @@
+ # Layer 5
+
+ ## Examples
+
+ 3x3 grids cherry-picked from the first 100 features which activate on the CLS token for more than 9 of 500,000 laion images.
+
+ ![](./examples/1952_grid.png)
+ ![](./examples/2023_grid.png)
+ ![](./examples/648_grid.png)
8_resid/README.md ADDED
@@ -0,0 +1,9 @@
+ # Layer 8
+
+ ## Examples
+
+ 3x3 grids cherry-picked from the first 100 features which activate on the CLS token for more than 9 of 500,000 laion images.
+
+ ![](./examples/1597_grid.png)
+ ![](./examples/1735_grid.png)
+ ![](./examples/187_grid.png)
README.md CHANGED
@@ -10,36 +10,40 @@ Heavily inspired by [google/gemma-scope](https://huggingface.co/google/gemma-scope)

 ![](./media/mse.png)

- | Layer | MSE | Explained Variance | Dead Feature Proportion |
- |-------|----------|--------------------|-------------------------|
- | 2 | 267.95 | 0.763 | 0.000912 |
- | 5 | 354.46 | 0.665 | 0 |
- | 8 | 357.58 | 0.642 | 0 |
- | 11 | 321.23 | 0.674 | 0 |
- | 14 | 319.64 | 0.689 | 0 |
- | 17 | 261.20 | 0.731 | 0 |
- | 20 | 278.06 | 0.706 | 0.0000763 |
- | 22 | 299.96 | 0.684 | 0 |

- Training logs are available [via wandb](https://wandb.ai/lewington/ViT-L-14-laion2B-s32B-b82K/workspace) and training code is available on [github](https://github.com/Lewington-pitsos/vitsae). The training process is heavily reliant on [AWS ECS](https://aws.amazon.com/ecs/) so may contain some strange artifacts when a spot instance is killed and the training is resumed by another instance. Some of the code is ripped directly from [Hugo Fry](https://github.com/HugoFry/mats_sae_training_for_ViTs).

 ### Vital Statistics:

 - Number of tokens trained per autoencoder: 1.2 Billion
- - Token type: all 257 image tokens (as opposed to just the cls token)
 - Number of unique images trained per autoencoder: 4.5 Million
 - Training Dataset: [Laion-2b](https://huggingface.co/datasets/laion/laion2B-multi-joined-translated-to-en)
 - SAE Architecture: topk with k=32
 - Layer Location: always the residual stream
 - Training Checkpoints: every ~25 million tokens
- - Number of features: 65536

 ## Usage

- First install our pypi package and PIL

- `pip install clipscope PIL`

 Then
@@ -73,9 +77,9 @@ print('latent shape', output['latent'].shape) # (1, 65536)
 print('reconstruction shape', output['reconstruction'].shape) # (1, 1024)
 ```

- ## Error Formulae

- We calculate MSE as `(batch - reconstruction).pow(2).sum(dim=-1).mean()` i.e. The MSE between the batch and the un-normalized reconstruction, summed across features. We use batch norm to bring all activations into a similar range.

 We calculate Explained Variance as
@@ -89,9 +93,41 @@ We calculate dead feature proportion as the proportion of features which have no

 ## Subjective Interpretability

- To give an intuitive feel for the interpretability of these models we run 500,000 images from laion2b selected at random through the final trained SAE for each layer and record the latent activations for each. We then winnow down to the first 100 features which activate for at least 9 images. We cherry pick 3 of these and display them in a 3x3 grid for each layer. We do this twice, one for the CLS token and once for token 137 (near the middle of the image). Below are the 6 grids for feature 22. Other grids are available for each layer.

 ## Automated Sort EVALs

 ## Token-wise MSE

 ![](./media/mse.png)

+ | Layer | MSE | Explained Variance | Dead Feature Proportion | Active Feature Proportion | Sort Eval Accuracy (CLS token) |
+ |-------|---------|--------------------|-------------------------|---------------------------|-------------------------------|
+ | 2 | 267.95 | 0.763 | 0.000912 | 0.001 | - |
+ | 5 | 354.46 | 0.665 | 0 | 0.0034 | - |
+ | 8 | 357.58 | 0.642 | 0 | 0.01074 | - |
+ | 11 | 321.23 | 0.674 | 0 | 0.0415 | 0.7334 |
+ | 14 | 319.64 | 0.689 | 0 | 0.07866 | 0.7427 |
+ | 17 | 261.20 | 0.731 | 0 | 0.1477 | 0.8689 |
+ | 20 | 278.06 | 0.706 | 0.0000763 | 0.2036 | 0.9149 |
+ | 22 | 299.96 | 0.684 | 0 | 0.1588 | 0.8641 |

+ ![](./media/sort-eval.png)
+
+ Training logs are available [via wandb](https://wandb.ai/lewington/ViT-L-14-laion2B-s32B-b82K/workspace) and training code is available on [github](https://github.com/Lewington-pitsos/vitsae). The training process is heavily reliant on [AWS ECS](https://aws.amazon.com/ecs/), so the weights and biases logs may contain some strange artifacts when a spot instance is killed and the training is resumed by another instance. Some of the code is ripped directly from [Hugo Fry](https://github.com/HugoFry/mats_sae_training_for_ViTs).

 ### Vital Statistics:

 - Number of tokens trained per autoencoder: 1.2 Billion
+ - Token type: all 257 image tokens (as opposed to just the CLS token)
 - Number of unique images trained per autoencoder: 4.5 Million
 - Training Dataset: [Laion-2b](https://huggingface.co/datasets/laion/laion2B-multi-joined-translated-to-en)
 - SAE Architecture: topk with k=32
 - Layer Location: always the residual stream
 - Training Checkpoints: every ~25 million tokens
+ - Number of features per autoencoder: 65536 (expansion factor 16)
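For intuition about the architecture listed above, here is a minimal sketch of a TopK SAE forward pass matching these statistics (1024-dimensional ViT-L residual activations, 65536 latents, k=32). The class and its details are illustrative assumptions, not the clipscope implementation.

```python
import torch

class TopKSAE(torch.nn.Module):
    """Illustrative TopK sparse autoencoder; dimensions match the stats above."""
    def __init__(self, d_model=1024, n_features=65536, k=32):
        super().__init__()
        self.k = k
        self.encoder = torch.nn.Linear(d_model, n_features)
        self.decoder = torch.nn.Linear(n_features, d_model)

    def forward(self, x):
        pre = self.encoder(x)                      # (batch, n_features)
        top = torch.topk(pre, self.k, dim=-1)      # keep only the k largest pre-activations
        latent = torch.zeros_like(pre)
        latent.scatter_(-1, top.indices, torch.relu(top.values))
        reconstruction = self.decoder(latent)      # (batch, d_model)
        return latent, reconstruction

sae = TopKSAE()
latent, reconstruction = sae(torch.randn(1, 1024))
print(latent.shape, reconstruction.shape)          # (1, 65536), (1, 1024)
```

The (1, 65536) latent and (1, 1024) reconstruction shapes line up with the shapes printed in the usage example below.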

 ## Usage

+ First install our pypi package and PIL (pillow)

+ `pip install clipscope pillow`

 Then

 print('reconstruction shape', output['reconstruction'].shape) # (1, 1024)
 ```

+ ## Formulae

+ We calculate MSE as `(batch - reconstruction).pow(2).sum(dim=-1).mean()`, i.e. the squared error between the batch and the reconstruction, summed across features and averaged over the batch. We normalize the input with batch norm before training, which brings all activations into a similar range.
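As a sketch of that metric (tensor names and the batch-norm setup here are illustrative assumptions, not the training code):

```python
import torch

def mse(batch, reconstruction):
    # Squared error summed across the 1024 features, then averaged over the batch.
    return (batch - reconstruction).pow(2).sum(dim=-1).mean()

# Batch norm brings the input activations into a similar range before training.
normalize = torch.nn.BatchNorm1d(1024, affine=False)
batch = normalize(torch.randn(256, 1024))
reconstruction = torch.randn(256, 1024)
print(mse(batch, reconstruction).item())
```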

 We calculate Explained Variance as

 ## Subjective Interpretability

+ To give an intuitive feel for the interpretability of these models we run 500,000 randomly selected laion2b images through the final trained SAE for each layer and record the CLS-token latent activations for each image. We then winnow down to the first 100 features which activate for at least 9 images, cherry-pick 3 of these, and display the top activating images for each as a 3x3 grid. Below are the 3 grids for layer 22; grids for the other layers are available in that layer's `README.md`.
+
+ **feature 308**
+
+ ![](./22_resid/examples/308_grid.png)
+
+ **feature 464**
+
+ ![](./22_resid/examples/464_grid.png)
+
+ **feature 575**
+
+ ![](./22_resid/examples/575_grid.png)
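A rough sketch of the winnowing step just described, under the assumption that a feature "activates" on an image when its CLS-token latent is positive (function and variable names are illustrative):

```python
import torch

def first_active_features(latents, min_images=9, n=100):
    # latents: (n_images, n_features) CLS-token activations over the sampled images.
    images_per_feature = (latents > 0).sum(dim=0)               # images each feature fires on
    active = torch.nonzero(images_per_feature >= min_images).squeeze(-1)
    return active[:n]                                           # the first n qualifying feature ids

# Dummy stand-in for real CLS-token latents (500,000 x 65,536 in the actual run).
latents = torch.relu(torch.randn(1000, 4096) - 2.0)
print(first_active_features(latents)[:10])
```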

 ## Automated Sort EVALs

+ We performed automated [Sort Evals](https://transformer-circuits.pub/2024/august-update/index.html) following Anthropic, except that the *formatted dataset examples* for both features are replaced by a 3x3 grid of the top 9 activating images, and the *formatted query example* is replaced by an image which activates for only one of those features and is not included in either grid. Our methodology was as follows (steps 6 and 7 are sketched in code after the list):
+
+ 1. pass 500,000 laion images from the training dataset through the SAE and record the first 2048 latent activations for each image (only 2048 to conserve space)
+ 2. ignore all features which do not activate for at least 10 of those images
+ 3. if more than 100 features remain, keep only the first 100
+ 4. if fewer than 100 remain, the comparison would not be fair, so we do not proceed (this is why we do not have sort eval scores for layers 8, 5 and 2)
+ 5. randomly select 500 pairs from among the 100 features and perform a sort eval for each pair using gpt-4o
+ 6. randomly select 400 samples from among these 500 sort evals 5 times, each time recording the accuracy (n correct / 400) for that subsample of 400
+ 7. calculate the mean and standard deviation of these 5 accuracies
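A sketch of steps 6 and 7, where the boolean outcomes stand in for the gpt-4o sort eval results from step 5 (names are illustrative):

```python
import random
import statistics

def subsample_accuracy(results, n_subsamples=5, subsample_size=400):
    # results: one boolean per sort eval, True when gpt-4o matched the query image
    # to the correct feature's 3x3 grid.
    accuracies = []
    for _ in range(n_subsamples):
        sample = random.sample(results, subsample_size)    # 400 of the 500 sort evals
        accuracies.append(sum(sample) / subsample_size)    # n correct / 400
    return statistics.mean(accuracies), statistics.stdev(accuracies)

results = [random.random() < 0.87 for _ in range(500)]     # dummy stand-in for real outcomes
mean_acc, std_acc = subsample_accuracy(results)
print(f"accuracy: {mean_acc:.4f} +/- {std_acc:.4f}")
```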
+
+ The outcomes are plotted below. Active Feature Proportion is the proportion of features which activate for at least 10 images across the 500,000 image dataset for that layer. For the CLS token at layer 2 only 2 of the 2048 recorded features were "active" in this sense.
+
+ ![](./media/sort-eval.png)
+
+ ![](./media/active-feature-proportion.png)
+
 ## Token-wise MSE
+
+ All layers were trained across all 257 image tokens. Below we provide plots of the reconstruction MSE for each token (other than the CLS token) as training progressed. It seems that throughout training the outer tokens are easier to reconstruct than those near the middle of the image, presumably because the middle tokens capture more important content (i.e. foreground objects) and are therefore more information rich.
+
+ ![](./media/layer_22_training_outputs.png)
+ ![](./media/layer_22_individually_scaled.png)
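As a sketch of what token-wise MSE means here, assuming activations shaped (batch, 257, 1024) with the CLS token at position 0; this mirrors the MSE definition above but keeps one value per token position instead of averaging over them (names are illustrative):

```python
import torch

def tokenwise_mse(acts, recon):
    # acts, recon: (batch, 257, 1024) residual activations and SAE reconstructions.
    # Squared error summed over the feature dimension, averaged over the batch,
    # leaving one MSE value per token position.
    per_token = (acts - recon).pow(2).sum(dim=-1).mean(dim=0)   # (257,)
    return per_token[1:]                                        # drop CLS, keep the 256 patch tokens

per_patch = tokenwise_mse(torch.randn(8, 257, 1024), torch.randn(8, 257, 1024))
print(per_patch.shape)  # torch.Size([256])
```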
media/active-feature-proportion.png ADDED

Git LFS Details

  • SHA256: 368af09165daea07036f644b40b7c3f6dcb7a549dd2f8902e7f944dabe6b180f
  • Pointer size: 130 Bytes
  • Size of remote file: 54.4 kB
media/layer_22_individually_scaled.png ADDED

Git LFS Details

  • SHA256: 4f42629bc6c87f0fa48cc2d753187b637f6e59ae3b48e2b064def339518392e1
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
media/layer_22_training_outputs.png ADDED

Git LFS Details

  • SHA256: a73573ba44742f12233f624c4bfc0ab8d1cb7d2aa7d07fee11beae9e2976c10f
  • Pointer size: 130 Bytes
  • Size of remote file: 66 kB
media/sort-eval.png ADDED

Git LFS Details

  • SHA256: b8a834e44e63ba4510e5d6209f721d50bffa82329f1541b9a30b9a223c4485ad
  • Pointer size: 130 Bytes
  • Size of remote file: 78.6 kB