File size: 3,530 Bytes
51638da
1effaf5
 
 
 
 
51638da
1d1ee87
51638da
 
1effaf5
1d1ee87
1effaf5
 
 
 
52c1bfb
e80e29d
 
1effaf5
 
 
 
 
 
 
51638da
1effaf5
 
 
 
1d1ee87
1effaf5
 
 
 
 
 
 
 
51638da
 
 
 
 
 
 
 
 
1effaf5
51638da
1effaf5
51638da
 
a787930
51638da
1effaf5
 
 
1d1ee87
a787930
1effaf5
 
 
 
 
 
 
1d1ee87
1effaf5
a787930
 
 
 
 
 
 
1effaf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d1ee87
 
 
1effaf5
 
e80e29d
 
1effaf5
e80e29d
 
 
 
1effaf5
 
 
 
e80e29d
 
 
1effaf5
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
Define structure optimization tasks.
"""

from __future__ import annotations

from prefect import task
from prefect.cache_policies import INPUTS, TASK_SOURCE
from prefect.runtime import task_run

from ase import Atoms
from ase.constraints import FixSymmetry
from ase.filters import *  # type: ignore
from ase.filters import Filter
from ase.optimize import *  # type: ignore
from ase.optimize.optimize import Optimizer
from mlip_arena.models import MLIPEnum
from mlip_arena.tasks.utils import get_calculator, logger, pformat


# Lookup table from the short names accepted by ``run(filter=...)`` to ASE
# filter classes. The values are classes, not instances: ``run`` looks the
# name up here and then instantiates the class with the atoms and
# ``filter_kwargs``.
_valid_filters: dict[str, type[Filter]] = {
    "Filter": Filter,
    "UnitCell": UnitCellFilter,
    "ExpCell": ExpCellFilter,
    "Strain": StrainFilter,
    "FrechetCell": FrechetCellFilter,
}  # type: ignore

# Lookup table from the names accepted by ``run(optimizer=...)`` to ASE
# optimizer classes. The values are classes, not instances: ``run`` looks the
# name up here and then instantiates the class with the object to optimize
# and ``optimizer_kwargs``.
_valid_optimizers: dict[str, type[Optimizer]] = {
    "MDMin": MDMin,
    "FIRE": FIRE,
    "FIRE2": FIRE2,
    "LBFGS": LBFGS,
    "LBFGSLineSearch": LBFGSLineSearch,
    "BFGS": BFGS,
    "BFGSLineSearch": BFGSLineSearch,
    "QuasiNewton": QuasiNewton,
    "GPMin": GPMin,
    "CellAwareBFGS": CellAwareBFGS,
    "ODE12r": ODE12r,
}  # type: ignore


def _generate_task_run_name():
    """Build the display name for the current Prefect task run.

    Reads the task name and the call parameters from
    ``prefect.runtime.task_run`` and formats them as
    ``"<task name>: <chemical formula> - <calculator name>"``.
    """
    name = task_run.task_name
    call_args = task_run.parameters
    structure, calculator = call_args["atoms"], call_args["calculator_name"]

    formula = structure.get_chemical_formula()
    return f"{name}: {formula} - {calculator}"


@task(
    name="OPT", task_run_name=_generate_task_run_name, cache_policy=TASK_SOURCE + INPUTS
)
def run(
    atoms: Atoms,
    calculator_name: str | MLIPEnum,
    calculator_kwargs: dict | None = None,
    dispersion: bool = False,
    dispersion_kwargs: dict | None = None,
    device: str | None = None,
    optimizer: type[Optimizer] | str = BFGSLineSearch,
    optimizer_kwargs: dict | None = None,
    filter: type[Filter] | str | None = None,
    filter_kwargs: dict | None = None,
    criterion: dict | None = None,
    symmetry: bool = False,
) -> dict[str, Atoms]:
    """Optimize a structure with the selected calculator, optimizer and filter.

    Args:
        atoms: Structure to optimize. It is modified in place: a calculator is
            attached and, when ``symmetry`` is true, its constraints are set.
        calculator_name: Forwarded to ``get_calculator``.
        calculator_kwargs: Forwarded to ``get_calculator``.
        dispersion: Forwarded to ``get_calculator``.
        dispersion_kwargs: Forwarded to ``get_calculator``.
        device: Forwarded to ``get_calculator``.
        optimizer: Optimizer class, or one of the names in
            ``_valid_optimizers``.
        optimizer_kwargs: Extra keyword arguments for the optimizer
            constructor. ``None`` means no extra arguments.
        filter: Filter class (a subclass of ``ase.filters.Filter``), one of
            the names in ``_valid_filters``, or ``None`` to hand ``atoms``
            to the optimizer directly.
        filter_kwargs: Extra keyword arguments for the filter constructor.
            ``None`` means no extra arguments.
        criterion: Keyword arguments passed to the optimizer's ``run`` method
            (for ASE optimizers these are typically ``fmax`` and ``steps``).
            ``None`` means the optimizer's own defaults.
        symmetry: If true, set a ``FixSymmetry`` constraint on ``atoms``
            before optimizing.

    Returns:
        A dict with the single key ``"atoms"`` holding the same ``atoms``
        object that was passed in.

    Raises:
        ValueError: If ``filter`` or ``optimizer`` is a string that is not a
            known name.
        TypeError: If ``filter`` is neither ``None``, a string, nor a
            ``Filter`` subclass.
    """
    # Resolve and validate the selections before building the calculator so
    # that a typo fails fast instead of after the calculator has been set up.
    if isinstance(filter, str):
        if filter not in _valid_filters:
            raise ValueError(f"Invalid filter: {filter}")
        filter = _valid_filters[filter]

    if isinstance(optimizer, str):
        if optimizer not in _valid_optimizers:
            raise ValueError(f"Invalid optimizer: {optimizer}")
        optimizer = _valid_optimizers[optimizer]

    # Anything else (e.g. an already-built Filter instance) used to fall
    # through both branches, skipping the optimization entirely while still
    # returning the unrelaxed atoms. Reject it explicitly instead.
    if filter is not None and not (
        isinstance(filter, type) and issubclass(filter, Filter)
    ):
        raise TypeError(
            f"Invalid filter: {filter!r}; expected None, a Filter subclass, "
            f"or one of {sorted(_valid_filters)}"
        )

    filter_kwargs = filter_kwargs or {}
    optimizer_kwargs = optimizer_kwargs or {}
    criterion = criterion or {}

    atoms.calc = get_calculator(
        calculator_name=calculator_name,
        calculator_kwargs=calculator_kwargs,
        dispersion=dispersion,
        dispersion_kwargs=dispersion_kwargs,
        device=device,
    )

    if symmetry:
        # NOTE(review): only FixSymmetry is passed here, so this presumably
        # replaces any constraints already present on `atoms` - confirm that
        # this is intended.
        atoms.set_constraint(FixSymmetry(atoms))

    # The optimizer acts on the filtered view when a filter is requested,
    # otherwise on the bare atoms.
    target = atoms
    if filter is not None:
        target = filter(atoms, **filter_kwargs)
        logger.info(f"Using filter: {target}")
        logger.info(pformat(filter_kwargs))

    optimizer_instance = optimizer(target, **optimizer_kwargs)
    logger.info(f"Using optimizer: {optimizer_instance}")
    logger.info(pformat(optimizer_kwargs))
    logger.info(f"Criterion: {pformat(criterion)}")

    optimizer_instance.run(**criterion)

    return {
        "atoms": atoms,
    }