license: mit

Gemma 2b Residual Stream SAEs.

This is a "quick and dirty" SAE release to unblock researchers. These SAEs have not been extensively studied or characterized. However, as I add SAEs, I will try to update this README to reflect what I know about them.

These SAEs were trained with SAE Lens, and the library version used is stored in cfg.json.

All training hyperparameters are specified in cfg.json.
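Because every hyperparameter lives in cfg.json, you can inspect them by loading the file as plain JSON. The following is a minimal sketch that assumes you have already downloaded a cfg.json for one of these SAEs. The path in the usage comment is hypothetical. Key names depend on the SAE Lens version that wrote the file, so the sketch prints whatever keys are present instead of assuming any.

```python
import json
from pathlib import Path


def load_sae_cfg(path: str) -> dict:
    """Read an SAE Lens cfg.json file into a plain dict."""
    with Path(path).open("r", encoding="utf-8") as f:
        return json.load(f)


def print_cfg(cfg: dict) -> None:
    """Print every hyperparameter, sorted by key, one per line."""
    for key in sorted(cfg):
        print(f"{key}: {cfg[key]!r}")


# Example usage (hypothetical local path; point it at the cfg.json you downloaded):
# print_cfg(load_sae_cfg("blocks.0.hook_resid_post/cfg.json"))
```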

They can be loaded with SAE Lens in a few ways. The preferred method is the following:

EDIT: This snippet is out of date. Please see the SAE Lens tutorials for up-to-date syntax for loading pretrained SAEs.

import torch
from transformer_lens import HookedTransformer
from sae_lens import SparseAutoencoder, ActivationsStore

torch.set_grad_enabled(False)
model = HookedTransformer.from_pretrained("gemma-2b")
sparse_autoencoder = SparseAutoencoder.from_pretrained(
  "gemma-2b-res-jb", # to see the list of available releases, go to: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
  "blocks.0.hook_resid_post" # change this to another specific SAE ID in the release if desired. 
)
activation_store = ActivationsStore.from_config(model, sparse_autoencoder.cfg)

Resid Post 0

Stats:

  • 16384 Features (expansion factor 8)
  • CE Loss score of 99.1% (2.647 without SAE, 2.732 with the SAE)
  • Mean L0 54 (in practice, L0 is log-normally distributed and heavily right-tailed).
  • Dead Features: We think this SAE may have ~2.5k dead features.
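The CE loss score compares three next-token cross-entropy losses: the clean model, the model with the residual stream replaced by the SAE reconstruction, and the model with the residual stream zero-ablated. The sketch below is minimal and assumes the usual SAE Lens definition, where the score is the fraction of the zero-ablation loss increase that the SAE recovers. This README does not list the zero-ablation loss, so the example numbers are illustrative and are not measurements for these SAEs.

```python
def ce_loss_score(loss_clean: float, loss_with_sae: float, loss_zero_ablation: float) -> float:
    """Fraction of the zero-ablation loss gap recovered when the SAE is spliced in.

    1.0 means splicing in the SAE costs nothing; 0.0 means it is as bad as zero-ablating.
    """
    gap = loss_zero_ablation - loss_clean
    if gap <= 0:
        raise ValueError("zero-ablation loss must exceed the clean loss")
    return (loss_zero_ablation - loss_with_sae) / gap


# Illustrative numbers only (not measured values for these SAEs).
score = ce_loss_score(loss_clean=2.0, loss_with_sae=2.5, loss_zero_ablation=12.0)
print(f"{score:.1%}")  # 95.0%
```

The denominator is the zero-ablation gap, so two SAEs with the same absolute loss increase can receive different scores if their zero-ablation losses differ.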

Notes:

  • This SAE was trained with methods from the Anthropic April Update, except for activation normalization.
  • It is likely under-trained.

Resid Post 6

Stats:

  • 16384 Features (expansion factor 8)
  • CE Loss score of 95.33% (2.647 without SAE, 3.103 with the SAE)
  • Mean L0 53 (in practice, L0 is log-normally distributed and heavily right-tailed).
  • Dead Features: We think this SAE may have up to 7k dead features.
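L0 is the number of SAE features with a nonzero activation on a given token, and "Mean L0" is that count averaged over tokens. This minimal sketch assumes you already have a (tokens × features) array of SAE feature activations. Synthetic random data stands in for real activations here. The sketch shows how per-token L0 is computed and why a right-tailed L0 distribution puts the mean above the median. The log-normal parameters are arbitrary and were chosen only to produce counts in roughly the range reported above.

```python
import numpy as np


def per_token_l0(feature_acts: np.ndarray) -> np.ndarray:
    """Count nonzero features for each token; feature_acts has shape (n_tokens, d_sae)."""
    return (feature_acts > 0).sum(axis=1)


rng = np.random.default_rng(0)
n_tokens, d_sae = 500, 16_384

# Synthetic stand-in for real SAE activations: each token activates a
# log-normally distributed number of randomly chosen features.
l0_targets = np.clip(rng.lognormal(mean=3.7, sigma=0.6, size=n_tokens).astype(int), 0, d_sae)
feature_acts = np.zeros((n_tokens, d_sae), dtype=np.float32)
for i, k in enumerate(l0_targets):
    idx = rng.choice(d_sae, size=k, replace=False)
    feature_acts[i, idx] = rng.random(k) + 0.1  # strictly positive activations

l0 = per_token_l0(feature_acts)
print("mean L0:", l0.mean())
print("median L0:", np.median(l0))  # lower than the mean for a right-tailed distribution
```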

Notes:

  • This SAE was trained with methods from the Anthropic April Update
    • Except for activation normalization.
    • We increased the learning rate by one order of magnitude to explore whether this would speed up training (in particular, whether it would reach a lower L0 more quickly).
      • In practice, we find that L0 drops faster, but this results in significantly more dead features (likely causing worse reconstruction).
  • As above, it is likely under-trained.

Resid Post 12

Stats:

  • 16384 Features (expansion factor 8)
  • CE Loss score of 95.99% (2.563 without SAE, 2.96 with the SAE)
  • Mean L0 52 (in practice, L0 is log-normally distributed and heavily right-tailed).
  • Dead Features: Fewer than 200 dead features.
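The dead-feature counts above are estimates. One common way to estimate them is to run the SAE over a sample of tokens and count the features that never activate, or that fire below some tiny frequency threshold. This README does not state its method, so the approach here is an assumption. The minimal sketch below runs on a small synthetic activation array, not on the real SAE.

```python
import numpy as np


def feature_firing_frequency(feature_acts: np.ndarray) -> np.ndarray:
    """Fraction of tokens on which each feature is active; returns shape (d_sae,)."""
    return (feature_acts > 0).mean(axis=0)


def count_dead_features(feature_acts: np.ndarray, threshold: float = 0.0) -> int:
    """Count features whose firing frequency is <= threshold (0.0 means never fired)."""
    return int((feature_firing_frequency(feature_acts) <= threshold).sum())


rng = np.random.default_rng(0)
n_tokens, d_sae = 4_096, 1_024  # small synthetic example, not the real SAE size
feature_acts = np.maximum(rng.normal(-2.5, 1.0, size=(n_tokens, d_sae)), 0.0)
feature_acts[:, :100] = 0.0  # force the first 100 features to be dead

print(count_dead_features(feature_acts))  # 100
```

The count depends on how many tokens you sample. Rarely firing features can look dead in a small sample, so treat such counts as upper bounds unless the sample is large.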

Notes:

  • This SAE was trained with methods from the Anthropic April Update
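    • With activation normalization. This means that activations should be multiplied by a constant such that E(||x||) = sqrt(2048), i.e. the expected L2 norm of a residual stream vector equals the square root of Gemma 2B's residual stream width (2048).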
  • As above, it is likely under-trained.
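For Resid Post 12, activations must be rescaled before they are passed to the SAE. This minimal sketch assumes the constant is estimated from a sample of residual stream activations as sqrt(d_model) divided by the mean L2 norm. Synthetic data stands in for real activations. This is an illustration and not SAE Lens code, and SAE Lens may store or apply the constant differently. To map a reconstruction back to the model's original activation scale, divide by the same constant.

```python
import numpy as np

D_MODEL = 2048  # Gemma 2B residual stream width


def estimate_norm_scaling_factor(acts: np.ndarray) -> float:
    """Return c such that the mean of ||c * x||_2 over the sample equals sqrt(d_model)."""
    mean_norm = np.linalg.norm(acts, axis=-1).mean()
    return float(np.sqrt(acts.shape[-1]) / mean_norm)


rng = np.random.default_rng(0)
# Synthetic stand-in for a sample of residual stream activations at blocks.12.hook_resid_post.
sample = rng.normal(0.0, 3.0, size=(1_000, D_MODEL))

c = estimate_norm_scaling_factor(sample)
normalized = c * sample  # the scale the SAE expects as input
print(np.linalg.norm(normalized, axis=-1).mean())  # about sqrt(2048), i.e. roughly 45.25
restored = normalized / c  # undo the scaling, e.g. for an SAE reconstruction
```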