File size: 6,208 Bytes
d6682b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
from copy import deepcopy
from typing import Dict, List, Tuple
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from ...util import nethook
from ...util.generate import generate_fast
from .compute_u import compute_u
from .compute_v import compute_v
from .rome_hparams import ROMEHyperParams
import gradio as gr
# Module-wide cache of context templates. It stays None until the first call
# to get_context_templates(), which fills it and returns it on every later call.
CONTEXT_TEMPLATES_CACHE = None
def apply_rome_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: ROMEHyperParams,
    num_steps: int,
    edit_lr: float,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs
) -> AutoModelForCausalLM:
    """
    Apply a single ROME edit to ``model`` and return the edited model.

    The whole ``request['prompt']`` is used as the subject of the edit. This is
    done on an internal shallow copy of ``request``, so the caller's dict is not
    modified.

    :param model: model to edit; edited in place unless ``copy`` is true.
    :param tok: tokenizer matching ``model``.
    :param request: edit request dict; must contain ``'prompt'`` and the keys
        read by ``execute_rome`` (e.g. ``'target_new'``).
    :param hparams: ROME hyper-parameters. NOTE: ``v_num_grad_steps`` and
        ``v_lr`` are written onto this object from ``num_steps`` / ``edit_lr``.
    :param num_steps: step budget; ``num_steps // 2`` becomes
        ``hparams.v_num_grad_steps``.
    :param edit_lr: stored into ``hparams.v_lr``.
    :param copy: if true, deep-copy ``model`` and edit the copy instead. You are
        then responsible for freeing the new model's memory.
    :param return_orig_weights: accepted for interface compatibility; original
        weights are not returned.
    :param keep_original_weight: accepted for interface compatibility; unused.
    :return: the edited model (a single object, not a tuple).
    """
    if copy:
        model = deepcopy(model)

    # NOTE: this writes onto the caller's hparams object.
    hparams.v_num_grad_steps = num_steps // 2
    hparams.v_lr = edit_lr

    # Shallow copy so the caller's request dict does not gain a 'subject' key.
    request = dict(request)
    request["subject"] = request["prompt"]

    deltas = execute_rome(model, tok, request, hparams)

    with torch.no_grad():
        for w_name, (delta_u, delta_v) in deltas.items():
            # Rank-1 outer-product update, oriented to match the stored weight.
            upd_matrix = delta_u.unsqueeze(1) @ delta_v.unsqueeze(0)
            w = nethook.get_parameter(model, w_name)
            upd_matrix = upd_matrix_match_shape(upd_matrix, w.shape)
            w[...] += upd_matrix

    print(f"New weights successfully inserted into {list(deltas.keys())}")
    gr.Info("Completed editing via ROME!")
    return model
