cyrusyc commited on
Commit
9d1a2a5
1 Parent(s): 68aa5f5

reform scaffold

Browse files
.gitignore CHANGED
@@ -1,5 +1,7 @@
1
  tests/
2
  *.out
 
 
3
  mlip_arena/tasks/*/*/*/
4
 
5
  # Byte-compiled / optimized / DLL files
 
1
  tests/
2
  *.out
3
+ *.ipynb
4
+ *.extxyz
5
  mlip_arena/tasks/*/*/*/
6
 
7
  # Byte-compiled / optimized / DLL files
environment.yml ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ channels:
2
+ - defaults
3
+ - conda-forge
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=main
6
+ - _openmp_mutex=5.1=1_gnu
7
+ - bzip2=1.0.8=h5eee18b_5
8
+ - ca-certificates=2024.3.11=h06a4308_0
9
+ - ld_impl_linux-64=2.38=h1181459_1
10
+ - libffi=3.4.4=h6a678d5_0
11
+ - libgcc-ng=11.2.0=h1234567_1
12
+ - libgomp=11.2.0=h1234567_1
13
+ - libstdcxx-ng=11.2.0=h1234567_1
14
+ - libuuid=1.41.5=h5eee18b_0
15
+ - ncurses=6.4=h6a678d5_0
16
+ - openssl=3.0.13=h7f8727e_0
17
+ - pip=23.3.1=py311h06a4308_0
18
+ - python=3.11.8=h955ad1f_0
19
+ - readline=8.2=h5eee18b_0
20
+ - setuptools=68.2.2=py311h06a4308_0
21
+ - sqlite=3.41.2=h5eee18b_0
22
+ - tk=8.6.12=h1ccaba5_0
23
+ - wheel=0.41.2=py311h06a4308_0
24
+ - xz=5.4.6=h5eee18b_0
25
+ - zlib=1.2.13=h5eee18b_0
26
+ - pip:
27
+ - absl-py==2.1.0
28
+ - aiofiles==23.2.1
29
+ - aiohttp==3.9.3
30
+ - aioitertools==0.11.0
31
+ - aiosignal==1.3.1
32
+ - aiosqlite==0.20.0
33
+ - alembic==1.13.1
34
+ - alignn==2024.5.27
35
+ - altair==5.3.0
36
+ - annotated-types==0.6.0
37
+ - anyio==3.7.1
38
+ - appdirs==1.4.4
39
+ - apprise==1.7.5
40
+ - ase==3.23.0
41
+ - asgi-lifespan==2.1.0
42
+ - asttokens==2.4.1
43
+ - async-timeout==4.0.3
44
+ - asyncpg==0.29.0
45
+ - atomate2==0.0.14.post30+g256b39a1
46
+ - attrs==23.2.0
47
+ - autograd==1.5
48
+ - autoray==0.6.9
49
+ - bcrypt==4.1.2
50
+ - bidict==0.23.1
51
+ - blinker==1.7.0
52
+ - blosc2==2.7.0
53
+ - boto3==1.34.74
54
+ - botocore==1.34.74
55
+ - cachetools==5.3.3
56
+ - certifi==2024.7.4
57
+ - cffi==1.16.0
58
+ - charset-normalizer==3.3.2
59
+ - chgnet==0.3.5
60
+ - click==8.1.7
61
+ - cloudpickle==3.0.0
62
+ - colorama==0.4.6
63
+ - comm==0.2.2
64
+ - contourpy==1.2.0
65
+ - coolname==2.2.0
66
+ - covalent==0.232.0.post1
67
+ - croniter==2.0.3
68
+ - cryptography==42.0.5
69
+ - custodian==2024.3.12
70
+ - cycler==0.12.1
71
+ - cython==3.0.10
72
+ - dask==2024.3.1
73
+ - dask-jobqueue==0.8.5
74
+ - dateparser==1.2.0
75
+ - debugpy==1.8.1
76
+ - decorator==5.1.1
77
+ - dgl==2.3.0+cu121
78
+ - distributed==2024.3.1
79
+ - dnspython==2.6.1
80
+ - docker==6.1.3
81
+ - docker-pycreds==0.4.0
82
+ - e3nn==0.5.1
83
+ - email-validator==2.1.1
84
+ - emmet-core==0.82.1
85
+ - executing==2.0.1
86
+ - fairchem-core==1.0.0
87
+ - fastapi==0.110.0
88
+ - fastjsonschema==2.20.0
89
+ - filelock==3.15.4
90
+ - fireworks==2.0.3
91
+ - flake8==7.1.0
92
+ - flask==3.0.2
93
+ - flask-paginate==2024.3.28
94
+ - fonttools==4.50.0
95
+ - frozenlist==1.4.1
96
+ - fsspec==2024.6.1
97
+ - furl==2.1.3
98
+ - future==1.0.0
99
+ - gitdb==4.0.11
100
+ - gitpython==3.1.43
101
+ - google-auth==2.29.0
102
+ - gpaw==24.7.0b1
103
+ - greenlet==3.0.3
104
+ - griffe==0.42.1
105
+ - grpcio==1.64.1
106
+ - gunicorn==21.2.0
107
+ - h11==0.14.0
108
+ - h2==4.1.0
109
+ - h5py==3.10.0
110
+ - hpack==4.0.0
111
+ - httpcore==1.0.5
112
+ - httptools==0.6.1
113
+ - httpx==0.27.0
114
+ - huggingface-hub==0.22.2
115
+ - hyperframe==6.0.1
116
+ - idna==3.7
117
+ - importlib-metadata==7.1.0
118
+ - importlib-resources==6.1.3
119
+ - inflect==7.3.1
120
+ - ipykernel==6.29.4
121
+ - ipython==8.22.2
122
+ - ipywidgets==8.1.2
123
+ - itsdangerous==2.1.2
124
+ - jarvis-tools==2024.5.10
125
+ - jedi==0.19.1
126
+ - jinja2==3.1.4
127
+ - jmespath==1.0.1
128
+ - jobflow==0.1.17
129
+ - joblib==1.3.2
130
+ - jsonpatch==1.33
131
+ - jsonpointer==2.4
132
+ - jsonschema==4.21.1
133
+ - jsonschema-specifications==2023.12.1
134
+ - jupyter-client==8.6.1
135
+ - jupyter-core==5.7.2
136
+ - jupyterlab-widgets==3.0.10
137
+ - kiwisolver==1.4.5
138
+ - kubernetes==29.0.0
139
+ - latexcodec==3.0.0
140
+ - lightning-utilities==0.11.2
141
+ - llvmlite==0.42.0
142
+ - lmdb==1.4.1
143
+ - lmdbm==0.0.5
144
+ - locket==1.0.0
145
+ - looseversion==1.3.0
146
+ - mace-torch==0.3.4
147
+ - maggma==0.64.0
148
+ - mako==1.3.2
149
+ - markdown==3.6
150
+ - markdown-it-py==2.2.0
151
+ - markupsafe==2.1.5
152
+ - matgl==1.0.0
153
+ - matplotlib==3.8.3
154
+ - matplotlib-inline==0.1.6
155
+ - matscipy==1.0.0
156
+ - mccabe==0.7.0
157
+ - mdurl==0.1.2
158
+ - mlip-arena==0.0.1
159
+ - mongogrant==0.3.3
160
+ - mongomock==4.1.2
161
+ - monty==2024.2.26
162
+ - more-itertools==10.3.0
163
+ - mp-api==0.41.2
164
+ - mpire==2.10.1
165
+ - mpmath==1.3.0
166
+ - msgpack==1.0.8
167
+ - multidict==6.0.5
168
+ - natsort==8.4.0
169
+ - nbformat==5.10.4
170
+ - ndindex==1.8
171
+ - nest-asyncio==1.6.0
172
+ - networkx==3.3
173
+ - numba==0.59.1
174
+ - numexpr==2.10.1
175
+ - numpy==1.26.4
176
+ - nvidia-cublas-cu12==12.1.3.1
177
+ - nvidia-cuda-cupti-cu12==12.1.105
178
+ - nvidia-cuda-nvrtc-cu12==12.1.105
179
+ - nvidia-cuda-runtime-cu12==12.1.105
180
+ - nvidia-cudnn-cu12==8.9.2.26
181
+ - nvidia-cufft-cu12==11.0.2.54
182
+ - nvidia-curand-cu12==10.3.2.106
183
+ - nvidia-cusolver-cu12==11.4.5.107
184
+ - nvidia-cusparse-cu12==12.1.0.106
185
+ - nvidia-ml-py3==7.352.0
186
+ - nvidia-nccl-cu12==2.19.3
187
+ - nvidia-nvjitlink-cu12==12.5.82
188
+ - nvidia-nvtx-cu12==12.1.105
189
+ - oauthlib==3.2.2
190
+ - opt-einsum==3.3.0
191
+ - opt-einsum-fx==0.1.4
192
+ - orderedmultidict==1.0.1
193
+ - orjson==3.10.0
194
+ - packaging==24.0
195
+ - palettable==3.3.3
196
+ - pandas==2.2.2
197
+ - paramiko==3.4.0
198
+ - parso==0.8.3
199
+ - partd==1.4.1
200
+ - pathspec==0.12.1
201
+ - pendulum==2.1.2
202
+ - pennylane==0.32.0
203
+ - pennylane-lightning==0.33.1
204
+ - pexpect==4.9.0
205
+ - pillow==10.2.0
206
+ - platformdirs==4.2.0
207
+ - plotly==5.20.0
208
+ - plumed==2.9.0
209
+ - prefect==2.16.8
210
+ - prefect-dask==0.2.6
211
+ - prettytable==3.10.0
212
+ - prompt-toolkit==3.0.43
213
+ - protobuf==4.25.3
214
+ - psutil==6.0.0
215
+ - ptyprocess==0.7.0
216
+ - pure-eval==0.2.2
217
+ - py-cpuinfo==9.0.0
218
+ - pyarrow==16.1.0
219
+ - pyasn1==0.6.0
220
+ - pyasn1-modules==0.4.0
221
+ - pybtex==0.24.0
222
+ - pycodestyle==2.12.0
223
+ - pycparser==2.22
224
+ - pydantic==2.6.4
225
+ - pydantic-core==2.16.3
226
+ - pydantic-settings==2.2.1
227
+ - pydash==8.0.0
228
+ - pydeck==0.9.1
229
+ - pydocstyle==6.3.0
230
+ - pyflakes==3.2.0
231
+ - pygments==2.17.2
232
+ - pymatgen==2024.4.13
233
+ - pymongo==4.6.3
234
+ - pynacl==1.5.0
235
+ - pyparsing==2.4.7
236
+ - python-dateutil==2.9.0.post0
237
+ - python-dotenv==1.0.1
238
+ - python-engineio==4.9.0
239
+ - python-graphviz==0.20.3
240
+ - python-hostlist==1.23.0
241
+ - python-multipart==0.0.9
242
+ - python-slugify==8.0.4
243
+ - python-socketio==5.11.2
244
+ - pytorch-lightning==2.2.1
245
+ - pytz==2024.1
246
+ - pytzdata==2020.1
247
+ - pyyaml==6.0.1
248
+ - pyzmq==25.1.2
249
+ - readchar==4.0.6
250
+ - referencing==0.34.0
251
+ - regex==2023.12.25
252
+ - requests==2.32.3
253
+ - requests-oauthlib==2.0.0
254
+ - rfc3339-validator==0.1.4
255
+ - rich==13.3.5
256
+ - rpds-py==0.18.0
257
+ - rsa==4.9
258
+ - ruamel-yaml==0.17.40
259
+ - ruamel-yaml-clib==0.2.8
260
+ - rustworkx==0.14.2
261
+ - s3transfer==0.10.1
262
+ - safetensors==0.4.2
263
+ - scikit-learn==1.4.1.post1
264
+ - scipy==1.14.0
265
+ - semantic-version==2.10.0
266
+ - sentinels==1.0.0
267
+ - sentry-sdk==2.7.1
268
+ - setproctitle==1.3.3
269
+ - shellingham==1.5.4
270
+ - simple-websocket==1.0.0
271
+ - simplejson==3.19.2
272
+ - six==1.16.0
273
+ - smart-open==7.0.4
274
+ - smmap==5.0.1
275
+ - sniffio==1.3.1
276
+ - snowballstemmer==2.2.0
277
+ - sortedcontainers==2.4.0
278
+ - spglib==2.3.1
279
+ - sqlalchemy==1.4.52
280
+ - sqlalchemy-utils==0.41.2
281
+ - sshtunnel==0.4.0
282
+ - stack-data==0.6.3
283
+ - starlette==0.36.3
284
+ - streamlit==1.36.0
285
+ - submitit==1.5.1
286
+ - sympy==1.12.1
287
+ - tables==3.9.2
288
+ - tabulate==0.9.0
289
+ - tblib==3.0.0
290
+ - tenacity==8.2.3
291
+ - tensorboard==2.17.0
292
+ - tensorboard-data-server==0.7.2
293
+ - text-unidecode==1.3
294
+ - threadpoolctl==3.4.0
295
+ - toml==0.10.2
296
+ - toolz==0.12.1
297
+ - torch==2.2.1
298
+ - torch-dftd==0.4.0
299
+ - torch-ema==0.3
300
+ - torch-geometric==2.5.2
301
+ - torch-scatter==2.1.2+pt22cu121
302
+ - torch-sparse==0.6.18+pt22cu121
303
+ - torchdata==0.7.1
304
+ - torchmetrics==1.3.2
305
+ - tornado==6.4
306
+ - tqdm==4.66.4
307
+ - trainstation==1.0
308
+ - traitlets==5.14.2
309
+ - triton==2.2.0
310
+ - typeguard==4.3.0
311
+ - typer==0.12.0
312
+ - typer-cli==0.12.0
313
+ - typer-slim==0.12.0
314
+ - typing-extensions==4.12.2
315
+ - tzdata==2024.1
316
+ - tzlocal==5.2
317
+ - ujson==5.9.0
318
+ - uncertainties==3.1.7
319
+ - urllib3==2.2.2
320
+ - uvicorn==0.18.3
321
+ - uvloop==0.19.0
322
+ - wandb==0.17.4
323
+ - watchdog==4.0.0
324
+ - watchfiles==0.21.0
325
+ - wcwidth==0.2.13
326
+ - websocket-client==1.7.0
327
+ - websockets==12.0
328
+ - werkzeug==3.0.1
329
+ - widgetsnbextension==4.0.10
330
+ - wrapt==1.16.0
331
+ - wsproto==1.2.0
332
+ - xmltodict==0.13.0
333
+ - yarl==1.9.4
334
+ - zict==3.0.0
335
+ - zipp==3.18.1
mlip_arena/models/chgnet.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from ase import Atoms
6
+ from ase.calculators.calculator import all_changes
7
+ from huggingface_hub import hf_hub_download
8
+ from torch_geometric.data import Data
9
+
10
+ from mlip_arena.models import MLIP, MLIPCalculator, ModuleMLIP
11
+
12
+
13
+ class CHGNetCalculator(MLIPCalculator):
14
+ def __init__(
15
+ self,
16
+ device: torch.device | None = None,
17
+ restart=None,
18
+ atoms=None,
19
+ directory=".",
20
+ **kwargs,
21
+ ):
22
+ super().__init__(restart=restart, atoms=atoms, directory=directory, **kwargs)
23
+
24
+ self.name: str = self.__class__.__name__
25
+
26
+ fpath = hf_hub_download(
27
+ repo_id="cyrusyc/mace-universal",
28
+ subfolder="pretrained",
29
+ filename="2023-12-12-mace-128-L1_epoch-199.model",
30
+ revision="main",
31
+ )
32
+
33
+ self.device = device or torch.device(
34
+ "cuda" if torch.cuda.is_available() else "cpu"
35
+ )
36
+
37
+ self.model = torch.load(fpath, map_location=self.device)
38
+
39
+ self.implemented_properties = ["energy", "forces", "stress"]
40
+
41
+ def calculate(
42
+ self, atoms: Atoms, properties: list[str], system_changes: list = all_changes
43
+ ):
44
+ """Calculate energies and forces for the given Atoms object"""
45
+ super().calculate(atoms, properties, system_changes)
46
+
47
+ output = self.forward(atoms)
48
+
49
+ self.results = {}
50
+ if "energy" in properties:
51
+ self.results["energy"] = output["energy"].item()
52
+ if "forces" in properties:
53
+ self.results["forces"] = output["forces"].cpu().detach().numpy()
54
+ if "stress" in properties:
55
+ self.results["stress"] = output["stress"].cpu().detach().numpy()
56
+
57
+ def forward(self, x: Data | Atoms) -> dict[str, torch.Tensor]:
58
+ """Implement data conversion, graph creation, and model forward pass"""
59
+ # TODO
60
+ raise NotImplementedError
mlip_arena/tasks/diatomics/mlip.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import timedelta
2
+ from typing import Union
3
+
4
+ # import covalent as ct
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ from ase import Atoms
9
+ from ase.calculators.calculator import Calculator
10
+ from ase.data import chemical_symbols
11
+ from dask.distributed import Client
12
+ from dask_jobqueue import SLURMCluster
13
+ from prefect import flow, task
14
+ from prefect.tasks import task_input_hash
15
+ from prefect_dask import DaskTaskRunner
16
+
17
+ from mlip_arena.models import MLIPCalculator
18
+ from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator
19
+
20
+ cluster_kwargs = {
21
+ "cores": 4,
22
+ "memory": "64 GB",
23
+ "shebang": "#!/bin/bash",
24
+ "account": "m3828",
25
+ "walltime": "00:10:00",
26
+ "job_mem": "0",
27
+ "job_script_prologue": ["source ~/.bashrc"],
28
+ "job_directives_skip": ["-n", "--cpus-per-task"],
29
+ "job_extra_directives": ["-q debug", "-C gpu"],
30
+ }
31
+
32
+ cluster = SLURMCluster(**cluster_kwargs)
33
+ cluster.scale(jobs=10)
34
+ client = Client(cluster)
35
+
36
+
37
+ @task(cache_key_fn=task_input_hash, cache_expiration=timedelta(hours=1))
38
+ def calculate_single_diatomic(
39
+ calculator_name: str | EXTMLIPEnum,
40
+ calculator_kwargs: dict | None,
41
+ atom1: str,
42
+ atom2: str,
43
+ rmin: float = 0.1,
44
+ rmax: float = 6.5,
45
+ npts: int = int(1e3),
46
+ ):
47
+
48
+ calculator_kwargs = calculator_kwargs or {}
49
+
50
+ if isinstance(calculator_name, EXTMLIPEnum) and calculator_name in EXTMLIPEnum:
51
+ calc = external_ase_calculator(calculator_name, **calculator_kwargs)
52
+ elif calculator_name in MLIPMap:
53
+ calc = MLIPMap[calculator_name](**calculator_kwargs)
54
+
55
+ a = 2 * rmax
56
+
57
+ rs = np.linspace(rmin, rmax, npts)
58
+ e = np.zeros_like(rs)
59
+ f = np.zeros_like(rs)
60
+
61
+ da = atom1 + atom2
62
+
63
+ for i, r in enumerate(rs):
64
+
65
+ positions = [
66
+ [0, 0, 0],
67
+ [r, 0, 0],
68
+ ]
69
+
70
+ # Create the unit cell with two atoms
71
+ atoms = Atoms(da, positions=positions, cell=[a, a, a])
72
+
73
+ atoms.calc = calc
74
+
75
+ e[i] = atoms.get_potential_energy()
76
+ f[i] = np.inner(np.array([1, 0, 0]), atoms.get_forces()[1])
77
+
78
+ return {"r": rs, "E": e, "F": f, "da": da}
79
+
80
+
81
+ @flow
82
+ def calculate_multiple_diatomics(calculator_name, calculator_kwargs):
83
+
84
+ futures = []
85
+ for symbol in chemical_symbols:
86
+ if symbol == "X":
87
+ continue
88
+ futures.append(
89
+ calculate_single_diatomic.submit(
90
+ calculator_name, calculator_kwargs, symbol, symbol
91
+ )
92
+ )
93
+
94
+ return [i for future in futures for i in future.result()]
95
+
96
+
97
+ @flow(task_runner=DaskTaskRunner(address=client.scheduler.address), log_prints=True)
98
+ def calculate_homonuclear_diatomics(calculator_name, calculator_kwargs):
99
+
100
+ curves = calculate_multiple_diatomics(calculator_name, calculator_kwargs)
101
+
102
+ pd.DataFrame(curves).to_csv(f"homonuclear-diatomics-{calculator_name}.csv")
103
+
104
+
105
+ # with plt.style.context("default"):
106
+
107
+ # SMALL_SIZE = 6
108
+ # MEDIUM_SIZE = 8
109
+ # LARGE_SIZE = 10
110
+
111
+ # LINE_WIDTH = 1
112
+
113
+ # plt.rcParams.update(
114
+ # {
115
+ # "pgf.texsystem": "pdflatex",
116
+ # "font.family": "sans-serif",
117
+ # "text.usetex": True,
118
+ # "pgf.rcfonts": True,
119
+ # "figure.constrained_layout.use": True,
120
+ # "axes.labelsize": MEDIUM_SIZE,
121
+ # "axes.titlesize": MEDIUM_SIZE,
122
+ # "legend.frameon": False,
123
+ # "legend.fontsize": MEDIUM_SIZE,
124
+ # "legend.loc": "best",
125
+ # "lines.linewidth": LINE_WIDTH,
126
+ # "xtick.labelsize": SMALL_SIZE,
127
+ # "ytick.labelsize": SMALL_SIZE,
128
+ # }
129
+ # )
130
+
131
+ # fig, ax = plt.subplots(layout="constrained", figsize=(3, 2), dpi=300)
132
+
133
+ # color = "tab:red"
134
+ # ax.plot(rs, e, color=color, zorder=1)
135
+
136
+ # ax.axhline(ls="--", color=color, alpha=0.5, lw=0.5 * LINE_WIDTH)
137
+
138
+ # ylo, yhi = ax.get_ylim()
139
+ # ax.set(xlabel=r"r [$\AA]$", ylim=(max(-7, ylo), min(5, yhi)))
140
+ # ax.set_ylabel(ylabel="E [eV]", color=color)
141
+ # ax.tick_params(axis="y", labelcolor=color)
142
+ # ax.text(0.8, 0.85, da, fontsize=LARGE_SIZE, transform=ax.transAxes)
143
+
144
+ # color = "tab:blue"
145
+
146
+ # at = ax.twinx()
147
+ # at.plot(rs, f, color=color, zorder=0, lw=0.5 * LINE_WIDTH)
148
+
149
+ # at.axhline(ls="--", color=color, alpha=0.5, lw=0.5 * LINE_WIDTH)
150
+
151
+ # ylo, yhi = at.get_ylim()
152
+ # at.set(
153
+ # xlabel=r"r [$\AA]$",
154
+ # ylim=(max(-20, ylo), min(20, yhi)),
155
+ # )
156
+ # at.set_ylabel(ylabel="F [eV/$\AA$]", color=color)
157
+ # at.tick_params(axis="y", labelcolor=color)
158
+
159
+ # plt.show()
160
+
161
+
162
+ if __name__ == "__main__":
163
+ calculate_homonuclear_diatomics(
164
+ EXTMLIPEnum.MACE, dict(model="medium", device="cuda")
165
+ )
mlip_arena/tasks/diatomics/vasp/run.ipynb DELETED
@@ -1,509 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {
7
- "tags": []
8
- },
9
- "outputs": [],
10
- "source": [
11
- "from datetime import timedelta\n",
12
- "from typing import Union\n",
13
- "import os\n",
14
- "from pathlib import Path\n",
15
- "from itertools import combinations, combinations_with_replacement\n",
16
- "import psutil\n",
17
- "import subprocess\n",
18
- "\n",
19
- "import numpy as np\n",
20
- "import pandas as pd\n",
21
- "import torch\n",
22
- "from ase import Atoms, Atom\n",
23
- "from ase.calculators.calculator import Calculator\n",
24
- "from ase.calculators.singlepoint import SinglePointCalculator\n",
25
- "from ase.calculators.vasp import Vasp\n",
26
- "from ase.data import chemical_symbols, covalent_radii, vdw_alvarez\n",
27
- "from ase.io import read, write\n",
28
- "from dask.distributed import Client\n",
29
- "from dask_jobqueue import SLURMCluster\n",
30
- "from prefect import flow, task\n",
31
- "from prefect.tasks import task_input_hash\n",
32
- "from prefect_dask import DaskTaskRunner\n",
33
- "\n",
34
- "from pymatgen.core import Element\n",
35
- "from pymatgen.io.ase import AseAtomsAdaptor\n",
36
- "from pymatgen.io.vasp.inputs import Kpoints\n",
37
- "from pymatgen.command_line.chargemol_caller import ChargemolAnalysis\n",
38
- "\n",
39
- "from mlip_arena.models import MLIPCalculator\n",
40
- "from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator\n",
41
- "from jobflow import run_locally\n",
42
- "\n",
43
- "from atomate2.vasp.jobs.mp import MPGGAStaticMaker\n",
44
- "from atomate2.vasp.sets.mp import MPGGAStaticSetGenerator"
45
- ]
46
- },
47
- {
48
- "cell_type": "code",
49
- "execution_count": null,
50
- "metadata": {
51
- "scrolled": true,
52
- "tags": []
53
- },
54
- "outputs": [],
55
- "source": [
56
- "\n",
57
- "nodes_per_alloc = 1\n",
58
- "cpus_per_task = 16\n",
59
- "gpus_per_node = 1\n",
60
- "# tasks_per_node = int(128/4)\n",
61
- "ntasks = 1\n",
62
- "\n",
63
- "\n",
64
- "cluster_kwargs = dict(\n",
65
- " cores=1,\n",
66
- " memory=\"64 GB\",\n",
67
- " shebang=\"#!/bin/bash\",\n",
68
- " account=\"m3828\",\n",
69
- " walltime=\"01:00:00\",\n",
70
- " # processes=16,\n",
71
- " # nanny=True,\n",
72
- " job_mem=\"0\",\n",
73
- " job_script_prologue=[\n",
74
- " \"source ~/.bashrc\",\n",
75
- " \"module load python\",\n",
76
- " \"source activate /pscratch/sd/c/cyrisinstance.conda/mlip-arena\",\n",
77
- " \"module load vasp/6.4.1-gpu\",\n",
78
- " \n",
79
- " \"export DDEC6_ATOMIC_DENSITIES_DIR='/global/homes/c/cyrusyc/chargemol/atomic_densities/'\",\n",
80
- " f\"export OMP_NUM_THREADS={gpus_per_node}\",\n",
81
- " \"export OMP_PLACES=threads\",\n",
82
- " \"export OMP_PROC_BIND=spread\",\n",
83
- " f\"export ASE_VASP_COMMAND='srun -n {ntasks} -c {4*gpus_per_node} --cpu-bind=cores --gpus-per-node {gpus_per_node} vasp_ncl'\"\n",
84
- " # f\"export ASE_VASP_COMMAND='srun -N {nodes_per_alloc} --ntasks-per-node={tasks_per_node} -c {cpus_per_task} --cpu-bind=cores vasp_std'\"\n",
85
- " # \"export ATOMATE2_CONFIG_FILE='/global/homes/c/cyrusyc/atomate2/config/atomate2-prefect-cpu-node.yaml'\"\n",
86
- " ],\n",
87
- " job_directives_skip=[\"-n\", \"--cpus-per-task\", \"-J\"],\n",
88
- " job_extra_directives=[\n",
89
- " \"-J diatomics\",\n",
90
- " \"-q regular\",\n",
91
- " f\"-N {nodes_per_alloc}\",\n",
92
- " \"-C gpu\",\n",
93
- " # \"-n 1\",\n",
94
- " # \"-c 16\",\n",
95
- " # \"--gpus-per-task=1\",\n",
96
- " # \"--threads-per-task=1\",\n",
97
- " # \"--gpu-bind=single:1\",\n",
98
- " # \"--comment=00:20:00\",\n",
99
- " # \"--time-min=00:05:00\",\n",
100
- " # \"--signal=B:USR1@60\",\n",
101
- " # \"--requeue\",\n",
102
- " # \"--open-mode=append\",\n",
103
- " # \"--mail-type=end,requeue\",\n",
104
- " # \"[email protected]\",\n",
105
- " ],\n",
106
- " # python=\"srun python\",\n",
107
- " death_timeout=86400, #float('inf')\n",
108
- ")\n",
109
- "\n",
110
- "\n",
111
- "cluster = SLURMCluster(**cluster_kwargs)\n",
112
- "print(cluster.job_script())\n",
113
- "# cluster.scale(3)\n",
114
- "cluster.adapt(minimum_jobs=50, maximum_jobs=100)\n",
115
- "client = Client(cluster)"
116
- ]
117
- },
118
- {
119
- "cell_type": "code",
120
- "execution_count": null,
121
- "metadata": {
122
- "tags": []
123
- },
124
- "outputs": [],
125
- "source": [
126
- "\n",
127
- "\n",
128
- "@task(cache_key_fn=task_input_hash, cache_expiration=timedelta(hours=24), log_prints=True)\n",
129
- "def calculate_single_diatomic(\n",
130
- " calculator: str | EXTMLIPEnum | Calculator,\n",
131
- " calculator_kwargs: dict | None,\n",
132
- " atom1: str,\n",
133
- " atom2: str,\n",
134
- " rmin: float = 1.25,\n",
135
- " rmax: float = 6.25,\n",
136
- " rstep: float = 0.2,\n",
137
- " magnetism: str = \"FM\"\n",
138
- "):\n",
139
- "\n",
140
- " calculator_kwargs = calculator_kwargs or {}\n",
141
- "\n",
142
- " if isinstance(calculator, str) and calculator.lower() == 'vasp-mp-gga':\n",
143
- " calc = Vasp(**calculator_kwargs)\n",
144
- " calc.name = 'vasp-mp-gga'\n",
145
- " # calc.name='atomate2'\n",
146
- " elif isinstance(calculator, EXTMLIPEnum) and calculator in EXTMLIPEnum:\n",
147
- " calc = external_ase_calculator(calculator, **calculator_kwargs)\n",
148
- " elif calculator in MLIPMap:\n",
149
- " calc = MLIPMap[calculator](**calculator_kwargs)\n",
150
- " elif issubclass(calculator, Calculator):\n",
151
- " calc = calculator(**calculator_kwargs)\n",
152
- "\n",
153
- " a = 2 * rmax\n",
154
- "\n",
155
- " npts = int((rmax - rmin)/rstep)\n",
156
- "\n",
157
- " rs = np.linspace(rmin, rmax, npts)\n",
158
- " e = np.zeros_like(rs)\n",
159
- " f = np.zeros_like(rs)\n",
160
- "\n",
161
- " da = atom1 + atom2\n",
162
- " \n",
163
- " assert isinstance(calc, Calculator)\n",
164
- " \n",
165
- " out_dir = Path(str(da + f\"_{magnetism}\"))\n",
166
- " os.makedirs(out_dir, exist_ok=True)\n",
167
- " \n",
168
- " calc.directory = out_dir\n",
169
- " \n",
170
- " print(f\"write output to {calc.directory}\")\n",
171
- " \n",
172
- " element = Element(atom1)\n",
173
- " \n",
174
- " try:\n",
175
- " m = element.valence[1]\n",
176
- " if element.valence == (0, 2):\n",
177
- " m = 0\n",
178
- " except:\n",
179
- " m = 0\n",
180
- " \n",
181
- " r = rs[0]\n",
182
- " \n",
183
- " positions = [\n",
184
- " [a/2-r/2, a/2, a/2],\n",
185
- " [a/2+r/2, a/2, a/2],\n",
186
- " ]\n",
187
- " \n",
188
- " if magnetism == 'FM':\n",
189
- " if m == 0:\n",
190
- " return {}\n",
191
- " magmoms = [m, m]\n",
192
- " elif magnetism == 'AFM':\n",
193
- " if m == 0:\n",
194
- " return {}\n",
195
- " magmoms = [m, -m]\n",
196
- " elif magnetism == 'NM':\n",
197
- " magmoms = [0, 0]\n",
198
- " \n",
199
- " traj_fpath = out_dir / \"traj.extxyz\"\n",
200
- " \n",
201
- " skip = 0\n",
202
- " if traj_fpath.exists():\n",
203
- " traj = read(traj_fpath, index=\":\")\n",
204
- " skip = len(traj)\n",
205
- " atoms = traj[-1]\n",
206
- " else:\n",
207
- " atoms = Atoms(\n",
208
- " da, \n",
209
- " positions=positions,\n",
210
- " magmoms=magmoms,\n",
211
- " cell=[a, a+0.001, a+0.002], \n",
212
- " pbc=True\n",
213
- " )\n",
214
- " \n",
215
- " # \n",
216
- " \n",
217
- " structure = AseAtomsAdaptor().get_structure(atoms)\n",
218
- " \n",
219
- " if magnetism == 'FM':\n",
220
- " I_CONSTRAINED_M = 2\n",
221
- " LAMBDA = 10\n",
222
- " M_CONSTR = [0, 0, 1, 0, 0, 1] # \" \".join(map(str, [0, 0, 1])) + \" \" + \" \".join(map(str, [0, 0, 1]))\n",
223
- " elif magnetism == 'AFM':\n",
224
- " I_CONSTRAINED_M = 2\n",
225
- " LAMBDA = 10\n",
226
- " M_CONSTR = [0, 0, 1, 0, 0, -1] # \" \".join(map(str, [0, 0, 1])) + \" \" + \" \".join(map(str, [0, 0, -1]))\n",
227
- " elif magnetism == 'NM':\n",
228
- " I_CONSTRAINED_M = 1\n",
229
- " LAMBDA = 10\n",
230
- " M_CONSTR = [0, 0, 0, 0, 0, 0] #\" \".join(map(str, [0, 0, 0])) + \" \" + \" \".join(map(str, [0, 0, 0]))\n",
231
- "\n",
232
- " input_set_generator = MPGGAStaticSetGenerator(\n",
233
- " user_incar_settings=dict(\n",
234
- " ISYM = 0, # symmetry is off\n",
235
- " ISPIN = 2,\n",
236
- " ISMEAR = 0, # Gaussian smearing, otherwise negative occupancies might come up\n",
237
- " SIGMA = 0.002, # tiny smearing width to safely break symmetry\n",
238
- " AMIX = 0.2, # mixing set manually\n",
239
- " BMIX = 0.0001,\n",
240
- " LSUBROT= True, # spin orbit coupling (non collinear)\n",
241
- " ALGO = \"Accurate\",\n",
242
- " PREC = \"High\",\n",
243
- " ENCUT = 1000,\n",
244
- " ENAUG = 2000,\n",
245
- " ISTART = 1,\n",
246
- " ICHARG = 1,\n",
247
- " NELM = 200,\n",
248
- " TIME = 0.2,\n",
249
- " LELF = False,\n",
250
- " LMAXMIX=max(max(map(lambda a: \"spdf\".index(Element(a.symbol).block) * 2, atoms)), 2),\n",
251
- " LMIXTAU=False,\n",
252
- " VOSKOWN = 1,\n",
253
- " I_CONSTRAINED_M = I_CONSTRAINED_M,\n",
254
- " M_CONSTR = M_CONSTR,\n",
255
- " LAMBDA = LAMBDA,\n",
256
- " # performance\n",
257
- " # lplane=False,\n",
258
- " # npar=int(sqrt(ncpus)),\n",
259
- " # nsim=1,\n",
260
- " # LPLANE = True,\n",
261
- " # # NCORE = 128,\n",
262
- " # LSCALU = False,\n",
263
- " # NSIM = 4,\n",
264
- " # LPLANE = False,\n",
265
- " # NPAR = 16,\n",
266
- " # NSIM = 1,\n",
267
- " # LSCALU = False,\n",
268
- " # GPU\n",
269
- " KPAR = gpus_per_node,\n",
270
- " NSIM = 64,\n",
271
- "\n",
272
- " LVTOT = False,\n",
273
- " LAECHG = True, # AECCARs\n",
274
- " LASPH = True, # aspherical charge density\n",
275
- " LCHARG = True, # CHGCAR\n",
276
- " LWAVE = True\n",
277
- " ),\n",
278
- " user_kpoints_settings=Kpoints(), # Gamma point only\n",
279
- " user_potcar_settings={\n",
280
- " \"Yb\": \"Yb_3\"\n",
281
- " },\n",
282
- " sort_structure=False\n",
283
- " )\n",
284
- "\n",
285
- " vis = input_set_generator.get_input_set(structure=structure)\n",
286
- " vis.incar.pop(\"MAGMOM\")\n",
287
- "\n",
288
- " incar = {key.lower(): value for key, value in vis.incar.items()} \n",
289
- " calc.set(kpts=1, gamma=True, **incar)\n",
290
- " \n",
291
- " atoms.calc = calc\n",
292
- "\n",
293
- " for i, r in enumerate(np.flip(rs)):\n",
294
- "\n",
295
- " \n",
296
- " if i < skip:\n",
297
- " continue\n",
298
- "\n",
299
- " positions = [\n",
300
- " [a/2-r/2, a/2, a/2],\n",
301
- " [a/2+r/2, a/2, a/2],\n",
302
- " ]\n",
303
- " \n",
304
- " if i > 0: \n",
305
- " magmoms = atoms.get_magnetic_moments()\n",
306
- " \n",
307
- " atoms.set_initial_magnetic_moments(magmoms)\n",
308
- " atoms.set_positions(positions)\n",
309
- "\n",
310
- " print(f\"{atoms} separated by {r} A ({i+1}/{len(rs)})\")\n",
311
- " \n",
312
- "\n",
313
- " e[i] = atoms.get_potential_energy()\n",
314
- " f[i] = np.inner(np.array([1, 0, 0]), atoms.get_forces()[1])\n",
315
- " \n",
316
- " atoms.calc.results.update(dict(\n",
317
- " magmoms=atoms.get_magnetic_moments()\n",
318
- " ))\n",
319
- " \n",
320
- " write(out_dir / \"traj.extxyz\", atoms, append=\"a\")\n",
321
- " \n",
322
- "# additional_results = {}\n",
323
- " \n",
324
- "# try:\n",
325
- "# ncpus = psutil.cpu_count(logical=True)\n",
326
- "# nthreads = os.environ[\"OMP_NUM_THREADS\"]\n",
327
- "# subprocess.run([\"export\", f\"OMP_NUM_THREADS={ncpus}\"], shell=True)\n",
328
- " \n",
329
- "# ca = ChargemolAnalysis(path=out_dir)\n",
330
- " \n",
331
- "# if charges := ca.ddec_charges:\n",
332
- "# additional_results[\"charges\"] = np.array(charges)\n",
333
- "# if dipoles := ca.dipoles:\n",
334
- "# additional_results[\"dipoles\"] = np.array(dipoles)\n",
335
- "# if magmoms := ca.ddec_spin_moments:\n",
336
- "# additional_results[\"magmoms\"] = np.array(magmoms)\n",
337
- " \n",
338
- "# subprocess.run([\"export\", f\"OMP_NUM_THREADS={nthreads}\"], shell=True)\n",
339
- "# except:\n",
340
- "# print(\"DDEC failed\")\n",
341
- " \n",
342
- " \n",
343
- "# atoms.calc.results.update(additional_results)\n",
344
- " \n",
345
- " return {\"r\": rs, \"E\": e, \"F\": f, \"da\": da}\n",
346
- "\n"
347
- ]
348
- },
349
- {
350
- "cell_type": "code",
351
- "execution_count": null,
352
- "metadata": {
353
- "tags": []
354
- },
355
- "outputs": [],
356
- "source": [
357
- "@flow(task_runner=DaskTaskRunner(address=client.scheduler.address), log_prints=True)\n",
358
- "def calculate_multiple_diatomics(calculator_name, calculator_kwargs):\n",
359
- "\n",
360
- " futures = []\n",
361
- " for sa in chemical_symbols:\n",
362
- " \n",
363
- " s = set([sa])\n",
364
- " \n",
365
- " if 'X' in s:\n",
366
- " continue\n",
367
- " \n",
368
- " atom = Atom(sa)\n",
369
- " rmin = covalent_radii[atom.number] * 2 * 0.6\n",
370
- " rvdw = vdw_alvarez.vdw_radii[atom.number] if atom.number < len(vdw_alvarez.vdw_radii) else np.nan \n",
371
- " rmax = 3.1 * rvdw if not np.isnan(rvdw) else 6\n",
372
- " rstep = 0.2 #if rmin < 1 else 0.4\n",
373
- " \n",
374
- " futures.append(\n",
375
- " calculate_single_diatomic.submit(\n",
376
- " calculator_name, calculator_kwargs, sa, sa,\n",
377
- " rmin=rmin, rmax=rmax,\n",
378
- " rstep=rstep,\n",
379
- " magnetism=\"FM\"\n",
380
- " # npts=16 if 'H' in s else 21\n",
381
- " )\n",
382
- " )\n",
383
- " futures.append(\n",
384
- " calculate_single_diatomic.submit(\n",
385
- " calculator_name, calculator_kwargs, sa, sa,\n",
386
- " rmin=rmin, rmax=rmax,\n",
387
- " rstep=rstep, #0.1 if rmin < 1 else 0.25,\n",
388
- " magnetism=\"AFM\"\n",
389
- " # npts=16 if 'H' in s else 21\n",
390
- " )\n",
391
- " )\n",
392
- " futures.append(\n",
393
- " calculate_single_diatomic.submit(\n",
394
- " calculator_name, calculator_kwargs, sa, sa,\n",
395
- " rmin=rmin, rmax=rmax,\n",
396
- " rstep=rstep, #0.1 if rmin < 1 else 0.25,\n",
397
- " magnetism=\"NM\"\n",
398
- " # npts=16 if 'H' in s else 21\n",
399
- " )\n",
400
- " )\n",
401
- "# for sa, sb in combinations_with_replacement(chemical_symbols, 2):\n",
402
- " \n",
403
- "# if 'X' in set([sa, sb]):\n",
404
- "# continue\n",
405
- " \n",
406
- "# futures.append(\n",
407
- "# calculate_single_diatomic.submit(\n",
408
- "# calculator_name, calculator_kwargs, sa, sb,\n",
409
- "# rmin=0.5, rmax=6.5,\n",
410
- "# npts=16\n",
411
- "# )\n",
412
- "# )\n",
413
- "\n",
414
- " return [i for future in futures for i in future.result()]\n",
415
- "\n"
416
- ]
417
- },
418
- {
419
- "cell_type": "code",
420
- "execution_count": null,
421
- "metadata": {
422
- "scrolled": true,
423
- "tags": []
424
- },
425
- "outputs": [],
426
- "source": [
427
- "calculate_multiple_diatomics(\n",
428
- " \"vasp-mp-gga\", \n",
429
- " dict(\n",
430
- " xc=\"pbe\",\n",
431
- " kpts=1,\n",
432
- " # Massively parallel machines (Cray)\n",
433
- " # lplane=False,\n",
434
- " # npar=int(sqrt(ncpus)),\n",
435
- " # nsim=1\n",
436
- " # Multicore modern linux machines\n",
437
- " # lplane=True,\n",
438
- " # npar=2,\n",
439
- " # lscalu=False,\n",
440
- " # nsim=4\n",
441
- " )\n",
442
- ")\n"
443
- ]
444
- },
445
- {
446
- "cell_type": "code",
447
- "execution_count": null,
448
- "metadata": {
449
- "scrolled": true,
450
- "tags": []
451
- },
452
- "outputs": [],
453
- "source": [
454
- "\n",
455
- "calculate_homonuclear_diatomics(\n",
456
- " \"vasp-mp-gga\", \n",
457
- " dict(\n",
458
- " xc=\"pbe\",\n",
459
- " kpts=1,\n",
460
- " # Massively parallel machines (Cray)\n",
461
- " # lplane=False,\n",
462
- " # npar=int(sqrt(ncpus)),\n",
463
- " # nsim=1\n",
464
- " # Multicore modern linux machines\n",
465
- " # lplane=True,\n",
466
- " # npar=2,\n",
467
- " # lscalu=False,\n",
468
- " # nsim=4\n",
469
- " )\n",
470
- ")\n"
471
- ]
472
- },
473
- {
474
- "cell_type": "code",
475
- "execution_count": null,
476
- "metadata": {},
477
- "outputs": [],
478
- "source": []
479
- }
480
- ],
481
- "metadata": {
482
- "kernelspec": {
483
- "display_name": "mlip-arena",
484
- "language": "python",
485
- "name": "mlip-arena"
486
- },
487
- "language_info": {
488
- "codemirror_mode": {
489
- "name": "ipython",
490
- "version": 3
491
- },
492
- "file_extension": ".py",
493
- "mimetype": "text/x-python",
494
- "name": "python",
495
- "nbconvert_exporter": "python",
496
- "pygments_lexer": "ipython3",
497
- "version": "3.11.8"
498
- },
499
- "widgets": {
500
- "application/vnd.jupyter.widget-state+json": {
501
- "state": {},
502
- "version_major": 2,
503
- "version_minor": 0
504
- }
505
- }
506
- },
507
- "nbformat": 4,
508
- "nbformat_minor": 4
509
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mlip_arena/tasks/stability/__init__.py ADDED
File without changes
mlip_arena/tasks/stability/run.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+
2
+ from mlip_arena.tasks.utils import _valid_dynamics, _preset_dynamics
3
+
mlip_arena/tasks/utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, glob
2
+ from pathlib import Path
3
+ from ase.io import read, write
4
+ from ase import units
5
+ from ase import Atoms, units
6
+ from ase.calculators.calculator import Calculator
7
+ from ase.data import chemical_symbols
8
+ from ase.md.andersen import Andersen
9
+ from ase.md.langevin import Langevin
10
+ from ase.md.md import MolecularDynamics
11
+ from ase.md.npt import NPT
12
+ from ase.md.nptberendsen import NPTBerendsen
13
+ from ase.md.nvtberendsen import NVTBerendsen
14
+ from ase.md.velocitydistribution import (
15
+ MaxwellBoltzmannDistribution,
16
+ Stationary,
17
+ ZeroRotation,
18
+ )
19
+ from ase.md.verlet import VelocityVerlet
20
+ from dask.distributed import Client
21
+ from dask_jobqueue import SLURMCluster
22
+ from jobflow import Maker
23
+ from prefect import flow, task
24
+ from prefect.tasks import task_input_hash
25
+ from prefect_dask import DaskTaskRunner
26
+ from pymatgen.io.ase import AseAtomsAdaptor
27
+ from scipy.interpolate import interp1d
28
+ from scipy.linalg import schur
29
+
30
+ from mlip_arena.models import MLIPCalculator
31
+ from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator
32
+ from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
33
+ from mp_api.client import MPRester
34
+
35
+ from fireworks import LaunchPad
36
+ from atomate2.vasp.flows.core import RelaxBandStructureMaker
37
+ from atomate2.vasp.flows.mp import MPGGADoubleRelaxStaticMaker
38
+ from atomate2.vasp.powerups import add_metadata_to_flow
39
+ from atomate2.forcefields.md import (
40
+ CHGNetMDMaker,
41
+ GAPMDMaker,
42
+ M3GNetMDMaker,
43
+ MACEMDMaker,
44
+ NequipMDMaker,
45
+ )
46
+ from atomate2.forcefields.utils import MLFF
47
+ from pymatgen.io.ase import AseAtomsAdaptor
48
+ from pymatgen.transformations.advanced_transformations import CubicSupercellTransformation
49
+ from jobflow.managers.fireworks import flow_to_workflow
50
+ from jobflow import run_locally, SETTINGS
51
+ from tqdm.auto import tqdm
52
+
53
+ from datetime import timedelta, datetime
54
+ from typing import Literal, Sequence, Tuple
55
+
56
+ import numpy as np
57
+ import torch
58
+ from pymatgen.core.structure import Structure
59
+
60
+ from ase.calculators.mixing import SumCalculator
61
+ from scipy.interpolate import interp1d
62
+
63
+ from ase.io.trajectory import Trajectory
64
+
65
+
66
+ _valid_dynamics: dict[str, tuple[str, ...]] = {
67
+ "nve": ("velocityverlet",),
68
+ "nvt": ("nose-hoover", "langevin", "andersen", "berendsen"),
69
+ "npt": ("nose-hoover", "berendsen"),
70
+ }
71
+
72
+ _preset_dynamics: dict = {
73
+ "nve_velocityverlet": VelocityVerlet,
74
+ "nvt_andersen": Andersen,
75
+ "nvt_berendsen": NVTBerendsen,
76
+ "nvt_langevin": Langevin,
77
+ "nvt_nose-hoover": NPT,
78
+ "npt_berendsen": NPTBerendsen,
79
+ "npt_nose-hoover": NPT,
80
+ }
81
+
82
+ def _interpolate_quantity(values: Sequence | np.ndarray, n_pts: int) -> np.ndarray:
83
+ """Interpolate temperature / pressure on a schedule."""
84
+ n_vals = len(values)
85
+ return np.interp(
86
+ np.linspace(0, n_vals - 1, n_pts + 1),
87
+ np.linspace(0, n_vals - 1, n_vals),
88
+ values,
89
+ )
90
+
91
+ def _get_ensemble_schedule(
92
+ ensemble: Literal["nve", "nvt", "npt"] = "nvt",
93
+ n_steps: int = 1000,
94
+ temperature: float | Sequence | np.ndarray | None = 300.0,
95
+ pressure: float | Sequence | np.ndarray | None = None
96
+ ) -> Tuple[np.ndarray, np.ndarray]:
97
+ if ensemble == "nve":
98
+ # Disable thermostat and barostat
99
+ temperature = np.nan
100
+ pressure = np.nan
101
+ t_schedule = np.full(n_steps + 1, temperature)
102
+ p_schedule = np.full(n_steps + 1, pressure)
103
+ return t_schedule, p_schedule
104
+
105
+ if isinstance(temperature, Sequence) or (
106
+ isinstance(temperature, np.ndarray) and temperature.ndim == 1
107
+ ):
108
+ t_schedule = _interpolate_quantity(temperature, n_steps)
109
+ # NOTE: In ASE Langevin dynamics, the temperature are normally
110
+ # scalars, but in principle one quantity per atom could be specified by giving
111
+ # an array. This is not implemented yet here.
112
+ else:
113
+ t_schedule = np.full(n_steps + 1, temperature)
114
+
115
+ if ensemble == "nvt":
116
+ pressure = np.nan
117
+ p_schedule = np.full(n_steps + 1, pressure)
118
+ return t_schedule, p_schedule
119
+
120
+ if isinstance(pressure, Sequence) or (
121
+ isinstance(pressure, np.ndarray) and pressure.ndim == 1
122
+ ):
123
+ p_schedule = _interpolate_quantity(pressure, n_steps)
124
+ elif isinstance(pressure, np.ndarray) and pressure.ndim == 4:
125
+ p_schedule = interp1d(
126
+ np.arange(n_steps + 1), pressure, kind="linear"
127
+ )
128
+ assert isinstance(p_schedule, np.ndarray)
129
+ else:
130
+ p_schedule = np.full(n_steps + 1, pressure)
131
+
132
+ return t_schedule, p_schedule
133
+
134
+ def _get_ensemble_defaults(
135
+ ensemble: Literal["nve", "nvt", "npt"],
136
+ dynamics: str | MolecularDynamics,
137
+ t_schedule: np.ndarray,
138
+ p_schedule: np.ndarray,
139
+ ase_md_kwargs: dict | None = None) -> dict:
140
+ """Update ASE MD kwargs"""
141
+ ase_md_kwargs = ase_md_kwargs or {}
142
+
143
+ if ensemble == "nve":
144
+ ase_md_kwargs.pop("temperature", None)
145
+ ase_md_kwargs.pop("temperature_K", None)
146
+ ase_md_kwargs.pop("externalstress", None)
147
+ elif ensemble == "nvt":
148
+ ase_md_kwargs["temperature_K"] = t_schedule[0]
149
+ ase_md_kwargs.pop("externalstress", None)
150
+ elif ensemble == "npt":
151
+ ase_md_kwargs["temperature_K"] = t_schedule[0]
152
+ ase_md_kwargs["externalstress"] = p_schedule[0] * 1e3 * units.bar
153
+
154
+ if isinstance(dynamics, str) and dynamics.lower() == "langevin":
155
+ ase_md_kwargs["friction"] = ase_md_kwargs.get(
156
+ "friction",
157
+ 10.0 * 1e-3 / units.fs, # Same default as in VASP: 10 ps^-1
158
+ )
159
+
160
+ return ase_md_kwargs
161
+
pyproject.toml CHANGED
@@ -46,6 +46,7 @@ Issues = "https://github.com/atomind-ai/mlip-arena/issues"
46
 
