cyrusyc committed on
Commit
2f7e23a
·
1 Parent(s): 6bc1d71

clean up and add copyright notice

Browse files
mlip_arena/tasks/alexandria.py DELETED
@@ -1,7 +0,0 @@
1
-
2
-
3
- URL = "https://alexandria.icams.rub.de/"
4
-
5
-
6
- def whoami():
7
- print(f'TEST: {__file__}')
 
 
 
 
 
 
 
 
mlip_arena/tasks/md.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script has been adapted from Atomate2 MLFF MD workflow written by Aaron Kaplan and Yuan Chiang
3
+ https://github.com/materialsproject/atomate2/blob/main/src/atomate2/forcefields/md.py
4
+
5
+ atomate2 Copyright (c) 2015, The Regents of the University of
6
+ California, through Lawrence Berkeley National Laboratory (subject
7
+ to receipt of any required approvals from the U.S. Dept. of Energy).
8
+ All rights reserved.
9
+
10
+ Redistribution and use in source and binary forms, with or without
11
+ modification, are permitted provided that the following conditions
12
+ are met:
13
+
14
+ (1) Redistributions of source code must retain the above copyright
15
+ notice, this list of conditions and the following disclaimer.
16
+
17
+ (2) Redistributions in binary form must reproduce the above
18
+ copyright notice, this list of conditions and the following
19
+ disclaimer in the documentation and/or other materials provided with
20
+ the distribution.
21
+
22
+ (3) Neither the name of the University of California, Lawrence
23
+ Berkeley National Laboratory, U.S. Dept. of Energy nor the names of
24
+ its contributors may be used to endorse or promote products derived
25
+ from this software without specific prior written permission.
26
+
27
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
28
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
29
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
30
+ FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
31
+ COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
32
+ INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
33
+ BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
34
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
35
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
36
+ LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
37
+ ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38
+ POSSIBILITY OF SUCH DAMAGE.
39
+
40
+ You are under no obligation whatsoever to provide any bug fixes,
41
+ patches, or upgrades to the features, functionality or performance
42
+ of the source code ("Enhancements") to anyone; however, if you
43
+ choose to make your Enhancements available either publicly, or
44
+ directly to Lawrence Berkeley National Laboratory or its
45
+ contributors, without imposing a separate written license agreement
46
+ for such Enhancements, then you hereby grant the following license:
47
+ a non-exclusive, royalty-free perpetual license to install, use,
48
+ modify, prepare derivative works, incorporate into other computer
49
+ software, distribute, and sublicense such enhancements or derivative
50
+ works thereof, in binary and source code form.
51
+ """
52
+
53
+ from __future__ import annotations
54
+
55
+ from datetime import datetime, timedelta
56
+ from pathlib import Path
57
+ from typing import Literal, Sequence, Tuple
58
+
59
+ import numpy as np
60
+ from ase import Atoms, units
61
+ from ase.calculators.calculator import Calculator
62
+ from ase.calculators.mixing import SumCalculator
63
+ from ase.io import read
64
+ from ase.io.trajectory import Trajectory
65
+ from ase.md.andersen import Andersen
66
+ from ase.md.langevin import Langevin
67
+ from ase.md.md import MolecularDynamics
68
+ from ase.md.npt import NPT
69
+ from ase.md.nptberendsen import NPTBerendsen
70
+ from ase.md.nvtberendsen import NVTBerendsen
71
+ from ase.md.velocitydistribution import (
72
+ MaxwellBoltzmannDistribution,
73
+ Stationary,
74
+ ZeroRotation,
75
+ )
76
+ from ase.md.verlet import VelocityVerlet
77
+ from prefect import task
78
+ from prefect.tasks import task_input_hash
79
+ from scipy.interpolate import interp1d
80
+ from scipy.linalg import schur
81
+ from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
82
+ from tqdm.auto import tqdm
83
+
84
+ from mlip_arena.models.utils import MLIPEnum, get_freer_device
85
+
86
+ # from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator
87
+
88
# Names of the integrators / thermostats / barostats that may be requested
# (as lower-case strings) for each thermodynamic ensemble.
_valid_dynamics: dict[str, tuple[str, ...]] = dict(
    nve=("velocityverlet",),
    nvt=("nose-hoover", "langevin", "andersen", "berendsen"),
    npt=("nose-hoover", "berendsen"),
)
93
+
94
# Lookup from "<ensemble>_<dynamics>" to the ASE class that implements it.
# Both Nose-Hoover entries map to ASE's ``NPT`` class.
_preset_dynamics: dict = {
    # constant energy
    "nve_velocityverlet": VelocityVerlet,
    # constant temperature
    "nvt_andersen": Andersen,
    "nvt_berendsen": NVTBerendsen,
    "nvt_langevin": Langevin,
    "nvt_nose-hoover": NPT,
    # constant temperature and pressure
    "npt_berendsen": NPTBerendsen,
    "npt_nose-hoover": NPT,
}
103
+
104
+
105
def _interpolate_quantity(values: Sequence | np.ndarray, n_pts: int) -> np.ndarray:
    """Linearly resample a temperature / pressure schedule.

    The entries of ``values`` are treated as equally spaced control points;
    the result holds ``n_pts + 1`` samples spanning from the first to the
    last control point (both included).

    Args:
        values: Control points of the schedule.
        n_pts: Number of intervals in the output, i.e. the output has
            ``n_pts + 1`` entries.

    Returns:
        Array of length ``n_pts + 1`` with the interpolated values.
    """
    last_index = len(values) - 1
    control_positions = np.linspace(0, last_index, last_index + 1)
    sample_positions = np.linspace(0, last_index, n_pts + 1)
    return np.interp(sample_positions, control_positions, values)
113
+
114
+
115
def _get_ensemble_schedule(
    ensemble: Literal["nve", "nvt", "npt"] = "nvt",
    n_steps: int = 1000,
    temperature: float | Sequence | np.ndarray | None = 300.0,
    pressure: float | Sequence | np.ndarray | None = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build per-step temperature and pressure schedules for an MD run.

    Both returned arrays have ``n_steps + 1`` entries (one per step plus the
    initial state). A quantity that does not apply to the ensemble is filled
    with NaN: both for "nve", the pressure for "nvt".

    Args:
        ensemble: Thermodynamic ensemble of the simulation.
        n_steps: Number of MD steps the schedule has to cover.
        temperature: Scalar (held constant) or a sequence / 1-D array of
            control points that are linearly interpolated over the run.
            Ignored for "nve".
        pressure: Scalar (held constant) or a sequence / 1-D array of control
            points that are linearly interpolated over the run. Ignored for
            "nve" and "nvt".

    Returns:
        Tuple ``(t_schedule, p_schedule)``.

    Raises:
        NotImplementedError: If ``pressure`` is a 4-D array. The previous
            code built a ``scipy.interpolate.interp1d`` object for this case
            and then asserted that it was an ``ndarray``, which can never
            hold, so the case has never worked.
    """
    n_points = n_steps + 1

    def _is_series(quantity) -> bool:
        # A generic sequence or a 1-D array is treated as control points.
        return isinstance(quantity, Sequence) or (
            isinstance(quantity, np.ndarray) and quantity.ndim == 1
        )

    if ensemble == "nve":
        # Disable thermostat and barostat
        return np.full(n_points, np.nan), np.full(n_points, np.nan)

    if _is_series(temperature):
        t_schedule = _interpolate_quantity(temperature, n_steps)
        # NOTE: In ASE Langevin dynamics, the temperature are normally
        # scalars, but in principle one quantity per atom could be specified by
        # giving an array. This is not implemented yet here.
    else:
        t_schedule = np.full(n_points, temperature)

    if ensemble == "nvt":
        # No barostat: the pressure argument is ignored.
        return t_schedule, np.full(n_points, np.nan)

    if _is_series(pressure):
        p_schedule = _interpolate_quantity(pressure, n_steps)
    elif isinstance(pressure, np.ndarray) and pressure.ndim == 4:
        # Fail loudly and independently of ``python -O`` instead of relying on
        # an assertion that could never pass.
        raise NotImplementedError(
            "4-dimensional pressure schedules are not supported; pass a scalar "
            "or a 1-D sequence of pressures instead."
        )
    else:
        p_schedule = np.full(n_points, pressure)

    return t_schedule, p_schedule
