skar0 commited on
Commit
4e46e20
·
1 Parent(s): 24dc1da

Removed dataclasses from requirements

Browse files
attention_replication.py CHANGED
@@ -1,9 +1,9 @@
1
  # %%
2
  import torch as t
3
  import torch.nn as nn
4
- from typing import Union, List
5
  from fancy_einsum import einsum
6
- from einops import repeat, rearrange, reduce
7
  import numpy as np
8
  #%%
9
  def single_head_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
 
1
  # %%
2
  import torch as t
3
  import torch.nn as nn
4
+ from typing import Union
5
  from fancy_einsum import einsum
6
+ from einops import repeat, rearrange
7
  import numpy as np
8
  #%%
9
  def single_head_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
requirements.txt CHANGED
@@ -1,168 +1,161 @@
1
- aiofiles @ file:///home/conda/feedstock_root/build_artifacts/aiofiles_1676402724025/work
2
- aiohttp @ file:///home/conda/feedstock_root/build_artifacts/aiohttp_1676292661248/work
3
- aiosignal @ file:///home/conda/feedstock_root/build_artifacts/aiosignal_1667935791922/work
4
- altair @ file:///home/conda/feedstock_root/build_artifacts/altair_1675180856922/work
5
- anyio @ file:///home/conda/feedstock_root/build_artifacts/anyio_1666191106763/work/dist
6
- argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1640817743617/work
7
- argon2-cffi-bindings @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi-bindings_1666850768662/work
8
- asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work
9
- async-timeout @ file:///home/conda/feedstock_root/build_artifacts/async-timeout_1640026696943/work
10
- attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1671632566681/work
11
- backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
12
- backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work
13
- beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1675252249248/work
14
- bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1674535352125/work
15
- brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1666764671472/work
16
  certifi==2022.12.7
17
- cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1671179353105/work
18
- charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1661170624537/work
19
- click @ file:///home/conda/feedstock_root/build_artifacts/click_1666798198223/work
20
- colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
21
- comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1670575068857/work
22
- contourpy @ file:///home/conda/feedstock_root/build_artifacts/contourpy_1673633665736/work
23
- cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography-split_1675828607645/work
24
- cycler @ file:///home/conda/feedstock_root/build_artifacts/cycler_1635519461629/work
25
- dataclasses @ file:///home/conda/feedstock_root/build_artifacts/dataclasses_1628958434797/work
26
- datasets @ file:///home/conda/feedstock_root/build_artifacts/datasets_1674838636692/work
27
- debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1674522362098/work
28
- decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
29
- defusedxml @ file:///home/conda/feedstock_root/build_artifacts/defusedxml_1615232257335/work
30
- dill @ file:///home/conda/feedstock_root/build_artifacts/dill_1666603105584/work
31
- einops @ file:///home/conda/feedstock_root/build_artifacts/einops_1670600230829/work
32
- entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
33
- executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work
34
  fancy-einsum==0.0.3
35
- fastapi @ file:///home/conda/feedstock_root/build_artifacts/fastapi_1676407540585/work
36
- fastjsonschema @ file:///home/conda/feedstock_root/build_artifacts/python-fastjsonschema_1663619548554/work/dist
37
- ffmpy @ file:///home/conda/feedstock_root/build_artifacts/ffmpy_1659474992694/work
38
- filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1672354931606/work
39
- Flask @ file:///home/conda/feedstock_root/build_artifacts/flask_1676592993069/work
40
- flit_core @ file:///home/conda/feedstock_root/build_artifacts/flit-core_1667734568827/work/source/flit_core
41
- fonttools @ file:///home/conda/feedstock_root/build_artifacts/fonttools_1666827107856/work
42
- frozenlist @ file:///home/conda/feedstock_root/build_artifacts/frozenlist_1667935435842/work
43
- fsspec @ file:///home/conda/feedstock_root/build_artifacts/fsspec_1674184942191/work
44
- gradio @ file:///home/conda/feedstock_root/build_artifacts/gradio_1676897693557/work
45
- h11 @ file:///home/conda/feedstock_root/build_artifacts/h11_1664132893548/work
46
- h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1634280454336/work
47
  hpack==4.0.0
