Spaces:
Running
Running
from __future__ import annotations | |
from datetime import datetime, timedelta | |
from pathlib import Path | |
from typing import Literal, Sequence, Tuple | |
import numpy as np | |
from ase import Atoms, units | |
from ase.calculators.calculator import Calculator | |
from ase.calculators.mixing import SumCalculator | |
from ase.io import read | |
from ase.io.trajectory import Trajectory | |
from ase.md.andersen import Andersen | |
from ase.md.langevin import Langevin | |
from ase.md.md import MolecularDynamics | |
from ase.md.npt import NPT | |
from ase.md.nptberendsen import NPTBerendsen | |
from ase.md.nvtberendsen import NVTBerendsen | |
from ase.md.velocitydistribution import ( | |
MaxwellBoltzmannDistribution, | |
Stationary, | |
ZeroRotation, | |
) | |
from ase.md.verlet import VelocityVerlet | |
from prefect import task | |
from prefect.tasks import task_input_hash | |
from scipy.interpolate import interp1d | |
from scipy.linalg import schur | |
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator | |
from tqdm.auto import tqdm | |
from mlip_arena.models.utils import MLIPEnum, get_freer_device | |
# from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator | |
# Thermostat / integrator names that may be requested (as strings) for each
# thermodynamic ensemble.
_valid_dynamics: dict[str, tuple[str, ...]] = dict(
    nve=("velocityverlet",),
    nvt=("nose-hoover", "langevin", "andersen", "berendsen"),
    npt=("nose-hoover", "berendsen"),
)
# Lookup from "<ensemble>_<dynamics>" to the ASE class that implements it.
# Both Nose-Hoover entries map to the same ASE ``NPT`` class.
_preset_dynamics: dict = dict(
    (
        ("nve_velocityverlet", VelocityVerlet),
        ("nvt_andersen", Andersen),
        ("nvt_berendsen", NVTBerendsen),
        ("nvt_langevin", Langevin),
        ("nvt_nose-hoover", NPT),
        ("npt_berendsen", NPTBerendsen),
        ("npt_nose-hoover", NPT),
    )
)
def _interpolate_quantity(values: Sequence | np.ndarray, n_pts: int) -> np.ndarray:
    """Linearly resample a temperature / pressure schedule.

    The entries of ``values`` are treated as equally spaced control points.
    The result has ``n_pts + 1`` samples, with the first and last samples
    coinciding with the first and last control points.
    """
    last_index = len(values) - 1
    # Positions of the control points on the index axis 0 .. last_index.
    control_x = np.linspace(0, last_index, last_index + 1)
    # Positions at which the schedule is evaluated (one per MD step, plus step 0).
    sample_x = np.linspace(0, last_index, n_pts + 1)
    return np.interp(sample_x, control_x, values)