155
+
156
+
157
def _get_ensemble_defaults(
    ensemble: Literal["nve", "nvt", "npt"],
    dynamics: str | MolecularDynamics,
    t_schedule: np.ndarray,
    p_schedule: np.ndarray,
    ase_md_kwargs: dict | None = None,
) -> dict:
    """Return the keyword arguments used to construct the ASE MD object.

    The caller's ``ase_md_kwargs`` is never modified; a new dict is returned.

    Args:
        ensemble: Thermodynamic ensemble of the simulation.
        dynamics: Requested dynamics. Only a string equal to "langevin"
            (case-insensitive) has an effect here: it adds a default
            ``friction``.
        t_schedule: Temperature schedule; its first entry becomes
            ``temperature_K`` for "nvt" and "npt".
        p_schedule: Pressure schedule; its first entry becomes
            ``externalstress`` for "npt".
        ase_md_kwargs: User-supplied keyword arguments to start from.

    Returns:
        A new dict with ensemble-inappropriate keys removed and the initial
        temperature / stress filled in.
    """
    # Work on a copy so the dict the caller handed in is left untouched.
    md_kwargs = dict(ase_md_kwargs) if ase_md_kwargs else {}

    if ensemble == "nve":
        for key in ("temperature", "temperature_K", "externalstress"):
            md_kwargs.pop(key, None)
    elif ensemble == "nvt":
        md_kwargs["temperature_K"] = t_schedule[0]
        md_kwargs.pop("externalstress", None)
    elif ensemble == "npt":
        md_kwargs["temperature_K"] = t_schedule[0]
        # NOTE(review): the unit conversion below is disabled in the source,
        # so the raw schedule value is passed through — confirm intended units.
        md_kwargs["externalstress"] = p_schedule[0]  # * 1e3 * units.bar

    if isinstance(dynamics, str) and dynamics.lower() == "langevin":
        md_kwargs["friction"] = md_kwargs.get(
            "friction",
            10.0 * 1e-3 / units.fs,  # Same default as in VASP: 10 ps^-1
        )

    return md_kwargs
