Spaces:
Running
Running
File size: 3,298 Bytes
7471e3d 9241861 56200e6 80f7283 9241861 56200e6 df7cf57 56200e6 9241861 7471e3d 9241861 80f7283 9241861 7471e3d 56200e6 9241861 7471e3d 9241861 7471e3d 50064b5 7471e3d 9241861 56200e6 9241861 80f7283 9241861 56200e6 9241861 80f7283 9241861 df7cf57 9241861 df7cf57 9241861 df7cf57 9241861 c7442c5 9241861 c7442c5 9241861 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
"""
Define structure optimization tasks.
"""
from __future__ import annotations
from ase import Atoms
from ase.calculators.calculator import BaseCalculator
from ase.constraints import FixSymmetry
from ase.filters import * # type: ignore
from ase.filters import Filter
from ase.optimize import * # type: ignore
from ase.optimize.optimize import Optimizer
from prefect import task
from prefect.cache_policies import INPUTS, TASK_SOURCE
from prefect.runtime import task_run
from mlip_arena.tasks.utils import logger, pformat
# Registry of user-facing names -> ASE filter *classes* (not instances).
# ``run`` looks a string ``filter`` argument up in this table and instantiates
# the class on the atoms being relaxed.
# NOTE(review): the bare ``Filter`` presumably needs ``indices`` or ``mask``
# supplied via ``filter_kwargs``, and recent ASE releases deprecate
# ``ExpCellFilter`` in favour of ``FrechetCellFilter`` — confirm against the
# pinned ASE version.
_valid_filters: dict[str, type[Filter]] = {
    "Filter": Filter,
    "UnitCell": UnitCellFilter,
    "ExpCell": ExpCellFilter,
    "Strain": StrainFilter,
    "FrechetCell": FrechetCellFilter,
}  # type: ignore
# Registry of user-facing names -> ASE optimizer *classes* (not instances).
# ``run`` looks a string ``optimizer`` argument up in this table and
# instantiates the class on the atoms (or on the filter wrapping them).
_valid_optimizers: dict[str, type[Optimizer]] = {
    "MDMin": MDMin,
    "FIRE": FIRE,
    "FIRE2": FIRE2,
    "LBFGS": LBFGS,
    "LBFGSLineSearch": LBFGSLineSearch,
    "BFGS": BFGS,
    "BFGSLineSearch": BFGSLineSearch,
    "QuasiNewton": QuasiNewton,
    "GPMin": GPMin,
    "CellAwareBFGS": CellAwareBFGS,
    "ODE12r": ODE12r,
}  # type: ignore
def _generate_task_run_name():
    """Build the Prefect task-run name ``"<task>: <formula> - <calculator>"``.

    Used as the ``task_run_name`` callback of the ``run`` task. It reads the
    current run's bound ``atoms`` and ``calculator`` parameters from
    ``prefect.runtime.task_run``. The calculator appears via its string form.
    """
    params = task_run.parameters
    structure, calc = params["atoms"], params["calculator"]
    return f"{task_run.task_name}: {structure.get_chemical_formula()} - {calc}"
def _resolve_choice(choice, registry: dict, kind: str):
    """Return ``registry[choice]`` for string choices; pass anything else through.

    Raises:
        ValueError: ``choice`` is a string that is not a key of ``registry``.
    """
    if not isinstance(choice, str):
        return choice
    if choice not in registry:
        raise ValueError(f"Invalid {kind}: {choice}")
    return registry[choice]


@task(
    name="OPT", task_run_name=_generate_task_run_name, cache_policy=TASK_SOURCE + INPUTS
)
def run(
    atoms: Atoms,
    calculator: BaseCalculator,
    optimizer: type[Optimizer] | str = BFGSLineSearch,
    optimizer_kwargs: dict | None = None,
    filter: type[Filter] | str | None = None,
    filter_kwargs: dict | None = None,
    criterion: dict | None = None,
    symmetry: bool = False,
):
    """Relax a structure with an ASE optimizer, optionally through a filter.

    The input ``atoms`` is never modified. A copy is made, ``calculator`` is
    attached to it, and the copy is relaxed.

    Args:
        atoms: Structure to relax.
        calculator: Calculator attached to the copy of ``atoms``.
        optimizer: Optimizer class, or one of the names in
            ``_valid_optimizers``. Defaults to ``BFGSLineSearch``.
        optimizer_kwargs: Extra keyword arguments for the optimizer
            constructor.
        filter: ``Filter`` subclass, one of the names in ``_valid_filters``,
            or ``None`` to optimize the atoms directly.
        filter_kwargs: Extra keyword arguments for the filter constructor.
        criterion: Keyword arguments forwarded to the optimizer's ``run``
            method (for ASE optimizers typically ``fmax`` and ``steps``).
        symmetry: If true, the copy's constraints are *replaced* by a single
            ``FixSymmetry`` constraint. Any constraints already present on the
            input are dropped.

    Returns:
        A dict with three keys:
        ``"atoms"`` is the relaxed copy, with the calculator attached.
        ``"steps"`` is the optimizer's ``nsteps``.
        ``"converged"`` is the result of the optimizer's ``converged()``.

    Raises:
        ValueError: ``filter`` or ``optimizer`` is an unknown name.
        TypeError: ``filter`` is neither ``None``, a known name, nor a
            ``Filter`` subclass (e.g. an already-built Filter instance, which
            would be bound to a different Atoms object).
    """
    atoms = atoms.copy()
    atoms.calc = calculator

    filter = _resolve_choice(filter, _valid_filters, "filter")
    optimizer = _resolve_choice(optimizer, _valid_optimizers, "optimizer")

    # Validate up front. Previously an unsupported filter skipped every
    # branch and only failed at the return with an UnboundLocalError.
    if filter is not None and not (
        isinstance(filter, type) and issubclass(filter, Filter)
    ):
        raise TypeError(
            f"Unsupported filter {filter!r}: expected a Filter subclass, "
            f"one of {sorted(_valid_filters)}, or None"
        )

    filter_kwargs = filter_kwargs or {}
    optimizer_kwargs = optimizer_kwargs or {}
    criterion = criterion or {}

    if symmetry:
        # set_constraint replaces (not appends) the existing constraint list.
        atoms.set_constraint(FixSymmetry(atoms))

    # The optimizer drives either the bare atoms or a filter wrapping them.
    # The filter updates ``atoms`` in place, so ``atoms`` is what gets returned.
    target = atoms
    if filter is not None:
        target = filter(atoms, **filter_kwargs)
        logger.info(f"Using filter: {target}")
        logger.info(pformat(filter_kwargs))

    optimizer_instance = optimizer(target, **optimizer_kwargs)
    logger.info(f"Using optimizer: {optimizer_instance}")
    logger.info(pformat(optimizer_kwargs))
    logger.info(f"Criterion: {pformat(criterion)}")
    optimizer_instance.run(**criterion)

    return {
        "atoms": atoms,
        "steps": optimizer_instance.nsteps,
        "converged": optimizer_instance.converged(),
    }
|