Spaces:
Running
Running
File size: 3,867 Bytes
e59bc30 5716d3b a787930 5716d3b e59bc30 419b35b e59bc30 a787930 5716d3b a787930 e59bc30 5716d3b a787930 e59bc30 5716d3b a787930 e59bc30 8cb1d3b e59bc30 8cb1d3b e59bc30 419b35b e59bc30 5716d3b e80e29d 419b35b 5716d3b cb1fb61 419b35b e59bc30 e80e29d a787930 e59bc30 a787930 e59bc30 a787930 e59bc30 e80e29d a787930 e59bc30 419b35b e59bc30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
"""Utility functions for MLIP models."""
from __future__ import annotations
from pprint import pformat
import torch
from ase import units
from ase.calculators.calculator import BaseCalculator
from ase.calculators.mixing import SumCalculator
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
from mlip_arena.models import MLIPEnum
# Prefer Prefect's run-scoped logger. If Prefect is not installed, or no
# Prefect run context is active when this module is imported, fall back to
# loguru's global logger instead.
try:
    from prefect.logging import get_run_logger

    logger = get_run_logger()
except (RuntimeError, ImportError):
    from loguru import logger
def _cuda_free_memory(index: int) -> int:
    """Return free memory in bytes on CUDA device ``index``, or 0 if it cannot be queried."""
    try:
        free, _total = torch.cuda.mem_get_info(index)
    except RuntimeError:
        # Querying needs a CUDA context on the device; a device that cannot
        # provide one is treated as having no usable free memory.
        return 0
    return free


def get_freer_device() -> torch.device:
    """Get the CUDA GPU with the most free memory, else MPS, else CPU.

    Free memory is read device-wide via ``torch.cuda.mem_get_info``, so memory
    used by other processes counts against a GPU. Note that this initializes a
    CUDA context on each visible GPU.

    Returns:
        torch.device: ``cuda:<i>`` for the GPU with the most free memory; if no
        CUDA GPU is available, ``mps`` when MPS is available; otherwise ``cpu``.
    """
    device_count = torch.cuda.device_count()
    if device_count > 0:
        # total_memory - memory_allocated() only subtracts tensors owned by
        # this process (and is 0 before CUDA is initialized), so it cannot
        # tell a GPU busy with other jobs from an idle one.
        mem_free = [_cuda_free_memory(i) for i in range(device_count)]
        free_gpu_index = mem_free.index(max(mem_free))
        device = torch.device(f"cuda:{free_gpu_index}")
        logger.info(
            f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
        )
        return device

    if torch.backends.mps.is_available():
        # No CUDA GPUs, but the MPS backend is available.
        logger.info("No GPU available. Using MPS.")
        return torch.device("mps")

    # Neither CUDA nor MPS: fall back to CPU rather than failing.
    logger.info("No GPU or MPS available. Using CPU.")
    return torch.device("cpu")
def get_calculator(
    calculator_name: str | MLIPEnum | BaseCalculator | type[BaseCalculator],
    calculator_kwargs: dict | None = None,
    dispersion: bool = False,
    dispersion_kwargs: dict | None = None,
    device: str | None = None,
) -> BaseCalculator:
    """Get a calculator with optional dispersion correction.

    Args:
        calculator_name: The calculator to build. Accepted forms:

            - an ``MLIPEnum`` member, whose value is called with
              ``calculator_kwargs``;
            - the name of an ``MLIPEnum`` member;
            - a ``BaseCalculator`` subclass, called with ``calculator_kwargs``;
            - a ``BaseCalculator`` instance, returned as-is
              (``calculator_kwargs`` are ignored).
        calculator_kwargs: Keyword arguments for the calculator constructor.
            ``"device"`` is always set to the resolved device. The caller's
            dict is not modified.
        dispersion: If True, combine the calculator with a
            ``TorchDFTD3Calculator`` in a ``SumCalculator``.
        dispersion_kwargs: Keyword arguments for ``TorchDFTD3Calculator``. If
            None or empty, ``damping="bj", xc="pbe", cutoff=40 Bohr`` is used.
            ``"device"`` is always set to the resolved device. The caller's
            dict is not modified.
        device: Device string used for both calculators. If None (or empty),
            ``get_freer_device()`` chooses one.

    Returns:
        BaseCalculator: The calculator, or a ``SumCalculator`` of it and the
        D3 correction when ``dispersion`` is True.

    Raises:
        ValueError: If ``calculator_name`` is none of the accepted forms.
    """
    device = device or str(get_freer_device())
    # Build a new dict so injecting "device" never mutates the caller's kwargs.
    calculator_kwargs = {**(calculator_kwargs or {}), "device": device}
    logger.info(f"Using device: {device}")

    # NOTE: str() and f-strings look up __str__ on the type, not the instance,
    # so the per-instance __str__ assignments below only affect explicit
    # ``calc.__str__()`` calls; they are kept for backward compatibility.
    if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
        calc = calculator_name.value(**calculator_kwargs)
        calc.__str__ = lambda: calculator_name.name
    elif (
        isinstance(calculator_name, str)
        # __members__ contains only enum members; hasattr() would also accept
        # class attributes such as "mro" and then fail with KeyError below.
        and calculator_name in MLIPEnum.__members__
    ):
        calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
    elif isinstance(calculator_name, type) and issubclass(
        calculator_name, BaseCalculator
    ):
        logger.warning(f"Using custom calculator class: {calculator_name}")
        calc = calculator_name(**calculator_kwargs)
        # Bind the class name now: a closure over ``calc`` would report
        # "SumCalculator" once ``calc`` is rebound for dispersion below.
        class_name = calc.__class__.__name__
        calc.__str__ = lambda: class_name
    elif isinstance(calculator_name, BaseCalculator):
        logger.warning(
            f"Using custom calculator object (kwargs are ignored): {calculator_name}"
        )
        calc = calculator_name
        class_name = calc.__class__.__name__
        calc.__str__ = lambda: class_name
    else:
        raise ValueError(f"Invalid calculator: {calculator_name}")

    logger.info(f"Using calculator: {calc}")
    # Always non-empty: it carries at least "device".
    logger.info(pformat(calculator_kwargs))

    if dispersion:
        # Start from the caller's kwargs (or the defaults) in a fresh dict.
        disp_kwargs = {
            **(
                dispersion_kwargs
                or dict(damping="bj", xc="pbe", cutoff=40.0 * units.Bohr)
            ),
            "device": device,
        }
        disp_calc = TorchDFTD3Calculator(**disp_kwargs)
        calc = SumCalculator([calc, disp_calc])
        logger.info(f"Using dispersion: {disp_calc}")
        logger.info(pformat(disp_kwargs))

    # Internal sanity check that every branch produced a calculator.
    assert isinstance(calc, BaseCalculator)
    return calc
|