48
- httpcore @ file:///home/conda/feedstock_root/build_artifacts/httpcore_1671551055614/work
49
- httpx @ file:///home/conda/feedstock_root/build_artifacts/httpx_1672850625594/work
50
- huggingface-hub @ file:///home/conda/feedstock_root/build_artifacts/huggingface_hub_1676642337813/work
51
- hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1619110129307/work
52
- idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
53
- importlib-metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1672612343532/work
54
- importlib-resources @ file:///home/conda/feedstock_root/build_artifacts/importlib_resources_1676919000169/work
55
- ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1676322140253/work
56
- ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1676047456691/work
57
  ipython-genutils==0.2.0
58
- ipywidgets @ file:///home/conda/feedstock_root/build_artifacts/ipywidgets_1671720089366/work
59
- itsdangerous @ file:///home/conda/feedstock_root/build_artifacts/itsdangerous_1648147185463/work
60
- jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work
61
- Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work
62
- joblib @ file:///home/conda/feedstock_root/build_artifacts/joblib_1663332044897/work
63
- jsonschema @ file:///home/conda/feedstock_root/build_artifacts/jsonschema-meta_1669810440410/work
64
- jupyter @ file:///home/conda/feedstock_root/build_artifacts/jupyter_1670249595582/work
65
- jupyter-console @ file:///home/conda/feedstock_root/build_artifacts/jupyter_console_1676328545892/work
66
- jupyter-events @ file:///home/conda/feedstock_root/build_artifacts/jupyter_events_1673559782596/work
67
- jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1676579893731/work
68
- jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1675109846004/work
69
- jupyter_server @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_1676476189852/work
70
- jupyter_server_terminals @ file:///home/conda/feedstock_root/build_artifacts/jupyter_server_terminals_1673491454549/work
71
- jupyterlab-pygments @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_pygments_1649936611996/work
72
- jupyterlab-widgets @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_widgets_1671722028097/work
73
- kiwisolver @ file:///home/conda/feedstock_root/build_artifacts/kiwisolver_1666805701884/work
74
- linkify-it-py @ file:///home/conda/feedstock_root/build_artifacts/linkify-it-py_1651923627081/work
75
- markdown-it-py @ file:///home/conda/feedstock_root/build_artifacts/markdown-it-py_1650305363826/work
76
- MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1674135787083/work
77
- matplotlib @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-suite_1676406361850/work
78
- matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
79
- mdit-py-plugins @ file:///home/conda/feedstock_root/build_artifacts/mdit-py-plugins_1670348296204/work
80
- mdurl @ file:///home/conda/feedstock_root/build_artifacts/mdurl_1639515908913/work
81
- mistune @ file:///home/conda/feedstock_root/build_artifacts/mistune_1675771498296/work
82
- multidict @ file:///home/conda/feedstock_root/build_artifacts/multidict_1672339403932/work
83
- multiprocess @ file:///home/conda/feedstock_root/build_artifacts/multiprocess_1666932878376/work
84
  munkres==1.1.4
85
- nbclassic @ file:///home/conda/feedstock_root/build_artifacts/nbclassic_1676729186918/work
86
- nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1669795076334/work
87
- nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert-meta_1674590374792/work
88
- nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1673560067442/work
89
- nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work
90
- notebook @ file:///home/conda/feedstock_root/build_artifacts/notebook_1667565639349/work
91
- notebook_shim @ file:///home/conda/feedstock_root/build_artifacts/notebook-shim_1667478401171/work
92
- numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1675642512762/work
93
- orjson @ file:///home/conda/feedstock_root/build_artifacts/orjson_1673484660945/work/target/wheels/orjson-3.8.5-cp310-cp310-linux_x86_64.whl
94
- packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1673482170163/work
95
  pandas==1.5.3
96
- pandocfilters @ file:///home/conda/feedstock_root/build_artifacts/pandocfilters_1631603243851/work
97
- parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
98
- pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work
99
- pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
100
- Pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1675487172403/work
101
- pkgutil_resolve_name @ file:///home/conda/feedstock_root/build_artifacts/pkgutil-resolve-name_1633981968097/work
102
- platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1675735718929/work
 
103
  ply==3.11
