Spaces:
Running
Running
reform scaffold
Browse files- .gitignore +2 -0
- environment.yml +335 -0
- mlip_arena/models/chgnet.py +60 -0
- mlip_arena/tasks/diatomics/mlip.py +165 -0
- mlip_arena/tasks/diatomics/vasp/run.ipynb +0 -509
- mlip_arena/tasks/stability/__init__.py +0 -0
- mlip_arena/tasks/stability/run.py +3 -0
- mlip_arena/tasks/utils.py +161 -0
- pyproject.toml +1 -0
- scripts/install-pyg.sh +15 -0
- serve/models/alerts.py +4 -0
- serve/models/bugs.py +4 -0
- serve/tasks/homonuclear-diatomics.py +1 -1
.gitignore
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
tests/
|
2 |
*.out
|
|
|
|
|
3 |
mlip_arena/tasks/*/*/*/
|
4 |
|
5 |
# Byte-compiled / optimized / DLL files
|
|
|
1 |
tests/
|
2 |
*.out
|
3 |
+
*.ipynb
|
4 |
+
*.extxyz
|
5 |
mlip_arena/tasks/*/*/*/
|
6 |
|
7 |
# Byte-compiled / optimized / DLL files
|
environment.yml
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
channels:
|
2 |
+
- defaults
|
3 |
+
- conda-forge
|
4 |
+
dependencies:
|
5 |
+
- _libgcc_mutex=0.1=main
|
6 |
+
- _openmp_mutex=5.1=1_gnu
|
7 |
+
- bzip2=1.0.8=h5eee18b_5
|
8 |
+
- ca-certificates=2024.3.11=h06a4308_0
|
9 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
10 |
+
- libffi=3.4.4=h6a678d5_0
|
11 |
+
- libgcc-ng=11.2.0=h1234567_1
|
12 |
+
- libgomp=11.2.0=h1234567_1
|
13 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
14 |
+
- libuuid=1.41.5=h5eee18b_0
|
15 |
+
- ncurses=6.4=h6a678d5_0
|
16 |
+
- openssl=3.0.13=h7f8727e_0
|
17 |
+
- pip=23.3.1=py311h06a4308_0
|
18 |
+
- python=3.11.8=h955ad1f_0
|
19 |
+
- readline=8.2=h5eee18b_0
|
20 |
+
- setuptools=68.2.2=py311h06a4308_0
|
21 |
+
- sqlite=3.41.2=h5eee18b_0
|
22 |
+
- tk=8.6.12=h1ccaba5_0
|
23 |
+
- wheel=0.41.2=py311h06a4308_0
|
24 |
+
- xz=5.4.6=h5eee18b_0
|
25 |
+
- zlib=1.2.13=h5eee18b_0
|
26 |
+
- pip:
|
27 |
+
- absl-py==2.1.0
|
28 |
+
- aiofiles==23.2.1
|
29 |
+
- aiohttp==3.9.3
|
30 |
+
- aioitertools==0.11.0
|
31 |
+
- aiosignal==1.3.1
|
32 |
+
- aiosqlite==0.20.0
|
33 |
+
- alembic==1.13.1
|
34 |
+
- alignn==2024.5.27
|
35 |
+
- altair==5.3.0
|
36 |
+
- annotated-types==0.6.0
|
37 |
+
- anyio==3.7.1
|
38 |
+
- appdirs==1.4.4
|
39 |
+
- apprise==1.7.5
|
40 |
+
- ase==3.23.0
|
41 |
+
- asgi-lifespan==2.1.0
|
42 |
+
- asttokens==2.4.1
|
43 |
+
- async-timeout==4.0.3
|
44 |
+
- asyncpg==0.29.0
|
45 |
+
- atomate2==0.0.14.post30+g256b39a1
|
46 |
+
- attrs==23.2.0
|
47 |
+
- autograd==1.5
|
48 |
+
- autoray==0.6.9
|
49 |
+
- bcrypt==4.1.2
|
50 |
+
- bidict==0.23.1
|
51 |
+
- blinker==1.7.0
|
52 |
+
- blosc2==2.7.0
|
53 |
+
- boto3==1.34.74
|
54 |
+
- botocore==1.34.74
|
55 |
+
- cachetools==5.3.3
|
56 |
+
- certifi==2024.7.4
|
57 |
+
- cffi==1.16.0
|
58 |
+
- charset-normalizer==3.3.2
|
59 |
+
- chgnet==0.3.5
|
60 |
+
- click==8.1.7
|
61 |
+
- cloudpickle==3.0.0
|
62 |
+
- colorama==0.4.6
|
63 |
+
- comm==0.2.2
|
64 |
+
- contourpy==1.2.0
|
65 |
+
- coolname==2.2.0
|
66 |
+
- covalent==0.232.0.post1
|
67 |
+
- croniter==2.0.3
|
68 |
+
- cryptography==42.0.5
|
69 |
+
- custodian==2024.3.12
|
70 |
+
- cycler==0.12.1
|
71 |
+
- cython==3.0.10
|
72 |
+
- dask==2024.3.1
|
73 |
+
- dask-jobqueue==0.8.5
|
74 |
+
- dateparser==1.2.0
|
75 |
+
- debugpy==1.8.1
|
76 |
+
- decorator==5.1.1
|
77 |
+
- dgl==2.3.0+cu121
|
78 |
+
- distributed==2024.3.1
|
79 |
+
- dnspython==2.6.1
|
80 |
+
- docker==6.1.3
|
81 |
+
- docker-pycreds==0.4.0
|
82 |
+
- e3nn==0.5.1
|
83 |
+
- email-validator==2.1.1
|
84 |
+
- emmet-core==0.82.1
|
85 |
+
- executing==2.0.1
|
86 |
+
- fairchem-core==1.0.0
|
87 |
+
- fastapi==0.110.0
|
88 |
+
- fastjsonschema==2.20.0
|
89 |
+
- filelock==3.15.4
|
90 |
+
- fireworks==2.0.3
|
91 |
+
- flake8==7.1.0
|
92 |
+
- flask==3.0.2
|
93 |
+
- flask-paginate==2024.3.28
|
94 |
+
- fonttools==4.50.0
|
95 |
+
- frozenlist==1.4.1
|
96 |
+
- fsspec==2024.6.1
|
97 |
+
- furl==2.1.3
|
98 |
+
- future==1.0.0
|
99 |
+
- gitdb==4.0.11
|
100 |
+
- gitpython==3.1.43
|
101 |
+
- google-auth==2.29.0
|
102 |
+
- gpaw==24.7.0b1
|
103 |
+
- greenlet==3.0.3
|
104 |
+
- griffe==0.42.1
|
105 |
+
- grpcio==1.64.1
|
106 |
+
- gunicorn==21.2.0
|
107 |
+
- h11==0.14.0
|
108 |
+
- h2==4.1.0
|
109 |
+
- h5py==3.10.0
|
110 |
+
- hpack==4.0.0
|
111 |
+
- httpcore==1.0.5
|
112 |
+
- httptools==0.6.1
|
113 |
+
- httpx==0.27.0
|
114 |
+
- huggingface-hub==0.22.2
|
115 |
+
- hyperframe==6.0.1
|
116 |
+
- idna==3.7
|
117 |
+
- importlib-metadata==7.1.0
|
118 |
+
- importlib-resources==6.1.3
|
119 |
+
- inflect==7.3.1
|
120 |
+
- ipykernel==6.29.4
|
121 |
+
- ipython==8.22.2
|
122 |
+
- ipywidgets==8.1.2
|
123 |
+
- itsdangerous==2.1.2
|
124 |
+
- jarvis-tools==2024.5.10
|
125 |
+
- jedi==0.19.1
|
126 |
+
- jinja2==3.1.4
|
127 |
+
- jmespath==1.0.1
|
128 |
+
- jobflow==0.1.17
|
129 |
+
- joblib==1.3.2
|
130 |
+
- jsonpatch==1.33
|
131 |
+
- jsonpointer==2.4
|
132 |
+
- jsonschema==4.21.1
|
133 |
+
- jsonschema-specifications==2023.12.1
|
134 |
+
- jupyter-client==8.6.1
|
135 |
+
- jupyter-core==5.7.2
|
136 |
+
- jupyterlab-widgets==3.0.10
|
137 |
+
- kiwisolver==1.4.5
|
138 |
+
- kubernetes==29.0.0
|
139 |
+
- latexcodec==3.0.0
|
140 |
+
- lightning-utilities==0.11.2
|
141 |
+
- llvmlite==0.42.0
|
142 |
+
- lmdb==1.4.1
|
143 |
+
- lmdbm==0.0.5
|
144 |
+
- locket==1.0.0
|
145 |
+
- looseversion==1.3.0
|
146 |
+
- mace-torch==0.3.4
|
147 |
+
- maggma==0.64.0
|
148 |
+
- mako==1.3.2
|
149 |
+
- markdown==3.6
|
150 |
+
- markdown-it-py==2.2.0
|
151 |
+
- markupsafe==2.1.5
|
152 |
+
- matgl==1.0.0
|
153 |
+
- matplotlib==3.8.3
|
154 |
+
- matplotlib-inline==0.1.6
|
155 |
+
- matscipy==1.0.0
|
156 |
+
- mccabe==0.7.0
|
157 |
+
- mdurl==0.1.2
|
158 |
+
- mlip-arena==0.0.1
|
159 |
+
- mongogrant==0.3.3
|
160 |
+
- mongomock==4.1.2
|
161 |
+
- monty==2024.2.26
|
162 |
+
- more-itertools==10.3.0
|
163 |
+
- mp-api==0.41.2
|
164 |
+
- mpire==2.10.1
|
165 |
+
- mpmath==1.3.0
|
166 |
+
- msgpack==1.0.8
|
167 |
+
- multidict==6.0.5
|
168 |
+
- natsort==8.4.0
|
169 |
+
- nbformat==5.10.4
|
170 |
+
- ndindex==1.8
|
171 |
+
- nest-asyncio==1.6.0
|
172 |
+
- networkx==3.3
|
173 |
+
- numba==0.59.1
|
174 |
+
- numexpr==2.10.1
|
175 |
+
- numpy==1.26.4
|
176 |
+
- nvidia-cublas-cu12==12.1.3.1
|
177 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
178 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
179 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
180 |
+
- nvidia-cudnn-cu12==8.9.2.26
|
181 |
+
- nvidia-cufft-cu12==11.0.2.54
|
182 |
+
- nvidia-curand-cu12==10.3.2.106
|
183 |
+
- nvidia-cusolver-cu12==11.4.5.107
|
184 |
+
- nvidia-cusparse-cu12==12.1.0.106
|
185 |
+
- nvidia-ml-py3==7.352.0
|
186 |
+
- nvidia-nccl-cu12==2.19.3
|
187 |
+
- nvidia-nvjitlink-cu12==12.5.82
|
188 |
+
- nvidia-nvtx-cu12==12.1.105
|
189 |
+
- oauthlib==3.2.2
|
190 |
+
- opt-einsum==3.3.0
|
191 |
+
- opt-einsum-fx==0.1.4
|
192 |
+
- orderedmultidict==1.0.1
|
193 |
+
- orjson==3.10.0
|
194 |
+
- packaging==24.0
|
195 |
+
- palettable==3.3.3
|
196 |
+
- pandas==2.2.2
|
197 |
+
- paramiko==3.4.0
|
198 |
+
- parso==0.8.3
|
199 |
+
- partd==1.4.1
|
200 |
+
- pathspec==0.12.1
|
201 |
+
- pendulum==2.1.2
|
202 |
+
- pennylane==0.32.0
|
203 |
+
- pennylane-lightning==0.33.1
|
204 |
+
- pexpect==4.9.0
|
205 |
+
- pillow==10.2.0
|
206 |
+
- platformdirs==4.2.0
|
207 |
+
- plotly==5.20.0
|
208 |
+
- plumed==2.9.0
|
209 |
+
- prefect==2.16.8
|
210 |
+
- prefect-dask==0.2.6
|
211 |
+
- prettytable==3.10.0
|
212 |
+
- prompt-toolkit==3.0.43
|
213 |
+
- protobuf==4.25.3
|
214 |
+
- psutil==6.0.0
|
215 |
+
- ptyprocess==0.7.0
|
216 |
+
- pure-eval==0.2.2
|
217 |
+
- py-cpuinfo==9.0.0
|
218 |
+
- pyarrow==16.1.0
|
219 |
+
- pyasn1==0.6.0
|
220 |
+
- pyasn1-modules==0.4.0
|
221 |
+
- pybtex==0.24.0
|
222 |
+
- pycodestyle==2.12.0
|
223 |
+
- pycparser==2.22
|
224 |
+
- pydantic==2.6.4
|
225 |
+
- pydantic-core==2.16.3
|
226 |
+
- pydantic-settings==2.2.1
|
227 |
+
- pydash==8.0.0
|
228 |
+
- pydeck==0.9.1
|
229 |
+
- pydocstyle==6.3.0
|
230 |
+
- pyflakes==3.2.0
|
231 |
+
- pygments==2.17.2
|
232 |
+
- pymatgen==2024.4.13
|
233 |
+
- pymongo==4.6.3
|
234 |
+
- pynacl==1.5.0
|
235 |
+
- pyparsing==2.4.7
|
236 |
+
- python-dateutil==2.9.0.post0
|
237 |
+
- python-dotenv==1.0.1
|
238 |
+
- python-engineio==4.9.0
|
239 |
+
- python-graphviz==0.20.3
|
240 |
+
- python-hostlist==1.23.0
|
241 |
+
- python-multipart==0.0.9
|
242 |
+
- python-slugify==8.0.4
|
243 |
+
- python-socketio==5.11.2
|
244 |
+
- pytorch-lightning==2.2.1
|
245 |
+
- pytz==2024.1
|
246 |
+
- pytzdata==2020.1
|
247 |
+
- pyyaml==6.0.1
|
248 |
+
- pyzmq==25.1.2
|
249 |
+
- readchar==4.0.6
|
250 |
+
- referencing==0.34.0
|
251 |
+
- regex==2023.12.25
|
252 |
+
- requests==2.32.3
|
253 |
+
- requests-oauthlib==2.0.0
|
254 |
+
- rfc3339-validator==0.1.4
|
255 |
+
- rich==13.3.5
|
256 |
+
- rpds-py==0.18.0
|
257 |
+
- rsa==4.9
|
258 |
+
- ruamel-yaml==0.17.40
|
259 |
+
- ruamel-yaml-clib==0.2.8
|
260 |
+
- rustworkx==0.14.2
|
261 |
+
- s3transfer==0.10.1
|
262 |
+
- safetensors==0.4.2
|
263 |
+
- scikit-learn==1.4.1.post1
|
264 |
+
- scipy==1.14.0
|
265 |
+
- semantic-version==2.10.0
|
266 |
+
- sentinels==1.0.0
|
267 |
+
- sentry-sdk==2.7.1
|
268 |
+
- setproctitle==1.3.3
|
269 |
+
- shellingham==1.5.4
|
270 |
+
- simple-websocket==1.0.0
|
271 |
+
- simplejson==3.19.2
|
272 |
+
- six==1.16.0
|
273 |
+
- smart-open==7.0.4
|
274 |
+
- smmap==5.0.1
|
275 |
+
- sniffio==1.3.1
|
276 |
+
- snowballstemmer==2.2.0
|
277 |
+
- sortedcontainers==2.4.0
|
278 |
+
- spglib==2.3.1
|
279 |
+
- sqlalchemy==1.4.52
|
280 |
+
- sqlalchemy-utils==0.41.2
|
281 |
+
- sshtunnel==0.4.0
|
282 |
+
- stack-data==0.6.3
|
283 |
+
- starlette==0.36.3
|
284 |
+
- streamlit==1.36.0
|
285 |
+
- submitit==1.5.1
|
286 |
+
- sympy==1.12.1
|
287 |
+
- tables==3.9.2
|
288 |
+
- tabulate==0.9.0
|
289 |
+
- tblib==3.0.0
|
290 |
+
- tenacity==8.2.3
|
291 |
+
- tensorboard==2.17.0
|
292 |
+
- tensorboard-data-server==0.7.2
|
293 |
+
- text-unidecode==1.3
|
294 |
+
- threadpoolctl==3.4.0
|
295 |
+
- toml==0.10.2
|
296 |
+
- toolz==0.12.1
|
297 |
+
- torch==2.2.1
|
298 |
+
- torch-dftd==0.4.0
|
299 |
+
- torch-ema==0.3
|
300 |
+
- torch-geometric==2.5.2
|
301 |
+
- torch-scatter==2.1.2+pt22cu121
|
302 |
+
- torch-sparse==0.6.18+pt22cu121
|
303 |
+
- torchdata==0.7.1
|
304 |
+
- torchmetrics==1.3.2
|
305 |
+
- tornado==6.4
|
306 |
+
- tqdm==4.66.4
|
307 |
+
- trainstation==1.0
|
308 |
+
- traitlets==5.14.2
|
309 |
+
- triton==2.2.0
|
310 |
+
- typeguard==4.3.0
|
311 |
+
- typer==0.12.0
|
312 |
+
- typer-cli==0.12.0
|
313 |
+
- typer-slim==0.12.0
|
314 |
+
- typing-extensions==4.12.2
|
315 |
+
- tzdata==2024.1
|
316 |
+
- tzlocal==5.2
|
317 |
+
- ujson==5.9.0
|
318 |
+
- uncertainties==3.1.7
|
319 |
+
- urllib3==2.2.2
|
320 |
+
- uvicorn==0.18.3
|
321 |
+
- uvloop==0.19.0
|
322 |
+
- wandb==0.17.4
|
323 |
+
- watchdog==4.0.0
|
324 |
+
- watchfiles==0.21.0
|
325 |
+
- wcwidth==0.2.13
|
326 |
+
- websocket-client==1.7.0
|
327 |
+
- websockets==12.0
|
328 |
+
- werkzeug==3.0.1
|
329 |
+
- widgetsnbextension==4.0.10
|
330 |
+
- wrapt==1.16.0
|
331 |
+
- wsproto==1.2.0
|
332 |
+
- xmltodict==0.13.0
|
333 |
+
- yarl==1.9.4
|
334 |
+
- zict==3.0.0
|
335 |
+
- zipp==3.18.1
|
mlip_arena/models/chgnet.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from ase import Atoms
|
6 |
+
from ase.calculators.calculator import all_changes
|
7 |
+
from huggingface_hub import hf_hub_download
|
8 |
+
from torch_geometric.data import Data
|
9 |
+
|
10 |
+
from mlip_arena.models import MLIP, MLIPCalculator, ModuleMLIP
|
11 |
+
|
12 |
+
|
13 |
+
class CHGNetCalculator(MLIPCalculator):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
device: torch.device | None = None,
|
17 |
+
restart=None,
|
18 |
+
atoms=None,
|
19 |
+
directory=".",
|
20 |
+
**kwargs,
|
21 |
+
):
|
22 |
+
super().__init__(restart=restart, atoms=atoms, directory=directory, **kwargs)
|
23 |
+
|
24 |
+
self.name: str = self.__class__.__name__
|
25 |
+
|
26 |
+
fpath = hf_hub_download(
|
27 |
+
repo_id="cyrusyc/mace-universal",
|
28 |
+
subfolder="pretrained",
|
29 |
+
filename="2023-12-12-mace-128-L1_epoch-199.model",
|
30 |
+
revision="main",
|
31 |
+
)
|
32 |
+
|
33 |
+
self.device = device or torch.device(
|
34 |
+
"cuda" if torch.cuda.is_available() else "cpu"
|
35 |
+
)
|
36 |
+
|
37 |
+
self.model = torch.load(fpath, map_location=self.device)
|
38 |
+
|
39 |
+
self.implemented_properties = ["energy", "forces", "stress"]
|
40 |
+
|
41 |
+
def calculate(
|
42 |
+
self, atoms: Atoms, properties: list[str], system_changes: list = all_changes
|
43 |
+
):
|
44 |
+
"""Calculate energies and forces for the given Atoms object"""
|
45 |
+
super().calculate(atoms, properties, system_changes)
|
46 |
+
|
47 |
+
output = self.forward(atoms)
|
48 |
+
|
49 |
+
self.results = {}
|
50 |
+
if "energy" in properties:
|
51 |
+
self.results["energy"] = output["energy"].item()
|
52 |
+
if "forces" in properties:
|
53 |
+
self.results["forces"] = output["forces"].cpu().detach().numpy()
|
54 |
+
if "stress" in properties:
|
55 |
+
self.results["stress"] = output["stress"].cpu().detach().numpy()
|
56 |
+
|
57 |
+
def forward(self, x: Data | Atoms) -> dict[str, torch.Tensor]:
|
58 |
+
"""Implement data conversion, graph creation, and model forward pass"""
|
59 |
+
# TODO
|
60 |
+
raise NotImplementedError
|
mlip_arena/tasks/diatomics/mlip.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import timedelta
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
# import covalent as ct
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import torch
|
8 |
+
from ase import Atoms
|
9 |
+
from ase.calculators.calculator import Calculator
|
10 |
+
from ase.data import chemical_symbols
|
11 |
+
from dask.distributed import Client
|
12 |
+
from dask_jobqueue import SLURMCluster
|
13 |
+
from prefect import flow, task
|
14 |
+
from prefect.tasks import task_input_hash
|
15 |
+
from prefect_dask import DaskTaskRunner
|
16 |
+
|
17 |
+
from mlip_arena.models import MLIPCalculator
|
18 |
+
from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator
|
19 |
+
|
20 |
+
cluster_kwargs = {
|
21 |
+
"cores": 4,
|
22 |
+
"memory": "64 GB",
|
23 |
+
"shebang": "#!/bin/bash",
|
24 |
+
"account": "m3828",
|
25 |
+
"walltime": "00:10:00",
|
26 |
+
"job_mem": "0",
|
27 |
+
"job_script_prologue": ["source ~/.bashrc"],
|
28 |
+
"job_directives_skip": ["-n", "--cpus-per-task"],
|
29 |
+
"job_extra_directives": ["-q debug", "-C gpu"],
|
30 |
+
}
|
31 |
+
|
32 |
+
cluster = SLURMCluster(**cluster_kwargs)
|
33 |
+
cluster.scale(jobs=10)
|
34 |
+
client = Client(cluster)
|
35 |
+
|
36 |
+
|
37 |
+
@task(cache_key_fn=task_input_hash, cache_expiration=timedelta(hours=1))
|
38 |
+
def calculate_single_diatomic(
|
39 |
+
calculator_name: str | EXTMLIPEnum,
|
40 |
+
calculator_kwargs: dict | None,
|
41 |
+
atom1: str,
|
42 |
+
atom2: str,
|
43 |
+
rmin: float = 0.1,
|
44 |
+
rmax: float = 6.5,
|
45 |
+
npts: int = int(1e3),
|
46 |
+
):
|
47 |
+
|
48 |
+
calculator_kwargs = calculator_kwargs or {}
|
49 |
+
|
50 |
+
if isinstance(calculator_name, EXTMLIPEnum) and calculator_name in EXTMLIPEnum:
|
51 |
+
calc = external_ase_calculator(calculator_name, **calculator_kwargs)
|
52 |
+
elif calculator_name in MLIPMap:
|
53 |
+
calc = MLIPMap[calculator_name](**calculator_kwargs)
|
54 |
+
|
55 |
+
a = 2 * rmax
|
56 |
+
|
57 |
+
rs = np.linspace(rmin, rmax, npts)
|
58 |
+
e = np.zeros_like(rs)
|
59 |
+
f = np.zeros_like(rs)
|
60 |
+
|
61 |
+
da = atom1 + atom2
|
62 |
+
|
63 |
+
for i, r in enumerate(rs):
|
64 |
+
|
65 |
+
positions = [
|
66 |
+
[0, 0, 0],
|
67 |
+
[r, 0, 0],
|
68 |
+
]
|
69 |
+
|
70 |
+
# Create the unit cell with two atoms
|
71 |
+
atoms = Atoms(da, positions=positions, cell=[a, a, a])
|
72 |
+
|
73 |
+
atoms.calc = calc
|
74 |
+
|
75 |
+
e[i] = atoms.get_potential_energy()
|
76 |
+
f[i] = np.inner(np.array([1, 0, 0]), atoms.get_forces()[1])
|
77 |
+
|
78 |
+
return {"r": rs, "E": e, "F": f, "da": da}
|
79 |
+
|
80 |
+
|
81 |
+
@flow
|
82 |
+
def calculate_multiple_diatomics(calculator_name, calculator_kwargs):
|
83 |
+
|
84 |
+
futures = []
|
85 |
+
for symbol in chemical_symbols:
|
86 |
+
if symbol == "X":
|
87 |
+
continue
|
88 |
+
futures.append(
|
89 |
+
calculate_single_diatomic.submit(
|
90 |
+
calculator_name, calculator_kwargs, symbol, symbol
|
91 |
+
)
|
92 |
+
)
|
93 |
+
|
94 |
+
return [i for future in futures for i in future.result()]
|
95 |
+
|
96 |
+
|
97 |
+
@flow(task_runner=DaskTaskRunner(address=client.scheduler.address), log_prints=True)
|
98 |
+
def calculate_homonuclear_diatomics(calculator_name, calculator_kwargs):
|
99 |
+
|
100 |
+
curves = calculate_multiple_diatomics(calculator_name, calculator_kwargs)
|
101 |
+
|
102 |
+
pd.DataFrame(curves).to_csv(f"homonuclear-diatomics-{calculator_name}.csv")
|
103 |
+
|
104 |
+
|
105 |
+
# with plt.style.context("default"):
|
106 |
+
|
107 |
+
# SMALL_SIZE = 6
|
108 |
+
# MEDIUM_SIZE = 8
|
109 |
+
# LARGE_SIZE = 10
|
110 |
+
|
111 |
+
# LINE_WIDTH = 1
|
112 |
+
|
113 |
+
# plt.rcParams.update(
|
114 |
+
# {
|
115 |
+
# "pgf.texsystem": "pdflatex",
|
116 |
+
# "font.family": "sans-serif",
|
117 |
+
# "text.usetex": True,
|
118 |
+
# "pgf.rcfonts": True,
|
119 |
+
# "figure.constrained_layout.use": True,
|
120 |
+
# "axes.labelsize": MEDIUM_SIZE,
|
121 |
+
# "axes.titlesize": MEDIUM_SIZE,
|
122 |
+
# "legend.frameon": False,
|
123 |
+
# "legend.fontsize": MEDIUM_SIZE,
|
124 |
+
# "legend.loc": "best",
|
125 |
+
# "lines.linewidth": LINE_WIDTH,
|
126 |
+
# "xtick.labelsize": SMALL_SIZE,
|
127 |
+
# "ytick.labelsize": SMALL_SIZE,
|
128 |
+
# }
|
129 |
+
# )
|
130 |
+
|
131 |
+
# fig, ax = plt.subplots(layout="constrained", figsize=(3, 2), dpi=300)
|
132 |
+
|
133 |
+
# color = "tab:red"
|
134 |
+
# ax.plot(rs, e, color=color, zorder=1)
|
135 |
+
|
136 |
+
# ax.axhline(ls="--", color=color, alpha=0.5, lw=0.5 * LINE_WIDTH)
|
137 |
+
|
138 |
+
# ylo, yhi = ax.get_ylim()
|
139 |
+
# ax.set(xlabel=r"r [$\AA]$", ylim=(max(-7, ylo), min(5, yhi)))
|
140 |
+
# ax.set_ylabel(ylabel="E [eV]", color=color)
|
141 |
+
# ax.tick_params(axis="y", labelcolor=color)
|
142 |
+
# ax.text(0.8, 0.85, da, fontsize=LARGE_SIZE, transform=ax.transAxes)
|
143 |
+
|
144 |
+
# color = "tab:blue"
|
145 |
+
|
146 |
+
# at = ax.twinx()
|
147 |
+
# at.plot(rs, f, color=color, zorder=0, lw=0.5 * LINE_WIDTH)
|
148 |
+
|
149 |
+
# at.axhline(ls="--", color=color, alpha=0.5, lw=0.5 * LINE_WIDTH)
|
150 |
+
|
151 |
+
# ylo, yhi = at.get_ylim()
|
152 |
+
# at.set(
|
153 |
+
# xlabel=r"r [$\AA]$",
|
154 |
+
# ylim=(max(-20, ylo), min(20, yhi)),
|
155 |
+
# )
|
156 |
+
# at.set_ylabel(ylabel="F [eV/$\AA$]", color=color)
|
157 |
+
# at.tick_params(axis="y", labelcolor=color)
|
158 |
+
|
159 |
+
# plt.show()
|
160 |
+
|
161 |
+
|
162 |
+
if __name__ == "__main__":
|
163 |
+
calculate_homonuclear_diatomics(
|
164 |
+
EXTMLIPEnum.MACE, dict(model="medium", device="cuda")
|
165 |
+
)
|
mlip_arena/tasks/diatomics/vasp/run.ipynb
DELETED
@@ -1,509 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": null,
|
6 |
-
"metadata": {
|
7 |
-
"tags": []
|
8 |
-
},
|
9 |
-
"outputs": [],
|
10 |
-
"source": [
|
11 |
-
"from datetime import timedelta\n",
|
12 |
-
"from typing import Union\n",
|
13 |
-
"import os\n",
|
14 |
-
"from pathlib import Path\n",
|
15 |
-
"from itertools import combinations, combinations_with_replacement\n",
|
16 |
-
"import psutil\n",
|
17 |
-
"import subprocess\n",
|
18 |
-
"\n",
|
19 |
-
"import numpy as np\n",
|
20 |
-
"import pandas as pd\n",
|
21 |
-
"import torch\n",
|
22 |
-
"from ase import Atoms, Atom\n",
|
23 |
-
"from ase.calculators.calculator import Calculator\n",
|
24 |
-
"from ase.calculators.singlepoint import SinglePointCalculator\n",
|
25 |
-
"from ase.calculators.vasp import Vasp\n",
|
26 |
-
"from ase.data import chemical_symbols, covalent_radii, vdw_alvarez\n",
|
27 |
-
"from ase.io import read, write\n",
|
28 |
-
"from dask.distributed import Client\n",
|
29 |
-
"from dask_jobqueue import SLURMCluster\n",
|
30 |
-
"from prefect import flow, task\n",
|
31 |
-
"from prefect.tasks import task_input_hash\n",
|
32 |
-
"from prefect_dask import DaskTaskRunner\n",
|
33 |
-
"\n",
|
34 |
-
"from pymatgen.core import Element\n",
|
35 |
-
"from pymatgen.io.ase import AseAtomsAdaptor\n",
|
36 |
-
"from pymatgen.io.vasp.inputs import Kpoints\n",
|
37 |
-
"from pymatgen.command_line.chargemol_caller import ChargemolAnalysis\n",
|
38 |
-
"\n",
|
39 |
-
"from mlip_arena.models import MLIPCalculator\n",
|
40 |
-
"from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator\n",
|
41 |
-
"from jobflow import run_locally\n",
|
42 |
-
"\n",
|
43 |
-
"from atomate2.vasp.jobs.mp import MPGGAStaticMaker\n",
|
44 |
-
"from atomate2.vasp.sets.mp import MPGGAStaticSetGenerator"
|
45 |
-
]
|
46 |
-
},
|
47 |
-
{
|
48 |
-
"cell_type": "code",
|
49 |
-
"execution_count": null,
|
50 |
-
"metadata": {
|
51 |
-
"scrolled": true,
|
52 |
-
"tags": []
|
53 |
-
},
|
54 |
-
"outputs": [],
|
55 |
-
"source": [
|
56 |
-
"\n",
|
57 |
-
"nodes_per_alloc = 1\n",
|
58 |
-
"cpus_per_task = 16\n",
|
59 |
-
"gpus_per_node = 1\n",
|
60 |
-
"# tasks_per_node = int(128/4)\n",
|
61 |
-
"ntasks = 1\n",
|
62 |
-
"\n",
|
63 |
-
"\n",
|
64 |
-
"cluster_kwargs = dict(\n",
|
65 |
-
" cores=1,\n",
|
66 |
-
" memory=\"64 GB\",\n",
|
67 |
-
" shebang=\"#!/bin/bash\",\n",
|
68 |
-
" account=\"m3828\",\n",
|
69 |
-
" walltime=\"01:00:00\",\n",
|
70 |
-
" # processes=16,\n",
|
71 |
-
" # nanny=True,\n",
|
72 |
-
" job_mem=\"0\",\n",
|
73 |
-
" job_script_prologue=[\n",
|
74 |
-
" \"source ~/.bashrc\",\n",
|
75 |
-
" \"module load python\",\n",
|
76 |
-
" \"source activate /pscratch/sd/c/cyrisinstance.conda/mlip-arena\",\n",
|
77 |
-
" \"module load vasp/6.4.1-gpu\",\n",
|
78 |
-
" \n",
|
79 |
-
" \"export DDEC6_ATOMIC_DENSITIES_DIR='/global/homes/c/cyrusyc/chargemol/atomic_densities/'\",\n",
|
80 |
-
" f\"export OMP_NUM_THREADS={gpus_per_node}\",\n",
|
81 |
-
" \"export OMP_PLACES=threads\",\n",
|
82 |
-
" \"export OMP_PROC_BIND=spread\",\n",
|
83 |
-
" f\"export ASE_VASP_COMMAND='srun -n {ntasks} -c {4*gpus_per_node} --cpu-bind=cores --gpus-per-node {gpus_per_node} vasp_ncl'\"\n",
|
84 |
-
" # f\"export ASE_VASP_COMMAND='srun -N {nodes_per_alloc} --ntasks-per-node={tasks_per_node} -c {cpus_per_task} --cpu-bind=cores vasp_std'\"\n",
|
85 |
-
" # \"export ATOMATE2_CONFIG_FILE='/global/homes/c/cyrusyc/atomate2/config/atomate2-prefect-cpu-node.yaml'\"\n",
|
86 |
-
" ],\n",
|
87 |
-
" job_directives_skip=[\"-n\", \"--cpus-per-task\", \"-J\"],\n",
|
88 |
-
" job_extra_directives=[\n",
|
89 |
-
" \"-J diatomics\",\n",
|
90 |
-
" \"-q regular\",\n",
|
91 |
-
" f\"-N {nodes_per_alloc}\",\n",
|
92 |
-
" \"-C gpu\",\n",
|
93 |
-
" # \"-n 1\",\n",
|
94 |
-
" # \"-c 16\",\n",
|
95 |
-
" # \"--gpus-per-task=1\",\n",
|
96 |
-
" # \"--threads-per-task=1\",\n",
|
97 |
-
" # \"--gpu-bind=single:1\",\n",
|
98 |
-
" # \"--comment=00:20:00\",\n",
|
99 |
-
" # \"--time-min=00:05:00\",\n",
|
100 |
-
" # \"--signal=B:USR1@60\",\n",
|
101 |
-
" # \"--requeue\",\n",
|
102 |
-
" # \"--open-mode=append\",\n",
|
103 |
-
" # \"--mail-type=end,requeue\",\n",
|
104 |
-
" # \"[email protected]\",\n",
|
105 |
-
" ],\n",
|
106 |
-
" # python=\"srun python\",\n",
|
107 |
-
" death_timeout=86400, #float('inf')\n",
|
108 |
-
")\n",
|
109 |
-
"\n",
|
110 |
-
"\n",
|
111 |
-
"cluster = SLURMCluster(**cluster_kwargs)\n",
|
112 |
-
"print(cluster.job_script())\n",
|
113 |
-
"# cluster.scale(3)\n",
|
114 |
-
"cluster.adapt(minimum_jobs=50, maximum_jobs=100)\n",
|
115 |
-
"client = Client(cluster)"
|
116 |
-
]
|
117 |
-
},
|
118 |
-
{
|
119 |
-
"cell_type": "code",
|
120 |
-
"execution_count": null,
|
121 |
-
"metadata": {
|
122 |
-
"tags": []
|
123 |
-
},
|
124 |
-
"outputs": [],
|
125 |
-
"source": [
|
126 |
-
"\n",
|
127 |
-
"\n",
|
128 |
-
"@task(cache_key_fn=task_input_hash, cache_expiration=timedelta(hours=24), log_prints=True)\n",
|
129 |
-
"def calculate_single_diatomic(\n",
|
130 |
-
" calculator: str | EXTMLIPEnum | Calculator,\n",
|
131 |
-
" calculator_kwargs: dict | None,\n",
|
132 |
-
" atom1: str,\n",
|
133 |
-
" atom2: str,\n",
|
134 |
-
" rmin: float = 1.25,\n",
|
135 |
-
" rmax: float = 6.25,\n",
|
136 |
-
" rstep: float = 0.2,\n",
|
137 |
-
" magnetism: str = \"FM\"\n",
|
138 |
-
"):\n",
|
139 |
-
"\n",
|
140 |
-
" calculator_kwargs = calculator_kwargs or {}\n",
|
141 |
-
"\n",
|
142 |
-
" if isinstance(calculator, str) and calculator.lower() == 'vasp-mp-gga':\n",
|
143 |
-
" calc = Vasp(**calculator_kwargs)\n",
|
144 |
-
" calc.name = 'vasp-mp-gga'\n",
|
145 |
-
" # calc.name='atomate2'\n",
|
146 |
-
" elif isinstance(calculator, EXTMLIPEnum) and calculator in EXTMLIPEnum:\n",
|
147 |
-
" calc = external_ase_calculator(calculator, **calculator_kwargs)\n",
|
148 |
-
" elif calculator in MLIPMap:\n",
|
149 |
-
" calc = MLIPMap[calculator](**calculator_kwargs)\n",
|
150 |
-
" elif issubclass(calculator, Calculator):\n",
|
151 |
-
" calc = calculator(**calculator_kwargs)\n",
|
152 |
-
"\n",
|
153 |
-
" a = 2 * rmax\n",
|
154 |
-
"\n",
|
155 |
-
" npts = int((rmax - rmin)/rstep)\n",
|
156 |
-
"\n",
|
157 |
-
" rs = np.linspace(rmin, rmax, npts)\n",
|
158 |
-
" e = np.zeros_like(rs)\n",
|
159 |
-
" f = np.zeros_like(rs)\n",
|
160 |
-
"\n",
|
161 |
-
" da = atom1 + atom2\n",
|
162 |
-
" \n",
|
163 |
-
" assert isinstance(calc, Calculator)\n",
|
164 |
-
" \n",
|
165 |
-
" out_dir = Path(str(da + f\"_{magnetism}\"))\n",
|
166 |
-
" os.makedirs(out_dir, exist_ok=True)\n",
|
167 |
-
" \n",
|
168 |
-
" calc.directory = out_dir\n",
|
169 |
-
" \n",
|
170 |
-
" print(f\"write output to {calc.directory}\")\n",
|
171 |
-
" \n",
|
172 |
-
" element = Element(atom1)\n",
|
173 |
-
" \n",
|
174 |
-
" try:\n",
|
175 |
-
" m = element.valence[1]\n",
|
176 |
-
" if element.valence == (0, 2):\n",
|
177 |
-
" m = 0\n",
|
178 |
-
" except:\n",
|
179 |
-
" m = 0\n",
|
180 |
-
" \n",
|
181 |
-
" r = rs[0]\n",
|
182 |
-
" \n",
|
183 |
-
" positions = [\n",
|
184 |
-
" [a/2-r/2, a/2, a/2],\n",
|
185 |
-
" [a/2+r/2, a/2, a/2],\n",
|
186 |
-
" ]\n",
|
187 |
-
" \n",
|
188 |
-
" if magnetism == 'FM':\n",
|
189 |
-
" if m == 0:\n",
|
190 |
-
" return {}\n",
|
191 |
-
" magmoms = [m, m]\n",
|
192 |
-
" elif magnetism == 'AFM':\n",
|
193 |
-
" if m == 0:\n",
|
194 |
-
" return {}\n",
|
195 |
-
" magmoms = [m, -m]\n",
|
196 |
-
" elif magnetism == 'NM':\n",
|
197 |
-
" magmoms = [0, 0]\n",
|
198 |
-
" \n",
|
199 |
-
" traj_fpath = out_dir / \"traj.extxyz\"\n",
|
200 |
-
" \n",
|
201 |
-
" skip = 0\n",
|
202 |
-
" if traj_fpath.exists():\n",
|
203 |
-
" traj = read(traj_fpath, index=\":\")\n",
|
204 |
-
" skip = len(traj)\n",
|
205 |
-
" atoms = traj[-1]\n",
|
206 |
-
" else:\n",
|
207 |
-
" atoms = Atoms(\n",
|
208 |
-
" da, \n",
|
209 |
-
" positions=positions,\n",
|
210 |
-
" magmoms=magmoms,\n",
|
211 |
-
" cell=[a, a+0.001, a+0.002], \n",
|
212 |
-
" pbc=True\n",
|
213 |
-
" )\n",
|
214 |
-
" \n",
|
215 |
-
" # \n",
|
216 |
-
" \n",
|
217 |
-
" structure = AseAtomsAdaptor().get_structure(atoms)\n",
|
218 |
-
" \n",
|
219 |
-
" if magnetism == 'FM':\n",
|
220 |
-
" I_CONSTRAINED_M = 2\n",
|
221 |
-
" LAMBDA = 10\n",
|
222 |
-
" M_CONSTR = [0, 0, 1, 0, 0, 1] # \" \".join(map(str, [0, 0, 1])) + \" \" + \" \".join(map(str, [0, 0, 1]))\n",
|
223 |
-
" elif magnetism == 'AFM':\n",
|
224 |
-
" I_CONSTRAINED_M = 2\n",
|
225 |
-
" LAMBDA = 10\n",
|
226 |
-
" M_CONSTR = [0, 0, 1, 0, 0, -1] # \" \".join(map(str, [0, 0, 1])) + \" \" + \" \".join(map(str, [0, 0, -1]))\n",
|
227 |
-
" elif magnetism == 'NM':\n",
|
228 |
-
" I_CONSTRAINED_M = 1\n",
|
229 |
-
" LAMBDA = 10\n",
|
230 |
-
" M_CONSTR = [0, 0, 0, 0, 0, 0] #\" \".join(map(str, [0, 0, 0])) + \" \" + \" \".join(map(str, [0, 0, 0]))\n",
|
231 |
-
"\n",
|
232 |
-
" input_set_generator = MPGGAStaticSetGenerator(\n",
|
233 |
-
" user_incar_settings=dict(\n",
|
234 |
-
" ISYM = 0, # symmetry is off\n",
|
235 |
-
" ISPIN = 2,\n",
|
236 |
-
" ISMEAR = 0, # Gaussian smearing, otherwise negative occupancies might come up\n",
|
237 |
-
" SIGMA = 0.002, # tiny smearing width to safely break symmetry\n",
|
238 |
-
" AMIX = 0.2, # mixing set manually\n",
|
239 |
-
" BMIX = 0.0001,\n",
|
240 |
-
" LSUBROT= True, # spin orbit coupling (non collinear)\n",
|
241 |
-
" ALGO = \"Accurate\",\n",
|
242 |
-
" PREC = \"High\",\n",
|
243 |
-
" ENCUT = 1000,\n",
|
244 |
-
" ENAUG = 2000,\n",
|
245 |
-
" ISTART = 1,\n",
|
246 |
-
" ICHARG = 1,\n",
|
247 |
-
" NELM = 200,\n",
|
248 |
-
" TIME = 0.2,\n",
|
249 |
-
" LELF = False,\n",
|
250 |
-
" LMAXMIX=max(max(map(lambda a: \"spdf\".index(Element(a.symbol).block) * 2, atoms)), 2),\n",
|
251 |
-
" LMIXTAU=False,\n",
|
252 |
-
" VOSKOWN = 1,\n",
|
253 |
-
" I_CONSTRAINED_M = I_CONSTRAINED_M,\n",
|
254 |
-
" M_CONSTR = M_CONSTR,\n",
|
255 |
-
" LAMBDA = LAMBDA,\n",
|
256 |
-
" # performance\n",
|
257 |
-
" # lplane=False,\n",
|
258 |
-
" # npar=int(sqrt(ncpus)),\n",
|
259 |
-
" # nsim=1,\n",
|
260 |
-
" # LPLANE = True,\n",
|
261 |
-
" # # NCORE = 128,\n",
|
262 |
-
" # LSCALU = False,\n",
|
263 |
-
" # NSIM = 4,\n",
|
264 |
-
" # LPLANE = False,\n",
|
265 |
-
" # NPAR = 16,\n",
|
266 |
-
" # NSIM = 1,\n",
|
267 |
-
" # LSCALU = False,\n",
|
268 |
-
" # GPU\n",
|
269 |
-
" KPAR = gpus_per_node,\n",
|
270 |
-
" NSIM = 64,\n",
|
271 |
-
"\n",
|
272 |
-
" LVTOT = False,\n",
|
273 |
-
" LAECHG = True, # AECCARs\n",
|
274 |
-
" LASPH = True, # aspherical charge density\n",
|
275 |
-
" LCHARG = True, # CHGCAR\n",
|
276 |
-
" LWAVE = True\n",
|
277 |
-
" ),\n",
|
278 |
-
" user_kpoints_settings=Kpoints(), # Gamma point only\n",
|
279 |
-
" user_potcar_settings={\n",
|
280 |
-
" \"Yb\": \"Yb_3\"\n",
|
281 |
-
" },\n",
|
282 |
-
" sort_structure=False\n",
|
283 |
-
" )\n",
|
284 |
-
"\n",
|
285 |
-
" vis = input_set_generator.get_input_set(structure=structure)\n",
|
286 |
-
" vis.incar.pop(\"MAGMOM\")\n",
|
287 |
-
"\n",
|
288 |
-
" incar = {key.lower(): value for key, value in vis.incar.items()} \n",
|
289 |
-
" calc.set(kpts=1, gamma=True, **incar)\n",
|
290 |
-
" \n",
|
291 |
-
" atoms.calc = calc\n",
|
292 |
-
"\n",
|
293 |
-
" for i, r in enumerate(np.flip(rs)):\n",
|
294 |
-
"\n",
|
295 |
-
" \n",
|
296 |
-
" if i < skip:\n",
|
297 |
-
" continue\n",
|
298 |
-
"\n",
|
299 |
-
" positions = [\n",
|
300 |
-
" [a/2-r/2, a/2, a/2],\n",
|
301 |
-
" [a/2+r/2, a/2, a/2],\n",
|
302 |
-
" ]\n",
|
303 |
-
" \n",
|
304 |
-
" if i > 0: \n",
|
305 |
-
" magmoms = atoms.get_magnetic_moments()\n",
|
306 |
-
" \n",
|
307 |
-
" atoms.set_initial_magnetic_moments(magmoms)\n",
|
308 |
-
" atoms.set_positions(positions)\n",
|
309 |
-
"\n",
|
310 |
-
" print(f\"{atoms} separated by {r} A ({i+1}/{len(rs)})\")\n",
|
311 |
-
" \n",
|
312 |
-
"\n",
|
313 |
-
" e[i] = atoms.get_potential_energy()\n",
|
314 |
-
" f[i] = np.inner(np.array([1, 0, 0]), atoms.get_forces()[1])\n",
|
315 |
-
" \n",
|
316 |
-
" atoms.calc.results.update(dict(\n",
|
317 |
-
" magmoms=atoms.get_magnetic_moments()\n",
|
318 |
-
" ))\n",
|
319 |
-
" \n",
|
320 |
-
" write(out_dir / \"traj.extxyz\", atoms, append=\"a\")\n",
|
321 |
-
" \n",
|
322 |
-
"# additional_results = {}\n",
|
323 |
-
" \n",
|
324 |
-
"# try:\n",
|
325 |
-
"# ncpus = psutil.cpu_count(logical=True)\n",
|
326 |
-
"# nthreads = os.environ[\"OMP_NUM_THREADS\"]\n",
|
327 |
-
"# subprocess.run([\"export\", f\"OMP_NUM_THREADS={ncpus}\"], shell=True)\n",
|
328 |
-
" \n",
|
329 |
-
"# ca = ChargemolAnalysis(path=out_dir)\n",
|
330 |
-
" \n",
|
331 |
-
"# if charges := ca.ddec_charges:\n",
|
332 |
-
"# additional_results[\"charges\"] = np.array(charges)\n",
|
333 |
-
"# if dipoles := ca.dipoles:\n",
|
334 |
-
"# additional_results[\"dipoles\"] = np.array(dipoles)\n",
|
335 |
-
"# if magmoms := ca.ddec_spin_moments:\n",
|
336 |
-
"# additional_results[\"magmoms\"] = np.array(magmoms)\n",
|
337 |
-
" \n",
|
338 |
-
"# subprocess.run([\"export\", f\"OMP_NUM_THREADS={nthreads}\"], shell=True)\n",
|
339 |
-
"# except:\n",
|
340 |
-
"# print(\"DDEC failed\")\n",
|
341 |
-
" \n",
|
342 |
-
" \n",
|
343 |
-
"# atoms.calc.results.update(additional_results)\n",
|
344 |
-
" \n",
|
345 |
-
" return {\"r\": rs, \"E\": e, \"F\": f, \"da\": da}\n",
|
346 |
-
"\n"
|
347 |
-
]
|
348 |
-
},
|
349 |
-
{
|
350 |
-
"cell_type": "code",
|
351 |
-
"execution_count": null,
|
352 |
-
"metadata": {
|
353 |
-
"tags": []
|
354 |
-
},
|
355 |
-
"outputs": [],
|
356 |
-
"source": [
|
357 |
-
"@flow(task_runner=DaskTaskRunner(address=client.scheduler.address), log_prints=True)\n",
|
358 |
-
"def calculate_multiple_diatomics(calculator_name, calculator_kwargs):\n",
|
359 |
-
"\n",
|
360 |
-
" futures = []\n",
|
361 |
-
" for sa in chemical_symbols:\n",
|
362 |
-
" \n",
|
363 |
-
" s = set([sa])\n",
|
364 |
-
" \n",
|
365 |
-
" if 'X' in s:\n",
|
366 |
-
" continue\n",
|
367 |
-
" \n",
|
368 |
-
" atom = Atom(sa)\n",
|
369 |
-
" rmin = covalent_radii[atom.number] * 2 * 0.6\n",
|
370 |
-
" rvdw = vdw_alvarez.vdw_radii[atom.number] if atom.number < len(vdw_alvarez.vdw_radii) else np.nan \n",
|
371 |
-
" rmax = 3.1 * rvdw if not np.isnan(rvdw) else 6\n",
|
372 |
-
" rstep = 0.2 #if rmin < 1 else 0.4\n",
|
373 |
-
" \n",
|
374 |
-
" futures.append(\n",
|
375 |
-
" calculate_single_diatomic.submit(\n",
|
376 |
-
" calculator_name, calculator_kwargs, sa, sa,\n",
|
377 |
-
" rmin=rmin, rmax=rmax,\n",
|
378 |
-
" rstep=rstep,\n",
|
379 |
-
" magnetism=\"FM\"\n",
|
380 |
-
" # npts=16 if 'H' in s else 21\n",
|
381 |
-
" )\n",
|
382 |
-
" )\n",
|
383 |
-
" futures.append(\n",
|
384 |
-
" calculate_single_diatomic.submit(\n",
|
385 |
-
" calculator_name, calculator_kwargs, sa, sa,\n",
|
386 |
-
" rmin=rmin, rmax=rmax,\n",
|
387 |
-
" rstep=rstep, #0.1 if rmin < 1 else 0.25,\n",
|
388 |
-
" magnetism=\"AFM\"\n",
|
389 |
-
" # npts=16 if 'H' in s else 21\n",
|
390 |
-
" )\n",
|
391 |
-
" )\n",
|
392 |
-
" futures.append(\n",
|
393 |
-
" calculate_single_diatomic.submit(\n",
|
394 |
-
" calculator_name, calculator_kwargs, sa, sa,\n",
|
395 |
-
" rmin=rmin, rmax=rmax,\n",
|
396 |
-
" rstep=rstep, #0.1 if rmin < 1 else 0.25,\n",
|
397 |
-
" magnetism=\"NM\"\n",
|
398 |
-
" # npts=16 if 'H' in s else 21\n",
|
399 |
-
" )\n",
|
400 |
-
" )\n",
|
401 |
-
"# for sa, sb in combinations_with_replacement(chemical_symbols, 2):\n",
|
402 |
-
" \n",
|
403 |
-
"# if 'X' in set([sa, sb]):\n",
|
404 |
-
"# continue\n",
|
405 |
-
" \n",
|
406 |
-
"# futures.append(\n",
|
407 |
-
"# calculate_single_diatomic.submit(\n",
|
408 |
-
"# calculator_name, calculator_kwargs, sa, sb,\n",
|
409 |
-
"# rmin=0.5, rmax=6.5,\n",
|
410 |
-
"# npts=16\n",
|
411 |
-
"# )\n",
|
412 |
-
"# )\n",
|
413 |
-
"\n",
|
414 |
-
" return [i for future in futures for i in future.result()]\n",
|
415 |
-
"\n"
|
416 |
-
]
|
417 |
-
},
|
418 |
-
{
|
419 |
-
"cell_type": "code",
|
420 |
-
"execution_count": null,
|
421 |
-
"metadata": {
|
422 |
-
"scrolled": true,
|
423 |
-
"tags": []
|
424 |
-
},
|
425 |
-
"outputs": [],
|
426 |
-
"source": [
|
427 |
-
"calculate_multiple_diatomics(\n",
|
428 |
-
" \"vasp-mp-gga\", \n",
|
429 |
-
" dict(\n",
|
430 |
-
" xc=\"pbe\",\n",
|
431 |
-
" kpts=1,\n",
|
432 |
-
" # Massively parallel machines (Cray)\n",
|
433 |
-
" # lplane=False,\n",
|
434 |
-
" # npar=int(sqrt(ncpus)),\n",
|
435 |
-
" # nsim=1\n",
|
436 |
-
" # Multicore modern linux machines\n",
|
437 |
-
" # lplane=True,\n",
|
438 |
-
" # npar=2,\n",
|
439 |
-
" # lscalu=False,\n",
|
440 |
-
" # nsim=4\n",
|
441 |
-
" )\n",
|
442 |
-
")\n"
|
443 |
-
]
|
444 |
-
},
|
445 |
-
{
|
446 |
-
"cell_type": "code",
|
447 |
-
"execution_count": null,
|
448 |
-
"metadata": {
|
449 |
-
"scrolled": true,
|
450 |
-
"tags": []
|
451 |
-
},
|
452 |
-
"outputs": [],
|
453 |
-
"source": [
|
454 |
-
"\n",
|
455 |
-
"calculate_homonuclear_diatomics(\n",
|
456 |
-
" \"vasp-mp-gga\", \n",
|
457 |
-
" dict(\n",
|
458 |
-
" xc=\"pbe\",\n",
|
459 |
-
" kpts=1,\n",
|
460 |
-
" # Massively parallel machines (Cray)\n",
|
461 |
-
" # lplane=False,\n",
|
462 |
-
" # npar=int(sqrt(ncpus)),\n",
|
463 |
-
" # nsim=1\n",
|
464 |
-
" # Multicore modern linux machines\n",
|
465 |
-
" # lplane=True,\n",
|
466 |
-
" # npar=2,\n",
|
467 |
-
" # lscalu=False,\n",
|
468 |
-
" # nsim=4\n",
|
469 |
-
" )\n",
|
470 |
-
")\n"
|
471 |
-
]
|
472 |
-
},
|
473 |
-
{
|
474 |
-
"cell_type": "code",
|
475 |
-
"execution_count": null,
|
476 |
-
"metadata": {},
|
477 |
-
"outputs": [],
|
478 |
-
"source": []
|
479 |
-
}
|
480 |
-
],
|
481 |
-
"metadata": {
|
482 |
-
"kernelspec": {
|
483 |
-
"display_name": "mlip-arena",
|
484 |
-
"language": "python",
|
485 |
-
"name": "mlip-arena"
|
486 |
-
},
|
487 |
-
"language_info": {
|
488 |
-
"codemirror_mode": {
|
489 |
-
"name": "ipython",
|
490 |
-
"version": 3
|
491 |
-
},
|
492 |
-
"file_extension": ".py",
|
493 |
-
"mimetype": "text/x-python",
|
494 |
-
"name": "python",
|
495 |
-
"nbconvert_exporter": "python",
|
496 |
-
"pygments_lexer": "ipython3",
|
497 |
-
"version": "3.11.8"
|
498 |
-
},
|
499 |
-
"widgets": {
|
500 |
-
"application/vnd.jupyter.widget-state+json": {
|
501 |
-
"state": {},
|
502 |
-
"version_major": 2,
|
503 |
-
"version_minor": 0
|
504 |
-
}
|
505 |
-
}
|
506 |
-
},
|
507 |
-
"nbformat": 4,
|
508 |
-
"nbformat_minor": 4
|
509 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mlip_arena/tasks/stability/__init__.py
ADDED
File without changes
|
mlip_arena/tasks/stability/run.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from mlip_arena.tasks.utils import _valid_dynamics, _preset_dynamics
|
3 |
+
|
mlip_arena/tasks/utils.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, glob
|
2 |
+
from pathlib import Path
|
3 |
+
from ase.io import read, write
|
4 |
+
from ase import units
|
5 |
+
from ase import Atoms, units
|
6 |
+
from ase.calculators.calculator import Calculator
|
7 |
+
from ase.data import chemical_symbols
|
8 |
+
from ase.md.andersen import Andersen
|
9 |
+
from ase.md.langevin import Langevin
|
10 |
+
from ase.md.md import MolecularDynamics
|
11 |
+
from ase.md.npt import NPT
|
12 |
+
from ase.md.nptberendsen import NPTBerendsen
|
13 |
+
from ase.md.nvtberendsen import NVTBerendsen
|
14 |
+
from ase.md.velocitydistribution import (
|
15 |
+
MaxwellBoltzmannDistribution,
|
16 |
+
Stationary,
|
17 |
+
ZeroRotation,
|
18 |
+
)
|
19 |
+
from ase.md.verlet import VelocityVerlet
|
20 |
+
from dask.distributed import Client
|
21 |
+
from dask_jobqueue import SLURMCluster
|
22 |
+
from jobflow import Maker
|
23 |
+
from prefect import flow, task
|
24 |
+
from prefect.tasks import task_input_hash
|
25 |
+
from prefect_dask import DaskTaskRunner
|
26 |
+
from pymatgen.io.ase import AseAtomsAdaptor
|
27 |
+
from scipy.interpolate import interp1d
|
28 |
+
from scipy.linalg import schur
|
29 |
+
|
30 |
+
from mlip_arena.models import MLIPCalculator
|
31 |
+
from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator
|
32 |
+
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
|
33 |
+
from mp_api.client import MPRester
|
34 |
+
|
35 |
+
from fireworks import LaunchPad
|
36 |
+
from atomate2.vasp.flows.core import RelaxBandStructureMaker
|
37 |
+
from atomate2.vasp.flows.mp import MPGGADoubleRelaxStaticMaker
|
38 |
+
from atomate2.vasp.powerups import add_metadata_to_flow
|
39 |
+
from atomate2.forcefields.md import (
|
40 |
+
CHGNetMDMaker,
|
41 |
+
GAPMDMaker,
|
42 |
+
M3GNetMDMaker,
|
43 |
+
MACEMDMaker,
|
44 |
+
NequipMDMaker,
|
45 |
+
)
|
46 |
+
from atomate2.forcefields.utils import MLFF
|
47 |
+
from pymatgen.io.ase import AseAtomsAdaptor
|
48 |
+
from pymatgen.transformations.advanced_transformations import CubicSupercellTransformation
|
49 |
+
from jobflow.managers.fireworks import flow_to_workflow
|
50 |
+
from jobflow import run_locally, SETTINGS
|
51 |
+
from tqdm.auto import tqdm
|
52 |
+
|
53 |
+
from datetime import timedelta, datetime
|
54 |
+
from typing import Literal, Sequence, Tuple
|
55 |
+
|
56 |
+
import numpy as np
|
57 |
+
import torch
|
58 |
+
from pymatgen.core.structure import Structure
|
59 |
+
|
60 |
+
from ase.calculators.mixing import SumCalculator
|
61 |
+
from scipy.interpolate import interp1d
|
62 |
+
|
63 |
+
from ase.io.trajectory import Trajectory
|
64 |
+
|
65 |
+
|
66 |
+
_valid_dynamics: dict[str, tuple[str, ...]] = {
|
67 |
+
"nve": ("velocityverlet",),
|
68 |
+
"nvt": ("nose-hoover", "langevin", "andersen", "berendsen"),
|
69 |
+
"npt": ("nose-hoover", "berendsen"),
|
70 |
+
}
|
71 |
+
|
72 |
+
_preset_dynamics: dict = {
|
73 |
+
"nve_velocityverlet": VelocityVerlet,
|
74 |
+
"nvt_andersen": Andersen,
|
75 |
+
"nvt_berendsen": NVTBerendsen,
|
76 |
+
"nvt_langevin": Langevin,
|
77 |
+
"nvt_nose-hoover": NPT,
|
78 |
+
"npt_berendsen": NPTBerendsen,
|
79 |
+
"npt_nose-hoover": NPT,
|
80 |
+
}
|
81 |
+
|
82 |
+
def _interpolate_quantity(values: Sequence | np.ndarray, n_pts: int) -> np.ndarray:
|
83 |
+
"""Interpolate temperature / pressure on a schedule."""
|
84 |
+
n_vals = len(values)
|
85 |
+
return np.interp(
|
86 |
+
np.linspace(0, n_vals - 1, n_pts + 1),
|
87 |
+
np.linspace(0, n_vals - 1, n_vals),
|
88 |
+
values,
|
89 |
+
)
|
90 |
+
|
91 |
+
def _get_ensemble_schedule(
|
92 |
+
ensemble: Literal["nve", "nvt", "npt"] = "nvt",
|
93 |
+
n_steps: int = 1000,
|
94 |
+
temperature: float | Sequence | np.ndarray | None = 300.0,
|
95 |
+
pressure: float | Sequence | np.ndarray | None = None
|
96 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
97 |
+
if ensemble == "nve":
|
98 |
+
# Disable thermostat and barostat
|
99 |
+
temperature = np.nan
|
100 |
+
pressure = np.nan
|
101 |
+
t_schedule = np.full(n_steps + 1, temperature)
|
102 |
+
p_schedule = np.full(n_steps + 1, pressure)
|
103 |
+
return t_schedule, p_schedule
|
104 |
+
|
105 |
+
if isinstance(temperature, Sequence) or (
|
106 |
+
isinstance(temperature, np.ndarray) and temperature.ndim == 1
|
107 |
+
):
|
108 |
+
t_schedule = _interpolate_quantity(temperature, n_steps)
|
109 |
+
# NOTE: In ASE Langevin dynamics, the temperature are normally
|
110 |
+
# scalars, but in principle one quantity per atom could be specified by giving
|
111 |
+
# an array. This is not implemented yet here.
|
112 |
+
else:
|
113 |
+
t_schedule = np.full(n_steps + 1, temperature)
|
114 |
+
|
115 |
+
if ensemble == "nvt":
|
116 |
+
pressure = np.nan
|
117 |
+
p_schedule = np.full(n_steps + 1, pressure)
|
118 |
+
return t_schedule, p_schedule
|
119 |
+
|
120 |
+
if isinstance(pressure, Sequence) or (
|
121 |
+
isinstance(pressure, np.ndarray) and pressure.ndim == 1
|
122 |
+
):
|
123 |
+
p_schedule = _interpolate_quantity(pressure, n_steps)
|
124 |
+
elif isinstance(pressure, np.ndarray) and pressure.ndim == 4:
|
125 |
+
p_schedule = interp1d(
|
126 |
+
np.arange(n_steps + 1), pressure, kind="linear"
|
127 |
+
)
|
128 |
+
assert isinstance(p_schedule, np.ndarray)
|
129 |
+
else:
|
130 |
+
p_schedule = np.full(n_steps + 1, pressure)
|
131 |
+
|
132 |
+
return t_schedule, p_schedule
|
133 |
+
|
134 |
+
def _get_ensemble_defaults(
|
135 |
+
ensemble: Literal["nve", "nvt", "npt"],
|
136 |
+
dynamics: str | MolecularDynamics,
|
137 |
+
t_schedule: np.ndarray,
|
138 |
+
p_schedule: np.ndarray,
|
139 |
+
ase_md_kwargs: dict | None = None) -> dict:
|
140 |
+
"""Update ASE MD kwargs"""
|
141 |
+
ase_md_kwargs = ase_md_kwargs or {}
|
142 |
+
|
143 |
+
if ensemble == "nve":
|
144 |
+
ase_md_kwargs.pop("temperature", None)
|
145 |
+
ase_md_kwargs.pop("temperature_K", None)
|
146 |
+
ase_md_kwargs.pop("externalstress", None)
|
147 |
+
elif ensemble == "nvt":
|
148 |
+
ase_md_kwargs["temperature_K"] = t_schedule[0]
|
149 |
+
ase_md_kwargs.pop("externalstress", None)
|
150 |
+
elif ensemble == "npt":
|
151 |
+
ase_md_kwargs["temperature_K"] = t_schedule[0]
|
152 |
+
ase_md_kwargs["externalstress"] = p_schedule[0] * 1e3 * units.bar
|
153 |
+
|
154 |
+
if isinstance(dynamics, str) and dynamics.lower() == "langevin":
|
155 |
+
ase_md_kwargs["friction"] = ase_md_kwargs.get(
|
156 |
+
"friction",
|
157 |
+
10.0 * 1e-3 / units.fs, # Same default as in VASP: 10 ps^-1
|
158 |
+
)
|
159 |
+
|
160 |
+
return ase_md_kwargs
|
161 |
+
|
pyproject.toml
CHANGED
@@ -46,6 +46,7 @@ Issues = "https://github.com/atomind-ai/mlip-arena/issues"
|
|
46 |
|
47 |
[tool.ruff]
|
48 |
# Exclude a variety of commonly ignored directories.
|
|
|
49 |
exclude = [
|
50 |
".bzr",
|
51 |
".direnv",
|
|
|
46 |
|
47 |
[tool.ruff]
|
48 |
# Exclude a variety of commonly ignored directories.
|
49 |
+
extend-include = ["*.ipynb"]
|
50 |
exclude = [
|
51 |
".bzr",
|
52 |
".direnv",
|
scripts/install-pyg.sh
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
# PyTorch Geometric (OCP)
|
4 |
+
TORCH=2.2.0
|
5 |
+
CUDA=cu121
|
6 |
+
|
7 |
+
pip install --verbose --no-cache torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
|
8 |
+
pip install --verbose --no-cache torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
|
9 |
+
|
10 |
+
# DGL (M3GNet)
|
11 |
+
pip install --verbose --no-cache dgl -f https://data.dgl.ai/wheels/{CUDA}/repo.html
|
12 |
+
|
13 |
+
|
14 |
+
# DGL (ALIGNN)
|
15 |
+
# pip install --verbose --no-cache dgl -f https://data.dgl.ai/wheels/torch-2.2/cu122/repo.html
|
serve/models/alerts.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
|
4 |
+
st.markdown("# Alerts")
|
serve/models/bugs.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
|
4 |
+
st.markdown("# Bugs")
|
serve/tasks/homonuclear-diatomics.py
CHANGED
@@ -136,7 +136,7 @@ for i, symbol in enumerate(chemical_symbols[1:]):
|
|
136 |
)
|
137 |
|
138 |
# Set x-axis title
|
139 |
-
fig.update_xaxes(title_text="Bond length
|
140 |
|
141 |
# Set y-axes titles
|
142 |
if energy_plot:
|
|
|
136 |
)
|
137 |
|
138 |
# Set x-axis title
|
139 |
+
fig.update_xaxes(title_text="Bond length [Å]")
|
140 |
|
141 |
# Set y-axes titles
|
142 |
if energy_plot:
|