---
library_name: transformers
tags: []
---

# Model Card for `gpt2_noLN`

This is a gpt2-small model with LayerNorm fine-tuned out.

The model was fine-tuned on OpenWebText for ~500M tokens (1000 iterations of batch size ~488 at 1024 context length) while gradually disabling LayerNorm layers.

There are 5 similar models available (v1 through v5) trained with different fine-tuning schedules. Please refer to the [paper](https://arxiv.org/abs/2409.13710) or [blog post](https://www.lesswrong.com/posts/THzcKKQd4oWkg4dSP/you-can-remove-gpt2-s-layernorm-by-fine-tuning-for-an-hour)
for details; the training code is available [here](https://github.com/ApolloResearch/gpt2_noLN). The best model (v4) is the default as of 6th September 2024 (previously v2 was the default).

The model is a `GPT2LMHeadModel` (to avoid requiring `trust_remote_code`) which technically contains LayerNorm blocks.
However, the epsilon values are all set to 1e12 so that the LayerNorm has no effect. The LN scale is set to 1e6 (to counter the 1e12 epsilon), and the bias to 0.
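
To make the arithmetic concrete, here is a minimal plain-Python sketch of the standard LayerNorm formula with `eps = 1e12`, scale `1e6`, and bias `0`. The helper `layer_norm` and the input vector are illustrative only (not taken from the model or the training code). Note that the formula still subtracts the mean, so the sketch checks that the output matches the mean-centered input rather than the raw input.

```python
import math

def layer_norm(x, weight, bias, eps):
    """LayerNorm over one vector: (x - mean) / sqrt(var + eps) * weight + bias."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance
    denom = math.sqrt(var + eps)
    return [(v - mean) / denom * weight + bias for v in x]

# Made-up activations, purely for illustration.
x = [0.5, -1.25, 3.0, 2.75]
mean = sum(x) / len(x)
centered = [v - mean for v in x]

# eps = 1e12 swamps the variance, so the denominator is ~1e6 and the
# 1e6 scale cancels it: the division no longer depends on the input.
y = layer_norm(x, weight=1e6, bias=0.0, eps=1e12)
max_err = max(abs(a - b) for a, b in zip(y, centered))
print(max_err)
```

The residual error is on the order of `var / (2 * eps)` relative to the input, which is negligible unless the activation variance approaches `1e12`.
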
The final LayerNorm also has 1e12 as epsilon, but non-unity weights and biases. This is because the embed and unembed matrices are tied (and there is no unembed bias),
thus the LN parameters cannot be folded into that matrix. You can completely remove all LNs by simply replacing the `ln_1` and `ln_2` modules with identities, and replacing
`ln_f` with modifications to the unembed matrix and unembed bias.
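
As a rough illustration of that last step, the sketch below folds an elementwise scale and bias into an unembed matrix and a new unembed bias, using toy plain-Python lists. All names and numbers here (`W_U`, `ln_w`, `ln_b`, `x`) are made up for the example, and it assumes the effective `ln_f` is a pure elementwise affine map (any mean-centering is ignored). In the real model the folded matrix would presumably need to be a separate copy, since the original unembed is tied to the embedding.

```python
def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

# Toy sizes: d_model = 3, vocab = 2. Values are illustrative, not from the model.
W_U = [[1.0, -2.0, 0.5],
       [0.25, 4.0, -1.0]]
ln_w = [2.0, 0.5, -1.5]   # stand-in for the effective ln_f scale
ln_b = [0.1, -0.2, 0.3]   # stand-in for the ln_f bias
x = [0.7, -1.1, 2.0]      # stand-in for a final residual-stream vector

# Reference path: apply the affine part of ln_f, then unembed.
logits_ref = matvec(W_U, [w * xi + b for w, xi, b in zip(ln_w, x, ln_b)])

# Folded path: scale each unembed column by ln_w, and turn ln_b into a bias.
W_U_folded = [[m * w for m, w in zip(row, ln_w)] for row in W_U]
b_U = matvec(W_U, ln_b)
logits_folded = [l + b for l, b in zip(matvec(W_U_folded, x), b_U)]

max_diff = max(abs(a - b) for a, b in zip(logits_ref, logits_folded))
print(max_diff)
```

Both paths compute `W_U @ (ln_w * x + ln_b)`, so the logits agree up to floating-point rounding.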

You can load the model with `transformers`, or one of the interpretability libraries listed below.
```python
from transformers import GPT2LMHeadModel
model = GPT2LMHeadModel.from_pretrained("apollo-research/gpt2_noLN").to("cpu")
```

## TransformerLens loading code
```python
import torch
from transformers import GPT2LMHeadModel
from transformer_lens import HookedTransformer

model = GPT2LMHeadModel.from_pretrained("apollo-research/gpt2_noLN").to("cpu")
hooked_model = HookedTransformer.from_pretrained("gpt2", hf_model=model, fold_ln=False, center_unembed=False).to("cpu")
# Kill the LayerNorms because TransformerLens overwrites eps
for block in hooked_model.blocks:
    block.ln1.eps = 1e12
    block.ln2.eps = 1e12
hooked_model.ln_final.eps = 1e12
```

Or with LNs properly replaced by identities:
```python
import torch
from transformers import GPT2LMHeadModel
from transformer_lens import HookedTransformer

model = GPT2LMHeadModel.from_pretrained("apollo-research/gpt2_noLN").to("cpu")

# Undo my hacky LayerNorm removal
for block in model.transformer.h:
    block.ln_1.weight.data = block.ln_1.weight.data / 1e6
    block.ln_1.eps = 1e-5
    block.ln_2.weight.data = block.ln_2.weight.data / 1e6
    block.ln_2.eps = 1e-5
model.transformer.ln_f.weight.data = model.transformer.ln_f.weight.data / 1e6
model.transformer.ln_f.eps = 1e-5

# Properly replace LayerNorms by Identities
class HookedTransformerNoLN(HookedTransformer):
    def removeLN(self):
        for i in range(len(self.blocks)):
            self.blocks[i].ln1 = torch.nn.Identity()
            self.blocks[i].ln2 = torch.nn.Identity()
        self.ln_final = torch.nn.Identity()

hooked_model = HookedTransformerNoLN.from_pretrained("gpt2", hf_model=model, fold_ln=True, center_unembed=False).to("cpu")
hooked_model.removeLN()
hooked_model.cfg.normalization_type = None
```
Edit: Added the last line updating `hooked_model.cfg.normalization_type` so that `ActivationCache.apply_ln_to_stack()` still works, thanks to [Quiche Eater](https://www.lesswrong.com/posts/THzcKKQd4oWkg4dSP/you-can-remove-gpt2-s-layernorm-by-fine-tuning-for-an-hour?commentId=s3zf5okYgPBnB2vMF) for pointing this out.

## NNSight loading code
Copy-pasted from [Logan Riggs' comment](https://www.lesswrong.com/posts/THzcKKQd4oWkg4dSP/you-can-remove-gpt2-s-layernorm-by-fine-tuning-for-an-hour?commentId=Gcq8wic9WmdnqM2Fm), based on code by Caden.
```python
import torch
from transformers import GPT2LMHeadModel
from transformer_lens import HookedTransformer
from nnsight.models.UnifiedTransformer import UnifiedTransformer


model = GPT2LMHeadModel.from_pretrained("apollo-research/gpt2_noLN").to("cpu")

# Undo my hacky LayerNorm removal
for block in model.transformer.h:
    block.ln_1.weight.data = block.ln_1.weight.data / 1e6
    block.ln_1.eps = 1e-5
    block.ln_2.weight.data = block.ln_2.weight.data / 1e6
    block.ln_2.eps = 1e-5
model.transformer.ln_f.weight.data = model.transformer.ln_f.weight.data / 1e6
model.transformer.ln_f.eps = 1e-5

# Properly replace LayerNorms by Identities
def removeLN(transformer_lens_model):
    for i in range(len(transformer_lens_model.blocks)):
        transformer_lens_model.blocks[i].ln1 = torch.nn.Identity()
        transformer_lens_model.blocks[i].ln2 = torch.nn.Identity()
    transformer_lens_model.ln_final = torch.nn.Identity()

hooked_model = HookedTransformer.from_pretrained("gpt2", hf_model=model, fold_ln=True, center_unembed=False).to("cpu")
removeLN(hooked_model)

model_nnsight = UnifiedTransformer(model="gpt2", hf_model=model, fold_ln=True, center_unembed=False).to("cpu")
removeLN(model_nnsight)
```