185
+
186
+
187
def _initialize_velocities(
    atoms: Atoms,
    t_schedule: np.ndarray,
    step: int,
    seed: int | None,
    zero_linear_momentum: bool,
    zero_angular_momentum: bool,
) -> None:
    """Draw initial velocities for ``atoms`` and optionally remove net momentum."""
    # Skip the draw when the schedule contains NaN (the NVE schedule is all NaN).
    if not np.isnan(t_schedule).any():
        MaxwellBoltzmannDistribution(
            atoms=atoms,
            temperature_K=t_schedule[step],
            rng=np.random.default_rng(seed=seed),
        )
    if zero_linear_momentum:
        Stationary(atoms)
    if zero_angular_momentum:
        ZeroRotation(atoms)


@task(cache_key_fn=task_input_hash, cache_expiration=timedelta(days=1))
def run(
    atoms: Atoms,
    calculator_name: str | MLIPEnum,
    calculator_kwargs: dict | None,
    dispersion: str | None = None,
    dispersion_kwargs: dict | None = None,
    device: str | None = None,
    ensemble: Literal["nve", "nvt", "npt"] = "nvt",
    dynamics: str | MolecularDynamics = "langevin",
    time_step: float | None = None,
    total_time: float = 1000,
    temperature: float | Sequence | np.ndarray | None = 300.0,
    pressure: float | Sequence | np.ndarray | None = None,
    ase_md_kwargs: dict | None = None,
    md_velocity_seed: int | None = None,
    zero_linear_momentum: bool = True,
    zero_angular_momentum: bool = True,
    traj_file: str | Path | None = None,
    traj_interval: int = 1,
    restart: bool = True,
):
    """Run a molecular dynamics simulation with an MLIP calculator.

    Args:
        atoms: Structure to simulate. It is modified in place: the calculator
            is attached, the cell may be rotated (NPT), and positions /
            momenta evolve with the simulation.
        calculator_name: An ``MLIPEnum`` member or the name of one.
        calculator_kwargs: Keyword arguments for the calculator constructor.
        dispersion: If not None, a ``TorchDFTD3Calculator`` is added to the
            MLIP calculator. The value itself is only printed.
        dispersion_kwargs: Keyword arguments for ``TorchDFTD3Calculator``;
            its "device" entry is always overwritten with ``device``.
        device: Device string; defaults to ``str(get_freer_device())``. In
            this function it is forwarded to the dispersion calculator only.
        ensemble: Thermodynamic ensemble.
        dynamics: Name of a preset dynamics valid for ``ensemble`` or a
            ``MolecularDynamics`` subclass to instantiate.
        time_step: Time step in fs. Defaults to 0.5 if the structure contains
            "H", otherwise 2.0.
        total_time: Total simulated time, in the same unit as ``time_step``;
            the number of steps is ``int(total_time / time_step)``.
        temperature: Scalar or schedule, used as ``temperature_K``.
        pressure: Scalar or schedule. During the run each value is multiplied
            by ``1e3 * units.bar`` before being handed to ASE.
        ase_md_kwargs: Extra keyword arguments for the MD class.
        md_velocity_seed: Seed for the initial velocity draw.
        zero_linear_momentum: Call ``Stationary`` after drawing velocities.
        zero_angular_momentum: Call ``ZeroRotation`` after drawing velocities.
        traj_file: Trajectory file to write. When None nothing is written and
            the velocities already stored on ``atoms`` are left as they are.
        traj_interval: Write every ``traj_interval``-th step.
        restart: Continue from the last frame of an existing ``traj_file``.

    Returns:
        Dict with "atoms", "runtime" (a ``timedelta``) and "n_steps" (the
        number of steps requested from the integrator in this call).

    Raises:
        ValueError: For an unknown calculator or invalid dynamics.
    """
    device = device or str(get_freer_device())

    print(f"Using device: {device}")

    calculator_kwargs = calculator_kwargs or {}

    if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
        assert issubclass(calculator_name.value, Calculator)
        calc = calculator_name.value(**calculator_kwargs)
    elif (
        isinstance(calculator_name, str) and calculator_name in MLIPEnum._member_names_
    ):
        calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
    else:
        raise ValueError(f"Invalid calculator: {calculator_name}")

    print(f"Using calculator: {calc}")

    # Build a new dict so the caller's ``dispersion_kwargs`` is not mutated.
    dispersion_kwargs = {**(dispersion_kwargs or {}), "device": device}

    if dispersion is not None:
        disp_calc = TorchDFTD3Calculator(
            **dispersion_kwargs,
        )
        calc = SumCalculator([calc, disp_calc])

    print(f"Using dispersion: {dispersion}")

    atoms.calc = calc

    if time_step is None:
        # Use a shorter default step (0.5 fs instead of 2 fs) when hydrogen is
        # present in the structure.
        has_hydrogen = "H" in atoms.get_chemical_symbols()
        time_step = 0.5 if has_hydrogen else 2.0

    n_steps = int(total_time / time_step)
    target_steps = n_steps

    t_schedule, p_schedule = _get_ensemble_schedule(
        ensemble=ensemble,
        n_steps=n_steps,
        temperature=temperature,
        pressure=pressure,
    )

    ase_md_kwargs = _get_ensemble_defaults(
        ensemble=ensemble,
        dynamics=dynamics,
        t_schedule=t_schedule,
        p_schedule=p_schedule,
        # Pass a copy so the caller's dict is never modified.
        ase_md_kwargs=dict(ase_md_kwargs or {}),
    )

    if isinstance(dynamics, str):
        # Use known dynamics if `dynamics` is a str
        dynamics = dynamics.lower()
        if dynamics not in _valid_dynamics[ensemble]:
            raise ValueError(
                f"{dynamics} thermostat not available for {ensemble}. "
                f"Available {ensemble} thermostats are: "
                + ", ".join(_valid_dynamics[ensemble])
            )
        # NOTE(review): after the check above the only accepted "nve" value is
        # already "velocityverlet", so this assignment has no effect.
        if ensemble == "nve":
            dynamics = "velocityverlet"
        md_class = _preset_dynamics[f"{ensemble}_{dynamics}"]
    elif isinstance(dynamics, type) and issubclass(dynamics, MolecularDynamics):
        # Accept any MolecularDynamics subclass, not only the base class itself.
        md_class = dynamics
    else:
        raise ValueError(f"Invalid dynamics: {dynamics}")

    if md_class is NPT:
        # Note that until md_func is instantiated, isinstance(md_func,NPT) is False
        # ASE NPT implementation requires upper triangular cell
        u, _ = schur(atoms.get_cell(complete=True), output="complex")
        atoms.set_cell(u.real, scale_atoms=True)

    last_step = 0
    traj = None

    if traj_file is not None:
        traj_file = Path(traj_file)
        traj_file.parent.mkdir(parents=True, exist_ok=True)

        if restart and traj_file.exists():
            try:
                frames = read(traj_file, index=":")
                last_atoms = frames[-1]
                if not isinstance(last_atoms, Atoms):
                    raise TypeError("last trajectory frame is not an Atoms object")
                resume_step = last_atoms.info.get(
                    "step", len(frames) * traj_interval
                )
                atoms.set_positions(last_atoms.get_positions())
                atoms.set_momenta(last_atoms.get_momenta())
                traj = Trajectory(traj_file, "a", atoms)
            except Exception as exc:
                # Best effort: an unreadable or incompatible trajectory is
                # discarded and the run starts from scratch below.
                print(
                    f"Could not restart from {traj_file} ({exc!r}); "
                    "starting a new trajectory."
                )
            else:
                # Commit the restart bookkeeping only once everything worked,
                # so a failed restart cannot leave a stale step offset behind.
                last_step = resume_step
                n_steps -= resume_step

        if traj is None:
            # Fresh run (no restart requested, no file yet, or restart failed).
            traj = Trajectory(traj_file, "w", atoms)
            _initialize_velocities(
                atoms,
                t_schedule,
                last_step,
                md_velocity_seed,
                zero_linear_momentum,
                zero_angular_momentum,
            )

    md_runner = md_class(
        atoms=atoms,
        timestep=time_step * units.fs,
        **ase_md_kwargs,
    )

    if traj is not None:
        md_runner.attach(traj.write, interval=traj_interval)

    with tqdm(total=n_steps) as pbar:

        def _callback(dyn: MolecularDynamics = md_runner) -> None:
            step = last_step + dyn.nsteps
            dyn.atoms.info["restart"] = last_step
            dyn.atoms.info["datetime"] = datetime.now()
            dyn.atoms.info["step"] = step
            dyn.atoms.info["target_steps"] = target_steps
            if ensemble != "nve":
                dyn.set_temperature(temperature_K=t_schedule[step])
                if ensemble != "nvt":
                    # NOTE(review): the initial externalstress is passed to the
                    # MD class without this conversion — confirm the units.
                    dyn.set_stress(p_schedule[step] * 1e3 * units.bar)
            # Advance the progress bar for every ensemble.
            pbar.update()

        md_runner.attach(_callback, interval=1)

        start_time = datetime.now()
        try:
            md_runner.run(steps=n_steps)
            end_time = datetime.now()
        finally:
            # Close the trajectory even if the MD run raised; nothing to close
            # when no trajectory file was requested.
            if traj is not None:
                traj.close()

    return {
        "atoms": atoms,
        "runtime": end_time - start_time,
        "n_steps": n_steps,
    }
mlip_arena/tasks/nacl.py DELETED
@@ -1,11 +0,0 @@
1
-
2
- from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
3
- from mlip_arena.models import MLIP
4
-
5
-
6
- def whoami():
7
- print(f'TEST: {__file__}')
8
-
9
-
10
- if __name__ == "__main__":
11
-
 
 
 
 
 
 
 
 
 
 
 
 
mlip_arena/tasks/qmof.py DELETED
@@ -1,4 +0,0 @@
1
-
2
-
3
- def whoami():
4
- print(f'TEST: {__file__}')