def execute_rome(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: ROMEHyperParams,
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Compute ROME rank-1 deltas for ``request`` at every layer in ``hparams.layers``.

    Layers are processed in ascending order. Each layer's update is applied to
    the model temporarily, so the computations for later layers run on the
    partially edited model. All touched weights are restored before returning,
    including when an exception is raised.

    :param model: model to compute deltas for; its weights are unchanged on return.
    :param tok: tokenizer matching ``model``.
    :param request: dict with ``'prompt'``, ``'subject'`` and ``'target_new'``.
        It is deep-copied, so the caller's dict is not modified.
    :param hparams: ROME hyper-parameters; ``layers``, ``rewrite_module_tmp`` and
        ``context_template_length_params`` are read here.
    :return: mapping from weight name to ``(left_vector, right_vector)``.
    :raises ValueError: if the prompt has no ``{}`` placeholder and does not
        contain the subject.
    """
    request = deepcopy(request)
    if not request["target_new"].startswith(" "):
        # Space required for correct tokenization
        request["target_new"] = " " + request["target_new"]

    if "{}" not in request["prompt"]:
        # Turn the prompt into a template by replacing the subject with '{}'.
        if request["subject"] not in request["prompt"]:
            raise ValueError(
                f"Subject:{request['subject']} do not exist in prompt: {request['prompt']}"
            )
        request["prompt"] = request["prompt"].replace(request["subject"], "{}")

    print(
        f"Executing ROME algorithm for the update: "
        f"[{request['prompt'].format(request['subject'])}] -> [{request['target_new']}]"
    )

    # Retrieve weights that user desires to change
    weights = {
        f"{hparams.rewrite_module_tmp.format(layer)}.weight": nethook.get_parameter(
            model, f"{hparams.rewrite_module_tmp.format(layer)}.weight"
        )
        for layer in hparams.layers
    }
    # Save old weights for future restoration
    weights_copy = {k: v.detach().clone() for k, v in weights.items()}

    deltas = {}
    try:
        if hparams.layers:
            # Fetched once, before any weight is modified; the helper caches
            # its result module-wide.
            context_templates = get_context_templates(
                model, tok, hparams.context_template_length_params
            )
        # Update loop: sequentially intervene at each specified layer
        for layer in sorted(hparams.layers):
            # Compute rank-1 update matrix
            left_vector: torch.Tensor = compute_u(
                model, tok, request, hparams, layer, context_templates
            )
            print("Left vector shape:", left_vector.shape)
            right_vector: torch.Tensor = compute_v(
                model, tok, request, hparams, layer, left_vector, context_templates
            )
            print("Right vector shape:", right_vector.shape)

            with torch.no_grad():
                # Determine correct transposition of delta matrix
                weight_name = f"{hparams.rewrite_module_tmp.format(layer)}.weight"
                upd_matrix = left_vector.unsqueeze(1) @ right_vector.unsqueeze(0)
                upd_matrix = upd_matrix_match_shape(
                    upd_matrix, weights[weight_name].shape
                )
                # Apply temporarily so later layers see this edit; record the delta.
                weights[weight_name][...] += upd_matrix
                deltas[weight_name] = (
                    left_vector.detach(),
                    right_vector.detach(),
                )
    finally:
        # Restore state of original model, even if a layer computation failed.
        with torch.no_grad():
            for k, v in weights.items():
                v[...] = weights_copy[k]

    print(f"Deltas successfully computed for {list(weights.keys())}")
    return deltas
def upd_matrix_match_shape(matrix: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """
    Orient a ROME update matrix so that it matches a target weight shape.

    GPT-2 and GPT-J have transposed weight representations, so the transpose is
    tried whenever the update does not already have the requested shape.
    NOTE: for a square target the untransposed matrix is always returned.

    :param matrix: update matrix computed by ROME.
    :param shape: shape of the weight the update will be added to.
    :return: ``matrix`` itself if it already has ``shape``, otherwise
        ``matrix.T`` if that has ``shape``.
    :raises ValueError: if neither orientation matches ``shape``.
    """
    if matrix.shape != shape:
        flipped = matrix.T
        if flipped.shape != shape:
            raise ValueError(
                "Update matrix computed by ROME does not match original weight shape. "
                "Check for bugs in the code?"
            )
        return flipped
    return matrix
def get_context_templates(model, tok, length_params):
    """
    Return the module-wide cached list of ROME context templates.

    On the first call, continuations of a few fixed seed words are generated with
    ``generate_fast``. Each continuation has its braces stripped and becomes a
    template ``"<text>. {}"``, and the bare ``"{}"`` template is prepended. Later
    calls return the cached list unchanged.

    NOTE(review): after the first call the cache ignores ``model``, ``tok`` and
    ``length_params``, so templates generated with one model are reused for all
    later calls.

    :param model: model used for generation on the first call.
    :param tok: tokenizer matching ``model``.
    :param length_params: iterable of ``(length, n_gen)`` pairs. For each pair,
        ``generate_fast`` is called with ``max_out_len=length`` and
        ``n_gen // <number of seed prompts>`` generations per seed prompt.
    :return: list of format strings, each containing exactly one ``{}``.
    """
    global CONTEXT_TEMPLATES_CACHE

    if CONTEXT_TEMPLATES_CACHE is None:
        seed_prompts = ["The", "Therefore", "Because", "I", "You"]
        # Flatten all generations in length_params order (linear, unlike sum(..., [])).
        generated = [
            text
            for length, n_gen in length_params
            for text in generate_fast(
                model,
                tok,
                seed_prompts,
                n_gen_per_prompt=n_gen // len(seed_prompts),
                max_out_len=length,
            )
        ]
        # Strip braces so str.format only ever sees the placeholder we append.
        CONTEXT_TEMPLATES_CACHE = ["{}"] + [
            text.replace("{", "").replace("}", "") + ". {}" for text in generated
        ]
        print(f"Cached context templates {CONTEXT_TEMPLATES_CACHE}")

    return CONTEXT_TEMPLATES_CACHE
|