Spaces:
Running
Running
Yuan (Cyrus) Chiang
committed on
hotfix mattersim pickling issue; add phonon task (#42)
Browse files- mlip_arena/models/externals/mattersim.py +15 -7
- mlip_arena/tasks/__init__.py +37 -36
- mlip_arena/tasks/phonon.py +162 -0
- mlip_arena/tasks/utils.py +48 -10
mlip_arena/models/externals/mattersim.py
CHANGED
@@ -24,13 +24,21 @@ class MatterSim(MatterSimCalculator):
|
|
24 |
load_path=checkpoint, device=str(device or get_freer_device()), **kwargs
|
25 |
)
|
26 |
|
27 |
-
def
|
28 |
-
self
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
# # convert unpicklizable atoms back to picklizable atoms to avoid prefect pickling error
|
36 |
# if isinstance(self.atoms, MSONAtoms):
|
|
|
24 |
load_path=checkpoint, device=str(device or get_freer_device()), **kwargs
|
25 |
)
|
26 |
|
27 |
+
def __getstate__(self):
    """Return the instance state used for pickling, without ``potential``.

    The instance ``__dict__`` itself is left untouched; only the returned
    mapping omits the ``potential`` entry.
    """
    # The potential is not picklable, so it is excluded from the pickled state.
    # NOTE(review): an unpickled instance will therefore lack ``potential``
    # unless it is restored elsewhere - confirm against callers.
    return {
        key: value for key, value in self.__dict__.items() if key != "potential"
    }
