cyrusyc commited on
Commit
e80e29d
·
1 Parent(s): cb1fb61

more logger; fix relaxation filter

Browse files
mlip_arena/models/utils.py CHANGED
@@ -2,6 +2,13 @@
2
 
3
  import torch
4
 
 
 
 
 
 
 
 
5
 
6
  def get_freer_device() -> torch.device:
7
  """Get the GPU with the most free memory, or use MPS if available.
@@ -22,16 +29,16 @@ def get_freer_device() -> torch.device:
22
  ]
23
  free_gpu_index = mem_free.index(max(mem_free))
24
  device = torch.device(f"cuda:{free_gpu_index}")
25
- print(
26
  f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
27
  )
28
  elif torch.backends.mps.is_available():
29
  # If no CUDA GPUs are available but MPS is, use MPS
30
- print("No GPU available. Using MPS.")
31
  device = torch.device("mps")
32
  else:
33
  # Fallback to CPU if neither CUDA GPUs nor MPS are available
34
- print("No GPU or MPS available. Using CPU.")
35
  device = torch.device("cpu")
36
 
37
  return device
 
2
 
3
  import torch
4
 
5
+ try:
6
+ from prefect.logging import get_run_logger
7
+
8
+ logger = get_run_logger()
9
+ except (ImportError, RuntimeError):
10
+ from loguru import logger
11
+
12
 
13
  def get_freer_device() -> torch.device:
14
  """Get the GPU with the most free memory, or use MPS if available.
 
29
  ]
30
  free_gpu_index = mem_free.index(max(mem_free))
31
  device = torch.device(f"cuda:{free_gpu_index}")
32
+ logger.info(
33
  f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
34
  )
35
  elif torch.backends.mps.is_available():
36
  # If no CUDA GPUs are available but MPS is, use MPS
37
+ logger.info("No GPU available. Using MPS.")
38
  device = torch.device("mps")
39
  else:
40
  # Fallback to CPU if neither CUDA GPUs nor MPS are available
41
+ logger.info("No GPU or MPS available. Using CPU.")
42
  device = torch.device("cpu")
43
 
44
  return device
mlip_arena/tasks/elasticity.py CHANGED
@@ -48,8 +48,6 @@ from prefect.runtime import task_run
48
  from prefect.states import State
49
 
50
  from ase import Atoms
51
- from ase.filters import * # type: ignore
52
- from ase.optimize import * # type: ignore
53
  from ase.optimize.optimize import Optimizer
54
  from mlip_arena.models import MLIPEnum
55
  from mlip_arena.tasks.optimize import run as OPT