def _get_ensemble_schedule(
    ensemble: Literal["nve", "nvt", "npt"] = "nvt",
    n_steps: int = 1000,
    temperature: float | Sequence | np.ndarray | None = 300.0,
    pressure: float | Sequence | np.ndarray | None = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build per-step temperature and pressure schedules for an MD run.

    Args:
        ensemble: ``"nve"``, ``"nvt"`` or ``"npt"``.
        n_steps: Number of MD steps; each schedule has ``n_steps + 1`` entries.
        temperature: A scalar (held constant) or a sequence / 1-D array of
            control points that is linearly interpolated over the run.
        pressure: Same conventions as ``temperature``. Only used for ``"npt"``.

    Returns:
        ``(t_schedule, p_schedule)``. Quantities that the ensemble does not
        control are filled with NaN: both for ``"nve"``, pressure for ``"nvt"``.

    Raises:
        ValueError: If ``pressure`` is an array with more than one dimension.
    """
    n_points = n_steps + 1

    if ensemble == "nve":
        # Neither thermostat nor barostat is active.
        return np.full(n_points, np.nan), np.full(n_points, np.nan)

    if isinstance(temperature, Sequence) or (
        isinstance(temperature, np.ndarray) and temperature.ndim == 1
    ):
        t_schedule = _interpolate_quantity(temperature, n_steps)
    else:
        # NOTE: ASE Langevin dynamics can in principle take one temperature
        # per atom (an array); that is not implemented here.
        t_schedule = np.full(n_points, temperature)

    if ensemble == "nvt":
        # No barostat: the pressure schedule is a NaN placeholder.
        return t_schedule, np.full(n_points, np.nan)

    if isinstance(pressure, np.ndarray) and pressure.ndim > 1:
        # The former ``ndim == 4`` branch stored a scipy ``interp1d`` object
        # and then asserted it was an ndarray, which could never succeed.
        # Fail early with an explicit message instead.
        raise ValueError(
            "pressure must be a scalar or a 1-D schedule; "
            f"got an array with shape {pressure.shape}."
        )

    if isinstance(pressure, Sequence) or (
        isinstance(pressure, np.ndarray) and pressure.ndim == 1
    ):
        p_schedule = _interpolate_quantity(pressure, n_steps)
    else:
        p_schedule = np.full(n_points, pressure)

    return t_schedule, p_schedule
def _get_ensemble_defaults(
    ensemble: Literal["nve", "nvt", "npt"],
    dynamics: str | MolecularDynamics,
    t_schedule: np.ndarray,
    p_schedule: np.ndarray,
    ase_md_kwargs: dict | None = None,
) -> dict:
    """Return the ASE MD keyword arguments adjusted for the chosen ensemble.

    Args:
        ensemble: ``"nve"``, ``"nvt"`` or ``"npt"``.
        dynamics: Name of the dynamics; only the string ``"langevin"``
            (case-insensitive) triggers a default here.
        t_schedule: Temperature schedule; its first entry seeds ``temperature_K``.
        p_schedule: Pressure schedule; its first entry seeds ``externalstress``.
        ase_md_kwargs: User-supplied keyword arguments. Not modified.

    Returns:
        A new dict:

        * ``nve``: ``temperature``, ``temperature_K`` and ``externalstress``
          are removed.
        * ``nvt``: ``temperature_K`` is set and ``externalstress`` removed.
        * ``npt``: ``temperature_K`` and ``externalstress`` are both set.
        * Langevin dynamics: ``friction`` gets a default if not supplied.
    """
    # Work on a shallow copy so the caller's dict is never mutated in place.
    md_kwargs = dict(ase_md_kwargs) if ase_md_kwargs else {}

    if ensemble == "nve":
        for key in ("temperature", "temperature_K", "externalstress"):
            md_kwargs.pop(key, None)
    elif ensemble == "nvt":
        md_kwargs["temperature_K"] = t_schedule[0]
        md_kwargs.pop("externalstress", None)
    elif ensemble == "npt":
        md_kwargs["temperature_K"] = t_schedule[0]
        # NOTE(review): the value is passed through without the
        # ``* 1e3 * units.bar`` conversion (it was commented out in the
        # original code) - confirm the unit expected at construction time.
        md_kwargs["externalstress"] = p_schedule[0]

    if isinstance(dynamics, str) and dynamics.lower() == "langevin":
        # Same default as in VASP: 10 ps^-1
        md_kwargs.setdefault("friction", 10.0 * 1e-3 / units.fs)

    return md_kwargs
def md(
    atoms: Atoms,
    calculator_name: str | MLIPEnum,
    calculator_kwargs: dict | None,
    dispersion: str | None = None,
    dispersion_kwargs: dict | None = None,
    device: str | None = None,
    ensemble: Literal["nve", "nvt", "npt"] = "nvt",
    dynamics: str | MolecularDynamics = "langevin",
    time_step: float | None = None,
    total_time: float = 1000,
    temperature: float | Sequence | np.ndarray | None = 300.0,
    pressure: float | Sequence | np.ndarray | None = None,
    ase_md_kwargs: dict | None = None,
    mb_velocity_seed: int | None = None,
    zero_linear_momentum: bool = True,
    zero_angular_momentum: bool = True,
    traj_file: str | Path | None = None,
    traj_interval: int = 1,
    restart: bool = True,
):
    """Run an ASE molecular-dynamics simulation with an MLIP calculator.

    Args:
        atoms: Structure to simulate. Its calculator is replaced, and its
            positions / momenta / cell may be modified in place.
        calculator_name: An ``MLIPEnum`` member or the name of one.
        calculator_kwargs: Keyword arguments for the calculator constructor.
        dispersion: If not ``None``, a ``TorchDFTD3Calculator`` is summed with
            the MLIP calculator.
        dispersion_kwargs: Keyword arguments for ``TorchDFTD3Calculator``;
            ``device`` is always overridden with the selected device.
        device: Device string; defaults to ``get_freer_device()``.
        ensemble: ``"nve"``, ``"nvt"`` or ``"npt"``.
        dynamics: Name of a preset thermostat / barostat (see
            ``_valid_dynamics``) or a ``MolecularDynamics`` subclass. For
            ``"nve"`` a string value is always replaced by ``"velocityverlet"``.
        time_step: Time step; multiplied by ``ase.units.fs`` before use.
            Defaults to 0.5 if the structure contains ``"H"``, otherwise 2.0.
        total_time: Total simulated time in the same unit as ``time_step``;
            the number of steps is ``int(total_time / time_step)``.
        temperature: Scalar, or schedule interpolated over the run; passed
            to ASE as ``temperature_K``.
        pressure: Scalar, or schedule interpolated over the run; multiplied
            by ``1e3 * units.bar`` when applied during the run.
        ase_md_kwargs: Extra keyword arguments for the ASE MD class.
        mb_velocity_seed: Seed for the Maxwell-Boltzmann velocity draw.
        zero_linear_momentum: Apply ``Stationary`` to freshly drawn velocities.
        zero_angular_momentum: Apply ``ZeroRotation`` to freshly drawn velocities.
        traj_file: Trajectory file to write; ``None`` disables output.
        traj_interval: Write a frame every this many steps.
        restart: If true and ``traj_file`` exists, continue from its last frame.

    Returns:
        ``{"runtime": timedelta, "n_steps": int}`` where ``n_steps`` is the
        number of steps requested from the integrator in this call.

    Raises:
        ValueError: For an unknown calculator or an unsupported ``dynamics``.
    """
    device = device or str(get_freer_device())

    print(f"Using device: {device}")

    calculator_kwargs = calculator_kwargs or {}

    if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
        assert issubclass(calculator_name.value, Calculator)
        calc = calculator_name.value(**calculator_kwargs)
    elif (
        isinstance(calculator_name, str) and calculator_name in MLIPEnum._member_names_
    ):
        calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
    else:
        raise ValueError(f"Invalid calculator: {calculator_name}")

    print(f"Using calculator: {calc}")

    if dispersion is not None:
        # Merge into a new dict so the caller's ``dispersion_kwargs`` is not
        # mutated when the device is injected.
        disp_calc = TorchDFTD3Calculator(
            **{**(dispersion_kwargs or {}), "device": device},
        )
        calc = SumCalculator([calc, disp_calc])

        print(f"Using dispersion: {dispersion}")

    atoms.calc = calc

    if time_step is None:
        # Default `time_step` is 0.5 fs when the structure contains hydrogen,
        # and 2 fs otherwise.
        has_h_isotope = "H" in atoms.get_chemical_symbols()
        time_step = 0.5 if has_h_isotope else 2.0

    n_steps = int(total_time / time_step)
    target_steps = n_steps

    # Resolve the MD class first so the kwargs defaults below are computed
    # for the dynamics that will actually run.
    if isinstance(dynamics, str):
        dynamics = dynamics.lower()
        if ensemble == "nve":
            # NVE has a single integrator. Forcing it before validation lets
            # the default ``dynamics="langevin"`` work with ``ensemble="nve"``.
            dynamics = "velocityverlet"
        if dynamics not in _valid_dynamics[ensemble]:
            raise ValueError(
                f"{dynamics} thermostat not available for {ensemble}. "
                f"Available {ensemble} thermostats are: "
                + ", ".join(_valid_dynamics[ensemble])
            )
        md_class = _preset_dynamics[f"{ensemble}_{dynamics}"]
    elif isinstance(dynamics, type) and issubclass(dynamics, MolecularDynamics):
        md_class = dynamics
    else:
        raise ValueError(f"Invalid dynamics: {dynamics}")

    t_schedule, p_schedule = _get_ensemble_schedule(
        ensemble=ensemble,
        n_steps=n_steps,
        temperature=temperature,
        pressure=pressure,
    )

    ase_md_kwargs = _get_ensemble_defaults(
        ensemble=ensemble,
        dynamics=dynamics,
        t_schedule=t_schedule,
        p_schedule=p_schedule,
        # Pass a copy so the caller's dict is never modified.
        ase_md_kwargs=dict(ase_md_kwargs or {}),
    )

    if md_class is NPT:
        # The ASE NPT implementation requires an upper-triangular cell; the
        # Schur form of the cell matrix provides one.
        u, _ = schur(atoms.get_cell(complete=True), output="complex")
        atoms.set_cell(u.real, scale_atoms=True)

    last_step = 0
    traj = None
    resumed = False

    if traj_file is not None:
        traj_file = Path(traj_file)
        traj_file.parent.mkdir(parents=True, exist_ok=True)

        if restart and traj_file.exists():
            try:
                frames = read(traj_file, index=":")
                last_atoms = frames[-1]
                if not isinstance(last_atoms, Atoms):
                    raise TypeError("Last trajectory frame is not an Atoms object")
                resumed_step = last_atoms.info.get(
                    "step", len(frames) * traj_interval
                )
                traj = Trajectory(traj_file, "a", atoms)
                atoms.set_positions(last_atoms.get_positions())
                atoms.set_momenta(last_atoms.get_momenta())
            except Exception as err:
                # Best-effort restart: fall back to a fresh run. The step
                # counters are only committed on success (below), so a
                # partial failure cannot leave them inconsistent.
                if traj is not None:
                    traj.close()
                    traj = None
                print(
                    f"Could not restart from {traj_file} ({err!r}); "
                    "starting a new trajectory."
                )
            else:
                last_step = resumed_step
                # Never request a negative number of steps when the stored
                # trajectory already reaches or exceeds the target.
                n_steps = max(target_steps - last_step, 0)
                resumed = True

        if traj is None:
            traj = Trajectory(traj_file, "w", atoms)

    if not resumed:
        # Fresh run (with or without a trajectory file): draw initial
        # velocities unless the schedule is NaN (NVE).
        if not np.isnan(t_schedule).any():
            MaxwellBoltzmannDistribution(
                atoms=atoms,
                temperature_K=t_schedule[last_step],
                rng=np.random.default_rng(seed=mb_velocity_seed),
            )

        if zero_linear_momentum:
            Stationary(atoms)
        if zero_angular_momentum:
            ZeroRotation(atoms)

    try:
        md_runner = md_class(
            atoms=atoms,
            timestep=time_step * units.fs,
            **ase_md_kwargs,
        )

        # NOTE(review): the trajectory writer is attached before `_callback`.
        # If observers run in attachment order, each frame is written before
        # `info["step"]` is refreshed - confirm the restart bookkeeping.
        if traj is not None:
            md_runner.attach(traj.write, interval=traj_interval)

        with tqdm(total=n_steps) as pbar:

            def _callback(dyn: MolecularDynamics = md_runner) -> None:
                step = last_step + dyn.nsteps
                dyn.atoms.info["restart"] = last_step
                dyn.atoms.info["datetime"] = datetime.now()
                dyn.atoms.info["step"] = step
                dyn.atoms.info["target_steps"] = target_steps

                # Schedules hold target_steps + 1 entries; clamp so a
                # trajectory resumed past the target cannot index out of range.
                idx = min(step, target_steps)
                if ensemble != "nve":
                    dyn.set_temperature(temperature_K=t_schedule[idx])
                    if ensemble != "nvt":
                        dyn.set_stress(p_schedule[idx] * 1e3 * units.bar)

                # Sync the bar with the integrator's step count for every
                # ensemble; the former early returns skipped it for NVE / NVT.
                pbar.update(dyn.nsteps - pbar.n)

            md_runner.attach(_callback, interval=1)

            start_time = datetime.now()
            md_runner.run(steps=n_steps)
            end_time = datetime.now()
    finally:
        # Close the trajectory even if the run fails; there is nothing to
        # close when no trajectory file was requested.
        if traj is not None:
            traj.close()

    return {"runtime": end_time - start_time, "n_steps": n_steps}