|
|
|
from huggingface_hub import model_info |
|
import argparse |
|
from copy import deepcopy |
|
import inspect |
|
from logging import warn |
|
from pathlib import Path |
|
from tqdm import tqdm |
|
import json |
|
|
|
from tuned_lens.model_surgery import get_final_norm, get_transformer_layers |
|
from tuned_lens.load_artifacts import load_lens_artifacts |
|
from tuned_lens.nn import TunedLens |
|
from transformers.models.bloom.modeling_bloom import BloomBlock |
|
from transformers import PreTrainedModel, AutoModelForCausalLM |
|
from typing import Optional, Generator, Union |
|
import torch as th |
|
|
|
from tuned_lens.stats.distance import js_divergence |
|
|
|
|
|
def instantiate_layer(model_config, layer_idx: int, model_type: str) -> th.nn.Module:
    """Construct a new transformer block for the named architecture.

    Args:
        model_config: Config object passed to the block's constructor.
        layer_idx: Index of the layer. Only forwarded to the constructors of
            the "gpt_neo" and "gpt2" blocks; the other branches ignore it.
        model_type: One of "bloom", "gpt_neo", "gpt_neox", "gpt2" or "opt".

    Returns:
        The freshly constructed block. For "bloom" the block is wrapped in
        `_BloomBlockWrapper`.

    Raises:
        ValueError: If `model_type` is not one of the names listed above.
    """
    # Each architecture's module is imported only inside its own branch, so
    # asking for one model type never imports the others.
    if model_type == "bloom":
        from transformers.models.bloom import modeling_bloom

        bloom_block = modeling_bloom.BloomBlock(model_config)
        return _BloomBlockWrapper(bloom_block)
    elif model_type == "gpt_neo":
        from transformers.models.gpt_neo import modeling_gpt_neo

        return modeling_gpt_neo.GPTNeoBlock(model_config, layer_idx)
    elif model_type == "gpt_neox":
        from transformers.models.gpt_neox import modeling_gpt_neox

        return modeling_gpt_neox.GPTNeoXLayer(model_config)
    elif model_type == "gpt2":
        from transformers.models.gpt2 import modeling_gpt2

        return modeling_gpt2.GPT2Block(model_config, layer_idx)
    elif model_type == "opt":
        from transformers.models.opt import modeling_opt

        return modeling_opt.OPTDecoderLayer(model_config)

    raise ValueError(f"Unknown model type '{model_type}'")
|
|
|
|
|
def maybe_wrap(layer: th.nn.Module) -> th.nn.Module:
    """Wrap `layer` in `_BloomBlockWrapper` if it is a `BloomBlock`.

    Args:
        layer: The module to inspect.

    Returns:
        A `_BloomBlockWrapper` around `layer` when it is a `BloomBlock`,
        otherwise `layer` itself, unchanged.
    """
    if isinstance(layer, BloomBlock):
        return _BloomBlockWrapper(layer)

    return layer