104
- prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1674535637125/work
105
- prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1670414775770/work
106
- psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1667885877572/work
107
- ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
108
- pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
109
- pyarrow==11.0.0
110
- pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
111
- pycryptodome @ file:///home/conda/feedstock_root/build_artifacts/pycryptodome_1669581639515/work
112
- pydantic @ file:///home/conda/feedstock_root/build_artifacts/pydantic_1676531650626/work
113
- pydub @ file:///home/conda/feedstock_root/build_artifacts/pydub_1615612442567/work
114
- Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1672682006896/work
115
- pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1672659226110/work
116
- pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1652235407899/work
117
  PyQt5==5.15.7
118
  PyQt5-sip==12.11.0
119
- pyrsistent @ file:///home/conda/feedstock_root/build_artifacts/pyrsistent_1672681463845/work
120
- PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
121
- python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
122
- python-json-logger @ file:///home/conda/feedstock_root/build_artifacts/python-json-logger_1676401516590/work
123
  python-multipart==0.0.5
124
- pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1673864280276/work
125
- PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1666772395347/work
126
- pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1673612669255/work
127
- qtconsole @ file:///home/conda/feedstock_root/build_artifacts/qtconsole-base_1667404144336/work
128
- QtPy @ file:///home/conda/feedstock_root/build_artifacts/qtpy_1667873092748/work
129
- regex @ file:///home/conda/feedstock_root/build_artifacts/regex_1667265033016/work
130
- requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1673863902341/work
131
- responses @ file:///home/conda/feedstock_root/build_artifacts/responses_1643839609465/work
132
- rfc3339-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3339-validator_1638811747357/work
133
- rfc3986 @ file:///home/conda/feedstock_root/build_artifacts/rfc3986_1620442452971/work
134
- rfc3986-validator @ file:///home/conda/feedstock_root/build_artifacts/rfc3986-validator_1598024191506/work
135
- sacremoses @ file:///home/conda/feedstock_root/build_artifacts/sacremoses_1651557636210/work
136
- Send2Trash @ file:///home/conda/feedstock_root/build_artifacts/send2trash_1628511208346/work
137
- sip @ file:///home/conda/feedstock_root/build_artifacts/sip_1675696581052/work
138
- six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
139
- sniffio @ file:///home/conda/feedstock_root/build_artifacts/sniffio_1662051266223/work
140
- soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work
141
- stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
142
- starlette @ file:///home/conda/feedstock_root/build_artifacts/starlette-recipe_1676402644778/work
143
- terminado @ file:///home/conda/feedstock_root/build_artifacts/terminado_1670253674810/work
144
- tinycss2 @ file:///home/conda/feedstock_root/build_artifacts/tinycss2_1666100256010/work
145
- tokenizers @ file:///home/conda/feedstock_root/build_artifacts/tokenizers_1674690844352/work/bindings/python
146
- toml @ file:///home/conda/feedstock_root/build_artifacts/toml_1604308577558/work
147
- toolz @ file:///home/conda/feedstock_root/build_artifacts/toolz_1657485559105/work
148
  torch==1.13.1
149
  torchaudio==0.13.1
150
  torchvision==0.14.1
151
- tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1666788589303/work
152
- tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1662214488106/work
153
- traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work
154
- transformers @ file:///home/conda/feedstock_root/build_artifacts/transformers_1676091074773/work
155
- typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1665144421445/work
156
- uc-micro-py @ file:///home/conda/feedstock_root/build_artifacts/uc-micro-py_1608058642472/work
157
- unicodedata2 @ file:///home/conda/feedstock_root/build_artifacts/unicodedata2_1667239886688/work
158
- urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1673452138552/work
159
- uvicorn @ file:///home/conda/feedstock_root/build_artifacts/uvicorn-split_1669234664979/work
160
- wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work
161
  webencodings==0.5.1
162
- websocket-client @ file:///home/conda/feedstock_root/build_artifacts/websocket-client_1675567828044/work
163
- websockets @ file:///home/conda/feedstock_root/build_artifacts/websockets_1666806213473/work
164
- Werkzeug @ file:///home/conda/feedstock_root/build_artifacts/werkzeug_1676411946679/work
165
- widgetsnbextension @ file:///home/conda/feedstock_root/build_artifacts/widgetsnbextension_1672066693230/work
166
- xxhash @ file:///home/conda/feedstock_root/build_artifacts/python-xxhash_1672695020159/work
167
- yarl @ file:///home/conda/feedstock_root/build_artifacts/yarl_1672340954791/work
168
- zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1676708471276/work
 
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==4.2.2
5
+ anyio==3.6.2
6
+ argon2-cffi==21.3.0
7
+ argon2-cffi-bindings==21.2.0
8
+ asttokens==2.2.1
9
+ async-timeout==4.0.2
10
+ attrs==22.2.0
11
+ backcall==0.2.0
12
+ backports.functools-lru-cache==1.6.4
13
+ beautifulsoup4==4.11.2
14
+ bleach==6.0.0
15
+ brotlipy==0.7.0
16
  certifi==2022.12.7
