Guy24 commited on
Commit
0108542
·
1 Parent(s): b7e1c46

adding application

Browse files
processor.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+
4
+
5
class RetrievalProcessor:
    """Extracts the next "word" from a token sequence in one of several
    multi-token variants (as-is, with an injected typo, or exploded into
    single characters) and re-tokenizes it together with its preceding
    context.
    """

    def __init__(self, model, tokenizer, multi_token_kind, num_tokens_to_generate,
                 add_context, model_name, whitespace_token='Ġ'):
        self.model = model
        self.tokenizer = tokenizer
        self.multi_token_kind = multi_token_kind
        self.num_tokens_to_generate = num_tokens_to_generate
        self.add_context = add_context
        self.model_name = model_name
        # Marker BPE tokenizers (GPT-2 style) prepend to word-initial tokens.
        self.whitespace_token = whitespace_token

    def get_next_word(self, tokens, i, max_length=1000, device='cuda'):
        """Greedily absorb the alphabetic continuation tokens of the word
        starting at index ``i``.

        Returns ``(j, word_tokens, word, context, tokenized_combined_text,
        combined_text, original_word)`` where ``j`` is the index of the first
        token after the word.
        """
        token_str = self.tokenizer.convert_ids_to_tokens(tokens[i].item())
        j = i + 1
        word_tokens = [tokens[i]]
        # Only extend when token i actually starts a word; continuation
        # tokens are alphabetic and carry no word-initial marker.
        if token_str.startswith(self.whitespace_token):
            while j < len(tokens) and self.is_alpha_not_prefix(tokens[j]):
                word_tokens.append(tokens[j])
                j += 1
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        combined_text = context + word
        tokenized_combined_text = self.tokenizer(combined_text, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        return j, word_tokens, word, context, tokenized_combined_text, combined_text, original_word

    def get_next_full_word_typo(self, tokens, i, max_length=1000, device='cuda'):
        """Like :meth:`get_next_word`, but when token ``i`` is a full word it
        is corrupted with one random typo before re-tokenizing."""
        tokens_str = self.tokenizer.convert_ids_to_tokens(tokens)
        word_tokens = [tokens[i]]
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        if self.is_full_word(tokens_str, i, word, word_tokens):
            word = self.introduce_typo(word)
            # [1:] assumes the tokenizer prepends one special (BOS) token —
            # TODO confirm for all supported tokenizers.
            word_tokens = self.tokenizer(word, return_tensors='pt', truncation=True, max_length=max_length).input_ids[0][1:]
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        combined_text = context + word

        tokenized_combined_text = self.tokenizer(combined_text, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        # NOTE(review): both branches subtract 1; the else branch additionally
        # offsets by i — verify this matches the caller's indexing scheme.
        j = len(tokenized_combined_text.input_ids[0]) - 1 if self.add_context else len(tokenized_combined_text.input_ids[0]) - 1 + i
        return j, word_tokens, word, context, tokenized_combined_text, combined_text, original_word

    def get_next_full_word_separated(self, tokens, i, max_length=1000, device='cuda'):
        """Like :meth:`get_next_word`, but when token ``i`` is a full word it
        is replaced by its character-by-character tokenization."""
        tokens_str = self.tokenizer.convert_ids_to_tokens(tokens)
        word_tokens = [tokens[i]]
        word = self.tokenizer.decode(word_tokens)
        original_word = word
        if self.is_full_word(tokens_str, i, word, word_tokens):
            word = torch.tensor(self.separate_word(word)).unsqueeze(0)
        else:
            word = word_tokens[0].unsqueeze(0).unsqueeze(0)
        context = self.tokenizer.decode(tokens[:i]) if self.add_context else ""
        tokenized_combined_text = self.tokenizer(context, return_tensors='pt', truncation=True,
                                                 max_length=max_length).to(device)
        # Append the (possibly character-split) word ids to the tokenized
        # context. Debug prints from the original were removed.
        # NOTE(review): ``word`` is built on CPU while input_ids may be on
        # ``device`` — confirm no device mismatch when running on CUDA.
        tokenized_combined_text.input_ids = torch.cat((tokenized_combined_text.input_ids, word), dim=1)
        word_tokens = word
        j = i + 1
        return j, word_tokens, word, context, tokenized_combined_text, self.tokenizer.decode(tokenized_combined_text.input_ids[0]), original_word

    def is_alpha_not_prefix(self, token):
        """True when ``token`` is alphabetic and does not start a new word."""
        token_str = self.tokenizer.convert_ids_to_tokens(token.item())
        return not token_str.startswith(self.whitespace_token) and token_str.isalpha()

    def introduce_typo(self, word, typo_type=None):
        """Return ``word`` with one random character-level typo.

        FIX: the original raised ``ValueError`` (empty randint range) for
        words shorter than 2 characters (and shorter than 3 for
        transposition); such words are now returned unchanged.
        Positions start at 1 so the leading character (often a whitespace
        marker) is never touched.
        """
        letters = 'abcdefghijklmnopqrstuvwxyz'
        if typo_type is None:
            typo_type = random.choice(["substitution", "deletion", "insertion", "transposition"])

        if len(word) < 2:
            return word
        if typo_type == "substitution":
            position = random.randint(1, len(word) - 1)
            original_char = word[position]
            typo_char = random.choice([c for c in letters if c != original_char])
            return word[:position] + typo_char + word[position + 1:]
        elif typo_type == "deletion":
            position = random.randint(1, len(word) - 1)
            return word[:position] + word[position + 1:]
        elif typo_type == "insertion":
            position = random.randint(1, len(word) - 1)
            typo_char = random.choice(letters)
            return word[:position] + typo_char + word[position:]
        elif typo_type == "transposition":
            if len(word) < 3:
                return word
            position = random.randint(1, len(word) - 2)
            return word[:position] + word[position + 1] + word[position] + word[position + 2:]
        else:
            return word

    def separate_word(self, word):
        """Tokenize ``word`` one character at a time.

        Each character is encoded after a newline so it is tokenized without
        a word-initial marker; the last id of each encoding is kept.
        NOTE(review): the ``[3:]`` offset looks tokenizer-specific — confirm
        it drops exactly the unwanted leading ids for the model in use.
        """
        character_tokens = [self.tokenizer.encode(f'\n{char}')[-1] for char in ''.join(word)]
        return character_tokens[3:]

    def is_full_word(self, token_str, i, token, word_tokens):
        """Heuristic: token ``i`` is a standalone word — more than 5 chars,
        word-initial marker present, alphabetic after the marker, and the
        following token is not alphabetic."""
        # NOTE(review): indexing ``word_tokens`` (usually length 1) with i+1
        # means next_token is almost always "" — possibly ``tokens`` was
        # intended; preserved as-is pending confirmation.
        next_token = self.tokenizer.decode(word_tokens[i + 1]) if i + 1 < len(word_tokens) else ""
        return (token[1:].isalpha() and
                len(token) > 5 and
                token_str[i].startswith(self.whitespace_token) and
                not next_token.isalpha())
requirements.txt CHANGED
@@ -1,213 +1,10 @@
1
- accelerate==1.2.1
2
- aiofiles==23.2.1
3
- aiohappyeyeballs==2.4.4
4
- aiohttp==3.11.11
5
- aiosignal==1.3.2
6
- annotated-types==0.7.0
7
- anyio @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_68kdsx8iyd/croot/anyio_1729121281958/work
8
- appnope @ file:///Users/ktietz/demo/mc3/conda-bld/appnope_1629146036738/work
9
- argon2-cffi @ file:///opt/conda/conda-bld/argon2-cffi_1645000214183/work
10
- argon2-cffi-bindings @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_2ef471wnyf/croot/argon2-cffi-bindings_1736182451265/work
11
- asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work
12
- async-lru @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_02efro5ps8/croot/async-lru_1699554529181/work
13
- async-timeout==5.0.1
14
- attrs @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_93pjmt0git/croot/attrs_1734533120523/work
15
- Babel @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_00k1rl2pus/croot/babel_1671781944131/work
16
- backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
17
- beautifulsoup4 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_94rx5n7wo9/croot/beautifulsoup4-split_1718029832430/work
18
- bleach @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_faqg19k8gh/croot/bleach_1732292152791/work
19
- blis==1.2.0
20
- Brotli @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_f7i0oxypt6/croot/brotli-split_1736182464088/work
21
- catalogue==2.0.10
22
- certifi @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d8j59rqun5/croot/certifi_1734473289913/work/certifi
23
- cffi @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_e4xd9yd9i2/croot/cffi_1736182819442/work
24
- charset-normalizer @ file:///croot/charset-normalizer_1721748349566/work
25
- click==8.1.8
26
- cloudpathlib==0.20.0
27
- cloudpickle==3.1.0
28
- comm @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_3doui0bmzb/croot/comm_1709322861485/work
29
- confection==0.1.5
30
- contourpy==1.3.0
31
- cycler==0.12.1
32
- cymem==2.0.11
33
- dask==2024.8.0
34
- dask-expr==1.1.10
35
- datasets==3.2.0
36
- debugpy @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_563_nwtkoc/croot/debugpy_1690905063850/work
37
- decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
38
- defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
39
- -e git+https://github.com/tokeron/diffusers.git@00769b5d64c2ea35201e0df7a082db3513619afe#egg=diffusers&subdirectory=../../../../../../diffusers
40
- dill==0.3.8
41
- distro==1.9.0
42
- docker-pycreds==0.4.0
43
- editdistance==0.8.1
44
- en_core_web_lg @ https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.8.0/en_core_web_lg-3.8.0-py3-none-any.whl#sha256=293e9547a655b25499198ab15a525b05b9407a75f10255e405e8c3854329ab63
45
- en_core_web_md @ https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.8.0/en_core_web_md-3.8.0-py3-none-any.whl#sha256=5e6329fe3fecedb1d1a02c3ea2172ee0fede6cea6e4aefb6a02d832dba78a310
46
- en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl#sha256=1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85
47
- eval_type_backport==0.2.2
48
- exceptiongroup @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_b2258scr33/croot/exceptiongroup_1706031391815/work
49
- executing @ file:///opt/conda/conda-bld/executing_1646925071911/work
50
- fastapi==0.115.12
51
- fastjsonschema @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_d1wgyi4enb/croot/python-fastjsonschema_1731939426145/work
52
- ffmpy==0.5.0
53
- filelock==3.16.1
54
- fonttools==4.55.3
55
- frozenlist==1.5.0
56
- fsspec==2024.9.0
57
- gitdb==4.0.12
58
- GitPython==3.1.44
59
- gradio==4.44.1
60
- gradio_client==1.3.0
61
- h11 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_110bmw2coo/croot/h11_1706652289620/work
62
- httpcore @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_fcxiho9nv7/croot/httpcore_1706728465004/work
63
- httpx @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_cc4egw1482/croot/httpx_1723474826664/work
64
- huggingface-hub==0.27.1
65
- idna==3.10
66
- importlib_metadata @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_cc4qelzghy/croot/importlib_metadata-suite_1732633706960/work
67
- importlib_resources==6.5.2
68
- ipykernel @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_ddflobe9t3/croot/ipykernel_1728665605034/work
69
- ipython @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_6599f73fa7/croot/ipython_1694181355402/work
70
- jedi @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_38ctoinnl0/croot/jedi_1733987402850/work
71
- Jinja2 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_b15nuwux5r/croot/jinja2_1730902833938/work
72
- jiter==0.8.2
73
- joblib==1.4.2
74
- json5 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_b9ww6ewhv3/croot/json5_1730786813588/work
75
- jsonschema @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_7boelfqucq/croot/jsonschema_1728486715888/work
76
- jsonschema-specifications @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_d38pclgu95/croot/jsonschema-specifications_1699032390832/work
77
- jupyter-events @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_db0avcjzq5/croot/jupyter_events_1718738111427/work
78
- jupyter-lsp @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_ae9br5v37x/croot/jupyter-lsp-meta_1699978259353/work
79
- jupyter_client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_58w2siozyz/croot/jupyter_client_1699455907045/work
80
- jupyter_core @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_73nomeum4p/croot/jupyter_core_1718818302815/work
81
- jupyter_server @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d1t69bk94b/croot/jupyter_server_1718827086930/work
82
- jupyter_server_terminals @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_e7ryd60iuw/croot/jupyter_server_terminals_1686870731283/work
83
- jupyterlab @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a2d0br6r6g/croot/jupyterlab_1725895226942/work
84
- jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
85
- jupyterlab_server @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_f64fg3hglz/croot/jupyterlab_server_1725865356410/work
86
- kiwisolver==1.4.7
87
- langcodes==3.5.0
88
- language_data==1.3.0
89
- locket==1.0.0
90
- marisa-trie==1.2.1
91
- markdown-it-py==3.0.0
92
- MarkupSafe @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a84ni4pci8/croot/markupsafe_1704206002077/work
93
- matplotlib==3.9.4
94
- matplotlib-inline @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_f6fdc0hldi/croots/recipe/matplotlib-inline_1662014472341/work
95
- matplotlib-venn==1.1.2
96
- mdurl==0.1.2
97
- mistune @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_17ya6k1sbs/croots/recipe/mistune_1661496228719/work
98
- mpmath==1.3.0
99
- multidict==6.1.0
100
- multiprocess==0.70.16
101
- murmurhash==1.0.12
102
- nbclient @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_626hpwnurm/croot/nbclient_1698934218848/work
103
- nbconvert @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_f4c1s1qk1f/croot/nbconvert_1728049432295/work
104
- nbformat @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_2cv_qoc1gw/croot/nbformat_1728049423516/work
105
- nest-asyncio @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_310vb5e2a0/croot/nest-asyncio_1708532678212/work
106
- networkx==3.2.1
107
- nltk==3.9.1
108
- notebook @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_539v4hufo2/croot/notebook_1727199149603/work
109
- notebook_shim @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d6_ze10f45/croot/notebook-shim_1699455897525/work
110
- numpy==2.0.2
111
- openai==1.59.7
112
- orjson==3.10.16
113
- overrides @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_70s80guh9g/croot/overrides_1699371144462/work
114
- packaging @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a6_qk3qyg7/croot/packaging_1734472142254/work
115
- pandas==2.2.3
116
- pandocfilters @ file:///opt/conda/conda-bld/pandocfilters_1643405455980/work
117
- parso @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_8824a1w4md/croot/parso_1733963320105/work
118
- partd==1.4.2
119
- patsy==1.0.1
120
- pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
121
- pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
122
- pillow==10.4.0
123
- platformdirs @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a8u4fy8k9o/croot/platformdirs_1692205661656/work
124
- plotly==5.24.1
125
- preshed==3.0.9
126
- prometheus_client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_803ymjpv2u/croot/prometheus_client_1731958793251/work
127
- prompt-toolkit @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_c63v4kqjzr/croot/prompt-toolkit_1704404354115/work
128
- propcache==0.2.1
129
- protobuf==5.29.2
130
- psutil @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1310b568-21f4-4cb0-b0e3-2f3d31e39728k9coaga5/croots/recipe/psutil_1656431280844/work
131
- ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
132
- pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
133
- pyarrow==18.1.0
134
- pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
135
- pydantic==2.10.4
136
- pydantic_core==2.27.2
137
- pydub==0.25.1
138
- Pygments @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_29bs9f_dh9/croot/pygments_1684279974747/work
139
- pyparsing==3.2.1
140
- PySocks @ file:///Users/ktietz/Code/oss/ci_pkgs/pysocks_1626781349491/work
141
- python-box==7.3.0
142
- python-dateutil @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_66ud1l42_h/croot/python-dateutil_1716495741162/work
143
- python-json-logger @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_9bjmcmh4nm/croot/python-json-logger_1734370248301/work
144
- python-multipart==0.0.20
145
- pytz @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a4b76c83ik/croot/pytz_1713974318928/work
146
- PyYAML @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_faoex52hrr/croot/pyyaml_1728657970485/work
147
- pyzmq @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_95lsut8ymz/croot/pyzmq_1734709560733/work
148
- referencing @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_5cz64gsx70/croot/referencing_1699012046031/work
149
- regex==2024.11.6
150
- requests @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_ee45nsd33z/croot/requests_1730999134038/work
151
- rfc3339-validator @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_76ae5cu30h/croot/rfc3339-validator_1683077051957/work
152
- rfc3986-validator @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d0l5zd97kt/croot/rfc3986-validator_1683058998431/work
153
- rich==13.9.4
154
- rpds-py @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_93fzmr7v9h/croot/rpds-py_1732228422522/work
155
- ruff==0.11.6
156
- safetensors==0.5.0
157
- scikit-learn==1.6.0
158
- scipy==1.13.1
159
- seaborn==0.13.2
160
- semantic-version==2.10.0
161
- Send2Trash @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5b31f0zzlv/croot/send2trash_1699371144121/work
162
- sentencepiece==0.2.0
163
- sentry-sdk==2.19.2
164
- setproctitle==1.3.4
165
- shellingham==1.5.4
166
- six @ file:///tmp/build/80754af9/six_1644875935023/work
167
- smart-open==7.1.0
168
- smmap==5.0.2
169
- sniffio @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1573pknjrg/croot/sniffio_1705431298885/work
170
- soupsieve @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_9798xzs_03/croot/soupsieve_1696347567192/work
171
- spacy==3.8.3
172
- spacy-legacy==3.0.12
173
- spacy-loggers==1.0.5
174
- srsly==2.5.1
175
- stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
176
- starlette==0.46.2
177
- statsmodels==0.14.4
178
- swifter==1.4.0
179
- sympy==1.13.1
180
- tabulate==0.9.0
181
- tenacity==9.0.0
182
- terminado @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_fcfvyc0an2/croot/terminado_1671751835701/work
183
- thinc==8.3.4
184
- threadpoolctl==3.5.0
185
- tinycss2 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_fcw5_i306t/croot/tinycss2_1668168825117/work
186
- together==1.4.1
187
- tokenizers==0.21.0
188
- tomli @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d0e5ffbf-5cf1-45be-8693-c5dff8108a2awhthtjlq/croots/recipe/tomli_1657175508477/work
189
- tomlkit==0.12.0
190
- toolz==1.0.0
191
- torch==2.5.1
192
- torchvision==0.20.1
193
- tornado @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_0axef5a0m0/croot/tornado_1733960501260/work
194
- tqdm==4.67.1
195
- traitlets @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_500m2_1wyk/croot/traitlets_1718227071952/work
196
- transformers==4.47.1
197
- typer==0.15.1
198
- typing_extensions @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_0b3jpv_f79/croot/typing_extensions_1734714864260/work
199
- tzdata==2024.2
200
- urllib3 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_06_m8gdsy6/croot/urllib3_1727769822458/work
201
- uvicorn==0.34.2
202
- wandb==0.19.1
203
- wasabi==1.1.3
204
- wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
205
- weasel==0.4.1
206
- webencodings==0.5.1
207
- websocket-client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d37u7gqts8/croot/websocket-client_1715878310260/work
208
- websockets==12.0
209
- wordcloud==1.9.4
210
- wrapt==1.17.2
211
- xxhash==3.5.0
212
- yarl==1.18.3
213
- zipp @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_echurpkwug/croot/zipp_1732630743967/work
 
1
+ torch
2
+ transformers
3
+ pandas
4
+ # functools is part of the Python standard library — not pip-installable, removed from install list
5
+ tqdm
6
+ # abc (stdlib — not pip-installable)
7
+ # enum (stdlib — not pip-installable)
8
+ # typing (stdlib — not pip-installable)
9
+ scikit-learn
10
+ gradio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/__init__.py ADDED
File without changes
utils/calibration_utils.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch import nn
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator
8
+ from transformers import get_scheduler
9
+ from accelerate import Accelerator
10
+ from accelerate.utils import set_seed
11
+ from collections import defaultdict
12
+ from torch.utils.data import DataLoader
13
+ import torch.optim as optim
14
+
15
+ from ..utils.data_utils import load_lm_dataset, extract_new_words_from_dataset, get_group_texts_func, get_tokenize_func
16
+
17
+
18
class EmbeddingCalibrator(nn.Module):
    """Learned residual linear correction applied to embedding rows.

    With ``lora_r=None`` a full-rank zero-initialized matrix is used, so the
    module is exactly the identity at initialization; otherwise a LoRA-style
    low-rank pair (A, B) is trained, with B zero-initialized for the same
    identity-at-init property.
    """

    def __init__(self, hidden_size, lora_r=None, lora_alpha=None, dtype=torch.bfloat16):
        super().__init__()
        self.use_lora = lora_r is not None

        if not self.use_lora:
            # Zero init => forward() is the identity until trained.
            self.weight = nn.Parameter(torch.zeros(hidden_size, hidden_size, dtype=dtype))
        else:
            self.lora_scaling = lora_alpha / lora_r if lora_alpha is not None else 1.0
            # BUG FIX: the original referenced an undefined name ``lora_rank``
            # (NameError whenever the LoRA path was taken); it must be lora_r.
            self.lora_A = nn.Parameter(torch.randn(lora_r, hidden_size, dtype=dtype) * (1 / lora_r))
            self.lora_B = nn.Parameter(torch.zeros(hidden_size, lora_r, dtype=dtype))

    def forward(self, x):
        """Return ``x`` plus the learned (full-rank or low-rank) correction."""
        if not self.use_lora:
            return x + torch.matmul(x, self.weight.t())
        # Low-rank adaptation: x @ A^T @ B^T, scaled.
        lora_out = torch.matmul(x, self.lora_A.t())
        lora_out = torch.matmul(lora_out, self.lora_B.t())
        return x + self.lora_scaling * lora_out
38
+
39
+
40
class CalibrationModel(nn.Module):
    """Frozen base LM + lm_head with trainable calibrators for newly added
    vocabulary tokens.

    The loss is a weighted mix of per-token cross-entropy over three masks:
    original-vocab targets, new-token targets, and tokens immediately
    following a new token.
    """

    def __init__(
            self,
            base_model, lm_head, original_vocab_size, num_new_tokens,
            calibrate_embedding=True, calibrate_lm_head=True, empty_init=False,
            lora_alpha=None, lora_r=None,
            target_loss_weight=0.15, subsequent_loss_weight=0.15,
    ):
        super().__init__()
        self.base_model = base_model
        self.lm_head = lm_head
        # New-token ids occupy the contiguous range [start, end).
        self.new_tokens_start = original_vocab_size
        self.new_tokens_end = original_vocab_size + num_new_tokens

        self.calibrate_lm_head = calibrate_lm_head
        self.calibrate_embedding = calibrate_embedding
        if not empty_init:
            # NOTE(review): with empty_init=True the calibrator attributes are
            # never created — set_calibrators()/load_calibrators() must run first.
            self.lm_head_calibrator = EmbeddingCalibrator(base_model.config.hidden_size, lora_r, lora_alpha)
            self.embedding_calibrator = EmbeddingCalibrator(base_model.config.hidden_size, lora_r, lora_alpha)

        # reduction="none": per-token loss so the three masks can be weighted.
        self.loss_fct = nn.CrossEntropyLoss(reduction="none")
        self.subsequent_tokens_loss_alpha = subsequent_loss_weight
        self.new_tokens_loss_alpha = target_loss_weight
        self.original_tokens_loss_alpha = 1 - self.new_tokens_loss_alpha - self.subsequent_tokens_loss_alpha

    def forward(self, input_ids, labels, attention_mask=None):
        """Causal-LM forward returning ``{'loss': ..., 'logits': ...}``."""
        # Shift labels by 1 for next-token prediction.
        labels = labels[:, 1:].contiguous()
        input_ids = input_ids[:, :-1].contiguous()

        if self.calibrate_embedding:
            # Rebuild the embedding table with calibrated rows for new tokens,
            # then feed inputs_embeds so gradients reach the calibrator.
            E_weights = self.base_model.get_input_embeddings().weight.data
            E_weights = torch.cat((E_weights[:self.new_tokens_start], self.embedding_calibrator(E_weights[self.new_tokens_start:])))
            input_embeddings = E_weights[input_ids]
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids, dtype=torch.long)
            outputs = self.base_model(inputs_embeds=input_embeddings, attention_mask=attention_mask)
        else:
            with torch.no_grad():
                # Base model is frozen; skip graph construction entirely.
                outputs = self.base_model(input_ids, attention_mask=attention_mask)

        if self.calibrate_lm_head:
            with torch.no_grad():
                lm_head_weights = self.lm_head.weight
                normed_weights = lm_head_weights.clone()
            # Calibrator runs OUTSIDE no_grad so its parameters receive
            # gradients through the new-token rows.
            normed_weights[self.new_tokens_start:self.new_tokens_end] = self.lm_head_calibrator(lm_head_weights[self.new_tokens_start:self.new_tokens_end])
            logits = torch.matmul(outputs['last_hidden_state'], normed_weights.T)
        else:
            if self.calibrate_embedding:
                logits = self.lm_head(outputs['last_hidden_state'])
            else:
                with torch.no_grad():
                    logits = self.lm_head(outputs['last_hidden_state'])

        per_example_loss = self.loss_fct(logits.transpose(1, 2), labels)
        original_tokens_mask = labels < self.new_tokens_start
        new_tokens_mask = ~original_tokens_mask
        loss = 0.0
        if self.original_tokens_loss_alpha > 0.0:
            loss += self.original_tokens_loss_alpha * per_example_loss[original_tokens_mask].mean()
        if self.new_tokens_loss_alpha > 0.0:
            loss += self.new_tokens_loss_alpha * per_example_loss[new_tokens_mask].mean()
        if self.subsequent_tokens_loss_alpha > 0.0:
            # Positions directly after a new-token label.
            subsequent_tokens_mask = torch.zeros_like(original_tokens_mask, dtype=torch.bool)
            subsequent_tokens_mask[:, 1:][new_tokens_mask[:, :-1]] = True
            loss += self.subsequent_tokens_loss_alpha * per_example_loss[subsequent_tokens_mask].mean()

        return {'loss': loss, 'logits': logits}

    def get_calibrators(self):
        """Return the active calibrators plus the new-token id range."""
        embedding_calibrator = self.embedding_calibrator if self.calibrate_embedding else None
        lm_head_calibrator = self.lm_head_calibrator if self.calibrate_lm_head else None
        return {
            "embedding_calibrator": embedding_calibrator,
            "lm_head_calibrator": lm_head_calibrator,
            "new_tokens_start": self.new_tokens_start,
            "new_tokens_end": self.new_tokens_end,
        }

    def set_calibrators(self, embedding_calibrator=None, lm_head_calibrator=None):
        """Install externally created/loaded calibrators."""
        self.embedding_calibrator = embedding_calibrator
        self.lm_head_calibrator = lm_head_calibrator

    def save_calibrators(self, save_dir):
        """Serialize the active calibrators into ``save_dir``."""
        os.makedirs(save_dir, exist_ok=True)
        if self.calibrate_embedding:
            torch.save(self.embedding_calibrator, os.path.join(save_dir, "embedding_calibrator.pt"))
        if self.calibrate_lm_head:
            torch.save(self.lm_head_calibrator, os.path.join(save_dir, "lm_head_calibrator.pt"))

    def load_calibrators(self, load_dir, fail_ok=False):
        """Load calibrators saved by :meth:`save_calibrators`.

        Returns True on success; with ``fail_ok`` returns False instead of
        raising. NOTE: ``torch.load`` is pickle-based — only load trusted files.
        """
        try:
            if self.calibrate_embedding:
                self.embedding_calibrator = torch.load(os.path.join(load_dir, "embedding_calibrator.pt"))
            if self.calibrate_lm_head:
                self.lm_head_calibrator = torch.load(os.path.join(load_dir, "lm_head_calibrator.pt"))
            return True
        except Exception:
            # FIX: was a bare ``except:`` which also swallowed
            # KeyboardInterrupt/SystemExit.
            if fail_ok:
                return False
            raise FileNotFoundError(f"Loading calibrators from '{load_dir}' failed")
143
+
144
+
145
def get_calibration_model(model, original_vocab_size, num_new_tokens, target_loss_weight=0.15, subsequent_loss_weight=0.15):
    """Wrap ``model`` in a CalibrationModel with the base LM and lm_head
    frozen (eval mode, no gradients) and only the two calibrators trainable.
    """
    wrapper = CalibrationModel(
        model.model, model.lm_head, original_vocab_size, num_new_tokens,
        target_loss_weight=target_loss_weight, subsequent_loss_weight=subsequent_loss_weight,
    )
    wrapper.base_model.eval()
    wrapper.lm_head.eval()

    # Freeze the pretrained parts, train only the calibrators.
    for frozen in (wrapper.base_model, wrapper.lm_head):
        for p in frozen.parameters():
            p.requires_grad = False
    for trainable in (wrapper.lm_head_calibrator, wrapper.embedding_calibrator):
        for p in trainable.parameters():
            p.requires_grad = True

    return wrapper
160
+
161
+
162
def train_calibration_model(calibrated_model: CalibrationModel, tokenizer, dataset, save_dir=None, max_samples=None, filter_examples_without_new_tokens=True, lr=1e-4, lr_schedule="linear", num_epochs=1, batch_size=8, max_length=256, n_warmup_steps=0, text_col_name="text", clip_grad_norm=1.0, mixed_precision=None):
    """Fit the calibrators of ``calibrated_model`` on ``dataset`` with a
    standard causal-LM training loop (AdamW + LR schedule, accelerate-managed).

    The text column is tokenized, chunked into blocks of at most
    ``max_length`` tokens, optionally filtered to examples that contain at
    least one new token, and trained for ``num_epochs``. If ``save_dir`` is
    given the trained calibrators are serialized there. Returns the
    (accelerate-prepared) calibrated model.
    """
    accelerator = Accelerator(mixed_precision=mixed_precision)
    # Optimizer
    optimizer = optim.AdamW(calibrated_model.parameters(), lr=lr)

    # Tokenize data
    if tokenizer.bos_token is not None and max_length:
        add_start_token = True
        # leave room for <BOS> token to be added:
        max_tokenized_len = max_length - 1
    else:
        add_start_token = False
        max_tokenized_len = max_length

    def _add_start_token(batch):
        # Prepend one BOS id to every row and extend the attention mask to match.
        bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * batch["input_ids"].size(dim=0)).to(batch["input_ids"].device)
        batch["input_ids"] = torch.cat([bos_tokens_tensor, batch["input_ids"]], dim=1)
        batch["attention_mask"] = torch.cat(
            [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(batch["attention_mask"].device), batch["attention_mask"]], dim=1)
        return batch

    tokenize_function = get_tokenize_func(tokenizer, text_col_name)

    column_names = dataset.column_names

    # Run the (cached-on-disk) preprocessing once on the main process.
    with accelerator.main_process_first():
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=column_names,
            load_from_cache_file=False,
            desc="Running tokenizer on dataset",
        )
        group_texts = get_group_texts_func(block_size=max_tokenized_len)
        lm_dataset = tokenized_dataset.map(
            group_texts,
            batched=True,
        )

    if filter_examples_without_new_tokens:
        # Keep only rows containing at least one newly added token id.
        examples_w_new_token = np.arange(len(lm_dataset))[np.any(np.array(lm_dataset['input_ids']) >= calibrated_model.new_tokens_start, axis=1)]
        lm_dataset = lm_dataset.select(examples_w_new_token)

    if max_samples is not None:
        lm_dataset = lm_dataset.select(np.arange(max_samples))

    data_collator = default_data_collator

    # Create data loaders
    dataloader = DataLoader(
        lm_dataset, collate_fn=data_collator, batch_size=batch_size, drop_last=True, shuffle=True,
    )

    # Learning rate scheduler
    # A float n_warmup_steps is interpreted as a fraction of one epoch.
    if isinstance(n_warmup_steps, float):
        n_warmup_steps = n_warmup_steps * len(dataloader)
    scheduler = get_scheduler(lr_schedule, optimizer=optimizer, num_warmup_steps=n_warmup_steps, num_training_steps=len(dataloader) * num_epochs)

    calibrated_model, dataloader = accelerator.prepare(calibrated_model, dataloader)

    # Freeze the original lm_head weights
    for param in calibrated_model.lm_head.parameters():
        param.requires_grad = False

    calibrated_model.train()
    for epoch in tqdm(range(num_epochs), unit="epochs", desc="Fitting calibration"):
        total_loss = 0.0
        for step, batch in tqdm(enumerate(dataloader), total=len(dataloader), miniters=10, unit="batches"):
            if add_start_token:
                batch = _add_start_token(batch)
            # Labels are the inputs themselves; CalibrationModel.forward shifts them.
            batch["labels"] = batch["input_ids"]
            optimizer.zero_grad()
            outputs = calibrated_model(**batch)
            loss = outputs['loss']
            loss.backward()
            torch.nn.utils.clip_grad_norm_(calibrated_model.parameters(), max_norm=clip_grad_norm)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

            # # Log loss
            # if step % 10 == 0:
            #     print(f"Epoch {epoch + 1}, Step {step}, Loss: {loss.item()}")

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch + 1} completed. Average Loss: {avg_loss}")

    if save_dir is not None:
        calibrated_model.save_calibrators(save_dir)

    return calibrated_model