|
|
|
|
|
|
|
|
|
class _BloomBlockWrapper(th.nn.Module):
    """Adapter that lets a `BloomBlock` be called with the hidden states only.

    The wrapped block is invoked as `block(x, alibi, causal_mask)`; this module
    builds the `alibi` and `causal_mask` arguments from the shape of `x`.
    """

    def __init__(self, block: BloomBlock):
        """Store the block to be wrapped.

        Args:
            block: The block that `forward` will call.
        """
        super().__init__()
        self.block = block

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Run the wrapped block on `x`.

        Args:
            x: A three-dimensional tensor; its first two dimensions are used as
                the batch size and sequence length.

        Returns:
            The first element of the wrapped block's output.
        """
        from transformers.models.bloom import modeling_bloom

        n_batch, n_tokens, _ = x.shape

        # An all-ones mask of shape (batch, seq) stands in for a real
        # attention mask when building the causal mask and the ALiBi tensor.
        ones_mask = x.new_ones([n_batch, n_tokens])

        # `_prepare_attn_mask` is called through the class with `None` in the
        # `self` position.
        attn_mask = modeling_bloom.BloomModel._prepare_attn_mask(
            None, ones_mask, (n_batch, n_tokens), 0
        )
        alibi_bias = modeling_bloom.build_alibi_tensor(
            ones_mask, self.block.num_heads, x.dtype
        )

        hidden, *_ = self.block(x, alibi_bias, attn_mask)
        return hidden
|
|
|
|
|
class TunedLensOld(th.nn.Module):
    """A tuned lens for decoding hidden states into logits.

    The lens holds a frozen layer norm and unembedding, an optional stack of
    extra layers applied before decoding, and one linear translator per layer
    (plus an optional translator for the input embeddings).
    """

    layer_norm: th.nn.LayerNorm
    unembedding: th.nn.Linear
    extra_layers: th.nn.Sequential
    layer_translators: th.nn.ModuleList

    def __init__(
        self,
        model: Optional["PreTrainedModel"] = None,
        *,
        bias: bool = True,
        extra_layers: int = 0,
        include_input: bool = True,
        reuse_unembedding: bool = True,
        model_config: Optional[dict] = None,
        d_model: Optional[int] = None,
        num_layers: Optional[int] = None,
        vocab_size: Optional[int] = None,
    ):
        """Create a TunedLensOld.

        Args:
            model : A pretrained model from the transformers library you wish
                to inspect.
            bias : Whether to include a bias term in the translator layers.
            extra_layers : The number of extra layers to apply to the hidden
                states before decoding into logits.
            include_input : Whether to include a lens that decodes the word
                embeddings.
            reuse_unembedding : Whether to reuse the unembedding matrix from
                the model.
            model_config : The config of the model. Used for saving and loading.
            d_model : The model's hidden size. Used for saving and loading.
            num_layers : The number of layers in the model. Used for saving and
                loading.
            vocab_size : The size of the vocabulary. Used for saving and
                loading.

        Raises:
            ValueError: if neither a model nor all of d_model, num_layers, and
                vocab_size are provided, or if a model is provided together
                with all three of them.
        """
        super().__init__()

        self.extra_layers = th.nn.Sequential()

        # The parentheses around `model is None` matter: without them Python
        # chains the comparisons into `model is None and None == (...)`, which
        # is always False, so the check could never fire.
        dims_missing = d_model is None or num_layers is None or vocab_size is None
        if (model is None) == dims_missing:
            raise ValueError(
                "Must provide either a model or d_model, num_layers, and vocab_size"
            )

        # Initializing from scratch without a model
        if model is None:
            assert d_model and num_layers and vocab_size
            self.layer_norm = th.nn.LayerNorm(d_model)
            self.unembedding = th.nn.Linear(d_model, vocab_size, bias=False)

        # Read the dimensions and the decoder from the model itself
        else:
            assert not (d_model or num_layers or vocab_size)
            d_model = model.config.hidden_size
            num_layers = model.config.num_hidden_layers
            vocab_size = model.config.vocab_size
            assert isinstance(d_model, int) and isinstance(vocab_size, int)

            model_config = model.config.to_dict()

            # The copied decoder is converted to full precision
            self.unembedding = deepcopy(model.get_output_embeddings()).float()
            if ln := get_final_norm(model):
                self.layer_norm = deepcopy(ln).float()
            else:
                self.layer_norm = th.nn.Identity()

            if extra_layers:
                _, layers = get_transformer_layers(model)
                self.extra_layers.extend(
                    [maybe_wrap(layer) for layer in layers[-extra_layers:]]
                )

        # Save the keyword-only constructor arguments so that `save` and `load`
        # can rebuild the lens without the original model. `num_layers` is
        # recorded here, before it is decremented below.
        self.config = {
            "bias": bias,
            "extra_layers": extra_layers,
            "include_input": include_input,
            "reuse_unembedding": reuse_unembedding,
            "model_config": model_config,
            "d_model": d_model,
            "num_layers": num_layers,
            "vocab_size": vocab_size,
        }

        # Freeze the decoder so it is not finetuned along with the translators
        assert d_model and num_layers
        self.layer_norm.requires_grad_(False)
        self.unembedding.requires_grad_(False)

        out_features = d_model if reuse_unembedding else vocab_size
        translator = th.nn.Linear(d_model, out_features, bias=bias)
        if reuse_unembedding:
            translator.weight.data.zero_()
        else:
            translator.weight.data = self.unembedding.weight.data.clone()

        # `translator.bias` is None when `bias=False`
        if translator.bias is not None:
            translator.bias.data.zero_()

        self.add_module("input_translator", translator if include_input else None)

        # One fewer layer translator than the model has layers
        num_layers -= 1

        self.layer_translators = th.nn.ModuleList(
            [deepcopy(translator) for _ in range(num_layers)]
        )

    def __getitem__(self, item: int) -> th.nn.Module:
        """Get the probe module at the given index.

        Index 0 is the input translator when one exists; the layer translators
        follow. Negative indices count from the end of the whole lens.

        Raises:
            IndexError: if `item` is out of range.
        """
        if item < 0:
            # Resolve against the full length first; otherwise the shift for
            # the input translator below would turn -1 into -2.
            item += len(self)
            if item < 0:
                raise IndexError("translator index out of range")

        if self.input_translator is not None:
            if item == 0:
                return self.input_translator

            item -= 1

        return self.layer_translators[item]

    def __iter__(self) -> Generator[th.nn.Module, None, None]:
        """Get iterator over the translators within the lens."""
        if self.input_translator is not None:
            yield self.input_translator

        yield from self.layer_translators

    @classmethod
    def load(cls, resource_id: str, **kwargs) -> "TunedLensOld":
        """Load a tuned lens from a local directory or the Hugging Face hub.

        Args:
            resource_id : The path to the directory containing the config and
                checkpoint or the name of the model on the hugging face hub.
            **kwargs : Additional arguments to pass to torch.load.

        Returns:
            A TunedLensOld instance.

        Raises:
            ValueError: if the config requests extra layers but has no
                'model_config' entry to rebuild them from.
        """
        config_path, ckpt_path = load_lens_artifacts(resource_id)

        with open(config_path, "r") as f:
            config = json.load(f)

        # NOTE(security): torch.load unpickles the checkpoint, so only load
        # checkpoints from sources you trust.
        state = th.load(ckpt_path, **kwargs)

        # Backwards compatibility: rename legacy parameter names. The renamed
        # key is built up first and swapped in once, so a key matching both
        # legacy names is not popped twice.
        for key in list(state.keys()):
            new_key = key
            for old_key in ["probe", "adapter"]:
                if old_key in new_key:
                    warn(
                        f"Loading a checkpoint with a '{old_key}' key. "
                        "This is deprecated and may be removed in a future version. "
                    )
                    new_key = new_key.replace(old_key, "translator")

            if new_key != key:
                state[new_key] = state.pop(key)

        # Drop config entries the constructor does not accept
        unrecognized = set(config) - set(inspect.getfullargspec(cls).kwonlyargs)
        for key in unrecognized:
            warn(f"Ignoring config key '{key}'")
            del config[key]

        lens = cls(**config)

        if num_extras := config.get("extra_layers"):
            # The lens was built without a model, so the extra layers have to
            # be instantiated from the saved model config before their
            # parameters can be loaded.
            from transformers.models.auto import CONFIG_MAPPING

            model_conf_dict = config.get("model_config")
            # Check before touching the dict so a missing entry is reported
            # with this message rather than a TypeError.
            if not model_conf_dict:
                raise ValueError("Need a 'model_config' entry to load extra layers")

            # Tolerate configs that have no 'torch_dtype' entry
            model_conf_dict.pop("torch_dtype", None)

            model_type = model_conf_dict["model_type"]
            config_cls = CONFIG_MAPPING[model_type]
            model_config = config_cls.from_dict(model_conf_dict)

            lens.extra_layers = th.nn.Sequential(
                *[
                    instantiate_layer(
                        model_config, model_config.num_hidden_layers - i - 1, model_type
                    )
                    for i in range(num_extras)
                ]
            )

        lens.load_state_dict(state)
        return lens

    def save(
        self,
        path: Union[Path, str],
        ckpt: str = "params.pt",
        config: str = "config.json",
    ) -> None:
        """Save the lens to a directory.

        The directory (and any missing parents) is created if needed.

        Args:
            path : The path to the directory to save the lens to.
            ckpt : The name of the checkpoint file to save the parameters to.
            config : The name of the config file to save the config to.
        """
        path = Path(path)
        path.mkdir(exist_ok=True, parents=True)
        th.save(self.state_dict(), path / ckpt)

        with open(path / config, "w") as f:
            json.dump(self.config, f)

    def normalize_(self):
        """Canonicalize the transforms by centering their weights and biases."""
        for linear in self:
            assert isinstance(linear, th.nn.Linear)

            A = linear.weight.data
            A -= A.mean(dim=0, keepdim=True)

            # Translators created with `bias=False` have no bias to center
            if linear.bias is not None:
                b = linear.bias.data
                b -= b.mean()

    def transform_hidden(self, h: th.Tensor, idx: int) -> th.Tensor:
        """Transform hidden state from layer `idx`.

        The translator output is added to `h` as a residual.

        Raises:
            RuntimeError: if the lens was created with
                `reuse_unembedding=False`.
        """
        if not self.config["reuse_unembedding"]:
            raise RuntimeError("TunedLensOld.transform_hidden requires reuse_unembedding")

        return h + self[idx](h)

    def to_logits(self, h: th.Tensor) -> th.Tensor:
        """Decode a hidden state into logits."""
        h = self.extra_layers(h)
        # Unwrap nested tuple outputs, keeping the first element each time
        while isinstance(h, tuple):
            h, *_ = h

        return self.unembedding(self.layer_norm(h))

    def forward(self, h: th.Tensor, idx: int) -> th.Tensor:
        """Transform and then decode the hidden states into logits."""
        # Without a shared unembedding, the translator for layer `idx` maps the
        # normalized hidden state straight to the output.
        if not self.config["reuse_unembedding"]:
            h_ = self.layer_norm(h)
            return self[idx](h_)

        h = self.transform_hidden(h, idx)
        return self.to_logits(h)

    def __len__(self) -> int:
        """Return the number of translators in the lens.

        This counts the layer translators plus the input translator, if any.
        """
        N = len(self.layer_translators)
        if self.input_translator is not None:
            N += 1

        return N
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: load an old-format lens, copy its translators into a
    # new `TunedLens`, check that both give matching outputs on random inputs,
    # and save the new lens.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", type=str, default="gpt2")
    cli.add_argument("--resource-id", type=str, default="gpt2")
    cli.add_argument("--output-dir", type=str, default="lens/gpt2")
    args = cli.parse_args()

    model = AutoModelForCausalLM.from_pretrained(args.model)
    # The hub revision (commit sha) is passed on to the new lens
    revision = model_info(args.model).sha
    model.eval()
    model.requires_grad_(False)

    if th.cuda.is_available():
        device = th.device("cuda:0")
    else:
        device = th.device("cpu")

    print("Loading old lens")
    tuned_lens_old = TunedLensOld.load(args.resource_id, map_location=device)

    print("Initializing new lens")
    use_bias = tuned_lens_old.config["bias"]
    tuned_lens = TunedLens.from_model(model, bias=use_bias, revision=revision)

    # Copy every translator's parameters from the old lens into the new one
    for idx in tqdm(range(len(tuned_lens_old)), desc="Copying parameters"):
        old_params = tuned_lens_old[idx].state_dict()
        tuned_lens[idx].load_state_dict(old_params)

    # Put both lenses and the model on the same device before comparing
    tuned_lens = tuned_lens.to(device)
    tuned_lens_old = tuned_lens_old.to(device)
    model = model.to(device)

    # Fuzz each layer with ten random hidden states and require the two
    # lenses to produce the same log-probabilities.
    with th.no_grad():
        for idx in tqdm(range(len(tuned_lens)), desc="Fuzzing layers"):
            for _trial in range(10):
                hidden = th.randn(1, 1, tuned_lens.config.d_model, device=device)
                new_logits = tuned_lens(hidden, idx)
                old_logits = tuned_lens_old(hidden, idx)
                new_log_ps = new_logits.log_softmax(-1)
                old_log_ps = old_logits.log_softmax(-1)
                print("js div", js_divergence(new_log_ps, old_log_ps))
                assert th.allclose(new_log_ps, old_log_ps, atol=1e-7), (
                    (new_log_ps - old_log_ps).abs().max()
                )

    print("Saving new lens to", args.output_dir)
    cpu_lens = tuned_lens.to(th.device("cpu"))
    cpu_lens.save(args.output_dir)
|
|