17
+ cffi==1.15.1
18
+ charset-normalizer==2.1.1
19
+ click==8.1.3
20
+ comm==0.1.2
21
+ contourpy==1.0.7
22
+ cryptography==39.0.1
23
+ cycler==0.11.0
24
+ debugpy==1.6.6
25
+ decorator==5.1.1
26
+ defusedxml==0.7.1
27
+ einops==0.6.0
28
+ entrypoints==0.4
29
+ executing==1.2.0
 
 
 
 
30
  fancy-einsum==0.0.3
31
+ fastapi==0.92.0
32
+ fastjsonschema==2.16.2
33
+ ffmpy==0.3.0
34
+ filelock==3.9.0
35
+ Flask==2.2.3
36
+ flit_core==3.8.0
37
+ fonttools==4.38.0
38
+ frozenlist==1.3.3
39
+ fsspec==2023.1.0
40
+ gradio==3.19.1
41
+ h11==0.14.0
42
+ h2==4.1.0
43
  hpack==4.0.0
44
+ httpcore==0.16.3
45
+ httpx==0.23.3
46
+ huggingface-hub==0.12.1
47
+ hyperframe==6.0.1
48
+ idna==3.4
49
+ importlib-metadata==6.0.0
50
+ importlib-resources==5.12.0
51
+ ipykernel==6.21.2
52
+ ipython==8.10.0
53
  ipython-genutils==0.2.0
54
+ ipywidgets==8.0.4
55
+ itsdangerous==2.1.2
56
+ jedi==0.18.2
57
+ Jinja2==3.1.2
58
+ jsonschema==4.17.3
59
+ jupyter==1.0.0
60
+ jupyter_client==8.0.3
61
+ jupyter-console==6.5.1
62
+ jupyter_core==5.2.0
63
+ jupyter-events==0.6.3
64
+ jupyter_server==2.3.0
65
+ jupyter_server_terminals==0.4.4
66
+ jupyterlab-pygments==0.2.2
67
+ jupyterlab-widgets==3.0.5
68
+ kiwisolver==1.4.4
69
+ linkify-it-py==2.0.0
70
+ markdown-it-py==2.1.0
71
+ MarkupSafe==2.1.2
72
+ matplotlib==3.7.0
73
+ matplotlib-inline==0.1.6
74
+ mdit-py-plugins==0.3.3
75
+ mdurl==0.1.0
76
+ mistune==2.0.5
77
+ multidict==6.0.4
 
 
78
  munkres==1.1.4
79
+ nbclassic==0.5.2
80
+ nbclient==0.7.2
81
+ nbconvert==7.2.9
82
+ nbformat==5.7.3
83
+ nest-asyncio==1.5.6
84
+ notebook==6.5.2
85
+ notebook_shim==0.2.2
86
+ numpy==1.24.2
87
+ orjson==3.8.5
88
+ packaging==23.0
89
  pandas==1.5.3
90
+ pandocfilters==1.5.0
91
+ parso==0.8.3
92
+ pexpect==4.8.0
93
+ pickleshare==0.7.5
94
+ Pillow==9.4.0
95
+ pip==23.0.1
96
+ pkgutil_resolve_name==1.3.10
97
+ platformdirs==3.0.0
98
  ply==3.11
99
+ prometheus-client==0.16.0
100
+ prompt-toolkit==3.0.36
101
+ psutil==5.9.4
102
+ ptyprocess==0.7.0
103
+ pure-eval==0.2.2
104
+ pycparser==2.21
105
+ pycryptodome==3.16.0
106
+ pydantic==1.10.5
107
+ pydub==0.25.1
108
+ Pygments==2.14.0
109
+ pyOpenSSL==23.0.0
110
+ pyparsing==3.0.9
 
111
  PyQt5==5.15.7
112
  PyQt5-sip==12.11.0