|
34 |
+
|
35 |
+
# def calculate(
|
36 |
+
# self,
|
37 |
+
# atoms: Atoms | None = None,
|
38 |
+
# properties: list | None = None,
|
39 |
+
# system_changes: list | None = None,
|
40 |
+
# ):
|
41 |
+
# super().calculate(atoms, properties, system_changes)
|
42 |
|
43 |
# # convert unpicklizable atoms back to picklizable atoms to avoid prefect pickling error
|
44 |
# if isinstance(self.atoms, MSONAtoms):
|
mlip_arena/tasks/__init__.py
CHANGED
@@ -3,8 +3,8 @@ from pathlib import Path
|
|
3 |
import yaml
|
4 |
from huggingface_hub import HfApi, HfFileSystem, hf_hub_download
|
5 |
|
6 |
-
from mlip_arena.models import MLIP
|
7 |
-
from mlip_arena.models import REGISTRY as MODEL_REGISTRY
|
8 |
|
9 |
try:
|
10 |
from .elasticity import run as ELASTICITY
|
@@ -13,8 +13,9 @@ try:
|
|
13 |
from .neb import run as NEB
|
14 |
from .neb import run_from_endpoints as NEB_FROM_ENDPOINTS
|
15 |
from .optimize import run as OPT
|
|
|
16 |
|
17 |
-
__all__ = ["OPT", "EOS", "MD", "NEB", "NEB_FROM_ENDPOINTS", "ELASTICITY"]
|
18 |
except ImportError:
|
19 |
pass
|
20 |
|
@@ -22,43 +23,43 @@ with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
|
|
22 |
REGISTRY = yaml.safe_load(f)
|
23 |
|
24 |
|
25 |
-
class Task:
|
26 |
-
|
27 |
-
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
|
64 |
-
|
|
|
3 |
import yaml
|
4 |
from huggingface_hub import HfApi, HfFileSystem, hf_hub_download
|
5 |
|
6 |
+
# from mlip_arena.models import MLIP
|
7 |
+
# from mlip_arena.models import REGISTRY as MODEL_REGISTRY
|
8 |
|
9 |
try:
|
10 |
from .elasticity import run as ELASTICITY
|
|
|
13 |
from .neb import run as NEB
|
14 |
from .neb import run_from_endpoints as NEB_FROM_ENDPOINTS
|
15 |
from .optimize import run as OPT
|
16 |
+
from .phonon import run as PHONON
|
17 |
|
18 |
+
__all__ = ["OPT", "EOS", "MD", "NEB", "NEB_FROM_ENDPOINTS", "ELASTICITY", "PHONON"]
|
19 |
except ImportError:
|
20 |
pass
|
21 |
|
|
|
23 |
REGISTRY = yaml.safe_load(f)
|
24 |
|
25 |
|
26 |
+
# class Task:
|
27 |
+
# def __init__(self):
|
28 |
+
# self.name: str = self.__class__.__name__ # display name on the leaderboard
|
29 |
|
30 |
+
# def run_local(self, model: MLIP):
|
31 |
+
# """Run the task using the given model and return the results."""
|
32 |
+
# raise NotImplementedError
|
33 |
|
34 |
+
# def run_hf(self, model: MLIP):
|
35 |
+
# """Run the task using the given model and return the results."""
|
36 |
+
# raise NotImplementedError
|
37 |
|
38 |
+
#     # Calculate evaluation metrics and postprocessed data
|
39 |
+
# api = HfApi()
|
40 |
+
# api.upload_file(
|
41 |
+
# path_or_fileobj="results.json",
|
42 |
+
# path_in_repo=f"{self.__class__.__name__}/{model.__class__.__name__}/results.json", # Upload to a specific folder
|
43 |
+
# repo_id="atomind/mlip-arena",
|
44 |
+
# repo_type="dataset",
|
45 |
+
# )
|
46 |
|
47 |
+
# def run_nersc(self, model: MLIP):
|
48 |
+
# """Run the task using the given model and return the results."""
|
49 |
+
# raise NotImplementedError
|
50 |
|
51 |
+
# def get_results(self):
|
52 |
+
# """Get the results from the task."""
|
53 |
+
# # fs = HfFileSystem()
|
54 |
+
# # files = fs.glob(f"datasets/atomind/mlip-arena/{self.__class__.__name__}/*/*.json")
|
55 |
|
56 |
+
# for model, metadata in MODEL_REGISTRY.items():
|
57 |
+
# results = hf_hub_download(
|
58 |
+
# repo_id="atomind/mlip-arena",
|
59 |
+
# filename="results.json",
|
60 |
+
# subfolder=f"{self.__class__.__name__}/{model}",
|
61 |
+
# repo_type="dataset",
|
62 |
+
# revision=None,
|
63 |
+
# )
|
64 |
|
65 |
+
# return results
|
mlip_arena/tasks/phonon.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This module has been adapted from Quacc (https://github.com/Quantum-Accelerators/quacc). By using this software, you agree to the Quacc license agreement: https://github.com/Quantum-Accelerators/quacc/blob/main/LICENSE.md
|
3 |
+
|
4 |
+
|
5 |
+
BSD 3-Clause License
|
6 |
+
|
7 |
+
Copyright (c) 2025, Andrew S. Rosen.
|
8 |
+
All rights reserved.
|
9 |
+
|
10 |
+
Redistribution and use in source and binary forms, with or without
|
11 |
+
modification, are permitted provided that the following conditions are met:
|
12 |
+
|
13 |
+
- Redistributions of source code must retain the above copyright notice, this
|
14 |
+
list of conditions and the following disclaimer.
|
15 |
+
|
16 |
+
- Redistributions in binary form must reproduce the above copyright notice,
|
17 |
+
this list of conditions and the following disclaimer in the documentation
|
18 |
+
and/or other materials provided with the distribution.
|
19 |
+
|
20 |
+
- Neither the name of the copyright holder nor the names of its
|
21 |
+
contributors may be used to endorse or promote products derived from
|
22 |
+
this software without specific prior written permission.
|
23 |
+
|
24 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
25 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
26 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
27 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
28 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
29 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
30 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
31 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
32 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
33 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
34 |
+
"""
|
35 |
+
|
36 |
+
from pathlib import Path
|
37 |
+
|
38 |
+
import numpy as np
|
39 |
+
from phonopy import Phonopy
|
40 |
+
from phonopy.structure.atoms import PhonopyAtoms
|
41 |
+
from prefect import task
|
42 |
+
from prefect.cache_policies import INPUTS, TASK_SOURCE
|
43 |
+
from prefect.runtime import task_run
|
44 |
+
|
45 |
+
from ase import Atoms
|
46 |
+
from ase.calculators.calculator import BaseCalculator
|
47 |
+
|
48 |
+
|
49 |
+
@task(cache_policy=TASK_SOURCE + INPUTS)
def get_phonopy(
    atoms: Atoms,
    supercell_matrix: list[int] | None = None,
    min_lengths: float | tuple[float, float, float] | None = None,
    symprec: float = 1e-5,
    distance: float = 0.01,
    phonopy_kwargs: dict | None = None,
) -> Phonopy:
    """Build a Phonopy object for ``atoms`` and generate its displacements.

    Args:
        atoms: Structure converted to ``PhonopyAtoms`` (symbols, cell, wrapped
            scaled positions and masses are carried over).
        supercell_matrix: Supercell matrix forwarded to ``Phonopy``. Takes
            precedence over ``min_lengths`` when both are given.
        min_lengths: Minimum supercell length(s), in the same unit as
            ``atoms.cell``. Only used when ``supercell_matrix`` is ``None``:
            each cell vector is repeated ``ceil(min_length / cell_length)``
            times.
        symprec: Symmetry tolerance forwarded to ``Phonopy``.
        distance: Displacement distance forwarded to
            ``Phonopy.generate_displacements``.
        phonopy_kwargs: Extra keyword arguments forwarded to the ``Phonopy``
            constructor. ``None`` means no extra arguments.

    Returns:
        The ``Phonopy`` object with displacements generated.
    """
    # Avoid a shared mutable default: build a fresh dict per call.
    extra_kwargs = {} if phonopy_kwargs is None else phonopy_kwargs

    if supercell_matrix is None and min_lengths is not None:
        # np.asarray lets a scalar or a 3-tuple broadcast against the three
        # cell lengths; the repeat counts are integral after ceil, so store
        # them as integers rather than floats.
        repeats = np.ceil(
            np.asarray(min_lengths, dtype=float) / atoms.cell.lengths()
        )
        supercell_matrix = np.diag(repeats.astype(int))

    phonon = Phonopy(
        PhonopyAtoms(
            symbols=atoms.get_chemical_symbols(),
            cell=atoms.get_cell(),
            scaled_positions=atoms.get_scaled_positions(wrap=True),
            masses=atoms.get_masses(),
        ),
        symprec=symprec,
        supercell_matrix=supercell_matrix,
        **extra_kwargs,
    )
    phonon.generate_displacements(distance=distance)

    return phonon
|
77 |
+
|
78 |
+
|
79 |
+
def _get_forces(
    phononpy_atoms: PhonopyAtoms,
    calculator: BaseCalculator,
) -> np.ndarray:
    """Evaluate forces on a Phonopy structure with the given ASE calculator.

    The structure is rebuilt as a fully periodic ASE ``Atoms`` object from its
    symbols, cell and scaled positions, the calculator is attached, and the
    result of ``get_forces()`` is returned.
    """
    displaced = Atoms(
        symbols=phononpy_atoms.symbols,
        cell=phononpy_atoms.cell,
        scaled_positions=phononpy_atoms.scaled_positions,
        pbc=True,
    )
    displaced.calc = calculator

    forces = displaced.get_forces()
    return forces
|
93 |
+
|
94 |
+
|
95 |
+
def _generate_task_run_name():
    """Compose the task-run name from the current run's task name and inputs.

    Reads the ``atoms`` and ``calculator`` entries of the current task run's
    parameters and formats them as
    ``"<task name>: <chemical formula> - <calculator class name>"``.
    """
    name = task_run.task_name
    params = task_run.parameters

    formula = params["atoms"].get_chemical_formula()
    calculator_cls = params["calculator"].__class__.__name__

    return f"{name}: {formula} - {calculator_cls}"
|
105 |
+
|
106 |
+
|
107 |
+
@task(
    name="PHONON",
    task_run_name=_generate_task_run_name,
    cache_policy=TASK_SOURCE + INPUTS,
)
def run(
    atoms: Atoms,
    calculator: BaseCalculator,
    supercell_matrix: list[int] | None = None,
    min_lengths: float | tuple[float, float, float] | None = None,
    symprec: float = 1e-5,
    distance: float = 0.01,
    phonopy_kwargs: dict | None = None,
    symmetry: bool = False,
    t_min: float = 0.0,
    t_max: float = 1000.0,
    t_step: float = 10.0,
    outdir: str | None = None,
) -> dict:
    """Run a finite-displacement phonon calculation with Phonopy.

    Displaced supercells are generated by ``get_phonopy``, forces on each are
    evaluated with ``calculator``, and force constants, mesh, total DOS,
    thermal properties and the automatic band structure are then computed.

    Args:
        atoms: Input structure.
        calculator: ASE calculator used to evaluate forces.
        supercell_matrix: Forwarded to ``get_phonopy``.
        min_lengths: Forwarded to ``get_phonopy``.
        symprec: Forwarded to ``get_phonopy``.
        distance: Forwarded to ``get_phonopy``.
        phonopy_kwargs: Extra keyword arguments for the ``Phonopy``
            constructor, forwarded to ``get_phonopy``. ``None`` means none.
        symmetry: If true, symmetrize the force constants (including by space
            group) after they are produced.
        t_min: Forwarded to ``Phonopy.run_thermal_properties``.
        t_max: Forwarded to ``Phonopy.run_thermal_properties``.
        t_step: Forwarded to ``Phonopy.run_thermal_properties``.
        outdir: Directory for ``band.yaml`` and ``phonopy.yaml``. Created if
            missing. When ``None``, nothing is written.

    Returns:
        A dict with the single key ``"phonon"`` holding the ``Phonopy`` object.
    """
    phonon = get_phonopy(
        atoms=atoms,
        supercell_matrix=supercell_matrix,
        min_lengths=min_lengths,
        symprec=symprec,
        distance=distance,
        # Always hand a real dict downstream; avoids a shared mutable default.
        phonopy_kwargs={} if phonopy_kwargs is None else phonopy_kwargs,
    )

    # Use one consistent "is output requested" check for every file written.
    out_path = Path(outdir) if outdir is not None else None
    if out_path is not None:
        # Create the directory up front so writing cannot fail on a missing
        # path after the force evaluations have already been done.
        out_path.mkdir(parents=True, exist_ok=True)

    supercells_with_displacements = phonon.supercells_with_displacements

    phonon.forces = [
        _get_forces(supercell, calculator)
        for supercell in supercells_with_displacements
        if supercell is not None
    ]
    phonon.produce_force_constants()

    if symmetry:
        phonon.symmetrize_force_constants()
        phonon.symmetrize_force_constants_by_space_group()

    phonon.run_mesh(with_eigenvectors=True)
    phonon.run_total_dos()
    phonon.run_thermal_properties(t_step=t_step, t_max=t_max, t_min=t_min)  # type: ignore
    phonon.auto_band_structure(
        write_yaml=out_path is not None,
        filename=out_path / "band.yaml" if out_path is not None else "band.yaml",
    )
    if out_path is not None:
        phonon.save(out_path / "phonopy.yaml", settings={"force_constants": True})

    return {
        "phonon": phonon,
    }
|
mlip_arena/tasks/utils.py
CHANGED
@@ -2,13 +2,15 @@
|
|
2 |
|
3 |
from __future__ import annotations
|
4 |
|
|
|
|
|
|
|
5 |
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
|
6 |
|
7 |
from ase import units
|
8 |
-
from ase.calculators.calculator import
|
9 |
from ase.calculators.mixing import SumCalculator
|
10 |
from mlip_arena.models import MLIPEnum
|
11 |
-
from mlip_arena.models.utils import get_freer_device
|
12 |
|
13 |
try:
|
14 |
from prefect.logging import get_run_logger
|
@@ -17,16 +19,48 @@ try:
|
|
17 |
except (ImportError, RuntimeError):
|
18 |
from loguru import logger
|
19 |
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
|
23 |
def get_calculator(
|
24 |
-
calculator_name: str | MLIPEnum |
|
25 |
-
calculator_kwargs: dict | None,
|
26 |
dispersion: bool = False,
|
27 |
dispersion_kwargs: dict | None = None,
|
28 |
device: str | None = None,
|
29 |
-
) ->
|
30 |
"""Get a calculator with optional dispersion correction."""
|
31 |
|
32 |
device = device or str(get_freer_device())
|
@@ -40,11 +74,15 @@ def get_calculator(
|
|
40 |
calc = calculator_name.value(**calculator_kwargs)
|
41 |
elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
|
42 |
calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
|
43 |
-
elif isinstance(calculator_name, type) and issubclass(
|
|
|
|
|
44 |
logger.warning(f"Using custom calculator class: {calculator_name}")
|
45 |
calc = calculator_name(**calculator_kwargs)
|
46 |
-
elif isinstance(calculator_name,
|
47 |
-
logger.warning(
|
|
|
|
|
48 |
calc = calculator_name
|
49 |
else:
|
50 |
raise ValueError(f"Invalid calculator: {calculator_name}")
|
@@ -69,5 +107,5 @@ def get_calculator(
|
|
69 |
if dispersion_kwargs:
|
70 |
logger.info(pformat(dispersion_kwargs))
|
71 |
|
72 |
-
assert isinstance(calc,
|
73 |
return calc
|
|
|
2 |
|
3 |
from __future__ import annotations
|
4 |
|
5 |
+
from pprint import pformat
|
6 |
+
|
7 |
+
import torch
|
8 |
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
|
9 |
|
10 |
from ase import units
|
11 |
+
from ase.calculators.calculator import BaseCalculator
|
12 |
from ase.calculators.mixing import SumCalculator
|
13 |
from mlip_arena.models import MLIPEnum
|
|
|
14 |
|
15 |
try:
|
16 |
from prefect.logging import get_run_logger
|
|
|
19 |
except (ImportError, RuntimeError):
|
20 |
from loguru import logger
|
21 |
|
22 |
+
|
23 |
+
def get_freer_device() -> torch.device:
    """Select a torch device: the least-loaded CUDA GPU, else MPS, else CPU.

    For each CUDA device the free-memory estimate is the device's total memory
    minus ``torch.cuda.memory_allocated`` for that device; the device with the
    largest estimate is chosen (the first one on ties).

    Returns:
        torch.device: The selected CUDA device, ``mps`` when no CUDA device is
        present but MPS is available, or ``cpu`` otherwise.
    """
    n_gpus = torch.cuda.device_count()

    if n_gpus > 0:
        # Estimate free memory per CUDA device and keep the best one.
        free_bytes = []
        for idx in range(n_gpus):
            total = torch.cuda.get_device_properties(idx).total_memory
            free_bytes.append(total - torch.cuda.memory_allocated(idx))
        best = free_bytes.index(max(free_bytes))
        chosen = torch.device(f"cuda:{best}")
        logger.info(
            f"Selected GPU {chosen} with {free_bytes[best] / 1024**2:.2f} MB free memory from {n_gpus} GPUs"
        )
        return chosen

    if torch.backends.mps.is_available():
        # No CUDA device, but MPS is usable.
        logger.info("No GPU available. Using MPS.")
        return torch.device("mps")

    # Neither CUDA nor MPS: fall back to the CPU.
    logger.info("No GPU or MPS available. Using CPU.")
    return torch.device("cpu")
|
55 |
|
56 |
|
57 |
def get_calculator(
|
58 |
+
calculator_name: str | MLIPEnum | BaseCalculator,
|
59 |
+
calculator_kwargs: dict | None = None,
|
60 |
dispersion: bool = False,
|
61 |
dispersion_kwargs: dict | None = None,
|
62 |
device: str | None = None,
|
63 |
+
) -> BaseCalculator:
|
64 |
"""Get a calculator with optional dispersion correction."""
|
65 |
|
66 |
device = device or str(get_freer_device())
|
|
|
74 |
calc = calculator_name.value(**calculator_kwargs)
|
75 |
elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
|
76 |
calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
|
77 |
+
elif isinstance(calculator_name, type) and issubclass(
|
78 |
+
calculator_name, BaseCalculator
|
79 |
+
):
|
80 |
logger.warning(f"Using custom calculator class: {calculator_name}")
|
81 |
calc = calculator_name(**calculator_kwargs)
|
82 |
+
elif isinstance(calculator_name, BaseCalculator):
|
83 |
+
logger.warning(
|
84 |
+
f"Using custom calculator object (kwargs are ignored): {calculator_name}"
|
85 |
+
)
|
86 |
calc = calculator_name
|
87 |
else:
|
88 |
raise ValueError(f"Invalid calculator: {calculator_name}")
|
|
|
107 |
if dispersion_kwargs:
|
108 |
logger.info(pformat(dispersion_kwargs))
|
109 |
|
110 |
+
assert isinstance(calc, BaseCalculator)
|
111 |
return calc
|