254
+
255
+
256
def merge_calibrators_to_hf_model(hf_model, new_tokens_start, new_tokens_end=None, embedding_calibrator=None, lm_head_calibrator=None):
    """Bake trained calibrators into ``hf_model``'s weights in place.

    Each calibrator, when provided, is applied to the weight rows
    ``[new_tokens_start:new_tokens_end]`` of the input embeddings /
    lm_head respectively. Returns the (mutated) ``hf_model``.

    BUG FIX: the original unconditionally called ``.to()`` on both
    calibrators before the ``is not None`` guards, crashing with
    ``AttributeError`` whenever either was left at its default ``None``.
    """
    if embedding_calibrator is not None:
        embedding_calibrator.to(hf_model.device)
        embedding_weights = hf_model.get_input_embeddings().weight
        with torch.no_grad():
            calibrated_weights = embedding_calibrator(embedding_weights[new_tokens_start:new_tokens_end])
            hf_model.model.embed_tokens.weight.data[
                new_tokens_start:new_tokens_end] = calibrated_weights

    if lm_head_calibrator is not None:
        lm_head_calibrator.to(hf_model.device)
        lm_head_weights = hf_model.get_output_embeddings().weight
        with torch.no_grad():
            calibrated_weights = lm_head_calibrator(lm_head_weights[new_tokens_start:new_tokens_end])
            hf_model.lm_head.weight.data[new_tokens_start:new_tokens_end] = calibrated_weights

    return hf_model
