File size: 5,332 Bytes
1e50f35
51638da
1e50f35
 
 
 
 
 
1d1ee87
1e50f35
 
51638da
1d1ee87
51638da
 
1d1ee87
51638da
1e50f35
 
 
 
52c1bfb
1e50f35
51638da
1e50f35
 
 
 
 
51638da
c7922c2
 
 
 
 
 
 
 
 
51638da
 
 
1d1ee87
 
51638da
 
1e50f35
 
1d1ee87
1e50f35
c7922c2
1e50f35
1d1ee87
1e50f35
 
 
 
1d1ee87
 
1e50f35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d1ee87
1e50f35
 
1d1ee87
1e50f35
1d1ee87
1e50f35
 
 
 
 
 
 
 
 
1d1ee87
1e50f35
 
1d1ee87
 
 
 
 
c7922c2
1e50f35
 
 
 
 
 
1d1ee87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7922c2
 
 
 
 
 
 
 
1e50f35
 
 
 
 
1d1ee87
 
1e50f35
 
08a88d8
 
 
 
1e50f35
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""
Define equation of state task.

https://github.com/materialsvirtuallab/matcalc/blob/main/matcalc/eos.py
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
from prefect import task
from prefect.cache_policies import INPUTS, TASK_SOURCE
from prefect.futures import wait
from prefect.runtime import task_run
from prefect.states import State

from ase import Atoms
from ase.filters import *  # type: ignore
from ase.optimize import *  # type: ignore
from ase.optimize.optimize import Optimizer
from mlip_arena.models import MLIPEnum
from mlip_arena.tasks.optimize import run as OPT
from pymatgen.analysis.eos import BirchMurnaghan

if TYPE_CHECKING:
    from ase.filters import Filter


def _generate_task_run_name():
    """Build a readable Prefect run name for the EOS task.

    Pulls the running task's name and its bound ``atoms`` / ``calculator_name``
    parameters from ``prefect.runtime.task_run``.

    Returns:
        ``"<task name>: <chemical formula> - <calculator name>"``.
    """
    name = task_run.task_name
    params = task_run.parameters

    structure = params["atoms"]
    calc = params["calculator_name"]
    formula = structure.get_chemical_formula()

    return f"{name}: {formula} - {calc}"


@task(
    name="EOS",
    task_run_name=_generate_task_run_name,
    cache_policy=TASK_SOURCE + INPUTS,
    # cache_key_fn=task_input_hash,
)
def run(
    atoms: Atoms,
    calculator_name: str | MLIPEnum,
    calculator_kwargs: dict | None = None,
    device: str | None = None,
    optimizer: Optimizer | str = "BFGSLineSearch",  # type: ignore
    optimizer_kwargs: dict | None = None,
    filter: Filter | str | None = "FrechetCell",  # type: ignore
    filter_kwargs: dict | None = None,
    criterion: dict | None = None,
    max_abs_strain: float = 0.1,
    npoints: int = 11,
    concurrent: bool = True,
) -> dict[str, Any] | State:
    """
    Compute the equation of state (EOS) for the given atoms and calculator.

    The input structure is first relaxed with ``filter`` applied. The relaxed
    cell is then scaled uniformly so that the volume runs linearly over
    ``[1 - max_abs_strain, 1 + max_abs_strain]`` times the relaxed volume, in
    ``npoints`` steps. Each strained structure is relaxed again with
    ``filter=None``, and a Birch-Murnaghan EOS is fitted to the resulting
    (volume, energy) points. Strained relaxations that do not complete are
    skipped.

    Args:
        atoms: The input atoms.
        calculator_name: The name of the calculator to use.
        calculator_kwargs: Additional kwargs to pass to the calculator.
        device: The device to use.
        optimizer: The optimizer to use.
        optimizer_kwargs: Additional kwargs to pass to the optimizer.
        filter: The filter to use for the initial relaxation only.
        filter_kwargs: Additional kwargs to pass to the filter.
        criterion: The criterion to use.
        max_abs_strain: The maximum absolute volumetric strain to sample.
        npoints: The number of volumes to sample.
        concurrent: Whether to relax the strained structures concurrently.

    Returns:
        If the initial relaxation fails, the failed prefect state object.
        Otherwise a dictionary with keys ``"atoms"`` (the relaxed structure),
        ``"calculator_name"``, ``"eos"`` (``{"volumes": [...], "energies": [...]}``
        sorted by volume), ``"K"`` (``BirchMurnaghan.b0_GPa``), and the fitted
        parameters ``"b0"``, ``"b1"``, ``"e0"``, ``"v0"``.

    Raises:
        ValueError: If none of the strained relaxations completed, so there is
            no data to fit.
    """
    state = OPT(
        atoms=atoms,
        calculator_name=calculator_name,
        calculator_kwargs=calculator_kwargs,
        device=device,
        optimizer=optimizer,
        optimizer_kwargs=optimizer_kwargs,
        filter=filter,
        filter_kwargs=filter_kwargs,
        criterion=criterion,
        return_state=True,
    )

    if state.is_failed():
        return state

    first_relax = state.result(raise_on_failure=False)
    assert isinstance(first_relax, dict)
    relaxed = first_relax["atoms"]

    c0 = relaxed.get_cell()

    # Uniform lattice scaling by s changes the volume by s**3, so take the cube
    # root to sample *volumes* linearly over the requested strain range.
    factors = np.linspace(1 - max_abs_strain, 1 + max_abs_strain, npoints) ** (1 / 3)

    strained_structures = []
    for factor in factors:
        strained = relaxed.copy()
        strained.set_cell(c0 * factor, scale_atoms=True)
        strained_structures.append(strained)

    # Shared settings for every strained relaxation. No filter is passed, so
    # (assuming OPT only relaxes the cell through a filter) each cell stays at
    # its strained value.
    strained_opt_kwargs = dict(
        calculator_name=calculator_name,
        calculator_kwargs=calculator_kwargs,
        device=device,
        optimizer=optimizer,
        optimizer_kwargs=optimizer_kwargs,
        filter=None,
        filter_kwargs=None,
        criterion=criterion,
    )

    if concurrent:
        futures = [
            OPT.submit(atoms=strained, **strained_opt_kwargs)
            for strained in strained_structures
        ]

        wait(futures)

        # Filter on each future's OWN state; failed points are dropped.
        results = [
            future.result(raise_on_failure=False)
            for future in futures
            if future.state.is_completed()
        ]
    else:
        states = [
            OPT(atoms=strained, return_state=True, **strained_opt_kwargs)
            for strained in strained_structures
        ]
        # Filter on each state individually; failed points are dropped.
        results = [
            point_state.result(raise_on_failure=False)
            for point_state in states
            if point_state.is_completed()
        ]

    if not results:
        raise ValueError(
            "No strained relaxation completed; cannot fit an equation of state."
        )

    # Sort by volume only (stable), keeping each energy paired with its volume.
    points = sorted(
        (
            (result["atoms"].get_volume(), result["atoms"].get_potential_energy())
            for result in results
        ),
        key=lambda point: point[0],
    )
    volumes = [volume for volume, _ in points]
    energies = [energy for _, energy in points]

    bm = BirchMurnaghan(volumes=volumes, energies=energies)
    bm.fit()

    return {
        "atoms": relaxed,
        "calculator_name": calculator_name,
        "eos": {"volumes": volumes, "energies": energies},
        "K": bm.b0_GPa,
        "b0": bm.b0,
        "b1": bm.b1,
        "e0": bm.e0,
        "v0": bm.v0,
    }