Yuan (Cyrus) Chiang commited on
Commit
587c7e5
·
unverified ·
1 Parent(s): ef47233

Add elasticity task (#37)

Browse files

* add elasticity

* add summary to pytest

.github/workflows/test.yaml CHANGED
@@ -63,4 +63,4 @@ jobs:
63
  PREFECT_API_KEY: ${{ secrets.PREFECT_API_KEY }}
64
  PREFECT_API_URL: ${{ secrets.PREFECT_API_URL }}
65
  run: |
66
- pytest -v tests
 
63
  PREFECT_API_KEY: ${{ secrets.PREFECT_API_KEY }}
64
  PREFECT_API_URL: ${{ secrets.PREFECT_API_URL }}
65
  run: |
66
+ pytest -vvra tests
mlip_arena/tasks/elasticity.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Defines the tasks for computing the elastic tensor.
3
+
4
+ This module has been modified from MatCalc
5
+ https://github.com/materialsvirtuallab/matcalc/blob/main/src/matcalc/elasticity.py
6
+
7
+ https://github.com/materialsvirtuallab/matcalc/blob/main/LICENSE
8
+
9
+ BSD 3-Clause License
10
+
11
+ Copyright (c) 2023, Materials Virtual Lab
12
+
13
+ Redistribution and use in source and binary forms, with or without
14
+ modification, are permitted provided that the following conditions are met:
15
+
16
+ 1. Redistributions of source code must retain the above copyright notice, this
17
+ list of conditions and the following disclaimer.
18
+
19
+ 2. Redistributions in binary form must reproduce the above copyright notice,
20
+ this list of conditions and the following disclaimer in the documentation
21
+ and/or other materials provided with the distribution.
22
+
23
+ 3. Neither the name of the copyright holder nor the names of its
24
+ contributors may be used to endorse or promote products derived from
25
+ this software without specific prior written permission.
26
+
27
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
28
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
29
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
30
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
31
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
32
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
33
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
34
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
35
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
36
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37
+ """
38
+
39
+ from __future__ import annotations
40
+
41
+ from typing import TYPE_CHECKING, Any
42
+
43
+ import numpy as np
44
+ from numpy.typing import ArrayLike
45
+ from prefect import task
46
+ from prefect.cache_policies import INPUTS, TASK_SOURCE
47
+ from prefect.runtime import task_run
48
+ from prefect.states import State
49
+
50
+ from ase import Atoms
51
+ from ase.filters import * # type: ignore
52
+ from ase.optimize import * # type: ignore
53
+ from ase.optimize.optimize import Optimizer
54
+ from mlip_arena.models import MLIPEnum
55
+ from mlip_arena.tasks.optimize import run as OPT
56
+ from pymatgen.analysis.elasticity import DeformedStructureSet, ElasticTensor, Strain
57
+ from pymatgen.analysis.elasticity.elastic import get_strain_state_dict
58
+ from pymatgen.io.ase import AseAtomsAdaptor
59
+
60
+ if TYPE_CHECKING:
61
+ from ase.filters import Filter
62
+
63
+
64
+ def _generate_task_run_name():
65
+ task_name = task_run.task_name
66
+ parameters = task_run.parameters
67
+
68
+ atoms = parameters["atoms"]
69
+ calculator_name = parameters["calculator_name"]
70
+
71
+ return f"{task_name}: {atoms.get_chemical_formula()} - {calculator_name}"
72
+
73
+
74
+ @task(
75
+ name="Elasticity",
76
+ task_run_name=_generate_task_run_name,
77
+ cache_policy=TASK_SOURCE + INPUTS,
78
+ # cache_key_fn=task_input_hash,
79
+ )
80
+ def run(
81
+ atoms: Atoms,
82
+ calculator_name: str | MLIPEnum,
83
+ calculator_kwargs: dict | None = None,
84
+ device: str | None = None,
85
+ optimizer: Optimizer | str = "BFGSLineSearch", # type: ignore
86
+ optimizer_kwargs: dict | None = None,
87
+ filter: Filter | str | None = "FrechetCell", # type: ignore
88
+ filter_kwargs: dict | None = None,
89
+ criterion: dict | None = None,
90
+ normal_strains: list[float] | np.ndarray | None = np.linspace(-0.01, 0.01, 4),
91
+ shear_strains: list[float] | np.ndarray | None = np.linspace(-0.06, 0.06, 4),
92
+ persist_opt: bool = True,
93
+ cache_opt: bool = True,
94
+ ) -> dict[str, Any] | State:
95
+ """
96
+ Compute the elastic tensor for the given structure and calculator.
97
+
98
+ Args:
99
+ atoms (Atoms): The input structure.
100
+ calculator_name (str | MLIPEnum): The calculator name.
101
+ calculator_kwargs (dict, optional): The calculator kwargs. Defaults to None.
102
+ device (str, optional): The device. Defaults to None.
103
+ optimizer (Optimizer | str, optional): The optimizer. Defaults to "BFGSLineSearch".
104
+ optimizer_kwargs (dict, optional): The optimizer kwargs. Defaults to None.
105
+ filter (Filter | str, optional): The filter. Defaults to "FrechetCell".
106
+ filter_kwargs (dict, optional): The filter kwargs. Defaults to None.
107
+ criterion (dict, optional): The criterion. Defaults to None.
108
+ normal_strains (list[float] | np.ndarray, optional): The normal strains. Defaults to np.linspace(-0.01, 0.01, 4).
109
+ shear_strains (list[float] | np.ndarray, optional): The shear strains. Defaults to np.linspace(-0.06, 0.06, 4).
110
+ concurrent (bool, optional): Whether to run concurrently. Defaults to True.
111
+ persist_opt (bool, optional): Whether to persist the optimizer results. Defaults to True.
112
+ cache_opt (bool, optional): Whether to cache the optimizer results. Defaults to True.
113
+
114
+ Returns:
115
+ dict[str, Any] | State: The elastic tensor.
116
+ """
117
+
118
+ OPT_ = OPT.with_options(
119
+ refresh_cache=not cache_opt,
120
+ persist_result=persist_opt,
121
+ )
122
+
123
+ first_relax = OPT_(
124
+ atoms=atoms,
125
+ calculator_name=calculator_name,
126
+ calculator_kwargs=calculator_kwargs,
127
+ device=device,
128
+ optimizer=optimizer,
129
+ optimizer_kwargs=optimizer_kwargs,
130
+ filter=filter,
131
+ filter_kwargs=filter_kwargs,
132
+ criterion=criterion,
133
+ return_state=True,
134
+ )
135
+
136
+ if first_relax.is_failed():
137
+ return first_relax
138
+
139
+ result = first_relax.result(raise_on_failure=False)
140
+
141
+ assert isinstance(result, dict)
142
+ relaxed = result["atoms"]
143
+
144
+ if isinstance(normal_strains, np.ndarray):
145
+ normal_strains = normal_strains.tolist()
146
+ if isinstance(shear_strains, np.ndarray):
147
+ shear_strains = shear_strains.tolist()
148
+
149
+ assert isinstance(relaxed, Atoms)
150
+ assert isinstance(normal_strains, list)
151
+ assert isinstance(shear_strains, list)
152
+
153
+ structure = AseAtomsAdaptor.get_structure(relaxed) # type: ignore
154
+
155
+ deformed_structure_set = DeformedStructureSet(
156
+ structure,
157
+ normal_strains,
158
+ shear_strains,
159
+ )
160
+
161
+ stresses = []
162
+ for deformed_structure in deformed_structure_set:
163
+ atoms = deformed_structure.to_ase_atoms()
164
+ atoms.calc = relaxed.calc
165
+ stresses.append(atoms.get_stress(voigt=False))
166
+
167
+ strains = [
168
+ Strain.from_deformation(deformation)
169
+ for deformation in deformed_structure_set.deformations
170
+ ]
171
+
172
+ fit = fit_elastic_tensor(
173
+ strains,
174
+ stresses,
175
+ eq_stress=relaxed.get_stress(voigt=False)
176
+ )
177
+
178
+ return {
179
+ "elastic_tensor": fit["elastic_tensor"],
180
+ "residuals_sum": fit["residuals_sum"],
181
+ }
182
+
183
+
184
+ @task
185
+ def fit_elastic_tensor(
186
+ strains: ArrayLike,
187
+ stresses: ArrayLike,
188
+ eq_stress: ArrayLike | None = None,
189
+ tolerance: float = 1e-7,
190
+ ):
191
+ """
192
+ Compute the elastic tensor from the given strains and stresses.
193
+
194
+ Args:
195
+ strains (ArrayLike): The strains.
196
+ stresses (ArrayLike): The stresses.
197
+ tolerance (float, optional): The tolerance. Defaults to 1e-7.
198
+
199
+ Returns:
200
+ ElasticTensor: The elastic tensor.
201
+ """
202
+
203
+ strain_states = [tuple(ss) for ss in np.eye(6)]
204
+ ss_dict = get_strain_state_dict(
205
+ strains,
206
+ stresses,
207
+ eq_stress=eq_stress,
208
+ add_eq=True if eq_stress is not None else False,
209
+ )
210
+ c_ij = np.zeros((6, 6))
211
+ residuals_sum = 0.0
212
+ for ii in range(6):
213
+ strain = ss_dict[strain_states[ii]]["strains"]
214
+ stress = ss_dict[strain_states[ii]]["stresses"]
215
+ for jj in range(6):
216
+ fit = np.polyfit(strain[:, ii], stress[:, jj], 1, full=True)
217
+ c_ij[ii, jj] = fit[0][0]
218
+ residuals_sum += fit[1][0] if len(fit[1]) > 0 else 0.0
219
+ elastic_tensor = ElasticTensor.from_voigt(c_ij)
220
+
221
+ return {
222
+ "elastic_tensor": elastic_tensor.zeroed(tolerance),
223
+ "residuals_sum": residuals_sum,
224
+ }
tests/test_elasticity.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import numpy as np
4
+ import pytest
5
+ from mlip_arena.models import MLIPEnum
6
+ from mlip_arena.tasks.elasticity import run as ELASTICITY
7
+ from prefect.testing.utilities import prefect_test_harness
8
+
9
+ from ase.build import bulk
10
+
11
+
12
+ @pytest.mark.skipif(
13
+ sys.version_info[:2] != (3, 11),
14
+ reason="avoid prefect race condition on concurrent tasks",
15
+ )
16
+ @pytest.mark.parametrize("model", [MLIPEnum["MACE-MP(M)"]])
17
+ def test_elasticity(model: MLIPEnum):
18
+ """
19
+ Test elasticity prefect workflow with a simple cubic lattice.
20
+ """
21
+
22
+ with prefect_test_harness():
23
+ result = ELASTICITY(
24
+ atoms=bulk("Cu", "fcc", a=3.6),
25
+ calculator_name=model.name,
26
+ calculator_kwargs={},
27
+ device=None,
28
+ optimizer="BFGSLineSearch",
29
+ optimizer_kwargs=None,
30
+ filter="FrechetCell",
31
+ filter_kwargs=None,
32
+ criterion=None,
33
+ persist_opt=False,
34
+ cache_opt=False,
35
+ )
36
+ assert isinstance(result, dict)
37
+ assert isinstance(result["elastic_tensor"], np.ndarray)
38
+ assert result["elastic_tensor"].shape == (3, 3, 3, 3)
39
+ assert isinstance(result["elastic_tensor"][0, 0, 0, 0], float)