273
+
274
+
275
def merge_calibration_model_to_hf_model(hf_model, calibrated_model):
    """Bake a CalibrationModel's trained calibrators back into ``hf_model``'s
    lm_head / input-embedding rows for the new-token id range, in place.

    Returns the (mutated) ``hf_model``.

    FIX: the calibrator forward passes now run under ``torch.no_grad()`` —
    this is a pure weight copy, so building an autograd graph was wasted
    work (and matches the sibling ``merge_calibrators_to_hf_model``).
    """
    calibrated_model.to(hf_model.device)
    if calibrated_model.calibrate_lm_head:
        with torch.no_grad():
            lm_head_weights = calibrated_model.lm_head.weight
            normed_weights = calibrated_model.lm_head_calibrator(
                lm_head_weights[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end])
            hf_model.lm_head.weight.data[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end] = normed_weights
    if calibrated_model.calibrate_embedding:
        with torch.no_grad():
            embedding_weights = calibrated_model.base_model.get_input_embeddings().weight
            normed_weights = calibrated_model.embedding_calibrator(
                embedding_weights[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end])
            hf_model.model.embed_tokens.weight.data[calibrated_model.new_tokens_start:calibrated_model.new_tokens_end] = normed_weights
    return hf_model
288
+
utils/data_utils.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from datasets import load_dataset, Dataset, DatasetDict
3
+ from itertools import chain
4
+ from tqdm import tqdm
5
+ from collections import Counter
6
+ from accelerate import Accelerator
7
+
8
+ LANGUAGES_TO_DECODE_FROM_BYTES = ["he", "fr", "uk"]
9
+ STREAMING_DATASETS = ["fineweb-edu"]
10
+
11
+
12
def load_pg19_val_and_test():
    """Stream the PG-19 corpus and materialize only its validation/test splits."""
    streamed = load_dataset("deepmind/pg19", split=None, streaming=True)

    # Materialize the (comparatively small) eval splits into in-memory datasets;
    # streaming avoids downloading the very large train split.
    validation_dataset = Dataset.from_list(list(streamed["validation"]))
    test_dataset = Dataset.from_list(list(streamed["test"]))

    return DatasetDict({"validation": validation_dataset, "test": test_dataset})
28
+
29
+
30
def load_pubmed(n_samples=10000):
    """Take 4*n_samples streamed PubMed records and split them 2/1/1 into
    train/validation/test, with the 'content' column renamed to 'text'."""
    streamed = load_dataset("MedRAG/pubmed", streaming=True)
    records = list(streamed["train"].take(4 * n_samples))

    splits = {
        "train": records[:2 * n_samples],
        "validation": records[2 * n_samples:3 * n_samples],
        "test": records[3 * n_samples:],
    }
    dataset = DatasetDict({name: Dataset.from_list(rows) for name, rows in splits.items()})
    return dataset.rename_column("content", "text")
46
+
47
+
48
def load_lm_dataset(dataset_name, language="en", split=None):
    """
    Loads a pretraining or perplexity-evaluation dataset by name.

    Args:
        dataset_name (str): Case-insensitive dataset name. Supported options:
            - 'wikitext' (wikitext-2, smaller WikiText dataset)
            - 'wikitext-103' (larger WikiText dataset)
            - 'fineweb-edu' (FineWeb-Edu 10BT sample)
            - 'cord19' (COVID-19 fulltext corpus)
            - 'pubmed' (MedRAG PubMed sample, see load_pubmed)
            - 'wikilingua' (English subset, source->text, target->summary)
            - 'xsum' (document->text)
            - 'cnn' (CNN/DailyMail, article->text, highlights->summary)
            - 'pg19' (validation/test splits only, see load_pg19_val_and_test)
            - 'wiki40b' (Wikipedia in multiple languages)
        language (str): Language code for multilingual datasets (wiki40b only).
            Defaults to 'en'.
        split (str): Optional split forwarded to datasets that accept it.

    Returns:
        Dataset | DatasetDict: the loaded Hugging Face dataset.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported options.
    """
    name = dataset_name.lower()  # hoisted: was re-lowered in every branch
    if name == 'wikitext':
        return load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split=split)
    elif name == 'fineweb-edu':
        return load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT")
    elif name == 'wikitext-103':
        return load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split=split)
    elif name == 'cord19':
        return load_dataset("allenai/cord19", "fulltext", trust_remote_code=True)
    elif name == 'pubmed':
        return load_pubmed()
    elif name == 'wikilingua':
        dataset = load_dataset("GEM/wiki_lingua", trust_remote_code=True)
        dataset = dataset.filter(lambda ex: (ex['source_language'] == "en") & (ex['target_language'] == "en"))
        dataset = dataset.rename_column("source", "text")
        dataset = dataset.rename_column("target", "summary")
        return dataset
    elif name == 'xsum':
        dataset = load_dataset("EdinburghNLP/xsum")
        dataset = dataset.rename_column("document", "text")
        return dataset
    elif name == 'cnn':
        dataset = load_dataset("abisee/cnn_dailymail", "3.0.0")
        dataset = dataset.rename_column("article", "text")
        dataset = dataset.rename_column("highlights", "summary")
        # Strip the "(CNN)" source tag that prefixes many articles.
        dataset = dataset.map(lambda example: {"text": example["text"].replace("(CNN)", "")})
        return dataset
    elif name == 'pg19':
        return load_pg19_val_and_test()
    elif name == 'wiki40b':
        dataset = load_dataset("google/wiki40b", language, split=split)
        if language in LANGUAGES_TO_DECODE_FROM_BYTES:
            # wiki40b stores some languages as escaped byte strings; decode them
            # back to UTF-8 and restore newlines.
            dataset = dataset.map(lambda x: {
                "text": bytes(x["text"][2:-1], "utf-8").decode("unicode_escape").encode("latin1").decode("utf-8").replace("_NEWLINE_", "\n")
            })
        return dataset
    else:
        # Bug fix: the old message advertised 'c4'/'mc4' (unsupported) and
        # omitted most of the options that actually exist.
        raise ValueError(
            "Dataset not recognized. Available options: 'wikitext', 'wikitext-103', "
            "'fineweb-edu', 'cord19', 'pubmed', 'wikilingua', 'xsum', 'cnn', 'pg19', 'wiki40b'.")
104
+
105
+
106
def extract_new_words_from_dataset(
        dataset: Dataset, tokenizer, text_column: str = "text", max_samples: int = None, filter_func=(lambda word, token_count: True)):
    """
    Extracts the unique words of a dataset that the tokenizer splits into
    multiple tokens — i.e. candidate "new words" to add to the vocabulary.

    Args:
        dataset (Dataset): Dataset to scan.
        tokenizer: Tokenizer used to decide whether a word is multi-token.
        text_column (str): The column in the dataset containing text. Defaults to 'text'.
        max_samples (int): Number of samples from the dataset to go over (None = all).
        filter_func: Extra predicate ``(word, token_count) -> bool``; a word is
            kept only when it returns True.

    Returns:
        tuple[list, dict]: the new words, and a word -> corpus-frequency mapping.
    """
    if max_samples:
        dataset = dataset.select(range(max_samples))

    # Words are alphanumeric runs, optionally joined by inner hyphens/apostrophes
    # (e.g. "state-of-the-art", "don't").
    word_pattern = re.compile(r"\b\w+(?:[-']\w+)*\b")

    all_words = list()
    for record in tqdm(dataset, total=len(dataset), miniters=10, desc="Extracting all words from dataset...", unit="examples"):
        text = record.get(text_column, "")
        all_words += word_pattern.findall(text)

    # Deduplicate while keeping corpus frequencies.
    word_frequencies = Counter(all_words)
    all_words = list(word_frequencies.keys())
    # A word counts as "new" only if it is multi-token both bare and with a
    # leading space (the form BPE-style tokenizers see mid-sentence).
    token_counts = [len(x) for x in tokenizer(all_words, add_special_tokens=False)["input_ids"]]
    w_whitespace_token_counts = [len(x) for x in tokenizer([f" {w}" for w in all_words], add_special_tokens=False)["input_ids"]]

    new_words = [word for word, count, w_whitespace_count in zip(all_words, token_counts, w_whitespace_token_counts) if ((count > 1) and (w_whitespace_count > 1) and filter_func(word, count))]
    new_words_freq = {word: word_frequencies[word] for word in new_words}
    return new_words, new_words_freq
149
+
150
+
151
def get_group_texts_func(block_size=1024):
    """Build a ``datasets.map``-compatible function that packs tokenized
    examples into contiguous chunks of exactly ``block_size`` tokens,
    dropping the trailing remainder."""
    def group_texts(examples):
        # Flatten every column into one long sequence.
        joined = {key: list(chain.from_iterable(values)) for key, values in examples.items()}
        first_key = next(iter(examples.keys()))
        # Keep only a whole number of blocks; the tail is discarded (an empty
        # batch is returned when total length < block_size).
        usable_length = (len(joined[first_key]) // block_size) * block_size
        chunked = {
            key: [sequence[start:start + block_size] for start in range(0, usable_length, block_size)]
            for key, sequence in joined.items()
        }
        # Causal LM training: labels are the inputs (the model shifts internally).
        chunked["labels"] = chunked["input_ids"].copy()
        return chunked
    return group_texts
168
+
169
+
170
def get_tokenize_func(tokenizer, text_col_name):
    """Return a ``datasets.map``-compatible closure that tokenizes the
    ``text_col_name`` column without special tokens or token-type ids."""
    def _tokenize(examples):
        return tokenizer(
            examples[text_col_name],
            return_token_type_ids=False,
            add_special_tokens=False,
        )
    return _tokenize
179
+
180
+
181
def tokenize_and_prepare_dataset(
        dataset, tokenizer, accelerator=None,
        text_col_name: str = "text",
        max_length: int = 256,
        eval_max_samples: int = None,
):
    """Tokenize a raw-text dataset and pack it into fixed-length LM blocks.

    Pipeline: tokenize ``text_col_name`` (dropping all original columns), then
    concatenate and chunk into blocks via ``get_group_texts_func``, then
    optionally truncate to ``eval_max_samples`` examples.

    Args:
        dataset: Hugging Face dataset with a ``text_col_name`` column.
        tokenizer: tokenizer used for the map step; its ``bos_token`` only
            affects the block size computed here (see below).
        accelerator: unused here — NOTE(review): accepted for signature
            compatibility with callers; confirm before removing.
        text_col_name: column containing the raw text.
        max_length: desired model context length; blocks are one token shorter
            when a BOS token will be prepended at eval time.
        eval_max_samples: keep only the first N packed blocks (None = all).

    Returns:
        The packed LM dataset with input_ids / attention_mask / labels columns.
    """

    if tokenizer.bos_token is not None and max_length:
        # leave room for <BOS> token to be added:
        max_tokenized_len = max_length - 1
    else:
        max_tokenized_len = max_length

    tokenize_function = get_tokenize_func(tokenizer, text_col_name)

    column_names = dataset.column_names

    # Tokenize and drop the raw columns so only token ids remain.
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=column_names,
        load_from_cache_file=False,
        desc="Running tokenizer on dataset",
    )
    # Concatenate and split into fixed-size blocks (remainder dropped).
    group_texts = get_group_texts_func(block_size=max_tokenized_len)
    lm_dataset = tokenized_dataset.map(
        group_texts,
        batched=True,
    )

    if eval_max_samples:
        lm_dataset = lm_dataset.select(range(eval_max_samples))

    return lm_dataset
utils/enums.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+
4
class RetrievalTechniques(Enum):
    """Techniques for mapping hidden states back to vocabulary tokens."""
    ReverseLogitLens = 1
    LogitLens = 2
    Patchscopes = 3
8
+
9
+
10
class MultiTokenKind(Enum):
    """Kinds of multi-token word conditions used by the retrieval processor
    (presumably: split tokenization, introduced typo, or natural occurrence —
    confirm against RetrievalProcessor usage)."""
    Split = 1
    Typo = 2
    Natural = 3
utils/eval_utils.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from torch.utils.data import DataLoader
5
+ from accelerate import Accelerator
6
+ from transformers import default_data_collator
7
+ from collections import defaultdict
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+
11
+
12
def is_not_number(s):
    """Return True when ``s`` cannot be parsed as a float, False otherwise."""
    try:
        float(s)
    except ValueError:
        # Conversion failed -> not numeric.
        return True
    return False