47
  [tool.ruff]
48
  # Exclude a variety of commonly ignored directories.
 
49
  exclude = [
50
  ".bzr",
51
  ".direnv",
 
46
 
47
  [tool.ruff]
48
  # Exclude a variety of commonly ignored directories.
49
+ extend-include = ["*.ipynb"]
50
  exclude = [
51
  ".bzr",
52
  ".direnv",
scripts/install-pyg.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ # PyTorch Geometric (OCP)
4
+ TORCH=2.2.0
5
+ CUDA=cu121
6
+
7
+ pip install --verbose --no-cache torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
8
+ pip install --verbose --no-cache torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
9
+
10
+ # DGL (M3GNet)
11
+ pip install --verbose --no-cache dgl -f https://data.dgl.ai/wheels/{CUDA}/repo.html
12
+
13
+
14
+ # DGL (ALIGNN)
15
+ # pip install --verbose --no-cache dgl -f https://data.dgl.ai/wheels/torch-2.2/cu122/repo.html
serve/models/alerts.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+
4
+ st.markdown("# Alerts")
serve/models/bugs.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+
4
+ st.markdown("# Bugs")
serve/tasks/homonuclear-diatomics.py CHANGED
@@ -136,7 +136,7 @@ for i, symbol in enumerate(chemical_symbols[1:]):
136
  )
137
 
138
  # Set x-axis title
139
- fig.update_xaxes(title_text="Bond length (Å)")
140
 
141
  # Set y-axes titles
142
  if energy_plot:
 
136
  )
137
 
138
  # Set x-axis title
139
+ fig.update_xaxes(title_text="Bond length [Å]")
140
 
141
  # Set y-axes titles
142
  if energy_plot: