File size: 3,298 Bytes
7471e3d
9241861
 
 
 
 
 
56200e6
80f7283
9241861
 
 
 
56200e6
 
 
df7cf57
56200e6
9241861
 
 
 
 
 
 
7471e3d
9241861
 
 
 
80f7283
9241861
 
 
 
 
 
 
 
7471e3d
 
 
 
 
 
 
 
56200e6
9241861
7471e3d
9241861
7471e3d
 
50064b5
7471e3d
9241861
 
56200e6
9241861
 
 
 
 
80f7283
9241861
56200e6
 
9241861
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80f7283
 
 
9241861
 
df7cf57
 
9241861
df7cf57
 
 
 
9241861
 
 
 
df7cf57
 
 
9241861
 
c7442c5
9241861
 
c7442c5
 
9241861
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Define structure optimization tasks.
"""

from __future__ import annotations

from ase import Atoms
from ase.calculators.calculator import BaseCalculator
from ase.constraints import FixSymmetry
from ase.filters import *  # type: ignore
from ase.filters import Filter
from ase.optimize import *  # type: ignore
from ase.optimize.optimize import Optimizer
from prefect import task
from prefect.cache_policies import INPUTS, TASK_SOURCE
from prefect.runtime import task_run

from mlip_arena.tasks.utils import logger, pformat

# Registry mapping the short string aliases accepted by ``run(filter=...)`` to
# ASE filter *classes* (not instances); ``run`` instantiates the selected class
# around its own copy of the atoms.
_valid_filters: dict[str, type[Filter]] = {
    "Filter": Filter,
    "UnitCell": UnitCellFilter,
    "ExpCell": ExpCellFilter,
    "Strain": StrainFilter,
    "FrechetCell": FrechetCellFilter,
}  # type: ignore

# Registry mapping the string names accepted by ``run(optimizer=...)`` to ASE
# optimizer *classes* (not instances); ``run`` constructs the selected class
# with the atoms (or filter) and ``optimizer_kwargs``.
_valid_optimizers: dict[str, type[Optimizer]] = {
    "MDMin": MDMin,
    "FIRE": FIRE,
    "FIRE2": FIRE2,
    "LBFGS": LBFGS,
    "LBFGSLineSearch": LBFGSLineSearch,
    "BFGS": BFGS,
    "BFGSLineSearch": BFGSLineSearch,
    "QuasiNewton": QuasiNewton,
    "GPMin": GPMin,
    "CellAwareBFGS": CellAwareBFGS,
    "ODE12r": ODE12r,
}  # type: ignore


def _generate_task_run_name():
    """Build the display name for the current task run.

    Reads the task name and call parameters from ``prefect.runtime.task_run``
    and returns ``"<task name>: <chemical formula> - <calculator>"``, where the
    formula comes from the ``atoms`` parameter and the ``calculator`` parameter
    is interpolated using its default string formatting.
    """
    label, run_params = task_run.task_name, task_run.parameters

    structure = run_params["atoms"]
    calc = run_params["calculator"]

    return "{}: {} - {}".format(label, structure.get_chemical_formula(), calc)


@task(
    name="OPT", task_run_name=_generate_task_run_name, cache_policy=TASK_SOURCE + INPUTS
)
def run(
    atoms: Atoms,
    calculator: BaseCalculator,
    optimizer: Optimizer | str = BFGSLineSearch,
    optimizer_kwargs: dict | None = None,
    filter: Filter | str | None = None,
    filter_kwargs: dict | None = None,
    criterion: dict | None = None,
    symmetry: bool = False,
):
    """Relax a structure with an ASE optimizer, optionally through a filter.

    The input ``atoms`` is never modified: the optimization runs on a copy
    with ``calculator`` attached, and that copy is returned.

    Args:
        atoms: Structure to optimize.
        calculator: Calculator attached to the copied structure.
        optimizer: Optimizer class, or the name of one registered in
            ``_valid_optimizers``.
        optimizer_kwargs: Extra keyword arguments for the optimizer
            constructor. ``None`` means no extra arguments.
        filter: Filter class (a ``Filter`` subclass), the name of one
            registered in ``_valid_filters``, or ``None`` to optimize the
            atoms directly.
        filter_kwargs: Extra keyword arguments for the filter constructor.
            ``None`` means no extra arguments.
        criterion: Keyword arguments forwarded to the optimizer's ``run``
            method (in ASE these are typically ``fmax`` and ``steps``).
            ``None`` means the optimizer's own defaults.
        symmetry: If true, a ``FixSymmetry`` constraint is set on the copied
            structure before optimizing.

    Returns:
        A dict with keys ``"atoms"`` (the optimized copy), ``"steps"`` (the
        optimizer's ``nsteps``) and ``"converged"`` (the result of the
        optimizer's ``converged()``).

    Raises:
        ValueError: If ``filter`` or ``optimizer`` is an unknown name, or if
            ``filter`` is neither ``None``, a registered name, nor a
            ``Filter`` subclass (e.g. an already-constructed filter instance,
            which would wrap a different ``Atoms`` object than the copy
            optimized here).
    """
    atoms = atoms.copy()
    atoms.calc = calculator

    if isinstance(filter, str):
        if filter not in _valid_filters:
            raise ValueError(f"Invalid filter: {filter}")
        filter = _valid_filters[filter]
    elif filter is not None and not (
        isinstance(filter, type) and issubclass(filter, Filter)
    ):
        # Fail fast with a clear message; previously this case skipped the
        # optimization entirely and crashed with UnboundLocalError on return.
        raise ValueError(
            f"Invalid filter: {filter!r}. Expected None, a Filter subclass, "
            f"or one of {sorted(_valid_filters)}"
        )

    if isinstance(optimizer, str):
        if optimizer not in _valid_optimizers:
            raise ValueError(f"Invalid optimizer: {optimizer}")
        optimizer = _valid_optimizers[optimizer]

    filter_kwargs = filter_kwargs or {}
    optimizer_kwargs = optimizer_kwargs or {}
    criterion = criterion or {}

    if symmetry:
        # Only the FixSymmetry constraint is passed to set_constraint here.
        atoms.set_constraint(FixSymmetry(atoms))

    # The optimizer acts on the filter when one is requested, otherwise on the
    # atoms themselves.
    target = atoms
    if filter is not None:
        target = filter(atoms, **filter_kwargs)
        logger.info(f"Using filter: {target}")
        logger.info(pformat(filter_kwargs))

    optimizer_instance = optimizer(target, **optimizer_kwargs)
    logger.info(f"Using optimizer: {optimizer_instance}")
    logger.info(pformat(optimizer_kwargs))
    logger.info(f"Criterion: {pformat(criterion)}")

    optimizer_instance.run(**criterion)

    return {
        "atoms": atoms,
        "steps": optimizer_instance.nsteps,
        # NOTE(review): the signature of Optimizer.converged() may differ
        # between ASE releases -- confirm against the pinned ASE version.
        "converged": optimizer_instance.converged(),
    }