18
+
19
+
20
def get_contexts_ending_with_word(word, dataset):
    """Collect every prefix of the dataset texts that ends exactly at an
    isolated occurrence of ``word`` (prefixes are whitespace-stripped)."""
    contexts = []
    word_len = len(word)

    for example in dataset:
        text = example["text"]
        search_from = 0
        while (idx := text.find(word, search_from)) != -1:
            # Accept only whole-word matches: non-alphanumeric (or boundary)
            # characters on both sides.
            preceded_ok = idx == 0 or not text[idx - 1].isalnum()
            followed_ok = idx + word_len == len(text) or not text[idx + word_len].isalnum()
            if preceded_ok and followed_ok:
                contexts.append(text[:idx + word_len].strip())
            search_from = idx + word_len

    return contexts
43
+
44
+
45
def get_texts_containing_word(words, dataset):
    """Return the texts whose whitespace-separated tokens include any of
    ``words`` (exact token match, punctuation attached to a word will miss)."""
    vocabulary = set(words)
    return [
        example["text"]
        for example in dataset
        if vocabulary.intersection(example["text"].split())
    ]
55
+
56
+
57
def compute_topk_token_rank(logits, labels, k=1000):
    """Return the 1-based rank of each label token among the top-k logits;
    labels outside the top-k are reported with rank ``k``."""
    _, topk_indices = torch.topk(logits, k, dim=-1)

    # Locations where the gold label appears among the top-k predictions;
    # each hit row is (batch_idx, seq_idx, position_within_topk).
    hits = (topk_indices == labels.unsqueeze(-1).expand_as(topk_indices)).nonzero(as_tuple=False)

    # Default every position to the worst observable rank (k), then overwrite
    # the positions where the label was actually found.
    ranks = torch.full(labels.shape, k, dtype=torch.long, device=logits.device)
    ranks[hits[:, 0], hits[:, 1]] = hits[:, 2] + 1

    return ranks
74
+
75
+
76
def count_tokens_in_dataset(dataset, tokenizer, text_column='text'):
    """Total number of tokens the tokenizer produces over ``text_column``."""
    def _count_batch(examples):
        return {'num_tokens': [len(tokenizer(text).input_ids) for text in examples[text_column]]}

    counted = dataset.map(_count_batch, batched=True, remove_columns=dataset.column_names)
    return sum(counted['num_tokens'])
84
+
85
+
86
def filter_single_token_words(array, tokenizer, add_space_prefix_for_lower=True):
    """Keep only entries of a pandas Series that tokenize into more than one token.

    Lowercase-initial words optionally get a leading space first (matching how
    they appear mid-sentence for BPE-style tokenizers). Returns the filtered
    Series and the per-entry token counts for the *full* input.
    """
    def _token_count(word):
        candidate = " " + word if add_space_prefix_for_lower and word[0].islower() else word
        return len(tokenizer.encode(candidate, add_special_tokens=False))

    token_counts = array.apply(_token_count)
    return array[token_counts > 1], token_counts
94
+
95
+
96
+ # TODO make clearer what's its use
97
def get_last_zero_in_every_seq_mask(tensor):
    """For a 0/1 mask, return its complement except that only the LAST zero of
    each zero-run stays 1 (earlier zeros in the run become 0).

    NOTE(review): the right padding value is 1, which never equals -1, so a
    zero-run that reaches the end of the sequence is never marked — confirm
    this is intended.
    """
    is_zero = tensor == 0
    # diff == -1 marks the final position of each zero-run.
    run_ends = torch.cat(
        [torch.diff(is_zero.int(), dim=1),
         torch.ones(tensor.size(0), 1, dtype=torch.int).to(tensor.device)],
        dim=1) == -1

    flipped = 1 - tensor
    flipped[is_zero & ~run_ends] = 0
    return flipped
107
+
108
+
109
def get_first_zero_in_every_seq_mask(tensor):
    """For a 0/1 mask, return its complement except that only the FIRST zero of
    each zero-run stays 1 (later zeros in the run become 0)."""
    is_zero = tensor == 0
    # diff == 1 (with a zero column prepended) marks where a zero-run begins.
    left_pad = torch.zeros(tensor.size(0), 1, dtype=torch.int).to(tensor.device)
    run_starts = torch.diff(is_zero.int(), dim=1, prepend=left_pad) == 1

    flipped = 1 - tensor
    flipped[is_zero & ~run_starts] = 0
    return flipped
119
+
120
+
121
+ def _add_start_token(batch, tokenizer):
122
+ bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * batch["input_ids"].size(dim=0)).to(batch["input_ids"].device)
123
+ batch["input_ids"] = torch.cat([bos_tokens_tensor, batch["input_ids"]], dim=1)
124
+ batch["attention_mask"] = torch.cat(
125
+ [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(batch["attention_mask"].device), batch["attention_mask"]], dim=1)
126
+ return batch
127
+
128
+
129
+ def _ignore_new_words_in_attention_mask(shift_attention_mask_batch, shift_labels, new_token_ids=None, replaced_token_seqs_by_len=None):
130
+ # Ignore token_ids of new vocabulary words in shift_labels and shift_logits
131
+ if new_token_ids is not None:
132
+ ignore_mask = torch.isin(shift_labels, new_token_ids)
133
+ shift_attention_mask_batch = shift_attention_mask_batch * (~ignore_mask).long()
134
+
135
+ # Ignore multi-token sequences of that were replaced with a single token
136
+ if replaced_token_seqs_by_len is not None:
137
+ # Create a mask that will be updated where sequences match
138
+ ignore_mask = shift_attention_mask_batch.clone() # Clone the attention mask to modify it
139
+ # Loop over sequences in skip_token_seqs
140
+ for seq_len, seqs in replaced_token_seqs_by_len.items():
141
+ # Create a sliding window of the same size as the skip_seq and check for matches
142
+ for i in range(shift_labels.size(1) - seq_len + 1):
143
+ # Check if the sequence matches at position i
144
+ window = shift_labels[:, i:i + seq_len]
145
+ curr_mask = torch.all(window.unsqueeze(1) == seqs.unsqueeze(0), dim=-1)
146
+ if curr_mask.any():
147
+ # Zero out the ignore mask for the length of the sequence
148
+ ignore_mask[curr_mask.any(dim=-1), i:i + seq_len] = 0
149
+ # Apply the ignore mask to the attention mask
150
+ shift_attention_mask_batch *= ignore_mask
151
+
152
+ return shift_attention_mask_batch, ignore_mask
153
+
154
+
155
+ # TODO consider not aggregating results here, to enable metrics for specific words
156
def compute_metrics(
        logits, labels, attention_mask,
        compute_target_metrics=True, compute_subsequent_metrics=True, compute_perplexity=False,
        return_successful_targets=False,
        original_labels=None, original_logits=None,
        debug=False):
    """Compute top-1 accuracy (and optionally perplexity) over three token groups.

    The attention mask is expected to have target-word positions zeroed out
    (see ``_ignore_new_words_in_attention_mask``): "background" tokens are
    those still attended, "target" tokens are the first zero of every zero-run
    (via ``get_first_zero_in_every_seq_mask``), and "overall" is their union.
    NOTE(review): grouping semantics inferred from the mask helpers — confirm
    against callers.

    Args:
        logits: (batch, seq, vocab) scores per position.
        labels: (batch, seq) gold token ids.
        attention_mask: (batch, seq) 0/1 mask; zeros mark target positions.
        compute_target_metrics: also report metrics restricted to target tokens.
        compute_subsequent_metrics: also report metrics for the first token
            following each target word.
        compute_perplexity: include masked perplexity in background results.
        return_successful_targets: additionally collect correctly predicted target ids.
        original_labels: labels mapped to the original tokenization (optional).
        original_logits: logits with the new-token columns removed (optional).
        debug: drop into pdb before returning.

    Returns:
        (background_results, target_results, overall_results, successful_targets)
    """
    target_results = dict()  # will hold metrics for all the new words we add or their original tokenization
    background_results = dict()  # will hold metrics for all background tokens, i.e., not the ones we add or replace
    overall_results = dict()  # will hold metrics for all tokens
    successful_targets = None  # will hold list of target tokens successfully predicted
    if compute_subsequent_metrics:
        # prepare labels and attention masks for computing metrics only for the 1st tokens following the new words
        subsequent_labels = labels[:, 1:]
        subsequent_attention_mask = get_last_zero_in_every_seq_mask(attention_mask[..., :-1].contiguous())
        subsequent_attention_mask_bool = subsequent_attention_mask == 1
    attention_mask_bool = attention_mask == 1
    overall_mask_bool = attention_mask_bool

    if compute_target_metrics:
        # target positions: the first zero of every masked-out run
        target_mask = get_first_zero_in_every_seq_mask(attention_mask)
        target_mask_bool = target_mask == 1
        overall_mask_bool = attention_mask_bool | target_mask_bool

    if compute_perplexity:
        # masked mean NLL per sequence, exponentiated, then averaged over the batch
        background_results["perplexity"] = torch.exp(
            (F.cross_entropy(logits.transpose(1, 2), labels, reduction="none") * attention_mask).sum(1)
            / attention_mask.sum(1)
        ).mean().detach().cpu().numpy()

    top1 = logits.argmax(dim=-1)
    if original_logits is not None:
        orig_top1 = original_logits.argmax(dim=-1)

    if compute_target_metrics:
        target_results["top1_acc"] = ((labels == top1)[target_mask_bool]).detach().cpu().numpy()
        if original_labels is not None:
            # count as correct if either the new token OR its original first token was predicted
            target_results["sum_top1_acc"] = (
                ((original_labels == top1) | (labels == top1))[target_mask_bool]).detach().cpu().numpy()
        if original_logits is not None:
            target_results["orig_top1_acc"] = (
                (original_labels == orig_top1)[target_mask_bool]).detach().cpu().numpy()

        if return_successful_targets:
            successful_targets = (labels[(labels == top1) & target_mask_bool]).detach().cpu().numpy()

    background_results["top1_acc"] = ((
        labels == top1)[attention_mask_bool]).detach().cpu().numpy()
    if compute_subsequent_metrics:
        # accuracy on the first token after each new word
        background_results["subsequent_top1_acc"] = ((subsequent_labels == top1[:, 1:])[subsequent_attention_mask_bool]).detach().cpu().numpy()
    if original_logits is not None:
        background_results["orig_top1_acc"] = (
            (original_labels == orig_top1)[attention_mask_bool]).detach().cpu().numpy()
        if compute_subsequent_metrics:
            background_results["orig_subsequent_top1_acc"] = (
                (subsequent_labels == orig_top1[:, 1:])[subsequent_attention_mask_bool]).detach().cpu().numpy()

    overall_results["top1_acc"] = ((labels == top1))[overall_mask_bool].detach().cpu().numpy()
    if original_labels is not None:
        overall_results["sum_top1_acc"] = (
            ((original_labels == top1) | (labels == top1)))[overall_mask_bool].detach().cpu().numpy()
    if original_logits is not None:
        overall_results["orig_top1_acc"] = (
            (original_labels == orig_top1)[overall_mask_bool]).detach().cpu().numpy()

    if debug:
        # NOTE(review): interactive debugger hook — consider removing for production runs
        import pdb; pdb.set_trace()
    return background_results, target_results, overall_results, successful_targets
223
+
224
+
225
def eval_next_word_prediction(
        model, tokenizer, lm_dataset, accelerator=None,
        batch_size: int = 4,
        new_token_ids=None, replaced_token_seqs_by_len=None,
        new_token_to_original_first_token=None,
        max_length: int = 256,
        drop_last: bool = True,
        eval_max_samples: int = None,
        eval_shuffle_samples: bool = False,
        reduction="none",
):
    """Evaluate next-token prediction over ``lm_dataset`` and aggregate metrics.

    Args:
        model: causal LM under evaluation (put into eval mode here).
        tokenizer: its tokenizer; when it defines a BOS token, one is prepended
            to every batch.
        lm_dataset: packed LM dataset (input_ids / attention_mask / labels).
        accelerator: optional ``accelerate.Accelerator``; created when None.
        batch_size: evaluation batch size.
        new_token_ids: ids of newly added vocabulary tokens to exclude from
            background metrics. NOTE(review): the original_logits slicing below
            assumes these ids form one contiguous range — confirm.
        replaced_token_seqs_by_len: {seq_len: [token-id sequences]} of
            multi-token spans replaced by single new tokens; also masked out.
        new_token_to_original_first_token: mapping new-token id -> first token
            id of its original tokenization, used for the "orig" metrics.
        max_length: context length (used here only to gate BOS handling).
        drop_last: drop the trailing incomplete batch.
        eval_max_samples: cap on the number of evaluated samples.
        eval_shuffle_samples: sample indices randomly instead of taking a prefix.
        reduction: "none" -> per-token numpy arrays; "mean" -> scalar means.

    Returns:
        (background_metrics, target_metrics, overall_metrics) dicts.
    """
    if accelerator is None:
        accelerator = Accelerator()
    model.eval()
    # Prepend BOS only when the tokenizer defines one.
    if tokenizer.bos_token is not None and max_length:
        add_start_token = True
    else:
        add_start_token = False

    data_collator = default_data_collator

    if eval_max_samples:
        # Bug fix: this used to be range(len(lm_dataset), min(...)), an always-
        # empty range that silently selected zero samples.
        eval_idx = range(min(eval_max_samples, len(lm_dataset)))
        if eval_shuffle_samples:
            eval_idx = np.random.choice(len(lm_dataset), min(eval_max_samples, len(lm_dataset)))
        lm_dataset = lm_dataset.select(eval_idx)

    # Create data loaders
    eval_dataloader = DataLoader(
        lm_dataset, collate_fn=data_collator, batch_size=batch_size, drop_last=drop_last, shuffle=False,
    )
    eval_dataloader = accelerator.prepare(eval_dataloader)

    model.eval()

    if new_token_ids is not None:
        new_token_ids = torch.tensor(new_token_ids).to(model.device)
    if replaced_token_seqs_by_len is not None:
        # keep only non-empty groups, moved to the model device
        replaced_token_seqs_by_len = {token_length: torch.tensor(skip_token_seqs).to(model.device) for token_length, skip_token_seqs in replaced_token_seqs_by_len.items() if len(skip_token_seqs) > 0}
    if new_token_to_original_first_token is not None:
        # Convert the mapping into a tensor for efficient indexing, create a mapping tensor that defaults to identity
        new_token_to_orig_first_mapping_tensor = torch.arange(len(tokenizer), device=model.device)
        new_token_to_orig_first_mapping_tensor[torch.tensor(list(new_token_to_original_first_token.keys()), device=model.device)] = \
            torch.tensor(list(new_token_to_original_first_token.values()), device=model.device)

    target_metrics = defaultdict(list)
    background_metrics = defaultdict(list)
    overall_metrics = defaultdict(list)

    # run eval and compute metrics
    for batch_i, batch in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader), miniters=10, desc="Evaluating vocabulary..."):
        if add_start_token:
            batch = _add_start_token(batch, tokenizer)

        labels = batch["input_ids"]
        attn_mask = batch["attention_mask"]
        batch.pop("labels")
        with torch.no_grad():
            outputs = model(**batch)
        out_logits = outputs.logits

        # Standard causal shift: position t predicts token t+1.
        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

        shift_attention_mask_batch, ignore_mask = \
            _ignore_new_words_in_attention_mask(
                shift_attention_mask_batch, shift_labels, new_token_ids, replaced_token_seqs_by_len)
        original_labels = None if new_token_to_original_first_token is None \
            else new_token_to_orig_first_mapping_tensor[shift_labels]
        # Drop the new-token columns from the logits (assumes a contiguous id range).
        original_logits = None if new_token_ids is None else torch.cat([shift_logits[:, :, :min(new_token_ids)], shift_logits[:, :, max(new_token_ids)+1:]], dim=-1)

        background_results, target_results, overall_results, successful_targets = \
            compute_metrics(
                shift_logits, shift_labels, shift_attention_mask_batch,
                original_labels=original_labels, original_logits=original_logits, compute_perplexity=True)

        for metric_name, metric_value in target_results.items():
            target_metrics[metric_name].append(np.array(metric_value))
        for metric_name, metric_value in background_results.items():
            background_metrics[metric_name].append(metric_value)
        for metric_name, metric_value in overall_results.items():
            overall_metrics[metric_name].append(metric_value)

    eval_dataloader = accelerator.free_memory(eval_dataloader)

    def _concat_func(x):
        # Flatten the per-batch arrays accumulated above into one array.
        # Portability fix: np.concat exists only in NumPy >= 2.0; use concatenate.
        if isinstance(x, np.ndarray) and len(x.shape) > 1:
            x = np.concatenate(x)
        elif isinstance(x, (list, tuple)) and len(x) > 1:
            if isinstance(x[0], np.ndarray) and len(x[0].shape) == 0:
                x = np.array(x)
            else:
                x = np.concatenate(x)
        return x

    # apply reduction
    reduce_func = _concat_func
    if reduction == 'mean':
        reduce_func = lambda x: np.mean(_concat_func(x)).item()

    for metric_name, metric_value in target_metrics.items():
        target_metrics[metric_name] = reduce_func(metric_value)
    for metric_name, metric_value in background_metrics.items():
        background_metrics[metric_name] = reduce_func(metric_value)
    for metric_name, metric_value in overall_metrics.items():
        overall_metrics[metric_name] = reduce_func(metric_value)
    return background_metrics, target_metrics, overall_metrics
