cyrusyc commited on
Commit
1e50f35
·
1 Parent(s): 699356f

add eos flow file

Browse files
Files changed (1) hide show
  1. mlip_arena/tasks/eos/run.py +113 -0
mlip_arena/tasks/eos/run.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Define equation of state flows.
3
+
4
+ https://github.com/materialsvirtuallab/matcalc/blob/main/matcalc/eos.py
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import asyncio
10
+ from typing import TYPE_CHECKING
11
+
12
+ import numpy as np
13
+ from ase import Atoms
14
+ from ase.filters import * # type: ignore
15
+ from ase.optimize import * # type: ignore
16
+ from ase.optimize.optimize import Optimizer
17
+ from prefect import flow
18
+ from pymatgen.analysis.eos import BirchMurnaghan
19
+
20
+ from prefect.futures import wait
21
+
22
+ from mlip_arena.models.utils import MLIPEnum
23
+ from mlip_arena.tasks.optimize import run as OPT
24
+
25
+ if TYPE_CHECKING:
26
+ from ase.filters import Filter
27
+
28
+
29
+ @flow
30
+ def fit(
31
+ atoms: Atoms,
32
+ calculator_name: str | MLIPEnum,
33
+ calculator_kwargs: dict | None,
34
+ device: str | None = None,
35
+ optimizer: Optimizer | str = BFGSLineSearch, # type: ignore
36
+ optimizer_kwargs: dict | None = None,
37
+ filter: Filter | str | None = None,
38
+ filter_kwargs: dict | None = None,
39
+ criterion: dict | None = None,
40
+ max_abs_strain: float = 0.1,
41
+ npoints: int = 11,
42
+ ):
43
+ """
44
+ Compute the equation of state (EOS) for the given atoms and calculator.
45
+
46
+ Args:
47
+ atoms: The input atoms.
48
+ calculator_name: The name of the calculator to use.
49
+ calculator_kwargs: Additional kwargs to pass to the calculator.
50
+ device: The device to use.
51
+ optimizer: The optimizer to use.
52
+ optimizer_kwargs: Additional kwargs to pass to the optimizer.
53
+ filter: The filter to use.
54
+ filter_kwargs: Additional kwargs to pass to the filter.
55
+ criterion: The criterion to use.
56
+ max_abs_strain: The maximum absolute strain to use.
57
+ npoints: The number of points to sample.
58
+
59
+ Returns:
60
+ A dictionary containing the EOS data and the bulk modulus.
61
+ """
62
+ result = OPT(
63
+ atoms=atoms,
64
+ calculator_name=calculator_name,
65
+ calculator_kwargs=calculator_kwargs,
66
+ device=device,
67
+ optimizer=optimizer,
68
+ optimizer_kwargs=optimizer_kwargs,
69
+ filter=filter,
70
+ filter_kwargs=filter_kwargs,
71
+ criterion=criterion,
72
+ )
73
+
74
+ relaxed = result["atoms"]
75
+
76
+ # p0 = relaxed.get_positions()
77
+ c0 = relaxed.get_cell()
78
+
79
+ factors = np.linspace(1 - max_abs_strain, 1 + max_abs_strain, npoints) ** (1 / 3)
80
+
81
+ futures = []
82
+ for f in factors:
83
+ atoms = relaxed.copy()
84
+ atoms.set_cell(c0 * f, scale_atoms=True)
85
+
86
+ future = OPT.submit(
87
+ atoms=atoms,
88
+ calculator_name=calculator_name,
89
+ calculator_kwargs=calculator_kwargs,
90
+ device=device,
91
+ optimizer=optimizer,
92
+ optimizer_kwargs=optimizer_kwargs,
93
+ filter=None,
94
+ filter_kwargs=None,
95
+ criterion=criterion,
96
+ )
97
+
98
+ futures.append(future)
99
+
100
+ wait(futures)
101
+
102
+ volumes = [f.result()["atoms"].get_volume() for f in futures]
103
+ energies = [f.result()["atoms"].get_potential_energy() for f in futures]
104
+
105
+ bm = BirchMurnaghan(volumes=volumes, energies=energies)
106
+ bm.fit()
107
+
108
+ volumes, energies = map(list, zip(*sorted(zip(volumes, energies, strict=False), key=lambda i: i[0]), strict=False))
109
+
110
+ return {
111
+ "eos": {"volumes": volumes, "energies": energies},
112
+ "K": bm.b0_GPa,
113
+ }