113
+ pyrsistent==0.19.3
114
+ PySocks==1.7.1
115
+ python-dateutil==2.8.2
116
+ python-json-logger==2.0.6
117
  python-multipart==0.0.5
118
+ pytz==2022.7.1
119
+ PyYAML==6.0
120
+ pyzmq==25.0.0
121
+ qtconsole==5.4.0
122
+ QtPy==2.3.0
123
+ regex==2022.10.31
124
+ requests==2.28.2
125
+ rfc3339-validator==0.1.4
126
+ rfc3986==1.5.0
127
+ rfc3986-validator==0.1.1
128
+ Send2Trash==1.8.0
129
+ setuptools==67.3.2
130
+ sip==6.7.7
131
+ six==1.16.0
132
+ sniffio==1.3.0
133
+ soupsieve==2.3.2.post1
134
+ stack-data==0.6.2
135
+ starlette==0.25.0
136
+ terminado==0.17.1
137
+ tinycss2==1.2.1
138
+ tokenizers==0.13.2
139
+ toml==0.10.2
140
+ toolz==0.12.0
 
141
  torch==1.13.1
142
  torchaudio==0.13.1
143
  torchvision==0.14.1
144
+ tornado==6.2
145
+ tqdm==4.64.1
146
+ traitlets==5.9.0
147
+ transformers==4.26.1
148
+ typing_extensions==4.4.0
149
+ uc-micro-py==1.0.1
150
+ unicodedata2==15.0.0
151
+ urllib3==1.26.14
152
+ uvicorn==0.20.0
153
+ wcwidth==0.2.6
154
  webencodings==0.5.1
155
+ websocket-client==1.5.1
156
+ websockets==10.4
157
+ Werkzeug==2.2.3
158
+ wheel==0.38.4
159
+ widgetsnbextension==4.0.5
160
+ yarl==1.8.2
161
+ zipp==3.14.0
transformer_replication.py CHANGED
@@ -5,13 +5,7 @@ import torch.nn as nn
5
  from typing import Union, List
6
  from fancy_einsum import einsum
7
  import torch as t
8
- from torch import nn
9
- from torchvision import datasets, transforms
10
- from torch.utils.data import DataLoader
11
- from typing import Union, Optional, Callable, Tuple
12
- import numpy as np
13
- from einops import rearrange
14
- import time
15
  # %%
16
  tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
17
  if __name__ == "__main__":
@@ -90,9 +84,6 @@ class LayerNorm(nn.Module):
90
  pass
91
 
92
  # %%
93
- from dataclasses import dataclass
94
-
95
- @dataclass(frozen=True)
96
  class TransformerConfig:
97
  '''Constants used throughout your decoder-only transformer model.'''
98
 
@@ -101,10 +92,22 @@ class TransformerConfig:
101
  vocab_size: int
102
  hidden_size: int
103
  max_seq_len: int
104
- dropout: float = 0.1
105
- layer_norm_epsilon: float = 1e-05
 
 
 
 
 
 
 
 
 
 
 
 
106
  # %%
107
- import attention_replication
108
 
109
  class BertMLP(nn.Module):
110
  def __init__(self, config: TransformerConfig):
 
5
  from typing import Union, List
6
  from fancy_einsum import einsum
7
  import torch as t
8
+ import attention_replication
 
 
 
 
 
 
9
  # %%
10
  tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
11
  if __name__ == "__main__":
 
84
  pass
85
 
86
  # %%
 
 
 
87
  class TransformerConfig:
88
  '''Constants used throughout your decoder-only transformer model.'''
89
 
 
92
  vocab_size: int
93
  hidden_size: int
94
  max_seq_len: int
95
+ dropout: float
96
+ layer_norm_epsilon: float
97
+
98
+ def __init__(
99
+ self, num_layers, num_heads, vocab_size, hidden_size, max_seq_len,
100
+ dropout=0.1, layer_norm_epsilon=1e-5,
101
+ ) -> None:
102
+ self.num_layers = num_layers
103
+ self.num_heads = num_heads
104
+ self.vocab_size = vocab_size
105
+ self.hidden_size = hidden_size
106
+ self.max_seq_len = max_seq_len
107
+ self.dropout = dropout
108
+ self.layer_norm_epsilon = layer_norm_epsilon
109
  # %%
110
+
111
 
112
  class BertMLP(nn.Module):
113
  def __init__(self, config: TransformerConfig):