333
+
334
+
utils/file_utils.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import pandas as pd
4
+
5
+
6
def save_df_to_dir(results_df, base_dir, sub_dirs, file_name_format, add_context, model_name):
    """Write ``results_df`` as CSV under <this module's dir>/base_dir/<sub_dirs...>,
    naming the file from ``file_name_format`` with the model name and a
    with/without-context tag."""
    anchor_dir = os.path.dirname(os.path.abspath(__file__))
    target_dir = os.path.join(anchor_dir, base_dir, *sub_dirs)
    os.makedirs(target_dir, exist_ok=True)

    context_tag = "with_context" if add_context else "without_context"
    file_name = file_name_format.format(model_name=model_name, context=context_tag)

    results_df.to_csv(os.path.join(target_dir, file_name), index=False)
23
+
24
+
25
def merge_dfs(base_dir, exp_name, part_format="part_{i}_", output_dir=None,
              filename="patchscopes_results.parquet", output_filename="patchscopes_results.parquet"):
    """
    Merges DataFrames from directories matching the part format into a single DataFrame,
    and optionally saves the result to a file.

    Args:
        base_dir (str): The base directory containing the data.
        exp_name (str): The experiment name to look for within part directories.
        part_format (str): The general format for identifying parts (e.g., "part_{i}_").
        output_dir (str, optional): Directory to save the merged DataFrame. Default is None.
        filename (str): The filename of the Parquet file to read in each part directory.
        output_filename (str): Name of the output file if saving is enabled.

    Returns:
        tuple[pd.DataFrame, list[pd.DataFrame]]: the merged DataFrame and the
        list of per-part DataFrames it was built from.
    """
    dataframes = []
    # "part_{i}_" -> regex "part_\d+_" so any part index matches.
    part_regex = part_format.replace("{i}", r"\d+")

    # List all directories in base_dir
    for dir_name in os.listdir(base_dir):
        # A part directory must match the part prefix AND end with the experiment name.
        if os.path.isdir(os.path.join(base_dir, dir_name)) and re.match(part_regex, dir_name) and (dir_name.endswith(exp_name)):
            part_dir = os.path.join(base_dir, dir_name)
            file_path = os.path.join(part_dir, filename)

            if os.path.exists(file_path):
                # Read the DataFrame and add it to the list
                df = pd.read_parquet(file_path)
                dataframes.append(df)

    # Concatenate all DataFrames into a single DataFrame.
    # NOTE(review): axis=1 concatenates column-wise (parts assumed to hold
    # different columns for the same rows) — confirm parts are not row shards.
    # Also note pd.concat raises ValueError when no part files were found.
    merged_df = pd.concat(dataframes, axis=1)

    # Save the result to file if output_dir is given
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, output_filename)
        merged_df.to_parquet(output_path, index=False)

    return merged_df, dataframes
66
+
67
+
68
def parse_string_list_from_file(file_path, delimiter=None):
    """
    Read a file and parse its content into a list of non-empty strings.

    With ``delimiter=None``, common list notation (brackets, braces, quotes,
    commas, whitespace) is treated as separators. Otherwise the content is
    split on ``delimiter``; the literal string "newline" stands for "\\n".

    Args:
        file_path (str): Path to the file containing the list.
        delimiter (str, optional): Explicit separator, or None for auto mode.

    Returns:
        list: A list of parsed strings.
    """
    with open(file_path, 'r') as fh:
        content = fh.read()

    if delimiter is None:
        # Collapse all whitespace, then split on list punctuation/quotes.
        normalized = re.sub(r'\s+', ' ', content.strip())
        items = re.split(r'[,\[\]\(\)\{\}"\'\s]+', normalized)
    else:
        actual_delimiter = "\n" if delimiter == "newline" else delimiter  # TODO fix this
        items = [part.strip() for part in content.split(actual_delimiter)]

    # Drop empty strings produced by leading/trailing separators.
    return [part for part in items if part]
utils/logit_lens.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Provides a class for mapping transformer hidden states to logits (and vice versa).
2
+ Example:
3
+
4
+ from standalone_logit_lens import LogitLens, ReverseLogitLens
5
+
6
+ model = AutoModelForCausalLM.from_pretrained(model_name).to(device).to(dtype)
7
+ lens = LogitLens.from_model(model).to(device).to(dtype)
8
+ reverse_lens = ReverseLogitLens.from_model(model).to(device).to(dtype)
9
+
10
+ hidden_state = ...
11
+ result = lens(hidden_state, layer_index) # layer_index is not really used, you can pass whatever
12
+ """
13
+
14
+ import abc
15
+ import logging
16
+
17
+ import copy
18
+ from typing import Union
19
+
20
+ import torch
21
+ from torch import nn
22
+ import torch.nn.functional as F
23
+
24
+ import transformers
25
+ from transformers import models
26
+ from transformers import PreTrainedModel
27
+
28
+
29
+ Model = Union[PreTrainedModel]
30
+ Norm = Union[
31
+ nn.LayerNorm,
32
+ models.llama.modeling_llama.LlamaRMSNorm,
33
+ models.gemma.modeling_gemma.GemmaRMSNorm,
34
+ models.gemma2.modeling_gemma2.Gemma2RMSNorm,
35
+ nn.Module,
36
+ ]
37
+
38
+
39
def get_unembedding_matrix(model: "Model") -> nn.Linear:
    """Return the final linear transformation from the model hidden state to the output logits.

    Args:
        model: A HuggingFace model exposing ``get_output_embeddings``.

    Returns:
        The model's output-embedding (unembedding) layer.

    Raises:
        ValueError: If ``model`` is not a ``PreTrainedModel``, or its output
            embedding is not a plain ``nn.Linear`` (tied/sparse unembeddings
            are unsupported).
    """
    if not isinstance(model, PreTrainedModel):
        raise ValueError(f"Model class {type(model)} not recognized!")
    unembed = model.get_output_embeddings()
    if not isinstance(unembed, nn.Linear):
        # Fixed typo in the original message ("unemebdings").
        raise ValueError("We currently only support linear unembeddings")
    return unembed
48
+
49
+
50
def get_embedding_matrix(model: nn.Module) -> nn.Embedding:
    """The initial embedding matrix from the input tokens to the model hidden state.

    Raises:
        ValueError: For non-``PreTrainedModel`` inputs or non-``nn.Embedding``
            input embeddings.
    """
    # Guard clause: only HuggingFace pretrained models are supported.
    if not isinstance(model, PreTrainedModel):
        raise ValueError(f"Model class {type(model)} not recognized!")
    embed = model.get_input_embeddings()
    if not isinstance(embed, nn.Embedding):
        raise ValueError("We currently only support embedding matrices")
    return embed
59
+
60
+
61
def get_final_norm(model: Model) -> Norm:
    """Get the final norm from a model.

    This isn't standardized across models, so this will need to be updated as
    we add new models.

    Args:
        model: A HuggingFace model with a ``base_model`` attribute.

    Returns:
        The final normalization module applied before the unembedding.

    Raises:
        ValueError: If the model has no ``base_model`` or no final layer norm.
        NotImplementedError: For unsupported architectures.
    """

    if not hasattr(model, "base_model"):
        raise ValueError("Model does not have a `base_model` attribute.")

    base_model = model.base_model
    # Dispatch on architecture: each family stores its final norm under a
    # different attribute name.
    if isinstance(base_model, models.opt.modeling_opt.OPTModel):
        final_layer_norm = base_model.decoder.final_layer_norm
    elif isinstance(base_model, models.gpt_neox.modeling_gpt_neox.GPTNeoXModel):
        final_layer_norm = base_model.final_layer_norm
    elif isinstance(
        base_model,
        (
            models.bloom.modeling_bloom.BloomModel,
            models.gpt2.modeling_gpt2.GPT2Model,
            models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
            models.gptj.modeling_gptj.GPTJModel,
        ),
    ):
        # GPT-2-style families all expose the final norm as `ln_f`.
        final_layer_norm = base_model.ln_f
    elif isinstance(base_model, models.llama.modeling_llama.LlamaModel):
        final_layer_norm = base_model.norm
    elif isinstance(base_model, models.mistral.modeling_mistral.MistralModel):
        final_layer_norm = base_model.norm
    elif isinstance(base_model, models.t5.modeling_t5.T5ForConditionalGeneration):
        # For T5, use the LayerNorm from the last decoder block, before the feed-forward layer.
        final_layer_norm = base_model.decoder.block[-1].layer[1].layer_norm
    else:
        raise NotImplementedError(f"Unknown model type {type(base_model)}")

    if final_layer_norm is None:
        raise ValueError("Model does not have a final layer norm.")

    # Sanity-check the result against the Norm union declared at module level.
    assert isinstance(final_layer_norm, Norm.__args__)  # type: ignore

    return final_layer_norm
102
+
103
+
104
class Unembed(nn.Module):
    """Module that maps transformer hidden states to logits.

    Holds frozen deep copies of a model's final layer norm and its
    output-embedding matrix so hidden states can be decoded into vocabulary
    logits independently of the source model.
    """

    final_norm: "Norm"
    unembedding: nn.Linear

    def __init__(
        self,
        model: "Model",
    ):
        """Initialize unembed.

        Args:
            model: A HuggingFace model from which to extract the unembedding matrix.
        """
        super().__init__()
        final_norm = get_final_norm(model)
        unembedding_matrix = get_unembedding_matrix(model)

        # Deep-copy so later mutations of the source model cannot affect us.
        self.final_norm = copy.deepcopy(final_norm)
        self.unembedding = copy.deepcopy(unembedding_matrix)

        # In general we don't want to finetune the unembed operation.
        self.requires_grad_(False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Convert hidden states into logits: ``unembedding(final_norm(h))``."""
        return self.unembedding(self.final_norm(h))