@@ -81,6 +79,8 @@ def run(
81
  atoms: Atoms,
82
  calculator_name: str | MLIPEnum,
83
  calculator_kwargs: dict | None = None,
 
 
84
  device: str | None = None,
85
  optimizer: Optimizer | str = "BFGSLineSearch", # type: ignore
86
  optimizer_kwargs: dict | None = None,
@@ -124,6 +124,8 @@ def run(
124
  atoms=atoms,
125
  calculator_name=calculator_name,
126
  calculator_kwargs=calculator_kwargs,
 
 
127
  device=device,
128
  optimizer=optimizer,
129
  optimizer_kwargs=optimizer_kwargs,
 
48
  from prefect.states import State
49
 
50
  from ase import Atoms
 
 
51
  from ase.optimize.optimize import Optimizer
52
  from mlip_arena.models import MLIPEnum
53
  from mlip_arena.tasks.optimize import run as OPT
 
79
  atoms: Atoms,
80
  calculator_name: str | MLIPEnum,
81
  calculator_kwargs: dict | None = None,
82
+ dispersion: bool = False,
83
+ dispersion_kwargs: dict | None = None,
84
  device: str | None = None,
85
  optimizer: Optimizer | str = "BFGSLineSearch", # type: ignore
86
  optimizer_kwargs: dict | None = None,
 
124
  atoms=atoms,
125
  calculator_name=calculator_name,
126
  calculator_kwargs=calculator_kwargs,
127
+ dispersion=dispersion,
128
+ dispersion_kwargs=dispersion_kwargs,
129
  device=device,
130
  optimizer=optimizer,
131
  optimizer_kwargs=optimizer_kwargs,
mlip_arena/tasks/neb.py CHANGED
@@ -39,7 +39,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
39
  from __future__ import annotations
40
 
41
  from pathlib import Path
42
- from typing import TYPE_CHECKING, Any, Literal
43
 
44
  from prefect import task
45
  from prefect.cache_policies import INPUTS, TASK_SOURCE
@@ -54,13 +54,9 @@ from ase.optimize.optimize import Optimizer
54
  from ase.utils.forcecurve import fit_images
55
  from mlip_arena.models import MLIPEnum
56
  from mlip_arena.tasks.optimize import run as OPT
57
- from mlip_arena.tasks.utils import get_calculator
58
  from pymatgen.io.ase import AseAtomsAdaptor
59
 
60
-
61
- if TYPE_CHECKING:
62
- pass
63
-
64
  _valid_optimizers: dict[str, Optimizer] = {
65
  "MDMin": MDMin,
66
  "FIRE": FIRE,
@@ -86,7 +82,7 @@ def _generate_task_run_name():
86
  atoms = parameters["start"]
87
  else:
88
  raise ValueError("No images or start atoms found in parameters")
89
-
90
  calculator_name = parameters["calculator_name"]
91
 
92
  return f"{task_name}: {atoms.get_chemical_formula()} - {calculator_name}"
@@ -156,20 +152,17 @@ def run(
156
  criterion = criterion or {}
157
 
158
  optimizer_instance = optimizer(neb, trajectory=traj_file, **optimizer_kwargs) # type: ignore
159
-
 
 
160
  optimizer_instance.run(**criterion)
161
 
162
  neb_tool = NEBTools(neb.images)
163
- barrier = neb_tool.get_barrier()
164
-
165
- forcefit = fit_images(neb.images)
166
-
167
- images = neb.images
168
 
169
  return {
170
- "barrier": barrier,
171
- "images": images,
172
- "forcefit": forcefit,
173
  }
174
 
175
 
@@ -261,7 +254,7 @@ def run_from_end_points(
261
  )
262
  )
263
 
264
- images = [s.to_ase_atoms() for s in path]
265
 
266
  return run.with_options(
267
  refresh_cache=not cache_subtasks,
 
39
  from __future__ import annotations
40
 
41
  from pathlib import Path
42
+ from typing import Any, Literal
43
 
44
  from prefect import task
45
  from prefect.cache_policies import INPUTS, TASK_SOURCE
 
54
  from ase.utils.forcecurve import fit_images
55
  from mlip_arena.models import MLIPEnum
56
  from mlip_arena.tasks.optimize import run as OPT
57
+ from mlip_arena.tasks.utils import get_calculator, logger, pformat
58
  from pymatgen.io.ase import AseAtomsAdaptor
59
 
 
 
 
 
60
  _valid_optimizers: dict[str, Optimizer] = {
61
  "MDMin": MDMin,
62
  "FIRE": FIRE,
 
82
  atoms = parameters["start"]
83
  else:
84
  raise ValueError("No images or start atoms found in parameters")
85
+
86
  calculator_name = parameters["calculator_name"]
87
 
88
  return f"{task_name}: {atoms.get_chemical_formula()} - {calculator_name}"
 
152
  criterion = criterion or {}
153
 
154
  optimizer_instance = optimizer(neb, trajectory=traj_file, **optimizer_kwargs) # type: ignore
155
+ logger.info(f"Using optimizer: {optimizer_instance}")
156
+ logger.info(pformat(optimizer_kwargs))
157
+ logger.info(f"Criterion: {pformat(criterion)}")
158
  optimizer_instance.run(**criterion)
159
 
160
  neb_tool = NEBTools(neb.images)
 
 
 
 
 
161
 
162
  return {
163
+ "barrier": neb_tool.get_barrier(),
164
+ "images": neb.images,
165
+ "forcefit": fit_images(neb.images),
166
  }
167
 
168
 
 
254
  )
255
  )
256
 
257
+ images = [s.to_ase_atoms(msonable=False) for s in path]
258
 
259
  return run.with_options(
260
  refresh_cache=not cache_subtasks,
mlip_arena/tasks/optimize.py CHANGED
@@ -15,7 +15,8 @@ from ase.filters import Filter
15
  from ase.optimize import * # type: ignore
16
  from ase.optimize.optimize import Optimizer
17
  from mlip_arena.models import MLIPEnum
18
- from mlip_arena.tasks.utils import get_calculator
 
19
 
20
  _valid_filters: dict[str, Filter] = {
21
  "Filter": Filter,
@@ -94,16 +95,20 @@ def run(
94
 
95
  if isinstance(filter, type) and issubclass(filter, Filter):
96
  filter_instance = filter(atoms, **filter_kwargs)
97
- print(f"Using filter: {filter_instance}")
 
98
 
99
- optimizer_instance = optimizer(atoms, **optimizer_kwargs)
100
- print(f"Using optimizer: {optimizer_instance}")
 
 
101
 
102
  optimizer_instance.run(**criterion)
103
-
104
  elif filter is None:
105
  optimizer_instance = optimizer(atoms, **optimizer_kwargs)
106
- print(f"Using optimizer: {optimizer_instance}")
 
 
107
  optimizer_instance.run(**criterion)
108
 
109
  return {
 
15
  from ase.optimize import * # type: ignore
16
  from ase.optimize.optimize import Optimizer
17
  from mlip_arena.models import MLIPEnum
18
+ from mlip_arena.tasks.utils import get_calculator, logger, pformat
19
+
20
 
21
  _valid_filters: dict[str, Filter] = {
22
  "Filter": Filter,
 
95
 
96
  if isinstance(filter, type) and issubclass(filter, Filter):
97
  filter_instance = filter(atoms, **filter_kwargs)
98
+ logger.info(f"Using filter: {filter_instance}")
99
+ logger.info(pformat(filter_kwargs))
100
 
101
+ optimizer_instance = optimizer(filter_instance, **optimizer_kwargs)
102
+ logger.info(f"Using optimizer: {optimizer_instance}")
103
+ logger.info(pformat(optimizer_kwargs))
104
+ logger.info(f"Criterion: {pformat(criterion)}")
105
 
106
  optimizer_instance.run(**criterion)
 
107
  elif filter is None:
108
  optimizer_instance = optimizer(atoms, **optimizer_kwargs)
109
+ logger.info(f"Using optimizer: {optimizer_instance}")
110
+ logger.info(pformat(optimizer_kwargs))
111
+ logger.info(f"Criterion: {pformat(criterion)}")
112
  optimizer_instance.run(**criterion)
113
 
114
  return {
mlip_arena/tasks/utils.py CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations
5
  from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
6
 
7
  from ase import units
8
- from ase.calculators.calculator import Calculator
9
  from ase.calculators.mixing import SumCalculator
10
  from mlip_arena.models import MLIPEnum
11
  from mlip_arena.models.utils import get_freer_device
@@ -21,7 +21,7 @@ from pprint import pformat
21
 
22
 
23
  def get_calculator(
24
- calculator_name: str | MLIPEnum | Calculator,
25
  calculator_kwargs: dict | None,
26
  dispersion: bool = False,
27
  dispersion_kwargs: dict | None = None,
@@ -30,22 +30,24 @@ def get_calculator(
30
  """Get a calculator with optional dispersion correction."""
31
  device = device or str(get_freer_device())
32
 
33
- logger.info("Using device: %s", device)
34
 
35
  calculator_kwargs = calculator_kwargs or {}
36
 
37
  if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
38
- assert issubclass(calculator_name.value, Calculator)
39
  calc = calculator_name.value(**calculator_kwargs)
40
  elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
41
  calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
42
- elif isinstance(calculator_name, Calculator):
43
- logger.warning("Using custom calculator: {calculator_name}")
 
 
 
44
  calc = calculator_name
45
  else:
46
  raise ValueError(f"Invalid calculator: {calculator_name}")
47
 
48
- logger.info("Using calculator: %s", calc)
49
  if calculator_kwargs:
50
  logger.info(pformat(calculator_kwargs))
51
 
@@ -61,9 +63,9 @@ def get_calculator(
61
  )
62
  calc = SumCalculator([calc, disp_calc])
63
 
64
- logger.info("Using dispersion: %s", disp_calc)
65
  if dispersion_kwargs:
66
  logger.info(pformat(dispersion_kwargs))
67
 
68
- assert isinstance(calc, Calculator) or isinstance(calc, SumCalculator)
69
  return calc
 
5
  from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
6
 
7
  from ase import units
8
+ from ase.calculators.calculator import Calculator, BaseCalculator
9
  from ase.calculators.mixing import SumCalculator
10
  from mlip_arena.models import MLIPEnum
11
  from mlip_arena.models.utils import get_freer_device
 
21
 
22
 
23
  def get_calculator(
24
+ calculator_name: str | MLIPEnum | Calculator | SumCalculator,
25
  calculator_kwargs: dict | None,
26
  dispersion: bool = False,
27
  dispersion_kwargs: dict | None = None,
 
30
  """Get a calculator with optional dispersion correction."""
31
  device = device or str(get_freer_device())
32
 
33
+ logger.info(f"Using device: {device}")
34
 
35
  calculator_kwargs = calculator_kwargs or {}
36
 
37
  if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
 
38
  calc = calculator_name.value(**calculator_kwargs)
39
  elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
40
  calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
41
+ elif isinstance(calculator_name, type) and issubclass(calculator_name, BaseCalculator):
42
+ logger.warning(f"Using custom calculator class: {calculator_name}")
43
+ calc = calculator_name(**calculator_kwargs)
44
+ elif isinstance(calculator_name, Calculator | SumCalculator):
45
+ logger.warning(f"Using custom calculator object (kwargs are ignored): {calculator_name}")
46
  calc = calculator_name
47
  else:
48
  raise ValueError(f"Invalid calculator: {calculator_name}")
49
 
50
+ logger.info(f"Using calculator: {calc}")
51
  if calculator_kwargs:
52
  logger.info(pformat(calculator_kwargs))
53
 
 
63
  )
64
  calc = SumCalculator([calc, disp_calc])
65
 
66
+ logger.info(f"Using dispersion: {disp_calc}")
67
  if dispersion_kwargs:
68
  logger.info(pformat(dispersion_kwargs))
69
 
70
+ assert isinstance(calc, Calculator | SumCalculator)
71
  return calc