Lev McKinney commited on
Commit
e49cdfa
·
1 Parent(s): 37317f0

added migration utils

Browse files
Files changed (2) hide show
  1. lens_migration.py +381 -0
  2. migrate.sh +10 -0
lens_migration.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from huggingface_hub import model_info
3
+ import argparse
4
+ from copy import deepcopy
5
+ import inspect
6
+ from logging import warn
7
+ from pathlib import Path
8
+ import json
9
+
10
+ from tuned_lens.model_surgery import get_final_layer_norm, get_transformer_layers
11
+ from tuned_lens.load_artifacts import load_lens_artifacts
12
+ from tuned_lens.nn import TunedLens
13
+ from transformers.models.bloom.modeling_bloom import BloomBlock
14
+ from transformers import PreTrainedModel, AutoModelForCausalLM
15
+ from typing import Optional, Generator, Union
16
+ import torch as th
17
+
18
+ from tuned_lens.stats.distance import js_divergence
19
+
20
+
21
+ def instantiate_layer(model_config, layer_idx: int, model_type: str) -> th.nn.Module:
22
+ if model_type == "bloom":
23
+ from transformers.models.bloom.modeling_bloom import BloomBlock
24
+
25
+ return _BloomBlockWrapper(BloomBlock(model_config)) # type: ignore[arg-type]
26
+ if model_type == "gpt_neo":
27
+ from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoBlock
28
+
29
+ return GPTNeoBlock(model_config, layer_idx)
30
+ if model_type == "gpt_neox":
31
+ from transformers.models.gpt_neox.modeling_gpt_neox import (
32
+ GPTNeoXLayer,
33
+ )
34
+
35
+ return GPTNeoXLayer(model_config) # type: ignore[arg-type]
36
+ if model_type == "gpt2":
37
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Block
38
+
39
+ return GPT2Block(model_config, layer_idx) # type: ignore[arg-type]
40
+ if model_type == "opt":
41
+ from transformers.models.opt.modeling_opt import OPTDecoderLayer
42
+
43
+ return OPTDecoderLayer(model_config) # type: ignore[arg-type]
44
+ else:
45
+ raise ValueError(f"Unknown model type '{model_type}'")
46
+
47
+
48
+ def maybe_wrap(layer: th.nn.Module) -> th.nn.Module:
49
+ return _BloomBlockWrapper(layer) if isinstance(layer, BloomBlock) else layer
50
+
51
+
52
+ # Very annoying that we have to do this. See https://bit.ly/3XSQ7W6 for context on
53
+ # what we're doing here.
54
+ class _BloomBlockWrapper(th.nn.Module):
55
+ def __init__(self, block: BloomBlock):
56
+ super().__init__()
57
+ self.block = block
58
+
59
+ def forward(self, x: th.Tensor) -> th.Tensor:
60
+ from transformers.models.bloom.modeling_bloom import (
61
+ BloomModel,
62
+ build_alibi_tensor,
63
+ )
64
+
65
+ batch_size, seq_len, _ = x.shape
66
+ dummy_mask = x.new_ones([batch_size, seq_len])
67
+
68
+ # Causal mask isn't created inside the block itself, so we have to do it here.
69
+ # Weirdly _prepare_attn_mask doesn't depend on `self` at all but is still an
70
+ # instance method for some reason, so we pass `None` as the first argument.
71
+ causal_mask = BloomModel._prepare_attn_mask(
72
+ None, dummy_mask, (batch_size, seq_len), 0 # type: ignore[arg-type]
73
+ )
74
+ alibi = build_alibi_tensor(dummy_mask, self.block.num_heads, x.dtype)
75
+ h, *_ = self.block(x, alibi, causal_mask)
76
+ return h
77
+
78
+
79
+ class TunedLensOld(th.nn.Module):
80
+ """A tuned lens for decoding hidden states into logits."""
81
+
82
+ layer_norm: th.nn.LayerNorm
83
+ unembedding: th.nn.Linear
84
+ extra_layers: th.nn.Sequential
85
+ layer_translators: th.nn.ModuleList
86
+
87
+ def __init__(
88
+ self,
89
+ model: Optional[PreTrainedModel] = None,
90
+ *,
91
+ bias: bool = True,
92
+ extra_layers: int = 0,
93
+ include_input: bool = True,
94
+ reuse_unembedding: bool = True,
95
+ # Used when saving and loading the lens
96
+ model_config: Optional[dict] = None,
97
+ d_model: Optional[int] = None,
98
+ num_layers: Optional[int] = None,
99
+ vocab_size: Optional[int] = None,
100
+ ):
101
+ """Create a TunedLensOld.
102
+
103
+ Args:
104
+ model : A pertained model from the transformers library you wish to inspect.
105
+ bias : Whether to include a bias term in the translator layers.
106
+ extra_layers : The number of extra layers to apply to the hidden states
107
+ before decoding into logits.
108
+
109
+ include_input : Whether to include a lens that decodes the word embeddings.
110
+ reuse_unembedding : Weather to reuse the unembedding matrix from the model.
111
+ model_config : The config of the model. Used for saving and loading.
112
+ d_model : The models hidden size. Used for saving and loading.
113
+ num_layers : The number of layers in the model. Used for saving and loading.
114
+ vocab_size : The size of the vocabulary. Used for saving and loading.
115
+
116
+ Raises:
117
+ ValueError: if neither a model or d_model, num_layers, and vocab_size,
118
+ are provided.
119
+ """
120
+ super().__init__()
121
+
122
+ self.extra_layers = th.nn.Sequential()
123
+
124
+ if (
125
+ model
126
+ is None
127
+ == (d_model is None or num_layers is None or vocab_size is None)
128
+ ):
129
+ raise ValueError(
130
+ "Must provide either a model or d_model, num_layers, and vocab_size"
131
+ )
132
+
133
+ # Initializing from scratch without a model
134
+ if not model:
135
+ assert d_model and num_layers and vocab_size
136
+ self.layer_norm = th.nn.LayerNorm(d_model)
137
+ self.unembedding = th.nn.Linear(d_model, vocab_size, bias=False)
138
+
139
+ # Use HuggingFace methods to get decoder layers
140
+ else:
141
+ assert not (d_model or num_layers or vocab_size)
142
+ d_model = model.config.hidden_size
143
+ num_layers = model.config.num_hidden_layers
144
+ vocab_size = model.config.vocab_size
145
+ assert isinstance(d_model, int) and isinstance(vocab_size, int)
146
+
147
+ model_config = model.config.to_dict() # type: ignore[F841]
148
+
149
+ # Currently we convert the decoder to full precision
150
+ self.unembedding = deepcopy(model.get_output_embeddings()).float()
151
+ if ln := get_final_layer_norm(model):
152
+ self.layer_norm = deepcopy(ln).float()
153
+ else:
154
+ self.layer_norm = th.nn.Identity()
155
+
156
+ if extra_layers:
157
+ _, layers = get_transformer_layers(model)
158
+ self.extra_layers.extend(
159
+ [maybe_wrap(layer) for layer in layers[-extra_layers:]]
160
+ )
161
+
162
+ # Save config for later
163
+ config_keys = set(inspect.getfullargspec(TunedLensOld).kwonlyargs)
164
+ self.config = {k: v for k, v in locals().items() if k in config_keys}
165
+ del model_config
166
+
167
+ # Try to prevent finetuning the decoder
168
+ assert d_model and num_layers
169
+ self.layer_norm.requires_grad_(False)
170
+ self.unembedding.requires_grad_(False)
171
+
172
+ out_features = d_model if reuse_unembedding else vocab_size
173
+ translator = th.nn.Linear(d_model, out_features, bias=bias)
174
+ if not reuse_unembedding:
175
+ translator.weight.data = self.unembedding.weight.data.clone()
176
+ translator.bias.data.zero_()
177
+ else:
178
+ translator.weight.data.zero_()
179
+ translator.bias.data.zero_()
180
+
181
+ self.add_module("input_translator", translator if include_input else None)
182
+ # Don't include the final layer
183
+ num_layers -= 1
184
+
185
+ self.layer_translators = th.nn.ModuleList(
186
+ [deepcopy(translator) for _ in range(num_layers)]
187
+ )
188
+
189
+ def __getitem__(self, item: int) -> th.nn.Module:
190
+ """Get the probe module at the given index."""
191
+ if isinstance(self.input_translator, th.nn.Module):
192
+ if item == 0:
193
+ return self.input_translator
194
+ else:
195
+ item -= 1
196
+
197
+ return self.layer_translators[item]
198
+
199
+ def __iter__(self) -> Generator[th.nn.Module, None, None]:
200
+ """Get iterator over the translators within the lens."""
201
+ if isinstance(self.input_translator, th.nn.Module):
202
+ yield self.input_translator
203
+
204
+ yield from self.layer_translators
205
+
206
+ @classmethod
207
+ def load(cls, resource_id: str, **kwargs) -> "TunedLensOld":
208
+ """Load a tuned lens from a or hugging face hub.
209
+
210
+ Args:
211
+ resource_id : The path to the directory containing the config and checkpoint
212
+ or the name of the model on the hugging face hub.
213
+ **kwargs : Additional arguments to pass to torch.load.
214
+
215
+ Returns:
216
+ A TunedLensOld instance.
217
+ """
218
+ config_path, ckpt_path = load_lens_artifacts(resource_id)
219
+ # Load config
220
+ with open(config_path, "r") as f:
221
+ config = json.load(f)
222
+
223
+ # Load parameters
224
+ state = th.load(ckpt_path, **kwargs)
225
+
226
+ # Backwards compatibility we really need to stop renaming things
227
+ keys = list(state.keys())
228
+ for key in keys:
229
+ for old_key in ["probe", "adapter"]:
230
+ if old_key in key:
231
+ warn(
232
+ f"Loading a checkpoint with a '{old_key}' key. "
233
+ "This is deprecated and may be removed in a future version. "
234
+ )
235
+ new_key = key.replace(old_key, "translator")
236
+ state[new_key] = state.pop(key)
237
+
238
+ # Drop unrecognized config keys
239
+ unrecognized = set(config) - set(inspect.getfullargspec(cls).kwonlyargs)
240
+ for key in unrecognized:
241
+ warn(f"Ignoring config key '{key}'")
242
+ del config[key]
243
+
244
+ lens = cls(**config)
245
+
246
+ if num_extras := config.get("extra_layers"):
247
+ # This is sort of a hack but AutoConfig doesn't appear to have a from_dict
248
+ # for some reason.
249
+ from transformers.models.auto import CONFIG_MAPPING
250
+
251
+ model_conf_dict = config.get("model_config")
252
+ del model_conf_dict["torch_dtype"]
253
+ assert model_conf_dict, "Need a 'model_config' entry to load extra layers"
254
+
255
+ model_type = model_conf_dict["model_type"]
256
+ config_cls = CONFIG_MAPPING[model_type]
257
+ model_config = config_cls.from_dict(model_conf_dict)
258
+
259
+ lens.extra_layers = th.nn.Sequential(
260
+ *[
261
+ instantiate_layer(
262
+ model_config, model_config.num_hidden_layers - i - 1, model_type
263
+ )
264
+ for i in range(num_extras)
265
+ ]
266
+ )
267
+
268
+ lens.load_state_dict(state)
269
+ return lens
270
+
271
+ def save(
272
+ self,
273
+ path: Union[Path, str],
274
+ ckpt: str = "params.pt",
275
+ config: str = "config.json",
276
+ ) -> None:
277
+ """Save the lens to a directory.
278
+
279
+ Args:
280
+ path : The path to the directory to save the lens to.
281
+ ckpt : The name of the checkpoint file to save the parameters to.
282
+ config : The name of the config file to save the config to.
283
+ """
284
+ path = Path(path)
285
+ path.mkdir(exist_ok=True, parents=True)
286
+ th.save(self.state_dict(), path / ckpt)
287
+
288
+ with open(path / config, "w") as f:
289
+ json.dump(self.config, f)
290
+
291
+ def normalize_(self):
292
+ """Canonicalize the transforms by centering their weights and biases."""
293
+ for linear in self:
294
+ assert isinstance(linear, th.nn.Linear)
295
+
296
+ A, b = linear.weight.data, linear.bias.data
297
+ A -= A.mean(dim=0, keepdim=True)
298
+ b -= b.mean()
299
+
300
+ def transform_hidden(self, h: th.Tensor, idx: int) -> th.Tensor:
301
+ """Transform hidden state from layer `idx`."""
302
+ if not self.config["reuse_unembedding"]:
303
+ raise RuntimeError("TunedLensOld.transform_hidden requires reuse_unembedding")
304
+
305
+ # Note that we add the translator output residually, in contrast to the formula
306
+ # in the paper. By parametrizing it this way we ensure that weight decay
307
+ # regularizes the transform toward the identity, not the zero transformation.
308
+ return h + self[idx](h)
309
+
310
+ def to_logits(self, h: th.Tensor) -> th.Tensor:
311
+ """Decode a hidden state into logits."""
312
+ h = self.extra_layers(h)
313
+ while isinstance(h, tuple):
314
+ h, *_ = h
315
+
316
+ return self.unembedding(self.layer_norm(h))
317
+
318
+ def forward(self, h: th.Tensor, idx: int) -> th.Tensor:
319
+ """Transform and then decode the hidden states into logits."""
320
+ # Sanity check to make sure we don't finetune the decoder
321
+ # if any(p.requires_grad for p in self.parameters(recurse=False)):
322
+ # raise RuntimeError("Make sure to freeze the decoder")
323
+
324
+ # We're learning a separate unembedding for each layer
325
+ if not self.config["reuse_unembedding"]:
326
+ h_ = self.layer_norm(h)
327
+ return self[idx](h_)
328
+
329
+ h = self.transform_hidden(h, idx)
330
+ return self.to_logits(h)
331
+
332
+ def __len__(self) -> int:
333
+ """Return the number of layer translators in the lens."""
334
+ N = len(self.layer_translators)
335
+ if self.input_translator:
336
+ N += 1
337
+
338
+ return N
339
+
340
+
341
+ if __name__ == "__main__":
342
+ parser = argparse.ArgumentParser()
343
+ parser.add_argument("--model", type=str, default="gpt2")
344
+ parser.add_argument("--resource-id", type=str, default="gpt2")
345
+ parser.add_argument("--output-dir", type=str, default="lens/gpt2")
346
+ args = parser.parse_args()
347
+
348
+ model = AutoModelForCausalLM.from_pretrained(args.model)
349
+ revision = model_info(args.model).sha
350
+ model.eval()
351
+ model.requires_grad_(False)
352
+
353
+ device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
354
+
355
+ tuned_lens_old = TunedLensOld.load(args.resource_id, map_location=device)
356
+
357
+ tuned_lens = TunedLens.init_from_model(
358
+ model, bias=tuned_lens_old.config['bias'], revision=revision
359
+ )
360
+
361
+ for i in range(len(tuned_lens_old)):
362
+ tuned_lens[i].load_state_dict(tuned_lens_old[i].state_dict())
363
+
364
+
365
+ tuned_lens = tuned_lens.to(device)
366
+ tuned_lens_old = tuned_lens_old.to(device)
367
+ model = model.to(device)
368
+
369
+ # Fuzz the new lens against the old one's
370
+ with th.no_grad():
371
+ for i in range(len(tuned_lens)):
372
+ for _ in range(10):
373
+ a = th.randn(1, 1, tuned_lens.config.d_model, device=device)
374
+ logits_new = tuned_lens(a, i)
375
+ logits_old = tuned_lens_old(a, i)
376
+ log_ps_new = logits_new.log_softmax(-1)
377
+ log_ps_old = logits_old.log_softmax(-1)
378
+ assert (th.allclose(log_ps_new, log_ps_old))
379
+ print("js div", js_divergence(log_ps_new, log_ps_old))
380
+
381
+ tuned_lens.to(th.device("cpu")).save(args.output_dir)
migrate.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ for i in gpt2,gpt2 pythia-160m-deduped-v0,EleutherAI/pythia-160m-deduped-v0 gpt2-large,gpt2-large gpt2-xl,gpt2-xl opt-125m,facebook/opt-125m opt-6.7b,facebook/opt-6.7b pythia-1.4b-deduped-v0,EleutherAI/pythia-1.4b-deduped-v0 pythia-1b-deduped-v0,EleutherAI/pythia-1b-deduped-v0 pythia-6.9b-deduped-v0,EleutherAI/pythia-6.9b-deduped-v0 opt-1.3b,facebook/opt-1.3b pythia-410m-deduped-v0,EleutherAI/pythia-410m-deduped-v0 pythia-12b-deduped-v0,EleutherAI/pythia-12b-deduped-v0 gpt-neox-20b,EleutherAI/gpt-neox-20b
4
+ do
5
+ IFS=","
6
+ set -- $i
7
+ echo "migrating $2"
8
+ CUDA_VISIBLE_DEVICES=-1 python lens_migration.py --model $2 --resource-id $1 --output lens/$1
9
+ git commit -am "$1 migrated"
10
+ done