132
+
133
+
134
class Reembed(nn.Module):
    """Maps transformer hidden states back to token logits via the *input* embedding.

    Stores a frozen copy of the model's input-embedding matrix and scores a
    hidden state against every vocabulary embedding with a configurable
    similarity measure.
    """

    # Frozen copy of the input embedding matrix, presumably (vocab, hidden) — confirm upstream.
    embedding: torch.Tensor

    def __init__(
        self,
        model: "Model",
        distance_metric: str = "logits",
    ):
        """Initialize reembed.

        Args:
            model: A HuggingFace model from which to extract the embedding matrix.
            distance_metric: Similarity used in ``forward``: "logits"
                (dot product), "cosine", "euclidean" (negated distance);
                any other value falls back to a plain dot product.
        """
        super().__init__()
        embedding_matrix = get_embedding_matrix(model)

        self.embedding = copy.deepcopy(embedding_matrix.weight.data)

        self.distance_metric = distance_metric

        # In general we don't want to finetune the reembed operation.
        self.requires_grad_(False)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Score ``hidden_state`` against every vocabulary embedding."""

        if self.distance_metric == 'cosine':
            # Normalize E and h, then cosine similarity is a dot product.
            E_normalized = F.normalize(self.embedding, p=2, dim=-1)
            h_normalized = F.normalize(hidden_state, p=2, dim=-1)
            logits = torch.matmul(h_normalized, E_normalized.T).squeeze(0)

        elif self.distance_metric == 'euclidean':
            # Negate the Euclidean distance so larger still means "more similar".
            distances = torch.cdist(hidden_state, self.embedding, p=2).squeeze(0)
            logits = -distances

        else:
            # 'logits' and the default case are the same dot-product similarity
            # (the original had two identical branches; merged here).
            logits = torch.matmul(hidden_state, self.embedding.T).squeeze(0)
        return logits
182
+
183
+
184
class ReverseLens(abc.ABC, nn.Module):
    """Abstract base class for all reverse lenses.

    A reverse lens decodes hidden states into logits by scoring them against
    the model's *input* embedding matrix (via a ``Reembed`` op) rather than
    the output unembedding.
    """

    # The frozen reembed operation shared by all subclasses.
    reembed: Reembed

    def __init__(self, reembed: Reembed):
        """Create a ReverseLens.

        Args:
            reembed: The reembed operation to use.
        """
        super().__init__()

        self.reembed = reembed

    @abc.abstractmethod
    def forward(self, h: torch.Tensor, idx: int) -> torch.Tensor:
        """Decode hidden states into logits."""
        ...
203
+
204
+
205
class ReverseLogitLens(ReverseLens):
    """Reembeds the residual stream into logits."""

    reembed: Reembed

    def __init__(
        self,
        reembed: Reembed,
    ):
        """Wrap an existing reembed operation in a reverse logit lens.

        Args:
            reembed: The reembed operation to use.
        """
        super().__init__(reembed)

    @classmethod
    def from_model(
        cls,
        model: PreTrainedModel,
    ) -> "ReverseLogitLens":
        """Build a ReverseLogitLens by extracting the embedding matrix of ``model``.

        Args:
            model: A pretrained model from the transformers library you wish to inspect.
        """
        return cls(Reembed(model))

    def forward(self, h: torch.Tensor, idx: int) -> torch.Tensor:
        """Decode a hidden state into logits.

        Args:
            h: The hidden state to decode.
            idx: Layer index; accepted for interface compatibility but unused —
                the same reembedding applies to every layer.
        """
        del idx
        return self.reembed.forward(h)
243
+
244
+
245
class Lens(abc.ABC, nn.Module):
    """Abstract base class for all Lens.

    A lens decodes intermediate transformer hidden states into vocabulary
    logits via a fixed ``Unembed`` operation.
    """

    # The frozen unembed (final norm + output projection) shared by subclasses.
    unembed: Unembed

    def __init__(self, unembed: Unembed):
        """Create a Lens.

        Args:
            unembed: The unembed operation to use.
        """
        super().__init__()

        self.unembed = unembed

    @abc.abstractmethod
    def forward(self, h: torch.Tensor, idx: int) -> torch.Tensor:
        """Decode hidden states into logits."""
        ...
264
+
265
+
266
class LogitLens(Lens):
    """Unembeds the residual stream into logits."""

    unembed: Unembed

    def __init__(
        self,
        unembed: Unembed,
    ):
        """Wrap an existing unembed operation in a logit lens.

        Args:
            unembed: The unembed operation to use.
        """
        super().__init__(unembed)

    @classmethod
    def from_model(
        cls,
        model: PreTrainedModel,
    ) -> "LogitLens":
        """Build a LogitLens by extracting the norm + unembedding of ``model``.

        Args:
            model: A pretrained model from the transformers library you wish to inspect.
        """
        return cls(Unembed(model))

    def forward(self, h: torch.Tensor, idx: int) -> torch.Tensor:
        """Decode a hidden state into logits.

        Args:
            h: The hidden state to decode.
            idx: Layer index; accepted for interface compatibility but unused —
                the same unembedding applies to every layer.
        """
        del idx
        return self.unembed.forward(h)
304
+
utils/model_utils.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ from typing import Iterable, List, Union
3
+ from transformers import PreTrainedModel, PreTrainedTokenizer
4
+ import torch
5
+ from torch import nn
6
+ from sklearn.linear_model import LinearRegression
7
+ import torch.optim as optim
8
+ from torch.utils.data import DataLoader, TensorDataset
9
+
10
+
11
def extract_token_i_hidden_states(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    inputs: Union[str, List[str]],
    token_idx_to_extract: int = -1,
    batch_size: int = 1,
    layers_to_extract: List[int] = None,
    return_dict: bool = True,
    verbose: bool = True,
) -> torch.Tensor:
    """Extract the hidden state of one token position for every input prompt.

    Args:
        model: Model to run with ``output_hidden_states=True`` (already on its device).
        tokenizer: Tokenizer used to encode ``inputs``.
        inputs: A single prompt or a list of prompts.
        token_idx_to_extract: Token position whose hidden state is kept
            (default: last token).
        batch_size: Prompts per forward pass. NOTE(review): no padding is
            applied, so prompts in one batch must tokenize to equal lengths.
        layers_to_extract: Layers to keep; defaults to all layers except the
            initial embeddings (``hidden_states[0]``).
        return_dict: If True, return ``{layer: (num_inputs, hidden_dim)}``;
            otherwise concatenate all layers along dim 0.
        verbose: Show a progress bar.

    Returns:
        Per-layer hidden states (dict) or a single concatenated tensor.
    """
    device = model.device
    model.eval()

    if isinstance(inputs, str):
        inputs = [inputs]

    if layers_to_extract is None:
        layers_to_extract = list(range(1, model.config.num_hidden_layers + 1))  # extract all but initial embeddings
    all_hidden_states = {layer: [] for layer in layers_to_extract}

    with torch.no_grad():
        for i in tqdm(range(0, len(inputs), batch_size), desc="Extracting hidden states", unit="batch", disable=not verbose):
            input_ids = tokenizer(inputs[i:i+batch_size], return_tensors="pt", return_attention_mask=False)['input_ids']
            # Let exceptions propagate instead of dropping into pdb (removed
            # the bare `except: pdb.set_trace()` debugging hook).
            outputs = model(input_ids.to(device), output_hidden_states=True)
            # Append the selected token position once per batch. (Previously
            # this was nested inside a per-sample loop, which appended the
            # whole batch `batch_size` times and duplicated rows whenever
            # batch_size > 1.)
            for layer in layers_to_extract:
                hidden_states = outputs.hidden_states[layer]
                all_hidden_states[layer].append(hidden_states[:, token_idx_to_extract, :].detach().cpu())
    for layer in all_hidden_states:
        all_hidden_states[layer] = torch.concat(all_hidden_states[layer], dim=0)

    if not return_dict:
        all_hidden_states = torch.concat([all_hidden_states[layer] for layer in layers_to_extract], dim=0)

    return all_hidden_states
51
+
52
+
53
def extract_vocab_hidden_states(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    tokens_ids_to_extract: Iterable[int] = None,
    prompt: str = "{target}",
    prompt_target: str = "{target}",
    batch_size: int = 128,
    layers_to_extract: List[int] = None
) -> torch.Tensor:
    """Collect last-position hidden states for (a subset of) the vocabulary.

    Each token id is decoded, substituted for ``prompt_target`` inside
    ``prompt``, and the hidden state at the final position is recorded for
    every requested layer.

    Returns:
        dict[int, torch.Tensor]: Per-layer tensors of shape (num_tokens, hidden_dim).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    if layers_to_extract is None:
        # Default: every layer except the initial embeddings.
        layers_to_extract = list(range(1, model.config.num_hidden_layers + 1))
    hidden_by_layer = {layer: [] for layer in layers_to_extract}
    if tokens_ids_to_extract is None:
        tokens_ids_to_extract = range(tokenizer.vocab_size)
    targets = [tokenizer.decode(tok_id) for tok_id in tokens_ids_to_extract]

    # add pad token if necessary
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    with torch.no_grad():
        for start in tqdm(range(0, len(targets), batch_size), desc="Extracting hidden states", unit="batch"):
            batch_prompts = [prompt.replace(prompt_target, target) for target in targets[start:start + batch_size]]
            input_ids = tokenizer(batch_prompts, return_tensors="pt", padding=True, padding_side="left")["input_ids"]
            outputs = model(input_ids.to(device), output_hidden_states=True)
            for layer in layers_to_extract:
                # Left padding keeps the target token at the final position.
                hidden_by_layer[layer].append(outputs.hidden_states[layer][:, -1, :].detach().cpu())

    for layer in hidden_by_layer:
        hidden_by_layer[layer] = torch.concat(hidden_by_layer[layer], dim=0)

    return hidden_by_layer
90
+
91
+
92
def get_vocab_tokens(tokenizer: "PreTrainedTokenizer", min_word_len: "Optional[int]" = None):
    """Return all vocabulary token ids, optionally filtered by decoded length.

    Args:
        tokenizer: Tokenizer whose vocabulary is enumerated.
        min_word_len: If truthy, keep only ids whose decoded string has at
            least this many characters. (Annotation fixed from ``int = None``.)

    Returns:
        list[int]: The (possibly filtered) token ids.
    """
    token_ids = list(range(tokenizer.vocab_size))
    if min_word_len:
        token_ids = [
            tok_id
            for tok_id in token_ids
            if len(tokenizer.decode(tok_id)) >= min_word_len
        ]
    return token_ids
100
+
101
+
102
def learn_linear_map(X: torch.Tensor, Y: torch.Tensor, fit_intercept=False):
    """Fit a closed-form least-squares linear map from ``X`` to ``Y``.

    Args:
        X (torch.Tensor): Source vectors, shape (N, d_in).
        Y (torch.Tensor): Target vectors, shape (N, d_out).
        fit_intercept (bool): Whether to also learn a bias term.

    Returns:
        nn.Linear: A linear layer (cast back to ``X``'s dtype) computing the fit.
    """
    input_dtype = X.dtype
    # Fit in float64 on CPU via scikit-learn's closed-form least squares.
    linear_reg = LinearRegression(fit_intercept=fit_intercept).fit(X.cpu().to(float).numpy(), Y.cpu().to(float).numpy())
    linear_map = nn.Linear(X.size(1), Y.size(1), bias=fit_intercept)
    with torch.no_grad():
        # sklearn's coef_ has shape (n_targets, n_features) == nn.Linear's
        # (out_features, in_features), and both compute x @ W.T — so coef_
        # must be copied WITHOUT transposing. (The previous `.T` silently
        # applied the transposed map for square hidden->hidden fits.)
        linear_map.weight.data = torch.Tensor(linear_reg.coef_)
        if fit_intercept:
            linear_map.bias.data = torch.Tensor(linear_reg.intercept_)
    linear_map = linear_map.to(input_dtype)
    return linear_map
112
+
113
+
114
def train_model(
    model,
    dataloader,
    optimizer,
    loss_func="mse",
    scheduler=None,
    num_epochs=5,
    gradient_accumulation_steps=1,
    max_grads_norm=1.0,
):
    """
    Train ``model`` on (x, y) pairs drawn from ``dataloader``.

    Parameters:
        model (nn.Module): Model to train. NOTE(review): batches are moved to
            CUDA when available, so the caller must place ``model`` on that
            device beforehand (as ``learn_mlp``/``learn_ffn`` do).
        dataloader (DataLoader): Yields (x_batch, y_batch) pairs.
        optimizer (torch.optim.Optimizer): Optimizer over the model parameters.
        loss_func (str): Loss function to use ('mse', 'huber', 'cosine'). Default is 'mse'.
        scheduler: Optional LR scheduler stepped after each optimizer step.
        num_epochs (int): Number of training epochs. Default is 5.
        gradient_accumulation_steps (int): Number of batches per optimizer step. Default is 1.
        max_grads_norm (float | None): Gradient-clipping norm; None disables clipping.

    Returns:
        nn.Module: The trained model, moved to CPU.

    Raises:
        ValueError: If ``loss_func`` is not one of the supported names.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Select loss function
    if loss_func == "mse":
        criterion = nn.MSELoss()
    elif loss_func == "huber":
        criterion = nn.HuberLoss()
    elif loss_func == "cosine":
        criterion = nn.CosineEmbeddingLoss()
    else:
        raise ValueError("Unsupported loss function. Choose from 'mse', 'huber', or 'cosine'.")

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for i, (x_batch, y_batch) in enumerate(dataloader):
            outputs = model(x_batch.to(device))
            if loss_func == "cosine":
                # Cosine embedding loss needs a target of +1s ("similar" pairs);
                # it must live on the same device as the inputs (the original
                # allocated it on CPU, crashing on CUDA runs).
                target = torch.ones(x_batch.size(0), device=device)
                loss = criterion(outputs, y_batch.to(device), target)
            else:
                loss = criterion(outputs, y_batch.to(device))

            # Scale so accumulated gradients average over the accumulation window.
            loss = loss / gradient_accumulation_steps
            loss.backward()

            if max_grads_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), max_grads_norm)

            # Step on every accumulation boundary and on the final batch.
            if (i + 1) % gradient_accumulation_steps == 0 or (i + 1) == len(dataloader):
                optimizer.step()
                optimizer.zero_grad()
                if scheduler:
                    scheduler.step()

            # Undo the accumulation scaling for reporting.
            epoch_loss += loss.item() * gradient_accumulation_steps

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(dataloader):.6f}")

    return model.cpu()
183
+
184
+
185
def learn_mlp(
    X: torch.Tensor, Y: torch.Tensor,
    activation_func=nn.SiLU,
    batch_size=128,
    lr=0.001,
    weight_decay=0.0,
    loss_func="mse",
    lr_schedule="linear",
    expansion_alpha=1.0,
    num_epochs=5,
    gradient_accumulation_steps=1,
    max_grads_norm=1.0,
):
    """
    Trains a two-layer MLP to map hidden states from X to Y.

    Parameters:
        X (torch.Tensor): Input tensor of shape (N, D_in).
        Y (torch.Tensor): Target tensor of shape (N, D_out).
        activation_func (type[nn.Module]): Activation class for the hidden layer. Default is nn.SiLU.
        batch_size (int): Batch size for DataLoader. Default is 128.
        lr (float): Learning rate. Default is 0.001.
        weight_decay (float): Weight decay for AdamW. Default is 0.0.
        loss_func (str): Loss function to use ('mse', 'huber', 'cosine'). Default is 'mse'.
        lr_schedule (str): 'linear' for linear decay to zero; anything else disables scheduling.
        expansion_alpha (float): Hidden width as a multiple of the input width. Default is 1.0.
        num_epochs (int): Number of training epochs. Default is 5.
        gradient_accumulation_steps (int): Number of batches per optimizer step. Default is 1.
        max_grads_norm (float | None): Gradient-clipping norm; None disables clipping.

    Returns:
        nn.Module: Trained MLP model (on CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_dim = X.shape[1]
    # Hidden width scales with the input width.
    hidden_dim = int(input_dim * expansion_alpha)
    output_dim = Y.shape[1]
    model = nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        activation_func(),
        nn.Linear(hidden_dim, output_dim)
    ).to(device)

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # DataLoader setup
    dataset = TensorDataset(X, Y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Learning rate scheduler: linear decay from 1 -> 0 over all optimizer steps.
    if lr_schedule == "linear":
        total_steps = (len(dataloader) * num_epochs) // gradient_accumulation_steps
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1 - step / total_steps)
    else:
        scheduler = None

    return train_model(
        model,
        dataloader,
        optimizer,
        loss_func=loss_func,
        scheduler=scheduler,
        num_epochs=num_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_grads_norm=max_grads_norm,
    )
250
+
251
+
252
class FFN(nn.Module):
    """Gated feed-forward residual block: silu(gate_proj(x)) * x + map_proj(x)."""

    def __init__(self, input_dim):
        """Build two square projections over an ``input_dim``-sized hidden state."""
        super().__init__()
        self.gate_proj = nn.Linear(input_dim, input_dim)
        self.activation = nn.SiLU()
        self.map_proj = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        """Apply the SiLU-gated transform plus the linear map of ``x``."""
        gate = self.activation(self.gate_proj(x))
        return gate * x + self.map_proj(x)
