adding application

Changed files:
- processor.py +106 -0
- requirements.txt +9 -213
- utils/__init__.py +0 -0
- utils/calibration_utils.py +288 -0
- utils/data_utils.py +214 -0
- utils/enums.py +13 -0
- utils/eval_utils.py +334 -0
- utils/file_utils.py +94 -0
- utils/logit_lens.py +304 -0
- utils/model_utils.py +320 -0
- utils/procrustes/__init__.py +0 -0
- utils/procrustes/orthogonal.py +383 -0
- utils/procrustes/utils.py +495 -0
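
This commit adds the application's retrieval-processing entry point (processor.py), replaces the frozen conda-environment export in requirements.txt with a short list of direct dependencies, and introduces the utils package: calibration, data-loading, evaluation, logit-lens, and model helpers, plus an orthogonal-Procrustes subpackage. The individual diffs follow.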
processor.py
ADDED
@@ -0,0 +1,106 @@
import random

import torch


class RetrievalProcessor:
    def __init__(self, model, tokenizer, multi_token_kind, num_tokens_to_generate,
                 add_context, model_name, whitespace_token='Ġ'):
        self.model = model
        self.tokenizer = tokenizer
        self.multi_token_kind = multi_token_kind
        self.num_tokens_to_generate = num_tokens_to_generate
        self.add_context = add_context
        self.model_name = model_name
        self.whitespace_token = whitespace_token

    def get_next_word(self, tokens, i, max_length=1000, device='cuda'):
        # Collect the multi-token word that starts at position i: a token carrying the
        # whitespace marker, followed by alphabetic continuation pieces.
        token_str = self.tokenizer.convert_ids_to_tokens(tokens[i].item())
        j = i + 1
        word_tokens = [tokens[i]]
        if token_str.startswith(self.whitespace_token):
            while j < len(tokens) and self.is_alpha_not_prefix(tokens[j]):
                word_tokens.append(tokens[j])
                j += 1
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        combined_text = context + word

        tokenized_combined_text = self.tokenizer(combined_text, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        return j, word_tokens, word, context, tokenized_combined_text, combined_text, original_word

    def get_next_full_word_typo(self, tokens, i, max_length=1000, device='cuda'):
        # Same interface as get_next_word, but full words are corrupted with a random
        # typo before being re-tokenized.
        tokens_str = self.tokenizer.convert_ids_to_tokens(tokens)
        word_tokens = [tokens[i]]
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        if self.is_full_word(tokens_str, i, word, word_tokens):
            word = self.introduce_typo(word)
            # drop the leading special token the tokenizer prepends
            word_tokens = self.tokenizer(word, return_tensors='pt', truncation=True,
                                         max_length=max_length).input_ids[0][1:]
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        combined_text = context + word

        tokenized_combined_text = self.tokenizer(combined_text, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        # index of the word's last token within the combined text
        j = len(tokenized_combined_text.input_ids[0]) - 1 if self.add_context \
            else len(tokenized_combined_text.input_ids[0]) - 1 + i
        return j, word_tokens, word, context, tokenized_combined_text, combined_text, original_word

    def get_next_full_word_separated(self, tokens, i, max_length=1000, device='cuda'):
        # Same interface as get_next_word, but full words are split into character-level tokens.
        tokens_str = self.tokenizer.convert_ids_to_tokens(tokens)
        word_tokens = [tokens[i]]
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        if self.is_full_word(tokens_str, i, word, word_tokens):
            word = torch.tensor(self.separate_word(word)).unsqueeze(0)
        else:
            word = word_tokens[0].unsqueeze(0).unsqueeze(0)
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        tokenized_combined_text = self.tokenizer(context, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        tokenized_combined_text.input_ids = torch.cat(
            (tokenized_combined_text.input_ids, word.to(device)), dim=1)
        word_tokens = word
        j = i + 1
        return (j, word_tokens, word, context, tokenized_combined_text,
                self.tokenizer.decode(tokenized_combined_text.input_ids[0]), original_word)

    def is_alpha_not_prefix(self, token):
        # True for alphabetic continuation pieces that do not start a new word.
        token_str = self.tokenizer.convert_ids_to_tokens(token.item())
        return not token_str.startswith(self.whitespace_token) and token_str.isalpha()

    def introduce_typo(self, word, typo_type=None):
        # Apply one random character-level corruption, leaving the first character intact.
        letters = 'abcdefghijklmnopqrstuvwxyz'
        if typo_type is None:
            typo_type = random.choice(["substitution", "deletion", "insertion", "transposition"])

        if typo_type == "substitution":
            position = random.randint(1, len(word) - 1)
            original_char = word[position]
            typo_char = random.choice([c for c in letters if c != original_char])
            return word[:position] + typo_char + word[position + 1:]
        elif typo_type == "deletion":
            position = random.randint(1, len(word) - 1)
            return word[:position] + word[position + 1:]
        elif typo_type == "insertion":
            position = random.randint(1, len(word) - 1)
            typo_char = random.choice(letters)
            return word[:position] + typo_char + word[position:]
        elif typo_type == "transposition":
            position = random.randint(1, len(word) - 2)
            return word[:position] + word[position + 1] + word[position] + word[position + 2:]
        else:
            return word

    def separate_word(self, word):
        # Tokenize the word character by character; encoding '\n{char}' and keeping only
        # the last id isolates each character's token. The first three entries are then
        # intentionally skipped.
        character_tokens = [self.tokenizer.encode(f'\n{char}')[-1] for char in word]
        character_tokens = character_tokens[3:]
        return character_tokens

    def is_full_word(self, token_str, i, token, word_tokens):
        # A "full word": a single alphabetic token longer than 5 characters that carries
        # the whitespace marker and is not continued by another alphabetic token.
        next_token = self.tokenizer.decode(word_tokens[i + 1]) if i + 1 < len(word_tokens) else ""
        return (token[1:].isalpha() and
                len(token) > 5 and
                token_str[i].startswith(self.whitespace_token) and
                not next_token.isalpha())
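
For orientation, a minimal usage sketch of RetrievalProcessor (illustrative only, not part of the commit). It assumes a GPT-2-style BPE tokenizer, whose pieces carry the 'Ġ' whitespace marker the class expects by default; the model argument is unused by get_next_word, so None is passed here:

    import torch
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    proc = RetrievalProcessor(model=None, tokenizer=tok, multi_token_kind=None,
                              num_tokens_to_generate=1, add_context=True,
                              model_name="gpt2")
    tokens = torch.tensor(tok("The quick indefatigable fox").input_ids)
    # Position 2 is 'Ġind', the first piece of the multi-token word "indefatigable".
    j, word_tokens, word, context, enc, combined, original = proc.get_next_word(
        tokens, i=2, device="cpu")
    print(repr(word), j)  # ' indefatigable' 5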
requirements.txt
CHANGED
@@ -1,213 +1,9 @@
-asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work
-async-lru @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_02efro5ps8/croot/async-lru_1699554529181/work
-async-timeout==5.0.1
-attrs @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_93pjmt0git/croot/attrs_1734533120523/work
-Babel @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_00k1rl2pus/croot/babel_1671781944131/work
-backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
-beautifulsoup4 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_94rx5n7wo9/croot/beautifulsoup4-split_1718029832430/work
-bleach @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_faqg19k8gh/croot/bleach_1732292152791/work
-blis==1.2.0
-Brotli @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_f7i0oxypt6/croot/brotli-split_1736182464088/work
-catalogue==2.0.10
-certifi @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d8j59rqun5/croot/certifi_1734473289913/work/certifi
-cffi @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_e4xd9yd9i2/croot/cffi_1736182819442/work
-charset-normalizer @ file:///croot/charset-normalizer_1721748349566/work
-click==8.1.8
-cloudpathlib==0.20.0
-cloudpickle==3.1.0
-comm @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_3doui0bmzb/croot/comm_1709322861485/work
-confection==0.1.5
-contourpy==1.3.0
-cycler==0.12.1
-cymem==2.0.11
-dask==2024.8.0
-dask-expr==1.1.10
-datasets==3.2.0
-debugpy @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_563_nwtkoc/croot/debugpy_1690905063850/work
-decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
-defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
--e git+https://github.com/tokeron/diffusers.git@00769b5d64c2ea35201e0df7a082db3513619afe#egg=diffusers&subdirectory=../../../../../../diffusers
-dill==0.3.8
-distro==1.9.0
-docker-pycreds==0.4.0
-editdistance==0.8.1
-en_core_web_lg @ https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.8.0/en_core_web_lg-3.8.0-py3-none-any.whl#sha256=293e9547a655b25499198ab15a525b05b9407a75f10255e405e8c3854329ab63
-en_core_web_md @ https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.8.0/en_core_web_md-3.8.0-py3-none-any.whl#sha256=5e6329fe3fecedb1d1a02c3ea2172ee0fede6cea6e4aefb6a02d832dba78a310
-en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl#sha256=1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85
-eval_type_backport==0.2.2
-exceptiongroup @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_b2258scr33/croot/exceptiongroup_1706031391815/work
-executing @ file:///opt/conda/conda-bld/executing_1646925071911/work
-fastapi==0.115.12
-fastjsonschema @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_d1wgyi4enb/croot/python-fastjsonschema_1731939426145/work
-ffmpy==0.5.0
-filelock==3.16.1
-fonttools==4.55.3
-frozenlist==1.5.0
-fsspec==2024.9.0
-gitdb==4.0.12
-GitPython==3.1.44
-gradio==4.44.1
-gradio_client==1.3.0
-h11 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_110bmw2coo/croot/h11_1706652289620/work
-httpcore @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_fcxiho9nv7/croot/httpcore_1706728465004/work
-httpx @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_cc4egw1482/croot/httpx_1723474826664/work
-huggingface-hub==0.27.1
-idna==3.10
-importlib_metadata @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_cc4qelzghy/croot/importlib_metadata-suite_1732633706960/work
-importlib_resources==6.5.2
-ipykernel @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_ddflobe9t3/croot/ipykernel_1728665605034/work
-ipython @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_6599f73fa7/croot/ipython_1694181355402/work
-jedi @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_38ctoinnl0/croot/jedi_1733987402850/work
-Jinja2 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_b15nuwux5r/croot/jinja2_1730902833938/work
-jiter==0.8.2
-joblib==1.4.2
-json5 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_b9ww6ewhv3/croot/json5_1730786813588/work
-jsonschema @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_7boelfqucq/croot/jsonschema_1728486715888/work
-jsonschema-specifications @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_d38pclgu95/croot/jsonschema-specifications_1699032390832/work
-jupyter-events @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_db0avcjzq5/croot/jupyter_events_1718738111427/work
-jupyter-lsp @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_ae9br5v37x/croot/jupyter-lsp-meta_1699978259353/work
-jupyter_client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_58w2siozyz/croot/jupyter_client_1699455907045/work
-jupyter_core @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_73nomeum4p/croot/jupyter_core_1718818302815/work
-jupyter_server @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d1t69bk94b/croot/jupyter_server_1718827086930/work
-jupyter_server_terminals @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_e7ryd60iuw/croot/jupyter_server_terminals_1686870731283/work
-jupyterlab @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a2d0br6r6g/croot/jupyterlab_1725895226942/work
-jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
-jupyterlab_server @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_f64fg3hglz/croot/jupyterlab_server_1725865356410/work
-kiwisolver==1.4.7
-langcodes==3.5.0
-language_data==1.3.0
-locket==1.0.0
-marisa-trie==1.2.1
-markdown-it-py==3.0.0
-MarkupSafe @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a84ni4pci8/croot/markupsafe_1704206002077/work
-matplotlib==3.9.4
-matplotlib-inline @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_f6fdc0hldi/croots/recipe/matplotlib-inline_1662014472341/work
-matplotlib-venn==1.1.2
-mdurl==0.1.2
-mistune @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_17ya6k1sbs/croots/recipe/mistune_1661496228719/work
-mpmath==1.3.0
-multidict==6.1.0
-multiprocess==0.70.16
-murmurhash==1.0.12
-nbclient @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_626hpwnurm/croot/nbclient_1698934218848/work
-nbconvert @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_f4c1s1qk1f/croot/nbconvert_1728049432295/work
-nbformat @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_2cv_qoc1gw/croot/nbformat_1728049423516/work
-nest-asyncio @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_310vb5e2a0/croot/nest-asyncio_1708532678212/work
-networkx==3.2.1
-nltk==3.9.1
-notebook @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_539v4hufo2/croot/notebook_1727199149603/work
-notebook_shim @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d6_ze10f45/croot/notebook-shim_1699455897525/work
-numpy==2.0.2
-openai==1.59.7
-orjson==3.10.16
-overrides @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_70s80guh9g/croot/overrides_1699371144462/work
-packaging @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a6_qk3qyg7/croot/packaging_1734472142254/work
-pandas==2.2.3
-pandocfilters @ file:///opt/conda/conda-bld/pandocfilters_1643405455980/work
-parso @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_8824a1w4md/croot/parso_1733963320105/work
-partd==1.4.2
-patsy==1.0.1
-pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
-pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
-pillow==10.4.0
-platformdirs @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a8u4fy8k9o/croot/platformdirs_1692205661656/work
-plotly==5.24.1
-preshed==3.0.9
-prometheus_client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_803ymjpv2u/croot/prometheus_client_1731958793251/work
-prompt-toolkit @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_c63v4kqjzr/croot/prompt-toolkit_1704404354115/work
-propcache==0.2.1
-protobuf==5.29.2
-psutil @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1310b568-21f4-4cb0-b0e3-2f3d31e39728k9coaga5/croots/recipe/psutil_1656431280844/work
-ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
-pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
-pyarrow==18.1.0
-pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
-pydantic==2.10.4
-pydantic_core==2.27.2
-pydub==0.25.1
-Pygments @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_29bs9f_dh9/croot/pygments_1684279974747/work
-pyparsing==3.2.1
-PySocks @ file:///Users/ktietz/Code/oss/ci_pkgs/pysocks_1626781349491/work
-python-box==7.3.0
-python-dateutil @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_66ud1l42_h/croot/python-dateutil_1716495741162/work
-python-json-logger @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_9bjmcmh4nm/croot/python-json-logger_1734370248301/work
-python-multipart==0.0.20
-pytz @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a4b76c83ik/croot/pytz_1713974318928/work
-PyYAML @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_faoex52hrr/croot/pyyaml_1728657970485/work
-pyzmq @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_95lsut8ymz/croot/pyzmq_1734709560733/work
-referencing @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_5cz64gsx70/croot/referencing_1699012046031/work
-regex==2024.11.6
-requests @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_ee45nsd33z/croot/requests_1730999134038/work
-rfc3339-validator @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_76ae5cu30h/croot/rfc3339-validator_1683077051957/work
-rfc3986-validator @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d0l5zd97kt/croot/rfc3986-validator_1683058998431/work
-rich==13.9.4
-rpds-py @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_93fzmr7v9h/croot/rpds-py_1732228422522/work
-ruff==0.11.6
-safetensors==0.5.0
-scikit-learn==1.6.0
-scipy==1.13.1
-seaborn==0.13.2
-semantic-version==2.10.0
-Send2Trash @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5b31f0zzlv/croot/send2trash_1699371144121/work
-sentencepiece==0.2.0
-sentry-sdk==2.19.2
-setproctitle==1.3.4
-shellingham==1.5.4
-six @ file:///tmp/build/80754af9/six_1644875935023/work
-smart-open==7.1.0
-smmap==5.0.2
-sniffio @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1573pknjrg/croot/sniffio_1705431298885/work
-soupsieve @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_9798xzs_03/croot/soupsieve_1696347567192/work
-spacy==3.8.3
-spacy-legacy==3.0.12
-spacy-loggers==1.0.5
-srsly==2.5.1
-stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
-starlette==0.46.2
-statsmodels==0.14.4
-swifter==1.4.0
-sympy==1.13.1
-tabulate==0.9.0
-tenacity==9.0.0
-terminado @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_fcfvyc0an2/croot/terminado_1671751835701/work
-thinc==8.3.4
-threadpoolctl==3.5.0
-tinycss2 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_fcw5_i306t/croot/tinycss2_1668168825117/work
-together==1.4.1
-tokenizers==0.21.0
-tomli @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d0e5ffbf-5cf1-45be-8693-c5dff8108a2awhthtjlq/croots/recipe/tomli_1657175508477/work
-tomlkit==0.12.0
-toolz==1.0.0
-torch==2.5.1
-torchvision==0.20.1
-tornado @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_0axef5a0m0/croot/tornado_1733960501260/work
-tqdm==4.67.1
-traitlets @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_500m2_1wyk/croot/traitlets_1718227071952/work
-transformers==4.47.1
-typer==0.15.1
-typing_extensions @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_0b3jpv_f79/croot/typing_extensions_1734714864260/work
-tzdata==2024.2
-urllib3 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_06_m8gdsy6/croot/urllib3_1727769822458/work
-uvicorn==0.34.2
-wandb==0.19.1
-wasabi==1.1.3
-wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
-weasel==0.4.1
-webencodings==0.5.1
-websocket-client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d37u7gqts8/croot/websocket-client_1715878310260/work
-websockets==12.0
-wordcloud==1.9.4
-wrapt==1.17.2
-xxhash==3.5.0
-yarl==1.18.3
-zipp @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_echurpkwug/croot/zipp_1732630743967/work
+torch
+transformers
+pandas
+tqdm
+scikit-learn
+gradio
+numpy
+datasets
+accelerate
utils/__init__.py
ADDED
(empty file)
utils/calibration_utils.py
ADDED
@@ -0,0 +1,288 @@
import os

import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from accelerate import Accelerator
from transformers import default_data_collator, get_scheduler

from .data_utils import get_group_texts_func, get_tokenize_func


class EmbeddingCalibrator(nn.Module):
    """Learned residual correction applied to a block of embedding rows.

    Either a full-rank matrix (initialized to zero, so the initial transform is the
    identity) or a LoRA-style low-rank pair when lora_r is given.
    """

    def __init__(self, hidden_size, lora_r=None, lora_alpha=None, dtype=torch.bfloat16):
        super().__init__()
        self.use_lora = lora_r is not None

        if not self.use_lora:
            self.weight = nn.Parameter(torch.zeros(hidden_size, hidden_size, dtype=dtype))
        else:
            self.lora_scaling = lora_alpha / lora_r if lora_alpha is not None else 1.0
            self.lora_A = nn.Parameter(torch.randn(lora_r, hidden_size, dtype=dtype) * (1 / lora_r))
            self.lora_B = nn.Parameter(torch.zeros(hidden_size, lora_r, dtype=dtype))

    def forward(self, x):
        if not self.use_lora:
            # Full-rank residual correction: x + x W^T
            return x + torch.matmul(x, self.weight.t())
        else:
            # Low-rank (LoRA-style) correction: x + scaling * (x A^T) B^T
            lora_out = torch.matmul(x, self.lora_A.t())
            lora_out = torch.matmul(lora_out, self.lora_B.t())
            return x + self.lora_scaling * lora_out


class CalibrationModel(nn.Module):
    def __init__(
            self,
            base_model, lm_head, original_vocab_size, num_new_tokens,
            calibrate_embedding=True, calibrate_lm_head=True, empty_init=False,
            lora_alpha=None, lora_r=None,
            target_loss_weight=0.15, subsequent_loss_weight=0.15,
    ):
        super().__init__()
        self.base_model = base_model
        self.lm_head = lm_head
        self.new_tokens_start = original_vocab_size
        self.new_tokens_end = original_vocab_size + num_new_tokens

        self.calibrate_lm_head = calibrate_lm_head
        self.calibrate_embedding = calibrate_embedding
        if not empty_init:
            self.lm_head_calibrator = EmbeddingCalibrator(base_model.config.hidden_size, lora_r, lora_alpha)
            self.embedding_calibrator = EmbeddingCalibrator(base_model.config.hidden_size, lora_r, lora_alpha)

        self.loss_fct = nn.CrossEntropyLoss(reduction="none")
        self.subsequent_tokens_loss_alpha = subsequent_loss_weight
        self.new_tokens_loss_alpha = target_loss_weight
        self.original_tokens_loss_alpha = 1 - self.new_tokens_loss_alpha - self.subsequent_tokens_loss_alpha

    def forward(self, input_ids, labels, attention_mask=None):
        # Shift labels by one position for causal language modeling.
        labels = labels[:, 1:].contiguous()
        input_ids = input_ids[:, :-1].contiguous()
        if attention_mask is not None:
            # Keep the mask aligned with the shifted inputs.
            attention_mask = attention_mask[:, :-1].contiguous()

        if self.calibrate_embedding:
            # Run the calibrator over the new-token embedding rows only.
            E_weights = self.base_model.get_input_embeddings().weight.data
            E_weights = torch.cat((E_weights[:self.new_tokens_start],
                                   self.embedding_calibrator(E_weights[self.new_tokens_start:])))
            input_embeddings = E_weights[input_ids]
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids, dtype=torch.long)
            outputs = self.base_model(inputs_embeds=input_embeddings, attention_mask=attention_mask)
        else:
            with torch.no_grad():
                # Forward pass through the frozen base model
                outputs = self.base_model(input_ids, attention_mask=attention_mask)

        if self.calibrate_lm_head:
            lm_head_weights = self.lm_head.weight
            normed_weights = lm_head_weights.clone()
            normed_weights[self.new_tokens_start:self.new_tokens_end] = self.lm_head_calibrator(
                lm_head_weights[self.new_tokens_start:self.new_tokens_end])
            logits = torch.matmul(outputs['last_hidden_state'], normed_weights.T)
        else:
            if self.calibrate_embedding:
                logits = self.lm_head(outputs['last_hidden_state'])
            else:
                with torch.no_grad():
                    logits = self.lm_head(outputs['last_hidden_state'])

        # Weighted sum of losses over original tokens, new tokens, and the tokens
        # immediately following a new token.
        per_example_loss = self.loss_fct(logits.transpose(1, 2), labels)
        original_tokens_mask = labels < self.new_tokens_start
        new_tokens_mask = ~original_tokens_mask
        loss = 0.0
        if self.original_tokens_loss_alpha > 0.0:
            loss += self.original_tokens_loss_alpha * per_example_loss[original_tokens_mask].mean()
        if self.new_tokens_loss_alpha > 0.0:
            loss += self.new_tokens_loss_alpha * per_example_loss[new_tokens_mask].mean()
        if self.subsequent_tokens_loss_alpha > 0.0:
            subsequent_tokens_mask = torch.zeros_like(original_tokens_mask, dtype=torch.bool)
            subsequent_tokens_mask[:, 1:][new_tokens_mask[:, :-1]] = True
            loss += self.subsequent_tokens_loss_alpha * per_example_loss[subsequent_tokens_mask].mean()

        return {'loss': loss, 'logits': logits}

    def get_calibrators(self):
        embedding_calibrator = self.embedding_calibrator if self.calibrate_embedding else None
        lm_head_calibrator = self.lm_head_calibrator if self.calibrate_lm_head else None
        return {
            "embedding_calibrator": embedding_calibrator,
            "lm_head_calibrator": lm_head_calibrator,
            "new_tokens_start": self.new_tokens_start,
            "new_tokens_end": self.new_tokens_end,
        }

    def set_calibrators(self, embedding_calibrator=None, lm_head_calibrator=None):
        self.embedding_calibrator = embedding_calibrator
        self.lm_head_calibrator = lm_head_calibrator

    def save_calibrators(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        if self.calibrate_embedding:
            torch.save(self.embedding_calibrator, os.path.join(save_dir, "embedding_calibrator.pt"))
        if self.calibrate_lm_head:
            torch.save(self.lm_head_calibrator, os.path.join(save_dir, "lm_head_calibrator.pt"))

    def load_calibrators(self, load_dir, fail_ok=False):
        """Loads the calibrator modules from a directory."""
        try:
            if self.calibrate_embedding:
                self.embedding_calibrator = torch.load(os.path.join(load_dir, "embedding_calibrator.pt"))
            if self.calibrate_lm_head:
                self.lm_head_calibrator = torch.load(os.path.join(load_dir, "lm_head_calibrator.pt"))
            return True
        except Exception:
            if fail_ok:
                return False
            raise FileNotFoundError(f"Loading calibrators from '{load_dir}' failed")


def get_calibration_model(model, original_vocab_size, num_new_tokens, target_loss_weight=0.15, subsequent_loss_weight=0.15):
    # Wrap a HF causal LM (one that exposes .model and .lm_head) so that only the two
    # calibrators are trainable.
    calibrated_model = CalibrationModel(model.model, model.lm_head, original_vocab_size, num_new_tokens,
                                        target_loss_weight=target_loss_weight,
                                        subsequent_loss_weight=subsequent_loss_weight)
    calibrated_model.base_model.eval()
    calibrated_model.lm_head.eval()

    for param in calibrated_model.base_model.parameters():
        param.requires_grad = False
    for param in calibrated_model.lm_head.parameters():
        param.requires_grad = False
    for param in calibrated_model.lm_head_calibrator.parameters():
        param.requires_grad = True
    for param in calibrated_model.embedding_calibrator.parameters():
        param.requires_grad = True

    return calibrated_model


def train_calibration_model(calibrated_model: CalibrationModel, tokenizer, dataset, save_dir=None, max_samples=None,
                            filter_examples_without_new_tokens=True, lr=1e-4, lr_schedule="linear", num_epochs=1,
                            batch_size=8, max_length=256, n_warmup_steps=0, text_col_name="text",
                            clip_grad_norm=1.0, mixed_precision=None):
    accelerator = Accelerator(mixed_precision=mixed_precision)
    # Optimizer over the calibrator parameters (everything else is frozen)
    optimizer = optim.AdamW(calibrated_model.parameters(), lr=lr)

    # Tokenize data
    if tokenizer.bos_token is not None and max_length:
        add_start_token = True
        # leave room for the <BOS> token to be added:
        max_tokenized_len = max_length - 1
    else:
        add_start_token = False
        max_tokenized_len = max_length

    def _add_start_token(batch):
        bos_tokens_tensor = torch.tensor(
            [[tokenizer.bos_token_id]] * batch["input_ids"].size(dim=0)).to(batch["input_ids"].device)
        batch["input_ids"] = torch.cat([bos_tokens_tensor, batch["input_ids"]], dim=1)
        batch["attention_mask"] = torch.cat(
            [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(batch["attention_mask"].device),
             batch["attention_mask"]], dim=1)
        return batch

    tokenize_function = get_tokenize_func(tokenizer, text_col_name)

    column_names = dataset.column_names

    with accelerator.main_process_first():
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=column_names,
            load_from_cache_file=False,
            desc="Running tokenizer on dataset",
        )
        group_texts = get_group_texts_func(block_size=max_tokenized_len)
        lm_dataset = tokenized_dataset.map(
            group_texts,
            batched=True,
        )

    if filter_examples_without_new_tokens:
        # Keep only blocks that contain at least one new-vocabulary token
        examples_w_new_token = np.arange(len(lm_dataset))[
            np.any(np.array(lm_dataset['input_ids']) >= calibrated_model.new_tokens_start, axis=1)]
        lm_dataset = lm_dataset.select(examples_w_new_token)

    if max_samples is not None:
        lm_dataset = lm_dataset.select(np.arange(max_samples))

    data_collator = default_data_collator

    # Create data loader
    dataloader = DataLoader(
        lm_dataset, collate_fn=data_collator, batch_size=batch_size, drop_last=True, shuffle=True,
    )

    # Learning rate scheduler; a float n_warmup_steps is read as a fraction of an epoch
    if isinstance(n_warmup_steps, float):
        n_warmup_steps = n_warmup_steps * len(dataloader)
    scheduler = get_scheduler(lr_schedule, optimizer=optimizer, num_warmup_steps=n_warmup_steps,
                              num_training_steps=len(dataloader) * num_epochs)

    calibrated_model, dataloader = accelerator.prepare(calibrated_model, dataloader)

    # Freeze the original lm_head weights
    for param in calibrated_model.lm_head.parameters():
        param.requires_grad = False

    calibrated_model.train()
    for epoch in tqdm(range(num_epochs), unit="epochs", desc="Fitting calibration"):
        total_loss = 0.0
        for step, batch in tqdm(enumerate(dataloader), total=len(dataloader), miniters=10, unit="batches"):
            if add_start_token:
                batch = _add_start_token(batch)
            batch["labels"] = batch["input_ids"]
            optimizer.zero_grad()
            outputs = calibrated_model(**batch)
            loss = outputs['loss']
            accelerator.backward(loss)
            torch.nn.utils.clip_grad_norm_(calibrated_model.parameters(), max_norm=clip_grad_norm)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch + 1} completed. Average loss: {avg_loss}")

    if save_dir is not None:
        calibrated_model.save_calibrators(save_dir)

    return calibrated_model


def merge_calibrators_to_hf_model(hf_model, new_tokens_start, new_tokens_end=None,
                                  embedding_calibrator=None, lm_head_calibrator=None):
    # Fold trained calibrators back into a HF checkpoint's weight matrices.
    if embedding_calibrator is not None:
        embedding_calibrator.to(hf_model.device)
        embedding_weights = hf_model.get_input_embeddings().weight
        with torch.no_grad():
            calibrated_weights = embedding_calibrator(embedding_weights[new_tokens_start:new_tokens_end])
            hf_model.model.embed_tokens.weight.data[new_tokens_start:new_tokens_end] = calibrated_weights

    if lm_head_calibrator is not None:
        lm_head_calibrator.to(hf_model.device)
        lm_head_weights = hf_model.get_output_embeddings().weight
        with torch.no_grad():
            calibrated_weights = lm_head_calibrator(lm_head_weights[new_tokens_start:new_tokens_end])
            hf_model.lm_head.weight.data[new_tokens_start:new_tokens_end] = calibrated_weights

    return hf_model


def merge_calibration_model_to_hf_model(hf_model, calibrated_model):
    calibrated_model.to(hf_model.device)
    if calibrated_model.calibrate_lm_head:
        lm_head_weights = calibrated_model.lm_head.weight
        normed_weights = calibrated_model.lm_head_calibrator(
            lm_head_weights[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end])
        with torch.no_grad():
            hf_model.lm_head.weight.data[
                calibrated_model.new_tokens_start:calibrated_model.new_tokens_end] = normed_weights
    if calibrated_model.calibrate_embedding:
        embedding_weights = calibrated_model.base_model.get_input_embeddings().weight
        normed_weights = calibrated_model.embedding_calibrator(
            embedding_weights[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end])
        with torch.no_grad():
            hf_model.model.embed_tokens.weight.data[
                calibrated_model.new_tokens_start:calibrated_model.new_tokens_end] = normed_weights
    return hf_model
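
A sketch of how these pieces compose (hypothetical setup, not part of the commit): it assumes a Llama-style checkpoint that exposes .model and .lm_head, as get_calibration_model requires, and a tokenizer extended with two new whole-word tokens.

    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "my-llama-style-model"  # placeholder: any causal LM exposing .model and .lm_head
    model = AutoModelForCausalLM.from_pretrained(name)
    tokenizer = AutoTokenizer.from_pretrained(name)

    orig_vocab = len(tokenizer)
    tokenizer.add_tokens(["perplexity", "tokenization"])  # hypothetical new word tokens
    model.resize_token_embeddings(len(tokenizer))

    calib = get_calibration_model(model, orig_vocab, num_new_tokens=2)
    data = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train")
    calib = train_calibration_model(calib, tokenizer, data, num_epochs=1, batch_size=4,
                                    save_dir="calibrators")
    model = merge_calibration_model_to_hf_model(model, calib)  # fold corrections back in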
utils/data_utils.py
ADDED
@@ -0,0 +1,214 @@
import re
from collections import Counter
from itertools import chain

from datasets import load_dataset, Dataset, DatasetDict
from tqdm import tqdm

LANGUAGES_TO_DECODE_FROM_BYTES = ["he", "fr", "uk"]
STREAMING_DATASETS = ["fineweb-edu"]


def load_pg19_val_and_test():
    # Load in streaming mode, then materialize only the two small splits
    streaming_dataset = load_dataset("deepmind/pg19", split=None, streaming=True)

    test_dataset = Dataset.from_list(list(streaming_dataset["test"]))
    validation_dataset = Dataset.from_list(list(streaming_dataset["validation"]))

    return DatasetDict({"validation": validation_dataset, "test": test_dataset})


def load_pubmed(n_samples=10000):
    # Load in streaming mode and carve train/validation/test out of the stream
    streaming_dataset = load_dataset("MedRAG/pubmed", streaming=True)

    data = list(streaming_dataset["train"].take(n_samples * 4))
    train = Dataset.from_list(data[:2 * n_samples])
    validation = Dataset.from_list(data[2 * n_samples:3 * n_samples])
    test = Dataset.from_list(data[3 * n_samples:])
    dataset = DatasetDict({"train": train, "validation": validation, "test": test})
    dataset = dataset.rename_column("content", "text")
    return dataset


def load_lm_dataset(dataset_name, language="en", split=None):
    """
    Loads a pretraining or perplexity-evaluation dataset by name and language.

    Args:
        dataset_name (str): The dataset to load. Options:
            - 'wikitext' (WikiText-2, the smaller WikiText dataset)
            - 'wikitext-103' (the larger WikiText dataset)
            - 'fineweb-edu' (10BT sample of the FineWeb-Edu corpus)
            - 'cord19', 'pubmed' (scientific/biomedical text)
            - 'wikilingua', 'xsum', 'cnn' (summarization corpora, normalized to a 'text' column)
            - 'pg19' (Project Gutenberg dataset for long-context modeling)
            - 'wiki40b' (Wikipedia dataset in multiple languages)
        language (str): Language code for multilingual datasets (e.g., 'en' for English).
            Defaults to 'en'.
        split (str): Optional split to load.

    Returns:
        Dataset: Loaded Hugging Face dataset.
    """
    if dataset_name.lower() == 'wikitext':
        return load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split=split)
    elif dataset_name.lower() == 'fineweb-edu':
        return load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT")
    elif dataset_name.lower() == 'wikitext-103':
        return load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split=split)
    elif dataset_name.lower() == 'cord19':
        return load_dataset("allenai/cord19", "fulltext", trust_remote_code=True)
    elif dataset_name.lower() == 'pubmed':
        return load_pubmed()
    elif dataset_name.lower() == 'wikilingua':
        dataset = load_dataset("GEM/wiki_lingua", trust_remote_code=True)
        dataset = dataset.filter(lambda ex: (ex['source_language'] == "en") & (ex['target_language'] == "en"))
        dataset = dataset.rename_column("source", "text")
        dataset = dataset.rename_column("target", "summary")
        return dataset
    elif dataset_name.lower() == 'xsum':
        dataset = load_dataset("EdinburghNLP/xsum")
        dataset = dataset.rename_column("document", "text")
        return dataset
    elif dataset_name.lower() == 'cnn':
        dataset = load_dataset("abisee/cnn_dailymail", "3.0.0")
        dataset = dataset.rename_column("article", "text")
        dataset = dataset.rename_column("highlights", "summary")
        dataset = dataset.map(lambda example: {"text": example["text"].replace("(CNN)", "")})
        return dataset
    elif dataset_name.lower() == 'pg19':
        return load_pg19_val_and_test()
    elif dataset_name.lower() == 'wiki40b':
        dataset = load_dataset("google/wiki40b", language, split=split)
        if language in LANGUAGES_TO_DECODE_FROM_BYTES:
            # wiki40b stores some languages as escaped byte strings; decode them back to text
            dataset = dataset.map(lambda x: {
                "text": bytes(x["text"][2:-1], "utf-8").decode("unicode_escape").encode("latin1").decode("utf-8").replace("_NEWLINE_", "\n")
            })
        return dataset
    else:
        raise ValueError(
            "Dataset not recognized. Available options: 'wikitext', 'wikitext-103', 'fineweb-edu', "
            "'cord19', 'pubmed', 'wikilingua', 'xsum', 'cnn', 'pg19', 'wiki40b'.")


def extract_new_words_from_dataset(
        dataset: Dataset, tokenizer, text_column: str = "text", max_samples: int = None,
        filter_func=(lambda word, token_count: True)):
    """
    Extracts the unique words in a dataset that the tokenizer splits into multiple tokens.

    Args:
        dataset (Dataset): Dataset to scan.
        tokenizer: Tokenizer used to decide whether a word is multi-token.
        text_column (str): The column in the dataset containing text. Defaults to 'text'.
        max_samples (int): Optional cap on the number of examples to scan.
        filter_func: Optional predicate over (word, token_count) for extra filtering.

    Returns:
        tuple: (list of new words, dict mapping each new word to its frequency).
    """
    if max_samples:
        dataset = dataset.select(range(max_samples))

    # Regular expression to split text into words (adjust as needed for specific languages)
    word_pattern = re.compile(r"\b\w+(?:[-']\w+)*\b")

    # Collect words and their frequencies
    all_words = list()
    for record in tqdm(dataset, total=len(dataset), miniters=10,
                       desc="Extracting all words from dataset...", unit="examples"):
        text = record.get(text_column, "")
        all_words += word_pattern.findall(text)

    word_frequencies = Counter(all_words)
    all_words = list(word_frequencies.keys())
    # A word counts as "new" only if it is multi-token both with and without a leading space
    token_counts = [len(x) for x in tokenizer(all_words, add_special_tokens=False)["input_ids"]]
    w_whitespace_token_counts = [len(x) for x in tokenizer(
        [f" {w}" for w in all_words], add_special_tokens=False)["input_ids"]]

    new_words = [word for word, count, w_whitespace_count
                 in zip(all_words, token_counts, w_whitespace_token_counts)
                 if ((count > 1) and (w_whitespace_count > 1) and filter_func(word, count))]
    new_words_freq = {word: word_frequencies[word] for word in new_words}
    return new_words, new_words_freq


def get_group_texts_func(block_size=1024):
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}

        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder; if total_length < block_size, this batch is excluded
        # and an empty dict is returned. Padding could be used instead if the model supports
        # it; customize this part to your needs.
        total_length = (total_length // block_size) * block_size
        # Split into chunks of block_size.
        result = {
            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result
    return group_texts


def get_tokenize_func(tokenizer, text_col_name):
    def _tokenize(examples):
        output = tokenizer(
            examples[text_col_name],
            return_token_type_ids=False,
            add_special_tokens=False,
        )
        return output
    return _tokenize


def tokenize_and_prepare_dataset(
        dataset, tokenizer, accelerator=None,
        text_col_name: str = "text",
        max_length: int = 256,
        eval_max_samples: int = None,
):
    if tokenizer.bos_token is not None and max_length:
        # leave room for the <BOS> token to be added:
        max_tokenized_len = max_length - 1
    else:
        max_tokenized_len = max_length

    tokenize_function = get_tokenize_func(tokenizer, text_col_name)

    column_names = dataset.column_names

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=column_names,
        load_from_cache_file=False,
        desc="Running tokenizer on dataset",
    )
    group_texts = get_group_texts_func(block_size=max_tokenized_len)
    lm_dataset = tokenized_dataset.map(
        group_texts,
        batched=True,
    )

    if eval_max_samples:
        lm_dataset = lm_dataset.select(range(eval_max_samples))

    return lm_dataset
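
A self-contained check of the chunking helper (illustrative only): group_texts concatenates already-tokenized examples and re-slices them into fixed-size blocks, dropping the remainder and mirroring input_ids into labels.

    group_texts = get_group_texts_func(block_size=4)
    batch = {"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8]],
             "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1]]}
    out = group_texts(batch)
    print(out["input_ids"])                   # [[1, 2, 3, 4], [5, 6, 7, 8]]
    print(out["labels"] == out["input_ids"])  # True: labels mirror input_ids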
utils/enums.py
ADDED
@@ -0,0 +1,13 @@
from enum import Enum


class RetrievalTechniques(Enum):
    ReverseLogitLens = 1
    LogitLens = 2
    Patchscopes = 3


class MultiTokenKind(Enum):
    Split = 1
    Typo = 2
    Natural = 3
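
These values are consumed through RetrievalProcessor.multi_token_kind; a plausible dispatch (an assumption — the actual wiring lives in code not shown in this diff):

    def get_word_fn(processor):
        # Hypothetical mapping from MultiTokenKind to the processor's retrieval variants.
        return {
            MultiTokenKind.Natural: processor.get_next_word,
            MultiTokenKind.Typo: processor.get_next_full_word_typo,
            MultiTokenKind.Split: processor.get_next_full_word_separated,
        }[processor.multi_token_kind]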
utils/eval_utils.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from accelerate import Accelerator
|
6 |
+
from transformers import default_data_collator
|
7 |
+
from collections import defaultdict
|
8 |
+
from tqdm import tqdm
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
|
12 |
+
def is_not_number(s):
|
13 |
+
try:
|
14 |
+
float(s) # Try converting the string to a float
|
15 |
+
return False # If conversion is successful, it's a number
|
16 |
+
except ValueError:
|
17 |
+
return True # If conversion fails, it's not a number
|
18 |
+
|
19 |
+
|
20 |
+
def get_contexts_ending_with_word(word, dataset):
|
21 |
+
result_contexts = []
|
22 |
+
word_len = len(word)
|
23 |
+
|
24 |
+
# Iterate over the dataset
|
25 |
+
for example in dataset:
|
26 |
+
text = example["text"]
|
27 |
+
|
28 |
+
# Find all occurrences of the word in the text
|
29 |
+
start = 0
|
30 |
+
while True:
|
31 |
+
idx = text.find(word, start)
|
32 |
+
if idx == -1:
|
33 |
+
break
|
34 |
+
|
35 |
+
# Ensure that the word is isolated (not a substring of another word)
|
36 |
+
if (idx == 0 or not text[idx - 1].isalnum()) and (
|
37 |
+
idx + word_len == len(text) or not text[idx + word_len].isalnum()):
|
38 |
+
# Text ends with the word
|
39 |
+
result_contexts.append(text[:idx + word_len].strip())
|
40 |
+
start = idx + word_len
|
41 |
+
|
42 |
+
return result_contexts
|
43 |
+
|
44 |
+
|
45 |
+
def get_texts_containing_word(words, dataset):
|
46 |
+
result_texts = []
|
47 |
+
words_set = set(words)
|
48 |
+
|
49 |
+
# Iterate over the dataset
|
50 |
+
for example in dataset:
|
51 |
+
if words_set.intersection(set(example["text"].split())):
|
52 |
+
result_texts.append(example["text"])
|
53 |
+
|
54 |
+
return result_texts
|
55 |
+
|
56 |
+
|
57 |
+
def compute_topk_token_rank(logits, labels, k=1000):
|
58 |
+
# Get the top-k predicted logits and their indices
|
59 |
+
topk_logits, topk_indices = torch.topk(logits, k, dim=-1)
|
60 |
+
|
61 |
+
# Expand the labels for comparison
|
62 |
+
labels_expanded = labels.unsqueeze(-1).expand_as(topk_indices)
|
63 |
+
|
64 |
+
# Check if the label token is within the top-k predictions
|
65 |
+
rank_in_topk = (topk_indices == labels_expanded).nonzero(as_tuple=False)
|
66 |
+
|
67 |
+
# Create a rank tensor initialized with k (max rank is k)
|
68 |
+
ranks = torch.full(labels.shape, k, dtype=torch.long, device=logits.device)
|
69 |
+
|
70 |
+
# For labels in top-k, set the rank accordingly
|
71 |
+
ranks[rank_in_topk[:, 0], rank_in_topk[:, 1]] = rank_in_topk[:, 2] + 1
|
72 |
+
|
73 |
+
return ranks
|
74 |
+
|
75 |
+
|
76 |
+
def count_tokens_in_dataset(dataset, tokenizer, text_column='text'):
|
77 |
+
def tokenize_and_count(examples):
|
78 |
+
return {'num_tokens': [len(tokenizer(ex).input_ids) for ex in examples[text_column]]}
|
79 |
+
|
80 |
+
tokenized_dataset = dataset.map(tokenize_and_count, batched=True, remove_columns=dataset.column_names)
|
81 |
+
|
82 |
+
total_tokens = sum(tokenized_dataset['num_tokens'])
|
83 |
+
return total_tokens
|
84 |
+
|
85 |
+
|
86 |
+
def filter_single_token_words(array, tokenizer, add_space_prefix_for_lower=True):
|
87 |
+
def _is_multi_token(word):
|
88 |
+
if add_space_prefix_for_lower and word[0].islower():
|
89 |
+
word = " " + word
|
90 |
+
return len(tokenizer.encode(word, add_special_tokens=False))
|
91 |
+
token_counts = array.apply(_is_multi_token)
|
92 |
+
mask = token_counts > 1
|
93 |
+
return array[mask], token_counts
|
94 |
+
|
95 |
+
|
96 |
+
# TODO make clearer what's its use
|
97 |
+
def get_last_zero_in_every_seq_mask(tensor):
|
98 |
+
# Find where consecutive zeros end
|
99 |
+
zero_mask = (tensor == 0)
|
100 |
+
diff = torch.diff(zero_mask.int(), dim=1)
|
101 |
+
last_zero_mask = torch.cat([diff, torch.ones(tensor.size(0), 1, dtype=diff.dtype).to(tensor.device)], dim=1) == -1
|
102 |
+
|
103 |
+
# Create the output
|
104 |
+
output = 1 - tensor
|
105 |
+
output[zero_mask & ~last_zero_mask] = 0
|
106 |
+
return output
|
107 |
+
|
108 |
+
|
109 |
+
def get_first_zero_in_every_seq_mask(tensor):
|
110 |
+
# Identify where consecutive zeros begin
|
111 |
+
zero_mask = (tensor == 0)
|
112 |
+
diff = torch.diff(zero_mask.int(), dim=1, prepend=torch.zeros(tensor.size(0), 1, dtype=torch.int).to(tensor.device))
|
113 |
+
first_zero_mask = diff == 1 # Marks the beginning of each sequence of zeros
|
114 |
+
|
115 |
+
# Create the output
|
116 |
+
output = 1 - tensor
|
117 |
+
output[zero_mask & ~first_zero_mask] = 0
|
118 |
+
return output
|
119 |
+
|
120 |
+
|
121 |
+
def _add_start_token(batch, tokenizer):
    bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * batch["input_ids"].size(dim=0)).to(batch["input_ids"].device)
    batch["input_ids"] = torch.cat([bos_tokens_tensor, batch["input_ids"]], dim=1)
    batch["attention_mask"] = torch.cat(
        [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(batch["attention_mask"].device), batch["attention_mask"]], dim=1)
    return batch


def _ignore_new_words_in_attention_mask(shift_attention_mask_batch, shift_labels, new_token_ids=None, replaced_token_seqs_by_len=None):
    ignore_mask = None  # initialized so it is defined even if both arguments are None
    # Ignore token ids of new vocabulary words in shift_labels and shift_logits
    if new_token_ids is not None:
        ignore_mask = torch.isin(shift_labels, new_token_ids)
        shift_attention_mask_batch = shift_attention_mask_batch * (~ignore_mask).long()

    # Ignore multi-token sequences that were replaced with a single token
    if replaced_token_seqs_by_len is not None:
        # Create a mask that will be updated where sequences match
        ignore_mask = shift_attention_mask_batch.clone()  # Clone the attention mask to modify it
        # Loop over sequences in skip_token_seqs
        for seq_len, seqs in replaced_token_seqs_by_len.items():
            # Create a sliding window of the same size as the skip_seq and check for matches
            for i in range(shift_labels.size(1) - seq_len + 1):
                # Check if the sequence matches at position i
                window = shift_labels[:, i:i + seq_len]
                curr_mask = torch.all(window.unsqueeze(1) == seqs.unsqueeze(0), dim=-1)
                if curr_mask.any():
                    # Zero out the ignore mask for the length of the sequence
                    ignore_mask[curr_mask.any(dim=-1), i:i + seq_len] = 0
        # Apply the ignore mask to the attention mask
        shift_attention_mask_batch *= ignore_mask

    return shift_attention_mask_batch, ignore_mask


# TODO consider not aggregating results here, to enable metrics for specific words
def compute_metrics(
        logits, labels, attention_mask,
        compute_target_metrics=True, compute_subsequent_metrics=True, compute_perplexity=False,
        return_successful_targets=False,
        original_labels=None, original_logits=None,
        debug=False):
    target_results = dict()  # will hold metrics for all the new words we add or their original tokenization
    background_results = dict()  # will hold metrics for all background tokens, i.e., not the ones we add or replace
    overall_results = dict()  # will hold metrics for all tokens
    successful_targets = None  # will hold the list of target tokens successfully predicted
    if compute_subsequent_metrics:
        # prepare labels and attention masks for computing metrics only for the first tokens following the new words
        subsequent_labels = labels[:, 1:]
        subsequent_attention_mask = get_last_zero_in_every_seq_mask(attention_mask[..., :-1].contiguous())
        subsequent_attention_mask_bool = subsequent_attention_mask == 1
    attention_mask_bool = attention_mask == 1
    overall_mask_bool = attention_mask_bool

    if compute_target_metrics:
        target_mask = get_first_zero_in_every_seq_mask(attention_mask)
        target_mask_bool = target_mask == 1
        overall_mask_bool = attention_mask_bool | target_mask_bool

    if compute_perplexity:
        background_results["perplexity"] = torch.exp(
            (F.cross_entropy(logits.transpose(1, 2), labels, reduction="none") * attention_mask).sum(1)
            / attention_mask.sum(1)
        ).mean().detach().cpu().numpy()

    top1 = logits.argmax(dim=-1)
    if original_logits is not None:
        orig_top1 = original_logits.argmax(dim=-1)

    if compute_target_metrics:
        target_results["top1_acc"] = ((labels == top1)[target_mask_bool]).detach().cpu().numpy()
        if original_labels is not None:
            target_results["sum_top1_acc"] = (
                ((original_labels == top1) | (labels == top1))[target_mask_bool]).detach().cpu().numpy()
        if original_logits is not None:
            target_results["orig_top1_acc"] = (
                (original_labels == orig_top1)[target_mask_bool]).detach().cpu().numpy()

        if return_successful_targets:
            successful_targets = (labels[(labels == top1) & target_mask_bool]).detach().cpu().numpy()

    background_results["top1_acc"] = ((
        labels == top1)[attention_mask_bool]).detach().cpu().numpy()
    if compute_subsequent_metrics:
        background_results["subsequent_top1_acc"] = ((subsequent_labels == top1[:, 1:])[subsequent_attention_mask_bool]).detach().cpu().numpy()
    if original_logits is not None:
        background_results["orig_top1_acc"] = (
            (original_labels == orig_top1)[attention_mask_bool]).detach().cpu().numpy()
        if compute_subsequent_metrics:
            background_results["orig_subsequent_top1_acc"] = (
                (subsequent_labels == orig_top1[:, 1:])[subsequent_attention_mask_bool]).detach().cpu().numpy()

    overall_results["top1_acc"] = ((labels == top1))[overall_mask_bool].detach().cpu().numpy()
    if original_labels is not None:
        overall_results["sum_top1_acc"] = (
            ((original_labels == top1) | (labels == top1)))[overall_mask_bool].detach().cpu().numpy()
    if original_logits is not None:
        overall_results["orig_top1_acc"] = (
            (original_labels == orig_top1)[overall_mask_bool]).detach().cpu().numpy()

    if debug:
        import pdb; pdb.set_trace()
    return background_results, target_results, overall_results, successful_targets

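A self-contained toy call (editor's sketch; shapes and mask values are illustrative, with zeros in the attention mask marking a replaced target span):

import torch

logits = torch.randn(2, 8, 100)                           # toy batch: 2 sequences, 8 positions, vocab 100
labels = torch.randint(0, 100, (2, 8))
attention_mask = torch.tensor([[1, 1, 0, 0, 1, 1, 1, 1],  # zeros mark the target span
                               [1, 1, 1, 0, 0, 1, 1, 1]])
bg, tgt, overall, _ = compute_metrics(logits, labels, attention_mask, compute_perplexity=True)
print(bg["perplexity"], tgt["top1_acc"].mean())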
def eval_next_word_prediction(
        model, tokenizer, lm_dataset, accelerator=None,
        batch_size: int = 4,
        new_token_ids=None, replaced_token_seqs_by_len=None,
        new_token_to_original_first_token=None,
        max_length: int = 256,
        drop_last: bool = True,
        eval_max_samples: int = None,
        eval_shuffle_samples: bool = False,
        reduction="none",
):
    if accelerator is None:
        accelerator = Accelerator()
    model.eval()
    if tokenizer.bos_token is not None and max_length:
        add_start_token = True
    else:
        add_start_token = False

    data_collator = default_data_collator

    if eval_max_samples:
        eval_idx = range(min(eval_max_samples, len(lm_dataset)))  # fixed: range started at len(lm_dataset)
        if eval_shuffle_samples:
            eval_idx = np.random.choice(len(lm_dataset), min(eval_max_samples, len(lm_dataset)), replace=False)
        lm_dataset = lm_dataset.select(eval_idx)

    # Create data loaders
    eval_dataloader = DataLoader(
        lm_dataset, collate_fn=data_collator, batch_size=batch_size, drop_last=drop_last, shuffle=False,
    )
    eval_dataloader = accelerator.prepare(eval_dataloader)

    if new_token_ids is not None:
        new_token_ids = torch.tensor(new_token_ids).to(model.device)
    if replaced_token_seqs_by_len is not None:
        replaced_token_seqs_by_len = {token_length: torch.tensor(skip_token_seqs).to(model.device)
                                      for token_length, skip_token_seqs in replaced_token_seqs_by_len.items()
                                      if len(skip_token_seqs) > 0}
    if new_token_to_original_first_token is not None:
        # Convert the mapping into a tensor for efficient indexing; the mapping defaults to identity
        new_token_to_orig_first_mapping_tensor = torch.arange(len(tokenizer), device=model.device)
        new_token_to_orig_first_mapping_tensor[torch.tensor(list(new_token_to_original_first_token.keys()), device=model.device)] = \
            torch.tensor(list(new_token_to_original_first_token.values()), device=model.device)

    target_metrics = defaultdict(list)
    background_metrics = defaultdict(list)
    overall_metrics = defaultdict(list)

    # run eval and compute metrics
    for batch_i, batch in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader), miniters=10, desc="Evaluating vocabulary..."):
        if add_start_token:
            batch = _add_start_token(batch, tokenizer)

        labels = batch["input_ids"]
        attn_mask = batch["attention_mask"]
        batch.pop("labels")
        with torch.no_grad():
            outputs = model(**batch)
        out_logits = outputs.logits

        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

        shift_attention_mask_batch, ignore_mask = \
            _ignore_new_words_in_attention_mask(
                shift_attention_mask_batch, shift_labels, new_token_ids, replaced_token_seqs_by_len)
        original_labels = None if new_token_to_original_first_token is None \
            else new_token_to_orig_first_mapping_tensor[shift_labels]
        original_logits = None if new_token_ids is None else torch.cat(
            [shift_logits[:, :, :min(new_token_ids)], shift_logits[:, :, max(new_token_ids) + 1:]], dim=-1)

        background_results, target_results, overall_results, successful_targets = \
            compute_metrics(
                shift_logits, shift_labels, shift_attention_mask_batch,
                original_labels=original_labels, original_logits=original_logits, compute_perplexity=True)

        for metric_name, metric_value in target_results.items():
            target_metrics[metric_name].append(np.array(metric_value))
        for metric_name, metric_value in background_results.items():
            background_metrics[metric_name].append(metric_value)
        for metric_name, metric_value in overall_results.items():
            overall_metrics[metric_name].append(metric_value)

    eval_dataloader = accelerator.free_memory(eval_dataloader)

    def _concat_func(x):
        if isinstance(x, np.ndarray) and len(x.shape) > 1:
            x = np.concatenate(x)
        elif isinstance(x, (list, tuple)) and len(x) > 1:
            if isinstance(x[0], np.ndarray) and len(x[0].shape) == 0:
                x = np.array(x)
            else:
                x = np.concatenate(x)
        return x

    # apply reduction
    reduce_func = _concat_func
    if reduction == 'mean':
        reduce_func = lambda x: np.mean(_concat_func(x)).item()

    for metric_name, metric_value in target_metrics.items():
        target_metrics[metric_name] = reduce_func(metric_value)
    for metric_name, metric_value in background_metrics.items():
        background_metrics[metric_name] = reduce_func(metric_value)
    for metric_name, metric_value in overall_metrics.items():
        overall_metrics[metric_name] = reduce_func(metric_value)
    return background_metrics, target_metrics, overall_metrics

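A hedged sketch of the evaluation entry point; model, tokenizer, and lm_dataset are assumed to be prepared elsewhere in the app, with lm_dataset a tokenized datasets.Dataset providing input_ids, attention_mask, and labels columns:

# Assumed inputs: model, tokenizer, lm_dataset prepared elsewhere.
bg, tgt, overall = eval_next_word_prediction(
    model, tokenizer, lm_dataset,
    batch_size=8, eval_max_samples=1000, reduction="mean")
print(bg["perplexity"], bg["top1_acc"])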
utils/file_utils.py
ADDED
@@ -0,0 +1,94 @@
import os
import re
import pandas as pd


def save_df_to_dir(results_df, base_dir, sub_dirs, file_name_format, add_context, model_name):
    # Get the root directory of the project
    root_dir = os.path.dirname(os.path.abspath(__file__))

    # Construct the output directory path
    output_dir = os.path.join(root_dir, base_dir, *sub_dirs)
    os.makedirs(output_dir, exist_ok=True)

    # Construct the file name
    file_name = file_name_format.format(model_name=model_name,
                                        context="with_context" if add_context else "without_context")

    # Construct the full file path
    file_path = os.path.join(output_dir, file_name)

    # Save the DataFrame to CSV
    results_df.to_csv(file_path, index=False)


def merge_dfs(base_dir, exp_name, part_format="part_{i}_", output_dir=None,
              filename="patchscopes_results.parquet", output_filename="patchscopes_results.parquet"):
    """
    Merges DataFrames from directories matching the part format into a single DataFrame,
    and optionally saves the result to a file.

    Args:
        base_dir (str): The base directory containing the data.
        exp_name (str): The experiment name to look for within part directories.
        part_format (str): The general format for identifying parts (e.g., "part_{i}_").
        output_dir (str, optional): Directory to save the merged DataFrame. Default is None.
        filename (str): The filename of the Parquet file to read in each part directory.
        output_filename (str): Name of the output file if saving is enabled.

    Returns:
        tuple: The merged DataFrame and the list of per-part DataFrames.
    """
    dataframes = []
    part_regex = part_format.replace("{i}", r"\d+")

    # List all directories in base_dir
    for dir_name in os.listdir(base_dir):
        if os.path.isdir(os.path.join(base_dir, dir_name)) and re.match(part_regex, dir_name) and dir_name.endswith(exp_name):
            part_dir = os.path.join(base_dir, dir_name)
            file_path = os.path.join(part_dir, filename)

            if os.path.exists(file_path):
                # Read the DataFrame and add it to the list
                df = pd.read_parquet(file_path)
                dataframes.append(df)

    # Concatenate all DataFrames into a single DataFrame
    merged_df = pd.concat(dataframes, axis=1)

    # Save the result to file if output_dir is given
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, output_filename)
        merged_df.to_parquet(output_path, index=False)

    return merged_df, dataframes


def parse_string_list_from_file(file_path, delimiter=None):
    """
    Parses a list of strings from a file, handling various list formats.

    Args:
        file_path (str): Path to the file containing the list.
        delimiter (str, optional): Explicit delimiter to split on; "newline" is treated as "\n".
            If None, common list notations (commas, brackets, quotes) are handled automatically.

    Returns:
        list: A list of parsed strings.
    """
    with open(file_path, 'r') as file:
        content = file.read()

    if delimiter is None:
        # Remove newlines and excess whitespace
        content = re.sub(r'\s+', ' ', content.strip())

        # Handle different delimiters and list formats:
        # remove common list notations like commas, brackets, quotes, etc.
        items = re.split(r'[,\[\]\(\)\{\}"\'\s]+', content)
    else:
        if delimiter == "newline":  # TODO fix this
            delimiter = "\n"
        items = [item.strip() for item in content.split(delimiter)]

    # Filter out any empty strings from the list
    return [item for item in items if item]
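A hedged usage sketch of merge_dfs; the directory layout below is hypothetical, standing in for the per-part outputs this helper expects:

# Hypothetical layout: results/part_0_my_exp/patchscopes_results.parquet, part_1_my_exp/..., etc.
merged, parts = merge_dfs(base_dir="results", exp_name="my_exp", output_dir="results/merged")
print(merged.shape, len(parts))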
utils/logit_lens.py
ADDED
@@ -0,0 +1,304 @@
"""Provides a class for mapping transformer hidden states to logits (and vice versa).

Example:

    from utils.logit_lens import LogitLens, ReverseLogitLens

    model = AutoModelForCausalLM.from_pretrained(model_name).to(device).to(dtype)
    lens = LogitLens.from_model(model).to(device).to(dtype)
    reverse_lens = ReverseLogitLens.from_model(model).to(device).to(dtype)

    hidden_state = ...
    result = lens(hidden_state, layer_index)  # layer_index is unused; any value works
"""

import abc
import copy
from typing import Union

import torch
from torch import nn
import torch.nn.functional as F

from transformers import models
from transformers import PreTrainedModel


Model = Union[PreTrainedModel]
Norm = Union[
    nn.LayerNorm,
    models.llama.modeling_llama.LlamaRMSNorm,
    models.gemma.modeling_gemma.GemmaRMSNorm,
    models.gemma2.modeling_gemma2.Gemma2RMSNorm,
    nn.Module,
]


def get_unembedding_matrix(model: Model) -> nn.Linear:
    """The final linear transformation from the model hidden state to the output."""
    if isinstance(model, PreTrainedModel):
        unembed = model.get_output_embeddings()
        if not isinstance(unembed, nn.Linear):
            raise ValueError("We currently only support linear unembeddings")
        return unembed
    else:
        raise ValueError(f"Model class {type(model)} not recognized!")


def get_embedding_matrix(model: nn.Module) -> nn.Embedding:
    """The initial embedding matrix from the input tokens to the model hidden state."""
    if isinstance(model, PreTrainedModel):
        embed = model.get_input_embeddings()
        if not isinstance(embed, nn.Embedding):
            raise ValueError("We currently only support embedding matrices")
        return embed
    else:
        raise ValueError(f"Model class {type(model)} not recognized!")


def get_final_norm(model: Model) -> Norm:
    """Get the final norm from a model.

    This isn't standardized across models, so this will need to be updated as
    we add new models.
    """
    if not hasattr(model, "base_model"):
        raise ValueError("Model does not have a `base_model` attribute.")

    base_model = model.base_model
    if isinstance(base_model, models.opt.modeling_opt.OPTModel):
        final_layer_norm = base_model.decoder.final_layer_norm
    elif isinstance(base_model, models.gpt_neox.modeling_gpt_neox.GPTNeoXModel):
        final_layer_norm = base_model.final_layer_norm
    elif isinstance(
        base_model,
        (
            models.bloom.modeling_bloom.BloomModel,
            models.gpt2.modeling_gpt2.GPT2Model,
            models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
            models.gptj.modeling_gptj.GPTJModel,
        ),
    ):
        final_layer_norm = base_model.ln_f
    elif isinstance(base_model, models.llama.modeling_llama.LlamaModel):
        final_layer_norm = base_model.norm
    elif isinstance(base_model, models.mistral.modeling_mistral.MistralModel):
        final_layer_norm = base_model.norm
    elif isinstance(base_model, models.t5.modeling_t5.T5ForConditionalGeneration):
        # For T5, use the LayerNorm from the last decoder block, before the feed-forward layer.
        final_layer_norm = base_model.decoder.block[-1].layer[1].layer_norm
    else:
        raise NotImplementedError(f"Unknown model type {type(base_model)}")

    if final_layer_norm is None:
        raise ValueError("Model does not have a final layer norm.")

    assert isinstance(final_layer_norm, Norm.__args__)  # type: ignore

    return final_layer_norm


class Unembed(nn.Module):
    """Module that maps transformer hidden states to logits."""

    final_norm: Norm
    unembedding: nn.Linear

    def __init__(
        self,
        model: Model,
    ):
        """Initialize unembed.

        Args:
            model: A HuggingFace model from which to extract the unembedding matrix.
        """
        super().__init__()
        final_norm = get_final_norm(model)
        unembedding_matrix = get_unembedding_matrix(model)

        self.final_norm = copy.deepcopy(final_norm)
        self.unembedding = copy.deepcopy(unembedding_matrix)

        # In general we don't want to finetune the unembed operation.
        self.requires_grad_(False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Convert hidden states into logits."""
        return self.unembedding(self.final_norm(h))


class Reembed(nn.Module):
    """Module that scores transformer hidden states against the input embedding matrix."""

    embedding: torch.Tensor

    def __init__(
        self,
        model: Model,
        distance_metric: str = "logits",
    ):
        """Initialize reembed.

        Args:
            model: A HuggingFace model from which to extract the embedding matrix.
            distance_metric: Similarity used to score hidden states against embeddings
                ('logits', 'cosine', or 'euclidean').
        """
        super().__init__()
        embedding_matrix = get_embedding_matrix(model)

        self.embedding = copy.deepcopy(embedding_matrix.weight.data)

        self.distance_metric = distance_metric

        # In general we don't want to finetune the reembed operation.
        self.requires_grad_(False)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Convert hidden states into logits."""
        if self.distance_metric == 'logits':
            logits = torch.matmul(hidden_state, self.embedding.T).squeeze(0)

        elif self.distance_metric == 'cosine':
            # Normalize E and h
            E_normalized = F.normalize(self.embedding, p=2, dim=-1)
            h_normalized = F.normalize(hidden_state, p=2, dim=-1)

            # Compute cosine similarity
            logits = torch.matmul(h_normalized, E_normalized.T).squeeze(0)

        elif self.distance_metric == 'euclidean':
            # Compute Euclidean distance
            distances = torch.cdist(hidden_state, self.embedding, p=2).squeeze(0)

            # Convert distances to logits (negative distance for logits-like values)
            logits = -distances

        else:  # Compute a regular dot product as a similarity measure
            logits = torch.matmul(hidden_state, self.embedding.T).squeeze(0)
        return logits


class ReverseLens(abc.ABC, nn.Module):
    """Abstract base class for all reverse lenses."""

    reembed: Reembed

    def __init__(self, reembed: Reembed):
        """Create a ReverseLens.

        Args:
            reembed: The reembed operation to use.
        """
        super().__init__()

        self.reembed = reembed

    @abc.abstractmethod
    def forward(self, h: torch.Tensor, idx: int) -> torch.Tensor:
        """Decode hidden states into logits."""
        ...


class ReverseLogitLens(ReverseLens):
    """Reembeds the residual stream into logits."""

    reembed: Reembed

    def __init__(
        self,
        reembed: Reembed,
    ):
        """Create a Reverse Logit Lens.

        Args:
            reembed: The reembed operation to use.
        """
        super().__init__(reembed)

    @classmethod
    def from_model(
        cls,
        model: PreTrainedModel,
    ) -> "ReverseLogitLens":
        """Create a ReverseLogitLens from a pretrained model.

        Args:
            model: A pretrained model from the transformers library you wish to inspect.
        """
        reembed = Reembed(model)
        return cls(reembed)

    def forward(self, h: torch.Tensor, idx: int) -> torch.Tensor:
        """Decode a hidden state into logits.

        Args:
            h: The hidden state to decode.
            idx: the layer of the transformer these hidden states come from.
        """
        del idx
        return self.reembed.forward(h)


class Lens(abc.ABC, nn.Module):
    """Abstract base class for all lenses."""

    unembed: Unembed

    def __init__(self, unembed: Unembed):
        """Create a Lens.

        Args:
            unembed: The unembed operation to use.
        """
        super().__init__()

        self.unembed = unembed

    @abc.abstractmethod
    def forward(self, h: torch.Tensor, idx: int) -> torch.Tensor:
        """Decode hidden states into logits."""
        ...


class LogitLens(Lens):
    """Unembeds the residual stream into logits."""

    unembed: Unembed

    def __init__(
        self,
        unembed: Unembed,
    ):
        """Create a Logit Lens.

        Args:
            unembed: The unembed operation to use.
        """
        super().__init__(unembed)

    @classmethod
    def from_model(
        cls,
        model: PreTrainedModel,
    ) -> "LogitLens":
        """Create a LogitLens from a pretrained model.

        Args:
            model: A pretrained model from the transformers library you wish to inspect.
        """
        unembed = Unembed(model)
        return cls(unembed)

    def forward(self, h: torch.Tensor, idx: int) -> torch.Tensor:
        """Decode a hidden state into logits.

        Args:
            h: The hidden state to decode.
            idx: the layer of the transformer these hidden states come from.
        """
        del idx
        return self.unembed.forward(h)
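A self-contained sketch of the lens on real hidden states (editor's example; gpt2 is just a small stand-in model, not a choice made by this app):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")        # small stand-in model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
lens = LogitLens.from_model(model)

enc = tokenizer("The capital of France is", return_tensors="pt")
out = model(**enc, output_hidden_states=True)
h = out.hidden_states[6][:, -1, :]                   # a mid-layer hidden state at the last position
print(tokenizer.decode(lens(h, 6).argmax(dim=-1)))   # greedy readout through final norm + unembedding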
utils/model_utils.py
ADDED
@@ -0,0 +1,320 @@
from tqdm import tqdm
from typing import Iterable, List, Union
from transformers import PreTrainedModel, PreTrainedTokenizer
import torch
from torch import nn
from sklearn.linear_model import LinearRegression
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def extract_token_i_hidden_states(
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        inputs: Union[str, List[str]],
        token_idx_to_extract: int = -1,
        batch_size: int = 1,
        layers_to_extract: List[int] = None,
        return_dict: bool = True,
        verbose: bool = True,
) -> torch.Tensor:
    device = model.device
    model.eval()

    if isinstance(inputs, str):
        inputs = [inputs]

    if layers_to_extract is None:
        layers_to_extract = list(range(1, model.config.num_hidden_layers + 1))  # extract all but initial embeddings
    all_hidden_states = {layer: [] for layer in layers_to_extract}

    with torch.no_grad():
        for i in tqdm(range(0, len(inputs), batch_size), desc="Extracting hidden states", unit="batch", disable=not verbose):
            input_ids = tokenizer(inputs[i:i+batch_size], return_tensors="pt", return_attention_mask=False)['input_ids']
            outputs = model(input_ids.to(device), output_hidden_states=True)
            # Append the extracted position's hidden states for the whole batch, once per layer
            # (the original looped over inputs here as well, duplicating each batch).
            for layer in layers_to_extract:
                hidden_states = outputs.hidden_states[layer]
                all_hidden_states[layer].append(hidden_states[:, token_idx_to_extract, :].detach().cpu())
    for layer in all_hidden_states:
        all_hidden_states[layer] = torch.concat(all_hidden_states[layer], dim=0)

    if not return_dict:
        all_hidden_states = torch.concat([all_hidden_states[layer] for layer in layers_to_extract], dim=0)

    return all_hidden_states


def extract_vocab_hidden_states(
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        tokens_ids_to_extract: Iterable[int] = None,
        prompt: str = "{target}",
        prompt_target: str = "{target}",
        batch_size: int = 128,
        layers_to_extract: List[int] = None
) -> torch.Tensor:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    if layers_to_extract is None:
        layers_to_extract = list(range(1, model.config.num_hidden_layers + 1))  # extract all but initial embeddings
    all_hidden_states = {layer: [] for layer in layers_to_extract}
    tokens_ids_to_extract = tokens_ids_to_extract if tokens_ids_to_extract is not None else range(tokenizer.vocab_size)
    tokens_to_extract = [tokenizer.decode(tok_id) for tok_id in tokens_ids_to_extract]

    # add a pad token if necessary
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    with torch.no_grad():
        for i in tqdm(range(0, len(tokens_to_extract), batch_size), desc="Extracting hidden states", unit="batch"):
            prompts = [prompt.replace(prompt_target, target) for target in tokens_to_extract[i:i+batch_size]]
            input_ids = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")["input_ids"]
            outputs = model(input_ids.to(device), output_hidden_states=True)
            for layer in layers_to_extract:
                hidden_states = outputs.hidden_states[layer]
                all_hidden_states[layer].append(hidden_states[:, -1, :].detach().cpu())

    for layer in all_hidden_states:
        all_hidden_states[layer] = torch.concat(all_hidden_states[layer], dim=0)

    return all_hidden_states


def get_vocab_tokens(tokenizer: PreTrainedTokenizer, min_word_len: int = None):
    vocab_size = tokenizer.vocab_size
    tokens = list(range(vocab_size))
    if min_word_len:
        tokens_str = [tokenizer.decode(i) for i in tokens]
        tokens_len = [len(x) for x in tokens_str]
        tokens = [tok for tok, tok_len in zip(tokens, tokens_len) if tok_len >= min_word_len]
    return tokens


def learn_linear_map(X: torch.Tensor, Y: torch.Tensor, fit_intercept=False):
    input_dtype = X.dtype
    linear_reg = LinearRegression(fit_intercept=fit_intercept).fit(X.cpu().to(float).numpy(), Y.cpu().to(float).numpy())
    linear_map = nn.Linear(X.size(1), Y.size(1), bias=fit_intercept)
    with torch.no_grad():
        # sklearn's coef_ already has shape (out_features, in_features), matching nn.Linear.weight,
        # so no transpose is needed (the original transposed it).
        linear_map.weight.data = torch.Tensor(linear_reg.coef_)
        if fit_intercept:
            linear_map.bias.data = torch.Tensor(linear_reg.intercept_)
    linear_map = linear_map.to(input_dtype)
    return linear_map

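A toy check of learn_linear_map (editor's sketch; shapes and tensors are illustrative):

import torch

X = torch.randn(1000, 768)        # e.g. hidden states from a source space
Y = X @ torch.randn(768, 1024)    # toy targets that are an exact linear image of X
f = learn_linear_map(X, Y)        # nn.Linear(768, 1024) fitted in closed form
print((f(X) - Y).abs().max())     # near zero for this noiseless toy problem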
def train_model(
        model,
        dataloader,
        optimizer,
        loss_func="mse",
        scheduler=None,
        num_epochs=5,
        gradient_accumulation_steps=1,
        max_grads_norm=1.0,
):
    """
    Runs the training loop for a model that maps hidden states from X to Y.

    Parameters:
        model (nn.Module): The model to train.
        dataloader (DataLoader): Yields (x, y) batches.
        optimizer (torch.optim.Optimizer): Optimizer over the model's parameters.
        loss_func (str): Loss function to use ('mse', 'huber', 'cosine'). Default is 'mse'.
        scheduler: Optional learning rate scheduler, stepped after each optimizer step.
        num_epochs (int): Number of training epochs. Default is 5.
        gradient_accumulation_steps (int): Number of steps to accumulate gradients. Default is 1.
        max_grads_norm (float): Gradient clipping norm; None disables clipping.

    Returns:
        nn.Module: The trained model (moved to CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Select loss function
    if loss_func == "mse":
        criterion = nn.MSELoss()
    elif loss_func == "huber":
        criterion = nn.HuberLoss()
    elif loss_func == "cosine":
        criterion = nn.CosineEmbeddingLoss()
    else:
        raise ValueError("Unsupported loss function. Choose from 'mse', 'huber', or 'cosine'.")

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for i, (x_batch, y_batch) in enumerate(dataloader):
            outputs = model(x_batch.to(device))
            if loss_func == "cosine":
                # Cosine loss requires an additional target tensor of 1s (placed on the same device)
                loss = criterion(outputs, y_batch.to(device), torch.ones(x_batch.size(0), device=device))
            else:
                loss = criterion(outputs, y_batch.to(device))

            loss = loss / gradient_accumulation_steps
            loss.backward()

            if max_grads_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), max_grads_norm)

            if (i + 1) % gradient_accumulation_steps == 0 or (i + 1) == len(dataloader):
                optimizer.step()
                optimizer.zero_grad()
                if scheduler:
                    scheduler.step()

            epoch_loss += loss.item() * gradient_accumulation_steps

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(dataloader):.6f}")

    return model.cpu()


def learn_mlp(
        X: torch.Tensor, Y: torch.Tensor,
        activation_func=nn.SiLU,
        batch_size=128,
        lr=0.001,
        weight_decay=0.0,
        loss_func="mse",
        lr_schedule="linear",
        expansion_alpha=1.0,
        num_epochs=5,
        gradient_accumulation_steps=1,
        max_grads_norm=1.0,
):
    """
    Trains a two-layer MLP to map hidden states from X to Y.

    Parameters:
        X (torch.Tensor): Input tensor of shape (N, D).
        Y (torch.Tensor): Target tensor of shape (N, D).
        activation_func (nn.Module): Activation function for the hidden layer. Default is SiLU.
        batch_size (int): Batch size for the DataLoader. Default is 128.
        lr (float): Learning rate. Default is 0.001.
        weight_decay (float): Weight decay for the optimizer. Default is 0.0.
        loss_func (str): Loss function to use ('mse', 'huber', 'cosine'). Default is 'mse'.
        lr_schedule (str): Learning rate schedule. Default is 'linear'.
        expansion_alpha (float): Hidden-layer width as a multiple of the input width. Default is 1.0.
        num_epochs (int): Number of training epochs. Default is 5.
        gradient_accumulation_steps (int): Number of steps to accumulate gradients. Default is 1.
        max_grads_norm (float): Gradient clipping norm; None disables clipping.

    Returns:
        nn.Module: Trained MLP model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_dim = X.shape[1]
    hidden_dim = int(input_dim * expansion_alpha)
    output_dim = Y.shape[1]
    model = nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        activation_func(),
        nn.Linear(hidden_dim, output_dim)
    ).to(device)

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # DataLoader setup
    dataset = TensorDataset(X, Y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Learning rate scheduler
    if lr_schedule == "linear":
        total_steps = (len(dataloader) * num_epochs) // gradient_accumulation_steps
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1 - step / total_steps)
    else:
        scheduler = None

    return train_model(
        model,
        dataloader,
        optimizer,
        loss_func=loss_func,
        scheduler=scheduler,
        num_epochs=num_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_grads_norm=max_grads_norm,
    )


class FFN(nn.Module):
    def __init__(self, input_dim):
        super(FFN, self).__init__()
        self.gate_proj = nn.Linear(input_dim, input_dim)
        self.activation = nn.SiLU()
        self.map_proj = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        # Gated residual mapping: SiLU(gate(x)) * x plus a learned linear map of x
        return (self.activation(self.gate_proj(x)) * x) + self.map_proj(x)


def learn_ffn(
        X: torch.Tensor, Y: torch.Tensor,
        activation_func=nn.SiLU,
        batch_size=128,
        lr=0.001,
        weight_decay=0.0,
        loss_func="mse",
        lr_schedule="linear",
        num_epochs=5,
        gradient_accumulation_steps=1,
        max_grads_norm=1.0,
):
    """
    Trains a gated FFN (see the FFN class above) to map hidden states from X to Y.

    Parameters:
        X (torch.Tensor): Input tensor of shape (N, D).
        Y (torch.Tensor): Target tensor of shape (N, D).
        activation_func (nn.Module): Unused; the FFN uses SiLU internally.
        batch_size (int): Batch size for the DataLoader. Default is 128.
        lr (float): Learning rate. Default is 0.001.
        weight_decay (float): Weight decay for the optimizer. Default is 0.0.
        loss_func (str): Loss function to use ('mse', 'huber', 'cosine'). Default is 'mse'.
        lr_schedule (str): Learning rate schedule. Default is 'linear'.
        num_epochs (int): Number of training epochs. Default is 5.
        gradient_accumulation_steps (int): Number of steps to accumulate gradients. Default is 1.
        max_grads_norm (float): Gradient clipping norm; None disables clipping.

    Returns:
        nn.Module: Trained FFN model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_dim = X.shape[1]
    model = FFN(input_dim).to(device)

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # DataLoader setup
    dataset = TensorDataset(X, Y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Learning rate scheduler
    if lr_schedule == "linear":
        total_steps = (len(dataloader) * num_epochs) // gradient_accumulation_steps
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1 - step / total_steps)
    else:
        scheduler = None

    return train_model(
        model,
        dataloader,
        optimizer,
        loss_func=loss_func,
        scheduler=scheduler,
        num_epochs=num_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_grads_norm=max_grads_norm,
    )
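The FFN above is a gated residual map in the spirit of SiLU-gated feed-forward blocks, and it requires X and Y to share a width. A toy training run (editor's sketch; shapes are illustrative):

import torch

X, Y = torch.randn(512, 64), torch.randn(512, 64)   # toy same-width spaces
ffn = learn_ffn(X, Y, num_epochs=2)                  # prints one loss line per epoch
print(ffn(X).shape)                                  # torch.Size([512, 64])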
utils/procrustes/__init__.py
ADDED
File without changes
utils/procrustes/orthogonal.py
ADDED
@@ -0,0 +1,383 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# The Procrustes library provides a set of functions for transforming
|
3 |
+
# a matrix to make it as similar as possible to a target matrix.
|
4 |
+
#
|
5 |
+
# Copyright (C) 2017-2022 The QC-Devs Community
|
6 |
+
#
|
7 |
+
# This file is part of Procrustes.
|
8 |
+
#
|
9 |
+
# Procrustes is free software; you can redistribute it and/or
|
10 |
+
# modify it under the terms of the GNU General Public License
|
11 |
+
# as published by the Free Software Foundation; either version 3
|
12 |
+
# of the License, or (at your option) any later version.
|
13 |
+
#
|
14 |
+
# Procrustes is distributed in the hope that it will be useful,
|
15 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
16 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
17 |
+
# GNU General Public License for more details.
|
18 |
+
#
|
19 |
+
# You should have received a copy of the GNU General Public License
|
20 |
+
# along with this program; if not, see <http://www.gnu.org/licenses/>
|
21 |
+
#
|
22 |
+
# --
|
23 |
+
"""Orthogonal Procrustes Module."""
|
24 |
+
|
25 |
+
# import warnings
|
26 |
+
|
27 |
+
from typing import Optional
|
28 |
+
|
29 |
+
import numpy as np
|
30 |
+
from .utils import compute_error, ProcrustesResult, setup_input_arrays
|
31 |
+
import scipy
|
32 |
+
|
33 |
+
|
34 |
+
__all__ = [
|
35 |
+
"orthogonal",
|
36 |
+
"orthogonal_2sided",
|
37 |
+
]
|
38 |
+
|
39 |
+
|
40 |
+
def orthogonal(
|
41 |
+
a: np.ndarray,
|
42 |
+
b: np.ndarray,
|
43 |
+
pad: bool = True,
|
44 |
+
translate: bool = False,
|
45 |
+
scale: bool = False,
|
46 |
+
unpad_col: bool = False,
|
47 |
+
unpad_row: bool = False,
|
48 |
+
check_finite: bool = True,
|
49 |
+
weight: Optional[np.ndarray] = None,
|
50 |
+
lapack_driver: str = "gesvd",
|
51 |
+
) -> ProcrustesResult:
|
52 |
+
r"""Perform orthogonal Procrustes.
|
53 |
+
|
54 |
+
Given a matrix :math:`\mathbf{A}_{m \times n}` and a reference matrix :math:`\mathbf{B}_{m
|
55 |
+
\times n}`, find the orthogonal transformation matrix :math:`\mathbf{Q}_{n
|
56 |
+
\times n}` that makes :math:`\mathbf{AQ}` as close as possible to :math:`\mathbf{B}`.
|
57 |
+
In other words,
|
58 |
+
|
59 |
+
.. math::
|
60 |
+
\underbrace{\min}_{\left\{\mathbf{Q} | \mathbf{Q}^{-1} = {\mathbf{Q}}^\dagger \right\}}
|
61 |
+
\|\mathbf{A}\mathbf{Q} - \mathbf{B}\|_{F}^2
|
62 |
+
|
63 |
+
This Procrustes method requires the :math:`\mathbf{A}` and :math:`\mathbf{B}` matrices to
|
64 |
+
have the same shape, which is gauranteed with the default ``pad`` argument for any given
|
65 |
+
:math:`\mathbf{A}` and :math:`\mathbf{B}` matrices. In preparing the :math:`\mathbf{A}` and
|
66 |
+
:math:`\mathbf{B}` matrices, the (optional) order of operations is: **1)** unpad zero
|
67 |
+
rows/columns, **2)** translate the matrices to the origin, **3)** weight entries of
|
68 |
+
:math:`\mathbf{A}`, **4)** scale the matrices to have unit norm, **5)** pad matrices with zero
|
69 |
+
rows/columns so they have the same shape.
|
70 |
+
|
71 |
+
Parameters
|
72 |
+
----------
|
73 |
+
a : ndarray
|
74 |
+
The 2D-array :math:`\mathbf{A}` which is going to be transformed.
|
75 |
+
b : ndarray
|
76 |
+
The 2D-array :math:`\mathbf{B}` representing the reference matrix.
|
77 |
+
pad : bool, optional
|
78 |
+
Add zero rows (at the bottom) and/or columns (to the right-hand side) of matrices
|
79 |
+
:math:`\mathbf{A}` and :math:`\mathbf{B}` so that they have the same shape.
|
80 |
+
translate : bool, optional
|
81 |
+
If True, both arrays are centered at origin (columns of the arrays will have mean zero).
|
82 |
+
scale : bool, optional
|
83 |
+
If True, both arrays are normalized with respect to the Frobenius norm, i.e.,
|
84 |
+
:math:`\text{Tr}\left[\mathbf{A}^\dagger\mathbf{A}\right] = 1` and
|
85 |
+
:math:`\text{Tr}\left[\mathbf{B}^\dagger\mathbf{B}\right] = 1`.
|
86 |
+
unpad_col : bool, optional
|
87 |
+
If True, zero columns (with values less than 1.0e-8) on the right-hand side of the intial
|
88 |
+
:math:`\mathbf{A}` and :math:`\mathbf{B}` matrices are removed.
|
89 |
+
unpad_row : bool, optional
|
90 |
+
If True, zero rows (with values less than 1.0e-8) at the bottom of the intial
|
91 |
+
:math:`\mathbf{A}` and :math:`\mathbf{B}` matrices are removed.
|
92 |
+
check_finite : bool, optional
|
93 |
+
If True, convert the input to an array, checking for NaNs or Infs.
|
94 |
+
weight : ndarray, optional
|
95 |
+
The 1D-array representing the weights of each row of :math:`\mathbf{A}`. This defines the
|
96 |
+
elements of the diagonal matrix :math:`\mathbf{W}` that is multiplied by :math:`\mathbf{A}`
|
97 |
+
matrix, i.e., :math:`\mathbf{A} \rightarrow \mathbf{WA}`.
|
98 |
+
lapack_driver : {'gesvd', 'gesdd'}, optional
|
99 |
+
Whether to use the more efficient divide-and-conquer approach ('gesdd') or the more robust
|
100 |
+
general rectangular approach ('gesvd') to compute the singular-value decomposition with
|
101 |
+
`scipy.linalg.svd`.
|
102 |
+
|
103 |
+
Returns
|
104 |
+
-------
|
105 |
+
res : ProcrustesResult
|
106 |
+
The Procrustes result represented as a class:`utils.ProcrustesResult` object.
|
107 |
+
|
108 |
+
Notes
|
109 |
+
-----
|
110 |
+
The optimal orthogonal matrix is obtained by,
|
111 |
+
|
112 |
+
.. math::
|
113 |
+
\mathbf{Q}^{\text{opt}} =
|
114 |
+
\arg \underbrace{\min}_{\left\{\mathbf{Q} \left| {\mathbf{Q}^{-1} = {\mathbf{Q}}^\dagger}
|
115 |
+
\right. \right\}} \|\mathbf{A}\mathbf{Q} - \mathbf{B}\|_{F}^2 =
|
116 |
+
\arg \underbrace{\max}_{\left\{\mathbf{Q} \left| {\mathbf{Q}^{-1} = {\mathbf{Q}}^\dagger}
|
117 |
+
\right. \right\}} \text{Tr}\left[\mathbf{Q^\dagger}\mathbf{A^\dagger}\mathbf{B}\right]
|
118 |
+
|
119 |
+
The solution is obtained using the singular value decomposition (SVD) of the
|
120 |
+
:math:`\mathbf{A}^\dagger \mathbf{B}` matrix,
|
121 |
+
|
122 |
+
.. math::
|
123 |
+
\mathbf{A}^\dagger \mathbf{B} &= \tilde{\mathbf{U}} \tilde{\mathbf{\Sigma}}
|
124 |
+
\tilde{\mathbf{V}}^{\dagger} \\
|
125 |
+
\mathbf{Q}^{\text{opt}} &= \tilde{\mathbf{U}} \tilde{\mathbf{V}}^{\dagger}
|
126 |
+
|
127 |
+
The singular values are always listed in decreasing order, with the smallest singular
|
128 |
+
value in the bottom-right-hand corner of :math:`\tilde{\mathbf{\Sigma}}`.
|
129 |
+
|
130 |
+
Examples
|
131 |
+
--------
|
132 |
+
>>> import numpy as np
|
133 |
+
>>> from scipy.stats import ortho_group
|
134 |
+
>>> from procrustes import orthogonal
|
135 |
+
>>> a = np.random.rand(5, 3) # random input matrix
|
136 |
+
>>> q = ortho_group.rvs(3) # random orthogonal transformation
|
137 |
+
>>> b = np.dot(a, q) + np.random.rand(1, 3) # random target matrix
|
138 |
+
>>> result = orthogonal(a, b, translate=True, scale=False)
|
139 |
+
>>> print(result.error) # error (should be zero)
|
140 |
+
>>> print(result.t) # transformation matrix (same as q)
|
141 |
+
>>> print(result.new_a) # translated array a
|
142 |
+
>>> print(result.new_b) # translated array b
|
143 |
+
|
144 |
+
"""
|
145 |
+
# check inputs
|
146 |
+
new_a, new_b = setup_input_arrays(
|
147 |
+
a,
|
148 |
+
b,
|
149 |
+
unpad_col,
|
150 |
+
unpad_row,
|
151 |
+
pad,
|
152 |
+
translate,
|
153 |
+
scale,
|
154 |
+
check_finite,
|
155 |
+
weight,
|
156 |
+
)
|
157 |
+
if new_a.shape != new_b.shape:
|
158 |
+
raise ValueError(
|
159 |
+
f"Shape of A and B does not match: {new_a.shape} != {new_b.shape} "
|
160 |
+
"Check pad, unpad_col, and unpad_row arguments."
|
161 |
+
)
|
162 |
+
# calculate SVD of A.T * B
|
163 |
+
u, _, vt = scipy.linalg.svd(np.dot(new_a.T, new_b), lapack_driver=lapack_driver)
|
164 |
+
# compute optimal orthogonal transformation
|
165 |
+
u_opt = np.dot(u, vt)
|
166 |
+
# compute one-sided error
|
167 |
+
error = compute_error(new_a, new_b, u_opt)
|
168 |
+
|
169 |
+
return ProcrustesResult(error=error, new_a=new_a, new_b=new_b, t=u_opt, s=None)
|
170 |
+
|
171 |
+
|
172 |
+
def orthogonal_2sided(
|
173 |
+
a: np.ndarray,
|
174 |
+
b: np.ndarray,
|
175 |
+
single: bool = True,
|
176 |
+
pad: bool = True,
|
177 |
+
translate: bool = False,
|
178 |
+
scale: bool = False,
|
179 |
+
unpad_col: bool = False,
|
180 |
+
unpad_row: bool = False,
|
181 |
+
check_finite: bool = True,
|
182 |
+
weight: Optional[np.ndarray] = None,
|
183 |
+
lapack_driver: str = "gesvd",
|
184 |
+
) -> ProcrustesResult:
|
185 |
+
r"""Perform two-sided orthogonal Procrustes with one- or two-transformations.
|
186 |
+
|
187 |
+
**Two Transformations:** Given a matrix :math:`\mathbf{A}_{m \times n}` and a reference matrix
|
188 |
+
:math:`\mathbf{B}_{m \times n}`, find two :math:`n \times n` orthogonal
|
189 |
+
transformation matrices :math:`\mathbf{Q}_1^\dagger` and :math:`\mathbf{Q}_2` that makes
|
190 |
+
:math:`\mathbf{Q}_1^\dagger\mathbf{A}\mathbf{Q}_2` as close as possible to :math:`\mathbf{B}`.
|
191 |
+
In other words,
|
192 |
+
|
193 |
+
.. math::
|
194 |
+
\underbrace{\text{min}}_{\left\{ {\mathbf{Q}_1 \atop \mathbf{Q}_2} \left|
|
195 |
+
{\mathbf{Q}_1^{-1} = \mathbf{Q}_1^\dagger \atop \mathbf{Q}_2^{-1} =
|
196 |
+
\mathbf{Q}_2^\dagger} \right. \right\}}
|
197 |
+
\|\mathbf{Q}_1^\dagger \mathbf{A} \mathbf{Q}_2 - \mathbf{B}\|_{F}^2
|
198 |
+
|
199 |
+
**Single Transformations:** Given a **symmetric** matrix :math:`\mathbf{A}_{n \times n}` and
|
200 |
+
a reference :math:`\mathbf{B}_{n \times n}`, find one orthogonal transformation
|
201 |
+
matrix :math:`\mathbf{Q}_{n \times n}` that makes :math:`\mathbf{A}` as close as possible to
|
202 |
+
:math:`\mathbf{B}`. In other words,
|
203 |
+
|
204 |
+
.. math::
|
205 |
+
\underbrace{\min}_{\left\{\mathbf{Q} | \mathbf{Q}^{-1} = {\mathbf{Q}}^\dagger \right\}}
|
206 |
+
\|\mathbf{Q}^\dagger\mathbf{A}\mathbf{Q} - \mathbf{B}\|_{F}^2
|
207 |
+
|
208 |
+
This Procrustes method requires the :math:`\mathbf{A}` and :math:`\mathbf{B}` matrices to
|
209 |
+
have the same shape, which is gauranteed with the default ``pad`` argument for any given
|
210 |
+
:math:`\mathbf{A}` and :math:`\mathbf{B}` matrices. In preparing the :math:`\mathbf{A}` and
|
211 |
+
:math:`\mathbf{B}` matrices, the (optional) order of operations is: **1)** unpad zero
|
212 |
+
rows/columns, **2)** translate the matrices to the origin, **3)** weight entries of
|
213 |
+
:math:`\mathbf{A}`, **4)** scale the matrices to have unit norm, **5)** pad matrices with zero
|
214 |
+
rows/columns so they have the same shape.
|
215 |
+
|
216 |
+
Parameters
|
217 |
+
----------
|
218 |
+
a : ndarray
|
219 |
+
The 2D-array :math:`\mathbf{A}` which is going to be transformed.
|
220 |
+
b : ndarray
|
221 |
+
The 2D-array :math:`\mathbf{B}` representing the reference matrix.
|
222 |
+
single : bool, optional
|
223 |
+
If True, single transformation is used (i.e., :math:`\mathbf{Q}_1=\mathbf{Q}_2=\mathbf{Q}`),
|
224 |
+
otherwise, two transformations are used.
|
225 |
+
pad : bool, optional
|
226 |
+
Add zero rows (at the bottom) and/or columns (to the right-hand side) of matrices
|
227 |
+
:math:`\mathbf{A}` and :math:`\mathbf{B}` so that they have the same shape.
|
228 |
+
translate : bool, optional
|
229 |
+
If True, both arrays are centered at origin (columns of the arrays will have mean zero).
|
230 |
+
scale : bool, optional
|
231 |
+
If True, both arrays are normalized with respect to the Frobenius norm, i.e.,
|
232 |
+
:math:`\text{Tr}\left[\mathbf{A}^\dagger\mathbf{A}\right] = 1` and
|
233 |
+
    :math:`\text{Tr}\left[\mathbf{B}^\dagger\mathbf{B}\right] = 1`.
    unpad_col : bool, optional
        If True, zero columns (with values less than 1.0e-8) on the right-hand side of the
        initial :math:`\mathbf{A}` and :math:`\mathbf{B}` matrices are removed.
    unpad_row : bool, optional
        If True, zero rows (with values less than 1.0e-8) at the bottom of the initial
        :math:`\mathbf{A}` and :math:`\mathbf{B}` matrices are removed.
    check_finite : bool, optional
        If True, convert the input to an array, checking for NaNs or Infs.
    weight : ndarray, optional
        The 1D-array representing the weights of each row of :math:`\mathbf{A}`. This defines
        the elements of the diagonal matrix :math:`\mathbf{W}` that is multiplied by the
        :math:`\mathbf{A}` matrix, i.e., :math:`\mathbf{A} \rightarrow \mathbf{WA}`.
    lapack_driver : {"gesvd", "gesdd"}, optional
        Used in the singular value decomposition function from SciPy. Only two options are
        allowed: "gesvd" is less efficient than "gesdd" but more robust. Default is "gesvd".

    Returns
    -------
    res : ProcrustesResult
        The Procrustes result represented as a :class:`utils.ProcrustesResult` object.

    Notes
    -----
    **Two-Sided Orthogonal Procrustes with Two Transformations:**
    The optimal orthogonal transformations are obtained by:

    .. math::
       \mathbf{Q}_{1}^{\text{opt}}, \mathbf{Q}_{2}^{\text{opt}} = \arg
       \underbrace{\text{min}}_{\left\{ {\mathbf{Q}_1 \atop \mathbf{Q}_2} \left|
       {\mathbf{Q}_1^{-1} = \mathbf{Q}_1^\dagger \atop \mathbf{Q}_2^{-1} =
       \mathbf{Q}_2^\dagger} \right. \right\}}
       \|\mathbf{Q}_1^\dagger \mathbf{A} \mathbf{Q}_2 - \mathbf{B}\|_{F}^2 = \arg
       \underbrace{\text{max}}_{\left\{ {\mathbf{Q}_1 \atop \mathbf{Q}_2} \left|
       {\mathbf{Q}_1^{-1} = \mathbf{Q}_1^\dagger \atop \mathbf{Q}_2^{-1} =
       \mathbf{Q}_2^\dagger} \right. \right\}}
       \text{Tr}\left[\mathbf{Q}_2^\dagger\mathbf{A}^\dagger\mathbf{Q}_1\mathbf{B}\right]

    This is solved by taking the singular value decomposition (SVD) of :math:`\mathbf{A}` and
    :math:`\mathbf{B}`,

    .. math::
       \mathbf{A} = \mathbf{U}_A \mathbf{\Sigma}_A \mathbf{V}_A^\dagger \\
       \mathbf{B} = \mathbf{U}_B \mathbf{\Sigma}_B \mathbf{V}_B^\dagger

    Then the two optimal orthogonal matrices are given by,

    .. math::
       \mathbf{Q}_1^{\text{opt}} = \mathbf{U}_A \mathbf{U}_B^\dagger \\
       \mathbf{Q}_2^{\text{opt}} = \mathbf{V}_A \mathbf{V}_B^\dagger

    **Two-Sided Orthogonal Procrustes with Single-Transformation:**
    The optimal orthogonal transformation is obtained by:

    .. math::
       \mathbf{Q}^{\text{opt}} = \arg
       \underbrace{\min}_{\left\{\mathbf{Q} | \mathbf{Q}^{-1} = {\mathbf{Q}}^\dagger \right\}}
       \|\mathbf{Q}^\dagger\mathbf{A}\mathbf{Q} - \mathbf{B}\|_{F}^2 = \arg
       \underbrace{\text{max}}_{\left\{\mathbf{Q} | \mathbf{Q}^{-1} = {\mathbf{Q}}^\dagger\right\}}
       \text{Tr}\left[\mathbf{Q}^\dagger\mathbf{A}^\dagger\mathbf{Q}\mathbf{B}\right]

    Using the eigenvalue decompositions of the symmetric matrices :math:`\mathbf{A}` and
    :math:`\mathbf{B}`,

    .. math::
       \mathbf{A} = \mathbf{U}_A \mathbf{\Lambda}_A \mathbf{U}_A^\dagger \\
       \mathbf{B} = \mathbf{U}_B \mathbf{\Lambda}_B \mathbf{U}_B^\dagger

    the optimal orthogonal matrix :math:`\mathbf{Q}^\text{opt}` is obtained through,

    .. math::
       \mathbf{Q}^\text{opt} = \mathbf{U}_A \mathbf{S} \mathbf{U}_B^\dagger

    where :math:`\mathbf{S}` is a diagonal matrix with :math:`\pm{1}` elements,

    .. math::
       \mathbf{S} =
       \begin{bmatrix}
           \pm 1  & 0      & \cdots & 0      \\
           0      & \pm 1  & \ddots & \vdots \\
           \vdots & \ddots & \ddots & 0      \\
           0      & \cdots & 0      & \pm 1
       \end{bmatrix}

    Here, the matrix :math:`\mathbf{S}` is chosen to be the identity matrix.

    Examples
    --------
    >>> import numpy as np
    >>> a = np.array([[30, 33, 20], [33, 53, 43], [20, 43, 46]])
    >>> b = np.array([[ 22.78131838,  -0.58896768, -43.00635291, 0., 0.],
    ...               [ -0.58896768,  16.77132475,   0.24289990, 0., 0.],
    ...               [-43.00635291,   0.24289990,  89.44735687, 0., 0.],
    ...               [  0.        ,   0.        ,   0.        , 0., 0.]])
    >>> res = orthogonal_2sided(a, b, single=True, pad=True, unpad_col=True)
    >>> res.t
    array([[ 0.25116633,  0.76371527,  0.59468855],
           [-0.95144277,  0.08183302,  0.29674906],
           [ 0.17796663, -0.64034549,  0.74718507]])
    >>> res.error
    1.9646186414076689e-26

    """
    # if translate:
    #     warnings.warn(
    #         "The translation matrix was not well defined. "
    #         "Two-sided rotation and translation don't commute.",
    #         stacklevel=2,
    #     )

    # check inputs
    new_a, new_b = setup_input_arrays(
        a,
        b,
        unpad_col,
        unpad_row,
        pad,
        translate,
        scale,
        check_finite,
        weight,
    )

    # check symmetry if single=True
    if single:
        if not np.allclose(new_a.T, new_a):
            raise ValueError(
                f"Array A with {new_a.shape} shape is not symmetric. "
                "Check pad, unpad_col, and unpad_row arguments."
            )
        if not np.allclose(new_b.T, new_b):
            raise ValueError(
                f"Array B with {new_b.shape} shape is not symmetric. "
                "Check pad, unpad_col, and unpad_row arguments."
            )

    # two-sided orthogonal Procrustes with a single transformation
    if single:
        _, ua = np.linalg.eigh(new_a)
        _, ub = np.linalg.eigh(new_b)
        u_opt = np.dot(ua, ub.T)
        # compute the one-sided error
        error = compute_error(new_a, new_b, u_opt, u_opt.T)
        return ProcrustesResult(error=error, new_a=new_a, new_b=new_b, t=u_opt, s=u_opt.T)

    # two-sided orthogonal Procrustes with two transformations
    ua, _, vta = scipy.linalg.svd(new_a, lapack_driver=lapack_driver)
    ub, _, vtb = scipy.linalg.svd(new_b, lapack_driver=lapack_driver)
    u_opt1 = np.dot(ua, ub.T)
    u_opt2 = np.dot(vta.T, vtb)
    error = compute_error(new_a, new_b, u_opt2, u_opt1.T)
    return ProcrustesResult(error=error, new_a=new_a, new_b=new_b, t=u_opt2, s=u_opt1.T)
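To make the behavior of `orthogonal_2sided` concrete, here is a minimal usage sketch. It is not part of the committed file; the import path mirroring this repo's layout (`utils/procrustes/orthogonal.py`) and the randomly constructed test matrices are assumptions for illustration.

```python
# Usage sketch (illustrative, not part of the commit). The import path is an
# assumption based on this repo's directory layout.
import numpy as np
from utils.procrustes.orthogonal import orthogonal_2sided

rng = np.random.default_rng(0)

# single-transformation mode: A and B must be symmetric; here B = Q^T A Q
q, _ = np.linalg.qr(rng.normal(size=(4, 4)))   # a random orthogonal matrix
a = np.diag([1.0, 2.0, 3.0, 4.0])
b = q.T @ a @ q
res = orthogonal_2sided(a, b, single=True)
print(res.error)  # ~0: A and B share eigenvalues, so the rotation is recovered

# two-transformation mode: generic matrices of the same shape
s_mat, _ = np.linalg.qr(rng.normal(size=(5, 5)))
t_mat, _ = np.linalg.qr(rng.normal(size=(3, 3)))
a2 = rng.normal(size=(5, 3))
b2 = s_mat.T @ a2 @ t_mat
res2 = orthogonal_2sided(a2, b2, single=False)
print(res2.error)  # ~0: A and B share singular values
```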
utils/procrustes/utils.py
ADDED
@@ -0,0 +1,495 @@
# -*- coding: utf-8 -*-
# The Procrustes library provides a set of functions for transforming
# a matrix to make it as similar as possible to a target matrix.
#
# Copyright (C) 2017-2022 The QC-Devs Community
#
# This file is part of Procrustes.
#
# Procrustes is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# Procrustes is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
#
# --
"""Utility Module."""
from typing import List, Optional, Tuple

import numpy as np

__all__ = [
    "compute_error",
    "setup_input_arrays",
    "ProcrustesResult",
]


def _zero_padding(
    array_a: np.ndarray, array_b: np.ndarray, pad_mode: str = "row-col"
) -> Tuple[np.ndarray, np.ndarray]:
    r"""
    Return arrays padded with rows and/or columns of zeros.

    Parameters
    ----------
    array_a : ndarray
        The 2D-array :math:`\mathbf{A}_{n_a \times m_a}`.
    array_b : ndarray
        The 2D-array :math:`\mathbf{B}_{n_b \times m_b}`.
    pad_mode : str
        Specifying how to pad the arrays. Should be one of
        - "row"
            The array with fewer rows is padded with zero rows so that both have the same
            number of rows.
        - "col"
            The array with fewer columns is padded with zero columns so that both have the
            same number of columns.
        - "row-col"
            The array with fewer rows is padded with zero rows, and the array with fewer
            columns is padded with zero columns, so that both have the same dimensions.
            This does not necessarily result in square arrays.
        - "square"
            The arrays are padded with zero rows and zero columns so that they are both
            square arrays. The dimension of the square arrays is specified based on the
            highest dimension, i.e. :math:`\text{max}(n_a, m_a, n_b, m_b)`.

    Returns
    -------
    padded_a : ndarray
        Padded array_a.
    padded_b : ndarray
        Padded array_b.

    """
    # sanity checks
    if not isinstance(array_a, np.ndarray) or not isinstance(array_b, np.ndarray):
        raise ValueError("Arguments array_a & array_b should be numpy arrays.")
    if array_a.ndim != 2 or array_b.ndim != 2:
        raise ValueError("Arguments array_a & array_b should be 2D arrays.")

    if array_a.shape == array_b.shape and array_a.shape[0] == array_a.shape[1]:
        # special case of equal-shape square arrays; pad_mode is set to None so that
        # array_a & array_b are returned unchanged.
        pad_mode = None

    if pad_mode == "square":
        # calculate the desired dimension of the square arrays
        (a_n1, a_m1), (a_n2, a_m2) = array_a.shape, array_b.shape
        dim = max(a_n1, a_n2, a_m1, a_m2)
        # pad rows so that both arrays have dim rows
        if a_n1 < dim:
            array_a = np.pad(array_a, [[0, dim - a_n1], [0, 0]], "constant", constant_values=0)
        if a_n2 < dim:
            array_b = np.pad(array_b, [[0, dim - a_n2], [0, 0]], "constant", constant_values=0)
        # pad columns so that both arrays have dim columns
        if a_m1 < dim:
            array_a = np.pad(array_a, [[0, 0], [0, dim - a_m1]], "constant", constant_values=0)
        if a_m2 < dim:
            array_b = np.pad(array_b, [[0, 0], [0, dim - a_m2]], "constant", constant_values=0)

    if pad_mode in ["row", "row-col"]:
        # pad rows so that both arrays have the same number of rows
        diff = array_a.shape[0] - array_b.shape[0]
        if diff < 0:
            array_a = np.pad(array_a, [[0, -diff], [0, 0]], "constant", constant_values=0)
        else:
            array_b = np.pad(array_b, [[0, diff], [0, 0]], "constant", constant_values=0)

    if pad_mode in ["col", "row-col"]:
        # pad columns so that both arrays have the same number of columns
        diff = array_a.shape[1] - array_b.shape[1]
        if diff < 0:
            array_a = np.pad(array_a, [[0, 0], [0, -diff]], "constant", constant_values=0)
        else:
            array_b = np.pad(array_b, [[0, 0], [0, diff]], "constant", constant_values=0)

    return array_a, array_b
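A small sketch of the pad modes described above (illustrative only; `_zero_padding` is a private helper, imported here purely for demonstration):

```python
# Illustrative sketch of the pad modes (not part of the commit).
import numpy as np
from utils.procrustes.utils import _zero_padding  # private helper, demo only

a = np.ones((2, 3))
b = np.ones((4, 2))

pa, pb = _zero_padding(a, b, pad_mode="row-col")
print(pa.shape, pb.shape)  # (4, 3) (4, 3): rows and columns padded independently

pa, pb = _zero_padding(a, b, pad_mode="square")
print(pa.shape, pb.shape)  # (4, 4) (4, 4): padded up to max(n_a, m_a, n_b, m_b)
```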
def _translate_array(
    array_a: np.ndarray, array_b: Optional[np.ndarray] = None, weight: Optional[np.ndarray] = None
) -> Tuple[np.ndarray, float]:
    """
    Return translated array_a and translation vector.

    Columns of both arrays will have mean zero.

    Parameters
    ----------
    array_a : ndarray
        The 2D-array to translate.
    array_b : ndarray, optional
        The 2D-array on which to base the translation of array_a.
    weight : ndarray, optional
        The weight vector.

    Returns
    -------
    array_a : ndarray
        If array_b is None, array_a is translated to the origin using its centroid.
        If array_b is given, array_a is translated so that its centroid coincides with
        the centroid of array_b.
    centroid : float
        If array_b is given, the centroid is returned.

    """
    # The mean is strongly affected by outliers and is not a robust estimator for central location
    # see https://docs.python.org/3.6/library/statistics.html?highlight=mean#statistics.mean
    if weight is not None:
        if weight.ndim != 1:
            raise ValueError("The weight should be a 1d row vector.")
        if not (weight >= 0).all():
            raise ValueError("The elements of the weight should be non-negative.")

    centroid_a = np.average(array_a, axis=0, weights=weight)
    if array_b is not None:
        # translation vector to b centroid
        centroid_a -= np.average(array_b, axis=0, weights=weight)
    return array_a - centroid_a, -1 * centroid_a


def _scale_array(array_a, array_b=None) -> Tuple[np.ndarray, float]:
    """
    Return scaled/normalized array_a and the scaling factor.

    Parameters
    ----------
    array_a : ndarray
        The 2D-array to scale.
    array_b : ndarray, default=None
        The 2D-array on which to base the scaling of array_a.

    Returns
    -------
    scaled_a : ndarray
        If array_b is None, array_a is normalized using the Frobenius norm.
        If array_b is given, array_a is scaled to match array_b's norm (the norm of array_a
        will be equal to the norm of array_b).
    scale : float
        The scaling factor used to match the norm of array_b.

    """
    # scaling factor to match unit sphere
    scale = 1.0 / np.linalg.norm(array_a)
    if array_b is not None:
        # scaling factor to match array_b norm
        scale *= np.linalg.norm(array_b)
    return array_a * scale, scale
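A short sketch tying the two helpers together (illustrative; both are private functions):

```python
# Sketch: center then normalize a matrix (not part of the commit).
import numpy as np
from utils.procrustes.utils import _scale_array, _translate_array

a = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

centered, shift = _translate_array(a)
print(centered.mean(axis=0))       # ~[0. 0.]: each column now has zero mean

normalized, factor = _scale_array(centered)
print(np.linalg.norm(normalized))  # ~1.0: unit Frobenius norm, i.e. Tr[A^T A] = 1
```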
def _hide_zero_padding(
    array_a: np.ndarray,
    remove_zero_col: bool = True,
    remove_zero_row: bool = True,
    tol: float = 1.0e-8,
) -> np.ndarray:
    r"""
    Return array with zero-padded rows (bottom) and columns (right) removed.

    Parameters
    ----------
    array_a : ndarray
        The initial array.
    remove_zero_col : bool, optional
        If True, zero columns (with values less than 1.0e-8) on the right side are removed.
    remove_zero_row : bool, optional
        If True, zero rows (with values less than 1.0e-8) at the bottom are removed.
    tol : float, optional
        Tolerance value.

    Returns
    -------
    new_A : ndarray
        The array with the near-zero columns and/or rows removed.

    """
    # Input checking
    if array_a.ndim > 2:
        raise TypeError("Matrix inputs must be 1- or 2-dimensional arrays")
    if remove_zero_row:
        # Check zero rows from bottom to top
        num_row = array_a.shape[0]
        tmp_a = array_a[..., np.newaxis] if array_a.ndim == 1 else array_a
        for array_v in tmp_a[::-1]:
            if any(abs(i) > tol for i in array_v):
                break
            num_row -= 1
        # Cut off zero rows
        array_a = array_a[:num_row]
    if remove_zero_col:
        if array_a.ndim == 2:
            # Check zero columns from right to left
            col_m = array_a.shape[1]
            for array_v in array_a.T[::-1]:
                if any(abs(i) > tol for i in array_v):
                    break
                col_m -= 1
            # Cut off zero columns
            array_a = array_a[:, :col_m]
    return array_a
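And the inverse direction, stripping padding back off (again a private-helper sketch, not part of the committed file):

```python
# Sketch: _hide_zero_padding trims trailing zero rows/columns (illustrative).
import numpy as np
from utils.procrustes.utils import _hide_zero_padding

padded = np.array([[1.0, 2.0, 0.0],
                   [3.0, 4.0, 0.0],
                   [0.0, 0.0, 0.0]])
print(_hide_zero_padding(padded).shape)  # (2, 2): zero row and zero column removed
```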
def compute_error(
    a: np.ndarray, b: np.ndarray, t: np.ndarray, s: Optional[np.ndarray] = None
) -> float:
    r"""Return the one- or two-sided Procrustes (squared Frobenius norm) error.

    The double-sided Procrustes error is defined as

    .. math::
       \|\mathbf{S}\mathbf{A}\mathbf{T} - \mathbf{B}\|_{F}^2 =
       \text{Tr}\left[
            \left(\mathbf{S}\mathbf{A}\mathbf{T} - \mathbf{B}\right)^\dagger
            \left(\mathbf{S}\mathbf{A}\mathbf{T} - \mathbf{B}\right)\right]

    When :math:`\mathbf{S}` is the identity matrix :math:`\mathbf{I}`, this is called the
    one-sided Procrustes error.

    Parameters
    ----------
    a : ndarray
        The 2D-array :math:`\mathbf{A}_{m \times n}` which is going to be transformed.
    b : ndarray
        The 2D-array :math:`\mathbf{B}_{m \times n}` representing the reference matrix.
    t : ndarray
        The 2D-array :math:`\mathbf{T}_{n \times n}` representing the right-hand-side
        transformation matrix.
    s : ndarray, optional
        The 2D-array :math:`\mathbf{S}_{m \times m}` representing the left-hand-side
        transformation matrix. If set to `None`, the one-sided Procrustes error is computed.

    Returns
    -------
    error : float
        The squared Frobenius norm of the difference between the transformed array,
        :math:`\mathbf{S}\mathbf{A}\mathbf{T}`, and the reference array, :math:`\mathbf{B}`.

    """
    # transform matrix A to either AT or SAT
    a_trans = np.dot(a, t) if s is None else np.dot(np.dot(s, a), t)
    # subtract matrix B and compute the squared Frobenius norm
    return np.linalg.norm(a_trans - b, ord=None) ** 2
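The error above is just the squared Frobenius norm; a quick check (not part of the commit):

```python
# Sketch: compute_error matches the squared Frobenius norm written out by hand.
import numpy as np
from utils.procrustes.utils import compute_error

a = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([[0.0, 1.0], [1.0, 0.0]])
t = np.eye(2)

err = compute_error(a, b, t)             # one-sided error: s defaults to None
manual = np.linalg.norm(a @ t - b) ** 2  # Tr[(AT - B)^T (AT - B)]
assert np.isclose(err, manual)
```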
def setup_input_arrays(
    array_a: np.ndarray,
    array_b: np.ndarray,
    remove_zero_col: bool,
    remove_zero_row: bool,
    pad: bool,
    translate: bool,
    scale: bool,
    check_finite: bool,
    weight: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    r"""
    Check and process array inputs for the Procrustes transformation routines.

    Usually, this is the precursor step for all Procrustes methods.

    Parameters
    ----------
    array_a : ndarray
        The 2D array :math:`\mathbf{A}` being transformed.
    array_b : ndarray
        The 2D reference array :math:`\mathbf{B}`.
    remove_zero_col : bool
        If True, zero columns (with values less than 1.0e-8) on the right side are removed.
    remove_zero_row : bool
        If True, zero rows (with values less than 1.0e-8) at the bottom are removed.
    pad : bool
        If True, add zero rows (at the bottom) and/or columns (on the right-hand side) to the
        matrices :math:`\mathbf{A}` and :math:`\mathbf{B}` so that they have the same shape.
    translate : bool
        If True, both arrays :math:`\mathbf{A}` and :math:`\mathbf{B}` are translated to the
        origin, i.e., the columns of the arrays will have mean zero.
    scale : bool
        If True, both arrays are normalized to one with respect to the Frobenius norm, i.e.,
        :math:`\text{Tr}\left[\mathbf{A}^\dagger\mathbf{A}\right] = 1`.
    check_finite : bool
        If True, checks that both arrays are two-dimensional numpy arrays with finite elements.
    weight : ndarray or list of ndarray, optional
        A list of weight arrays or a single numpy array. When only one numpy array is provided,
        it is assumed that the two arrays :math:`\mathbf{A}` and :math:`\mathbf{B}` share the
        same weight matrix.

    Returns
    -------
    (ndarray, ndarray) :
        The processed arrays, padded so that they have the same matrix dimensions.

    """
    array_a = _setup_input_array_lower(
        array_a, None, remove_zero_col, remove_zero_row, translate, scale, check_finite, weight
    )
    array_b = _setup_input_array_lower(
        array_b, None, remove_zero_col, remove_zero_row, translate, scale, check_finite, weight
    )
    if pad:
        array_a, array_b = _zero_padding(array_a, array_b, pad_mode="row-col")
    return array_a, array_b
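A sketch of the preprocessing call that each Procrustes routine makes first (the argument values here are illustrative):

```python
# Sketch: typical preprocessing before a Procrustes fit (not part of the commit).
import numpy as np
from utils.procrustes.utils import setup_input_arrays

rng = np.random.default_rng(1)
a = rng.normal(size=(3, 2))
b = rng.normal(size=(4, 4))

new_a, new_b = setup_input_arrays(
    a, b,
    remove_zero_col=False, remove_zero_row=False,
    pad=True, translate=True, scale=True, check_finite=True,
)
print(new_a.shape == new_b.shape)  # True: both zero-padded to a common (4, 4) shape
```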
def setup_input_arrays_multi(
    array_list: List[np.ndarray],
    array_ref: np.ndarray,
    remove_zero_col: bool,
    remove_zero_row: bool,
    pad_mode: str,
    translate: bool,
    scale: bool,
    check_finite: bool,
    weight: Optional[np.ndarray] = None,
) -> List[np.ndarray]:
    r"""
    Check and process a list of array inputs for the Procrustes transformation routines.

    Parameters
    ----------
    array_list : List
        A list of 2D arrays to be transformed.
    array_ref : ndarray
        The 2D reference array :math:`\mathbf{B}`.
    remove_zero_col : bool
        If True, zero columns (with values less than 1.0e-8) on the right side are removed.
    remove_zero_row : bool
        If True, zero rows (with values less than 1.0e-8) at the bottom are removed.
    pad_mode : str
        Specifying how to pad the arrays. Should be one of
        - "row"
            The array with fewer rows is padded with zero rows so that both have the same
            number of rows.
        - "col"
            The array with fewer columns is padded with zero columns so that both have the
            same number of columns.
        - "row-col"
            The array with fewer rows is padded with zero rows, and the array with fewer
            columns is padded with zero columns, so that both have the same dimensions.
            This does not necessarily result in square arrays.
        - "square"
            The arrays are padded with zero rows and zero columns so that they are both
            square arrays. The dimension of the square arrays is specified based on the
            highest dimension, i.e. :math:`\text{max}(n_a, m_a, n_b, m_b)`.
    translate : bool
        If True, the arrays are translated to the origin, i.e., the columns of the arrays
        will have mean zero.
    scale : bool
        If True, the arrays are normalized to one with respect to the Frobenius norm, i.e.,
        :math:`\text{Tr}\left[\mathbf{A}^\dagger\mathbf{A}\right] = 1`.
    check_finite : bool
        If True, checks that the arrays are two-dimensional numpy arrays with finite elements.
    weight : ndarray or list of ndarray, optional
        A list of weight arrays or a single numpy array. When only one numpy array is
        provided, it is assumed that all arrays share the same weight matrix.

    Returns
    -------
    List of arrays :
        The processed arrays, padded so that they have the same matrix dimensions.

    """
    array_list_new = [
        _setup_input_array_lower(
            array_a=arr,
            array_ref=array_ref,
            remove_zero_col=remove_zero_col,
            remove_zero_row=remove_zero_row,
            translate=translate,
            scale=scale,
            check_finite=check_finite,
            weight=weight,
        )
        for arr in array_list
    ]
    arr_shape = np.array([arr.shape for arr in array_list_new])
    array_b = np.ones(np.max(arr_shape, axis=0), dtype=int)
    array_list_new = [_zero_padding(arr, array_b, pad_mode=pad_mode) for arr in array_list_new]
    return array_list_new


def _setup_input_array_lower(
    array_a: np.ndarray,
    array_ref: Optional[np.ndarray],
    remove_zero_col: bool,
    remove_zero_row: bool,
    translate: bool,
    scale: bool,
    check_finite: bool,
    weight: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Pre-process a single matrix with translation and scaling."""
    _check_arraytypes(array_a)
    if check_finite:
        array_a = np.asarray_chkfinite(array_a)
    # Sometimes arrays already have zero padding that messes up zero padding below.
    array_a = _hide_zero_padding(array_a, remove_zero_col, remove_zero_row)
    if translate:
        array_a, _ = _translate_array(array_a, array_ref, weight)
    # apply the weight matrix when translate is False but a weight is given
    else:
        if weight is not None:
            array_a = np.dot(np.diag(weight), array_a)

    if scale:
        array_a, _ = _scale_array(array_a, array_ref)
    return array_a


def _check_arraytypes(*args) -> None:
    r"""Check array input types to Procrustes transformation routines."""
    if any(not isinstance(arr_x, np.ndarray) for arr_x in args):
        raise TypeError("Matrix inputs must be NumPy arrays")
    if any(x.ndim != 2 for x in args):
        raise TypeError("Matrix inputs must be 2-dimensional arrays")


class ProcrustesResult(dict):
    r"""Represents the Procrustes analysis result.

    Attributes
    ----------
    error : float
        The Procrustes (squared Frobenius norm) error.
    new_a : ndarray
        The translated/scaled numpy ndarray :math:`\mathbf{A}`.
    new_b : ndarray
        The translated/scaled numpy ndarray :math:`\mathbf{B}`.
    t : ndarray
        The 2D-array :math:`\mathbf{T}` representing the right-hand-side transformation matrix.
    s : ndarray
        The 2D-array :math:`\mathbf{S}` representing the left-hand-side transformation
        matrix. If set to `None`, the one-sided Procrustes was performed.

    """

    # modification on https://github.com/scipy/scipy/blob/v1.4.1/scipy/optimize/optimize.py#L77-L132
    def __getattr__(self, name: str):
        """Deal with attributes that the class does not explicitly manage."""
        try:
            return self[name]
        # Not using "raise from" makes the traceback inaccurate, because the message implies
        # there is a bug in the exception-handling code itself, which is a separate situation
        # from wrapping an exception
        # W0707 from http://pylint.pycqa.org/en/latest/technical_reference/features.html
        except KeyError as ke_info:
            raise AttributeError(name) from ke_info

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __repr__(self):
        """Return a human-friendly representation."""
        if self.keys():
            max_len = max(map(len, list(self.keys()))) + 1
            return "\n".join([k.rjust(max_len) + ": " + repr(v) for k, v in sorted(self.items())])
        return self.__class__.__name__ + "()"

    def __dir__(self):
        """Provide basic customization of module attribute access with a list."""
        return list(self.keys())
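Finally, a sketch of how the result object behaves; it is a plain `dict` subclass in the style of SciPy's `OptimizeResult` (illustrative, not part of the committed file):

```python
# Sketch: ProcrustesResult supports both dict and attribute access.
from utils.procrustes.utils import ProcrustesResult

res = ProcrustesResult(error=0.0, t=None, s=None)
print(res.error)     # attribute access goes through __getattr__ ...
print(res["error"])  # ... and plain dict access works as usual
print(res)           # __repr__ lists the sorted key/value pairs
```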