261
+
262
+
263
def learn_ffn(
    X: torch.Tensor, Y: torch.Tensor,
    activation_func=nn.SiLU,
    batch_size=128,
    lr=0.001,
    weight_decay=0.0,
    loss_func="mse",
    lr_schedule="linear",
    num_epochs=5,
    gradient_accumulation_steps=1,
    max_grads_norm=1.0,
):
    """
    Trains a gated FFN block to map hidden states from X to Y.

    Parameters:
        X (torch.Tensor): Input tensor of shape (N, D).
        Y (torch.Tensor): Target tensor of shape (N, D).
        activation_func (type[nn.Module]): UNUSED — kept only for signature
            compatibility with ``learn_mlp``; ``FFN`` hard-codes SiLU.
        batch_size (int): Batch size for DataLoader. Default is 128.
        lr (float): Learning rate. Default is 0.001.
        weight_decay (float): Weight decay for AdamW. Default is 0.0.
        loss_func (str): Loss function to use ('mse', 'huber', 'cosine'). Default is 'mse'.
        lr_schedule (str): 'linear' for linear decay to zero; anything else disables scheduling.
        num_epochs (int): Number of training epochs. Default is 5.
        gradient_accumulation_steps (int): Number of batches per optimizer step. Default is 1.
        max_grads_norm (float | None): Gradient-clipping norm; None disables clipping.

    Returns:
        nn.Module: Trained FFN model (on CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_dim = X.shape[1]
    model = FFN(input_dim).to(device)

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # DataLoader setup
    dataset = TensorDataset(X, Y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Learning rate scheduler: linear decay from 1 -> 0 over all optimizer steps.
    if lr_schedule == "linear":
        total_steps = (len(dataloader) * num_epochs) // gradient_accumulation_steps
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1 - step / total_steps)
    else:
        scheduler = None

    return train_model(
        model,
        dataloader,
        optimizer,
        loss_func=loss_func,
        scheduler=scheduler,
        num_epochs=num_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_grads_norm=max_grads_norm,
    )
utils/procrustes/__init__.py ADDED
File without changes
utils/procrustes/orthogonal.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # The Procrustes library provides a set of functions for transforming
3
+ # a matrix to make it as similar as possible to a target matrix.
4
+ #
5
+ # Copyright (C) 2017-2022 The QC-Devs Community
6
+ #
7
+ # This file is part of Procrustes.
8
+ #
9
+ # Procrustes is free software; you can redistribute it and/or
10
+ # modify it under the terms of the GNU General Public License
11
+ # as published by the Free Software Foundation; either version 3
12
+ # of the License, or (at your option) any later version.
13
+ #
14
+ # Procrustes is distributed in the hope that it will be useful,
15
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
16
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17
+ # GNU General Public License for more details.
18
+ #
19
+ # You should have received a copy of the GNU General Public License
20
+ # along with this program; if not, see <http://www.gnu.org/licenses/>
21
+ #
22
+ # --
23
+ """Orthogonal Procrustes Module."""
24
+
25
+ # import warnings
26
+
27
+ from typing import Optional
28
+
29
+ import numpy as np
30
+ from .utils import compute_error, ProcrustesResult, setup_input_arrays
31
+ import scipy
32
+
33
+
34
+ __all__ = [
35
+ "orthogonal",
36
+ "orthogonal_2sided",
37
+ ]
38
+
39
+
40
def orthogonal(
    a: np.ndarray,
    b: np.ndarray,
    pad: bool = True,
    translate: bool = False,
    scale: bool = False,
    unpad_col: bool = False,
    unpad_row: bool = False,
    check_finite: bool = True,
    weight: Optional[np.ndarray] = None,
    lapack_driver: str = "gesvd",
) -> ProcrustesResult:
    r"""Perform orthogonal Procrustes.

    Given a matrix :math:`\mathbf{A}_{m \times n}` and a reference matrix :math:`\mathbf{B}_{m
    \times n}`, find the orthogonal transformation matrix :math:`\mathbf{Q}_{n
    \times n}` that makes :math:`\mathbf{AQ}` as close as possible to :math:`\mathbf{B}`.
    In other words,

    .. math::
        \underbrace{\min}_{\left\{\mathbf{Q} | \mathbf{Q}^{-1} = {\mathbf{Q}}^\dagger \right\}}
        \|\mathbf{A}\mathbf{Q} - \mathbf{B}\|_{F}^2

    This Procrustes method requires the :math:`\mathbf{A}` and :math:`\mathbf{B}` matrices to
    have the same shape, which is guaranteed with the default ``pad`` argument for any given
    :math:`\mathbf{A}` and :math:`\mathbf{B}` matrices. In preparing the :math:`\mathbf{A}` and
    :math:`\mathbf{B}` matrices, the (optional) order of operations is: **1)** unpad zero
    rows/columns, **2)** translate the matrices to the origin, **3)** weight entries of
    :math:`\mathbf{A}`, **4)** scale the matrices to have unit norm, **5)** pad matrices with zero
    rows/columns so they have the same shape.

    Parameters
    ----------
    a : ndarray
        The 2D-array :math:`\mathbf{A}` which is going to be transformed.
    b : ndarray
        The 2D-array :math:`\mathbf{B}` representing the reference matrix.
    pad : bool, optional
        Add zero rows (at the bottom) and/or columns (to the right-hand side) of matrices
        :math:`\mathbf{A}` and :math:`\mathbf{B}` so that they have the same shape.
    translate : bool, optional
        If True, both arrays are centered at origin (columns of the arrays will have mean zero).
    scale : bool, optional
        If True, both arrays are normalized with respect to the Frobenius norm, i.e.,
        :math:`\text{Tr}\left[\mathbf{A}^\dagger\mathbf{A}\right] = 1` and
        :math:`\text{Tr}\left[\mathbf{B}^\dagger\mathbf{B}\right] = 1`.
    unpad_col : bool, optional
        If True, zero columns (with values less than 1.0e-8) on the right-hand side of the initial
        :math:`\mathbf{A}` and :math:`\mathbf{B}` matrices are removed.
    unpad_row : bool, optional
        If True, zero rows (with values less than 1.0e-8) at the bottom of the initial
        :math:`\mathbf{A}` and :math:`\mathbf{B}` matrices are removed.
    check_finite : bool, optional
        If True, convert the input to an array, checking for NaNs or Infs.
    weight : ndarray, optional
        The 1D-array representing the weights of each row of :math:`\mathbf{A}`. This defines the
        elements of the diagonal matrix :math:`\mathbf{W}` that is multiplied by :math:`\mathbf{A}`
        matrix, i.e., :math:`\mathbf{A} \rightarrow \mathbf{WA}`.
    lapack_driver : {'gesvd', 'gesdd'}, optional
        Whether to use the more efficient divide-and-conquer approach ('gesdd') or the more robust
        general rectangular approach ('gesvd') to compute the singular-value decomposition with
        `scipy.linalg.svd`.

    Returns
    -------
    res : ProcrustesResult
        The Procrustes result represented as a class:`utils.ProcrustesResult` object.

    Notes
    -----
    The optimal orthogonal matrix is obtained by,

    .. math::
        \mathbf{Q}^{\text{opt}} =
        \arg \underbrace{\min}_{\left\{\mathbf{Q} \left| {\mathbf{Q}^{-1} = {\mathbf{Q}}^\dagger}
        \right. \right\}} \|\mathbf{A}\mathbf{Q} - \mathbf{B}\|_{F}^2 =
        \arg \underbrace{\max}_{\left\{\mathbf{Q} \left| {\mathbf{Q}^{-1} = {\mathbf{Q}}^\dagger}
        \right. \right\}} \text{Tr}\left[\mathbf{Q^\dagger}\mathbf{A^\dagger}\mathbf{B}\right]

    The solution is obtained using the singular value decomposition (SVD) of the
    :math:`\mathbf{A}^\dagger \mathbf{B}` matrix,

    .. math::
        \mathbf{A}^\dagger \mathbf{B} &= \tilde{\mathbf{U}} \tilde{\mathbf{\Sigma}}
        \tilde{\mathbf{V}}^{\dagger} \\
        \mathbf{Q}^{\text{opt}} &= \tilde{\mathbf{U}} \tilde{\mathbf{V}}^{\dagger}

    The singular values are always listed in decreasing order, with the smallest singular
    value in the bottom-right-hand corner of :math:`\tilde{\mathbf{\Sigma}}`.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.stats import ortho_group
    >>> from procrustes import orthogonal
    >>> a = np.random.rand(5, 3) # random input matrix
    >>> q = ortho_group.rvs(3) # random orthogonal transformation
    >>> b = np.dot(a, q) + np.random.rand(1, 3) # random target matrix
    >>> result = orthogonal(a, b, translate=True, scale=False)
    >>> print(result.error) # error (should be zero)
    >>> print(result.t) # transformation matrix (same as q)
    >>> print(result.new_a) # translated array a
    >>> print(result.new_b) # translated array b

    """
    # check inputs: unpad/translate/weight/scale/pad per the order documented above
    new_a, new_b = setup_input_arrays(
        a,
        b,
        unpad_col,
        unpad_row,
        pad,
        translate,
        scale,
        check_finite,
        weight,
    )
    if new_a.shape != new_b.shape:
        raise ValueError(
            f"Shape of A and B does not match: {new_a.shape} != {new_b.shape} "
            "Check pad, unpad_col, and unpad_row arguments."
        )
    # calculate SVD of A.T * B
    u, _, vt = scipy.linalg.svd(np.dot(new_a.T, new_b), lapack_driver=lapack_driver)
    # compute optimal orthogonal transformation: Q_opt = U @ V^T (see Notes)
    u_opt = np.dot(u, vt)
    # compute one-sided error ||A Q - B||_F^2 (via project helper)
    error = compute_error(new_a, new_b, u_opt)

    return ProcrustesResult(error=error, new_a=new_a, new_b=new_b, t=u_opt, s=None)
+
171
+
172
def orthogonal_2sided(
    a: np.ndarray,
    b: np.ndarray,
    single: bool = True,
    pad: bool = True,
    translate: bool = False,
    scale: bool = False,
    unpad_col: bool = False,
    unpad_row: bool = False,
    check_finite: bool = True,
    weight: Optional[np.ndarray] = None,
    lapack_driver: str = "gesvd",
) -> ProcrustesResult:
    r"""Perform two-sided orthogonal Procrustes with one or two transformations.

    With ``single=True``, given symmetric matrices :math:`\mathbf{A}` and
    :math:`\mathbf{B}`, find one orthogonal matrix :math:`\mathbf{Q}` minimizing
    :math:`\|\mathbf{Q}^\dagger\mathbf{A}\mathbf{Q} - \mathbf{B}\|_{F}^2`.
    With ``single=False``, find two orthogonal matrices :math:`\mathbf{Q}_1` and
    :math:`\mathbf{Q}_2` minimizing
    :math:`\|\mathbf{Q}_1^\dagger\mathbf{A}\mathbf{Q}_2 - \mathbf{B}\|_{F}^2`.

    In preparing :math:`\mathbf{A}` and :math:`\mathbf{B}`, the (optional) order of
    operations is: 1) unpad zero rows/columns, 2) translate to the origin,
    3) weight entries of :math:`\mathbf{A}`, 4) scale to unit norm, 5) pad with zero
    rows/columns so both have the same shape.

    Parameters
    ----------
    a : ndarray
        The 2D-array :math:`\mathbf{A}` which is going to be transformed.
    b : ndarray
        The 2D-array :math:`\mathbf{B}` representing the reference matrix.
    single : bool, optional
        If True, a single transformation (:math:`\mathbf{Q}_1=\mathbf{Q}_2=\mathbf{Q}`)
        is used; otherwise two independent transformations are computed.
    pad : bool, optional
        Add zero rows/columns so :math:`\mathbf{A}` and :math:`\mathbf{B}` share a shape.
    translate : bool, optional
        If True, both arrays are centered at the origin (columns get zero mean).
    scale : bool, optional
        If True, both arrays are normalized with respect to the Frobenius norm.
    unpad_col : bool, optional
        If True, zero columns (values below 1.0e-8) on the right are removed first.
    unpad_row : bool, optional
        If True, zero rows (values below 1.0e-8) at the bottom are removed first.
    check_finite : bool, optional
        If True, convert the input to an array, checking for NaNs or Infs.
    weight : ndarray, optional
        1D weights defining a diagonal :math:`\mathbf{W}` applied as
        :math:`\mathbf{A} \rightarrow \mathbf{WA}`.
    lapack_driver : {"gesvd", "gesdd"}, optional
        SVD driver passed to SciPy; "gesvd" is slower but more robust.

    Returns
    -------
    res : ProcrustesResult
        The Procrustes result with ``error``, ``new_a``, ``new_b``, ``t`` and ``s``.

    Raises
    ------
    ValueError
        If ``single=True`` and either processed matrix is not symmetric.

    Notes
    -----
    For the single-transformation case, with eigendecompositions
    :math:`\mathbf{A}=\mathbf{U}_A\mathbf{\Lambda}_A\mathbf{U}_A^\dagger` and
    :math:`\mathbf{B}=\mathbf{U}_B\mathbf{\Lambda}_B\mathbf{U}_B^\dagger`, the solution
    is :math:`\mathbf{Q}^\text{opt}=\mathbf{U}_A\mathbf{S}\mathbf{U}_B^\dagger`, where
    the sign matrix :math:`\mathbf{S}` is taken to be the identity. For the
    two-transformation case, with SVDs
    :math:`\mathbf{A}=\mathbf{U}_A\mathbf{\Sigma}_A\mathbf{V}_A^\dagger` and
    :math:`\mathbf{B}=\mathbf{U}_B\mathbf{\Sigma}_B\mathbf{V}_B^\dagger`, the solutions
    are :math:`\mathbf{Q}_1=\mathbf{U}_A\mathbf{U}_B^\dagger` and
    :math:`\mathbf{Q}_2=\mathbf{V}_A\mathbf{V}_B^\dagger`.
    """
    # prepare (unpad/translate/weight/scale/pad) both input matrices
    new_a, new_b = setup_input_arrays(
        a,
        b,
        unpad_col,
        unpad_row,
        pad,
        translate,
        scale,
        check_finite,
        weight,
    )

    if single:
        # the single-transformation solution is only valid for symmetric inputs
        if not np.allclose(new_a.T, new_a):
            raise ValueError(
                f"Array A with {new_a.shape} shape is not symmetric. "
                "Check pad, unpad_col, and unpad_row arguments."
            )
        if not np.allclose(new_b.T, new_b):
            raise ValueError(
                f"Array B with {new_b.shape} shape is not symmetric. "
                "Check pad, unpad_col, and unpad_row arguments."
            )
        # Q = U_A @ U_B^T from the eigenvectors of A and B (sign matrix = identity)
        _, eigvec_a = np.linalg.eigh(new_a)
        _, eigvec_b = np.linalg.eigh(new_b)
        q_opt = np.dot(eigvec_a, eigvec_b.T)
        one_sided_err = compute_error(new_a, new_b, q_opt, q_opt.T)
        return ProcrustesResult(
            error=one_sided_err, new_a=new_a, new_b=new_b, t=q_opt, s=q_opt.T
        )

    # two-transformation case: Q1 = U_A @ U_B^T and Q2 = V_A @ V_B^T from the SVDs
    u_a, _, vt_a = scipy.linalg.svd(new_a, lapack_driver=lapack_driver)
    u_b, _, vt_b = scipy.linalg.svd(new_b, lapack_driver=lapack_driver)
    q_left = np.dot(u_a, u_b.T)
    q_right = np.dot(vt_a.T, vt_b)
    two_sided_err = compute_error(new_a, new_b, q_right, q_left.T)
    return ProcrustesResult(
        error=two_sided_err, new_a=new_a, new_b=new_b, t=q_right, s=q_left.T
    )
utils/procrustes/utils.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # The Procrustes library provides a set of functions for transforming
3
+ # a matrix to make it as similar as possible to a target matrix.
4
+ #
5
+ # Copyright (C) 2017-2022 The QC-Devs Community
6
+ #
7
+ # This file is part of Procrustes.
8
+ #
9
+ # Procrustes is free software; you can redistribute it and/or
10
+ # modify it under the terms of the GNU General Public License
11
+ # as published by the Free Software Foundation; either version 3
12
+ # of the License, or (at your option) any later version.
13
+ #
14
+ # Procrustes is distributed in the hope that it will be useful,
15
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
16
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17
+ # GNU General Public License for more details.
18
+ #
19
+ # You should have received a copy of the GNU General Public License
20
+ # along with this program; if not, see <http://www.gnu.org/licenses/>
21
+ #
22
+ # --
23
+ """Utility Module."""
24
+ from typing import List, Optional, Tuple
25
+
26
+ import numpy as np
27
+
28
+ __all__ = [
29
+ "compute_error",
30
+ "setup_input_arrays",
31
+ "ProcrustesResult",
32
+ ]
33
+
34
+
35
+ def _zero_padding(
36
+ array_a: np.ndarray, array_b: np.ndarray, pad_mode: str = "row-col"
37
+ ) -> Tuple[np.ndarray, np.ndarray]:
38
+ r"""
39
+ Return arrays padded with rows and/or columns of zero.
40
+
41
+ Parameters
42
+ ----------
43
+ array_a : ndarray
44
+ The 2D-array :math:`\mathbf{A}_{n_a \times m_a}`.
45
+ array_b : ndarray
46
+ The 2D-array :math:`\mathbf{B}_{n_b \times m_b}`.
47
+ pad_mode : str
48
+ Specifying how to pad the arrays. Should be one of
49
+ - "row"
50
+ The array with fewer rows is padded with zero rows so that both have the same
51
+ number of rows.
52
+ - "col"
53
+ The array with fewer columns is padded with zero columns so that both have the
54
+ same number of columns.
55
+ - "row-col"
56
+ The array with fewer rows is padded with zero rows, and the array with fewer
57
+ columns is padded with zero columns, so that both have the same dimensions.
58
+ This does not necessarily result in square arrays.
59
+ - "square"
60
+ The arrays are padded with zero rows and zero columns so that they are both
61
+ squared arrays. The dimension of square array is specified based on the highest
62
+ dimension, i.e. :math:`\text{max}(n_a, m_a, n_b, m_b)`.
63
+
64
+ Returns
65
+ -------
66
+ padded_a : ndarray
67
+ Padded array_a.
68
+ padded_b : ndarray
69
+ Padded array_b.
70
+
71
+ """
72
+ # sanity checks
73
+ if not isinstance(array_a, np.ndarray) or not isinstance(array_b, np.ndarray):
74
+ raise ValueError("Arguments array_a & array_b should be numpy arrays.")
75
+ if array_a.ndim != 2 or array_b.ndim != 2:
76
+ raise ValueError("Arguments array_a & array_b should be 2D arrays.")
77
+
78
+ if array_a.shape == array_b.shape and array_a.shape[0] == array_a.shape[1]:
79
+ # special case of square arrays, mode is set to None so that array_a & array_b are returned.
80
+ pad_mode = None
81
+
82
+ if pad_mode == "square":
83
+ # calculate desired dimension of square array
84
+ (a_n1, a_m1), (a_n2, a_m2) = array_a.shape, array_b.shape
85
+ dim = max(a_n1, a_n2, a_m1, a_m2)
86
+ # padding rows to have both arrays have dim rows
87
+ if a_n1 < dim:
88
+ array_a = np.pad(array_a, [[0, dim - a_n1], [0, 0]], "constant", constant_values=0)
89
+ if a_n2 < dim:
90
+ array_b = np.pad(array_b, [[0, dim - a_n2], [0, 0]], "constant", constant_values=0)
91
+ # padding columns to have both arrays have dim columns
92
+ if a_m1 < dim:
93
+ array_a = np.pad(array_a, [[0, 0], [0, dim - a_m1]], "constant", constant_values=0)
94
+ if a_m2 < dim:
95
+ array_b = np.pad(array_b, [[0, 0], [0, dim - a_m2]], "constant", constant_values=0)
96
+
97
+ if pad_mode in ["row", "row-col"]:
98
+ # padding rows to have both arrays have the same number of rows
99
+ diff = array_a.shape[0] - array_b.shape[0]
100
+ if diff < 0:
101
+ array_a = np.pad(array_a, [[0, -diff], [0, 0]], "constant", constant_values=0)
102
+ else:
103
+ array_b = np.pad(array_b, [[0, diff], [0, 0]], "constant", constant_values=0)
104
+
105
+ if pad_mode in ["col", "row-col"]:
106
+ # padding columns to have both arrays have the same number of columns
107
+ diff = array_a.shape[1] - array_b.shape[1]
108
+ if diff < 0:
109
+ array_a = np.pad(array_a, [[0, 0], [0, -diff]], "constant", constant_values=0)
110
+ else:
111
+ array_b = np.pad(array_b, [[0, 0], [0, diff]], "constant", constant_values=0)
112
+
113
+ return array_a, array_b
114
+
115
+
116
def _translate_array(
    array_a: np.ndarray, array_b: Optional[np.ndarray] = None, weight: Optional[np.ndarray] = None
) -> Tuple[np.ndarray, float]:
    """Translate array_a so its (weighted) column means vanish or match array_b's.

    Parameters
    ----------
    array_a : ndarray
        The 2D-array to translate.
    array_b : ndarray, optional
        If given, array_a is translated so its centroid coincides with array_b's
        centroid instead of the origin.
    weight : ndarray, optional
        Non-negative 1D row weights used when computing the centroids.

    Returns
    -------
    translated : ndarray
        The translated copy of array_a.
    shift : ndarray
        The translation vector that was added to each row of array_a.

    Raises
    ------
    ValueError
        If weight is not a 1D vector of non-negative entries.
    """
    # NOTE: the mean is sensitive to outliers; weights let callers compensate.
    if weight is not None:
        if weight.ndim != 1:
            raise ValueError("The weight should be a 1d row vector.")
        if not (weight >= 0).all():
            raise ValueError("The elements of the weight should be non-negative.")

    # centroid of A, optionally shifted so it lands on B's centroid
    offset = np.average(array_a, axis=0, weights=weight)
    if array_b is not None:
        offset -= np.average(array_b, axis=0, weights=weight)
    return array_a - offset, -1 * offset
156
+
157
+
158
def _scale_array(array_a, array_b=None) -> Tuple[np.ndarray, float]:
    """Scale array_a to unit Frobenius norm, or to match array_b's norm.

    Parameters
    ----------
    array_a : ndarray
        The 2D-array to scale.
    array_b : ndarray, default=None
        If given, array_a is rescaled so its Frobenius norm equals array_b's.

    Returns
    -------
    scaled_a : ndarray
        The rescaled array_a.
    scale : float
        The multiplicative factor that was applied.
    """
    # start from the factor that maps A onto the unit sphere
    factor = 1.0 / np.linalg.norm(array_a)
    if array_b is not None:
        # stretch further so the result matches B's norm
        factor *= np.linalg.norm(array_b)
    return array_a * factor, factor
185
+
186
+
187
def _hide_zero_padding(
    array_a: np.ndarray,
    remove_zero_col: bool = True,
    remove_zero_row: bool = True,
    tol: float = 1.0e-8,
) -> np.ndarray:
    r"""Strip trailing zero rows (bottom) and zero columns (right) from an array.

    Parameters
    ----------
    array_a : ndarray
        The initial 1D or 2D array.
    remove_zero_col : bool, optional
        If True, trailing columns whose entries are all below tol are removed.
    remove_zero_row : bool, optional
        If True, trailing rows whose entries are all below tol are removed.
    tol : float, optional
        Absolute tolerance below which an entry counts as zero.

    Returns
    -------
    ndarray
        The array with near-zero trailing rows and/or columns removed.

    Raises
    ------
    TypeError
        If the input has more than two dimensions.
    """
    if array_a.ndim > 2:
        raise TypeError("Matrix inputs must be 1- or 2- dimensional arrays")

    if remove_zero_row:
        # scan rows bottom-up until one contains a non-negligible entry;
        # a 1D array is viewed as a column so each "row" is iterable
        keep_rows = array_a.shape[0]
        view = array_a[..., np.newaxis] if array_a.ndim == 1 else array_a
        for row in view[::-1]:
            if any(abs(entry) > tol for entry in row):
                break
            keep_rows -= 1
        array_a = array_a[:keep_rows]

    if remove_zero_col and array_a.ndim == 2:
        # scan columns right-to-left the same way
        keep_cols = array_a.shape[1]
        for col in array_a.T[::-1]:
            if any(abs(entry) > tol for entry in col):
                break
            keep_cols -= 1
        array_a = array_a[:, :keep_cols]

    return array_a
237
+
238
+
239
def compute_error(
    a: np.ndarray, b: np.ndarray, t: np.ndarray, s: Optional[np.ndarray] = None
) -> float:
    r"""Return the one- or two-sided Procrustes (squared Frobenius norm) error.

    Computes :math:`\|\mathbf{S}\mathbf{A}\mathbf{T} - \mathbf{B}\|_{F}^2`; when
    :math:`\mathbf{S}` is omitted it reduces to the one-sided error
    :math:`\|\mathbf{A}\mathbf{T} - \mathbf{B}\|_{F}^2`.

    Parameters
    ----------
    a : ndarray
        The 2D-array :math:`\mathbf{A}_{m \times n}` being transformed.
    b : ndarray
        The 2D-array :math:`\mathbf{B}_{m \times n}` reference matrix.
    t : ndarray
        The right-hand-side transformation matrix :math:`\mathbf{T}_{n \times n}`.
    s : ndarray, optional
        The left-hand-side transformation matrix :math:`\mathbf{S}_{m \times m}`;
        if None, the one-sided error is computed.

    Returns
    -------
    error : float
        The squared Frobenius norm of the difference between the transformed
        array and the reference array.
    """
    # apply T on the right, and S on the left only when supplied
    if s is None:
        transformed = np.dot(a, t)
    else:
        transformed = np.dot(np.dot(s, a), t)
    # ord=None is the Frobenius norm for 2D inputs
    return np.linalg.norm(transformed - b, ord=None) ** 2
279
+
280
+
281
def setup_input_arrays(
    array_a: np.ndarray,
    array_b: np.ndarray,
    remove_zero_col: bool,
    remove_zero_row: bool,
    pad: bool,
    translate: bool,
    scale: bool,
    check_finite: bool,
    weight: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    r"""Check and pre-process the two input matrices for a Procrustes routine.

    Each array is independently validated, unpadded, optionally translated,
    weighted and scaled; the pair is then (optionally) zero-padded to a common
    shape. This is the usual precursor step before all Procrustes methods.

    Parameters
    ----------
    array_a : ndarray
        The 2D array :math:`A` being transformed.
    array_b : ndarray
        The 2D reference array :math:`B`.
    remove_zero_col : bool
        If True, trailing zero columns (values below 1e-8) are removed.
    remove_zero_row : bool
        If True, trailing zero rows (values below 1e-8) are removed.
    pad : bool
        If True, zero rows/columns are added so both arrays share a shape.
    translate : bool
        If True, both arrays are translated so their columns have zero mean.
    scale : bool
        If True, both arrays are normalized to unit Frobenius norm.
    check_finite : bool
        If True, inputs are converted to arrays while checking for NaNs/Infs.
    weight : ndarray, optional
        Row weights applied to each array during pre-processing.

    Returns
    -------
    (ndarray, ndarray)
        The processed (and, if requested, equally padded) pair of arrays.
    """
    # process each matrix independently with the shared lower-level routine
    prepared = tuple(
        _setup_input_array_lower(
            arr, None, remove_zero_col, remove_zero_row, translate, scale, check_finite, weight
        )
        for arr in (array_a, array_b)
    )
    if pad:
        prepared = _zero_padding(prepared[0], prepared[1], pad_mode="row-col")
    return prepared
337
+
338
+
339
def setup_input_arrays_multi(
    array_list: List[np.ndarray],
    array_ref: np.ndarray,
    remove_zero_col: bool,
    remove_zero_row: bool,
    pad_mode: str,
    translate: bool,
    scale: bool,
    check_finite: bool,
    weight: Optional[np.ndarray] = None,
) -> List[np.ndarray]:
    r"""Check and pre-process a list of matrices for a Procrustes routine.

    Each array is independently validated, unpadded, optionally translated,
    weighted and scaled (relative to array_ref), then every array is zero-padded
    against the largest observed shape so all results have equal dimensions.

    Parameters
    ----------
    array_list : List
        A list of 2D arrays being transformed.
    array_ref : ndarray
        The 2D reference array :math:`B`.
    remove_zero_col : bool
        If True, trailing zero columns (values below 1e-8) are removed.
    remove_zero_row : bool
        If True, trailing zero rows (values below 1e-8) are removed.
    pad_mode : str
        Padding mode forwarded to ``_zero_padding``: one of "row", "col",
        "row-col", or "square".
    translate : bool
        If True, each array is translated so its columns have zero mean.
    scale : bool
        If True, each array is normalized to unit Frobenius norm.
    check_finite : bool
        If True, inputs are converted to arrays while checking for NaNs/Infs.
    weight : ndarray, optional
        Row weights applied to each array during pre-processing.

    Returns
    -------
    List[ndarray]
        The processed arrays, padded to a common shape.
    """
    array_list_new = [
        _setup_input_array_lower(
            array_a=arr,
            array_ref=array_ref,
            remove_zero_col=remove_zero_col,
            remove_zero_row=remove_zero_row,
            translate=translate,
            scale=scale,
            check_finite=check_finite,
            weight=weight,
        )
        for arr in array_list
    ]
    # dummy reference array carrying the maximum shape over all processed arrays
    arr_shape = np.array([arr.shape for arr in array_list_new])
    array_b = np.ones(np.max(arr_shape, axis=0), dtype=int)
    # BUG FIX: _zero_padding returns a (padded_a, padded_b) tuple; the original
    # code stored the whole tuple, so the function returned List[Tuple] instead
    # of the documented List[ndarray]. Keep only the padded input array.
    array_list_new = [
        _zero_padding(arr, array_b, pad_mode=pad_mode)[0] for arr in array_list_new
    ]
    return array_list_new
413
+
414
+
415
def _setup_input_array_lower(
    array_a: np.ndarray,
    array_ref: Optional[np.ndarray],
    remove_zero_col: bool,
    remove_zero_row: bool,
    translate: bool,
    scale: bool,
    check_finite: bool,
    weight: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Pre-process one matrix: validate, unpad, translate/weight, and scale.

    Parameters
    ----------
    array_a : ndarray
        The 2D array to pre-process.
    array_ref : ndarray, optional
        Reference array used as the target for translation and scaling.
    remove_zero_col : bool
        If True, trailing zero columns (values below 1e-8) are removed.
        (Annotation fixed: this flag was previously mis-typed as ndarray.)
    remove_zero_row : bool
        If True, trailing zero rows (values below 1e-8) are removed.
        (Annotation fixed: this flag was previously mis-typed as ndarray.)
    translate : bool
        If True, translate array_a (toward array_ref's centroid if given).
    scale : bool
        If True, scale array_a (to array_ref's norm if given, else to unit norm).
    check_finite : bool
        If True, convert to an array while checking for NaNs/Infs.
    weight : ndarray, optional
        Row weights; folded into translation when translating, otherwise
        applied directly as a diagonal left-multiplication.

    Returns
    -------
    ndarray
        The pre-processed array.
    """
    _check_arraytypes(array_a)
    if check_finite:
        array_a = np.asarray_chkfinite(array_a)
    # Strip pre-existing zero padding that would confuse later padding steps.
    array_a = _hide_zero_padding(array_a, remove_zero_col, remove_zero_row)
    if translate:
        array_a, _ = _translate_array(array_a, array_ref, weight)
    elif weight is not None:
        # no translation requested: apply the weights directly
        array_a = np.dot(np.diag(weight), array_a)
    if scale:
        array_a, _ = _scale_array(array_a, array_ref)
    return array_a
441
+
442
+
443
def _check_arraytypes(*args) -> None:
    r"""Validate that every argument is a 2D NumPy array, else raise TypeError."""
    # all type checks run before any dimensionality check, matching caller expectations
    for candidate in args:
        if not isinstance(candidate, np.ndarray):
            raise TypeError("Matrix inputs must be NumPy arrays")
    for candidate in args:
        if candidate.ndim != 2:
            raise TypeError("Matrix inputs must be 2-dimensional arrays")
449
+
450
+
451
class ProcrustesResult(dict):
    r"""Represents the Procrustes analysis result.

    A dict subclass whose entries are also reachable as attributes
    (``res.error`` is ``res["error"]``).

    Attributes
    ----------
    error : float
        The Procrustes (squared Frobenius norm) error.
    new_a : ndarray
        The translated/scaled numpy ndarray :math:`\mathbf{A}`.
    new_b : ndarray
        The translated/scaled numpy ndarray :math:`\mathbf{B}`.
    t : ndarray
        The 2D-array :math:`\mathbf{T}` representing the right-hand-side transformation matrix.
    s : ndarray
        The 2D-array :math:`\mathbf{S}` representing the left-hand-side transformation
        matrix. If set to `None`, the one-sided Procrustes was performed.

    """

    # adapted from scipy.optimize's OptimizeResult pattern
    def __getattr__(self, name: str):
        """Resolve unknown attributes against the dictionary entries."""
        try:
            return self[name]
        except KeyError as key_err:
            # chain the original KeyError so the traceback stays accurate
            raise AttributeError(name) from key_err

    # attribute assignment and deletion operate directly on the dict entries
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __repr__(self):
        """Return a human friendly representation."""
        if not self.keys():
            return self.__class__.__name__ + "()"
        width = max(map(len, list(self.keys()))) + 1
        rows = [key.rjust(width) + ": " + repr(value) for key, value in sorted(self.items())]
        return "\n".join(rows)

    def __dir__(self):
        """Expose the stored keys as the available attribute names."""
        return list(self.keys())