# File size: 4,140 Bytes
# 11ac28c
"""Utility functions for MLIP models."""

from __future__ import annotations

from pprint import pformat

import torch
from ase import units
from ase.calculators.calculator import BaseCalculator
from ase.calculators.mixing import SumCalculator

from mlip_arena.models import MLIPEnum

# Module-level logger selection: try Prefect's run logger first and fall back
# to loguru when Prefect is not importable (ImportError) or when
# get_run_logger() raises RuntimeError.
# NOTE(review): the RuntimeError case is presumably "no active Prefect run
# context at import time" — confirm against the Prefect documentation.
try:
    from prefect.logging import get_run_logger

    logger = get_run_logger()
except (ImportError, RuntimeError):
    from loguru import logger


def get_freer_device() -> torch.device:
    """Select a torch device, preferring the CUDA GPU with the most free memory.

    Selection order:
        1. CUDA: the device with the largest
           ``total_memory - torch.cuda.memory_allocated(i)``.
        2. MPS, when ``torch.backends.mps`` exists and reports availability.
        3. CPU, as the final fallback.

    Returns:
        torch.device: The selected CUDA, MPS, or CPU device. This function
        never raises for a missing accelerator; it falls back to CPU.

    Note:
        NOTE(review): "free" memory is estimated from
        ``torch.cuda.memory_allocated``, which presumably only reflects
        tensors allocated through PyTorch in this process; memory held by
        other processes on the same GPU is likely not accounted for —
        confirm against the PyTorch documentation.
    """
    device_count = torch.cuda.device_count()
    if device_count > 0:
        # If CUDA GPUs are available, select the one with the most free memory
        mem_free = [
            torch.cuda.get_device_properties(i).total_memory
            - torch.cuda.memory_allocated(i)
            for i in range(device_count)
        ]
        # max() returns the first maximal element, so ties resolve to the
        # lowest device index.
        free_gpu_index = max(range(device_count), key=mem_free.__getitem__)
        device = torch.device(f"cuda:{free_gpu_index}")
        logger.info(
            f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
        )
        return device

    # Guard the attribute lookup: torch builds without the MPS backend module
    # would otherwise raise AttributeError instead of falling back to CPU.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        # If no CUDA GPUs are available but MPS is, use MPS
        logger.info("No GPU available. Using MPS.")
        return torch.device("mps")

    # Fallback to CPU if neither CUDA GPUs nor MPS are available
    logger.info("No GPU or MPS available. Using CPU.")
    return torch.device("cpu")


def get_calculator(
    calculator_name: str | MLIPEnum | BaseCalculator,
    calculator_kwargs: dict | None = None,
    dispersion: bool = False,
    dispersion_kwargs: dict | None = None,
    device: str | None = None,
) -> BaseCalculator:
    """Get a calculator with optional dispersion correction.

    Args:
        calculator_name: One of
            * an ``MLIPEnum`` member,
            * the name of an ``MLIPEnum`` member,
            * a ``BaseCalculator`` subclass (instantiated with
              ``calculator_kwargs``), or
            * a ready ``BaseCalculator`` instance (used as-is;
              ``calculator_kwargs`` are ignored).
        calculator_kwargs: Keyword arguments for the calculator constructor.
            The dict is copied, never mutated; a ``"device"`` entry is always
            set to the resolved device.
        dispersion: When True, wrap the calculator in a ``SumCalculator``
            together with a ``TorchDFTD3Calculator``.
        dispersion_kwargs: Keyword arguments for ``TorchDFTD3Calculator``.
            When falsy, ``damping="bj"``, ``xc="pbe"`` and
            ``cutoff=40.0 * units.Bohr`` are used. The dict is copied, never
            mutated; ``"device"`` is always set to the resolved device.
        device: Device string. When falsy, ``get_freer_device()`` chooses one.

    Returns:
        BaseCalculator: The requested calculator, or a ``SumCalculator`` of
        it and the dispersion calculator when ``dispersion`` is True.

    Raises:
        ValueError: If ``calculator_name`` is not a supported type or is a
            string that does not name an ``MLIPEnum`` member.
        ImportError: If ``dispersion`` is True and ``torch_dftd`` is missing.
    """

    device = device or str(get_freer_device())

    # Build a fresh dict so the caller's mapping is never mutated.
    calculator_kwargs = {**(calculator_kwargs or {}), "device": device}

    logger.info(f"Using device: {device}")

    if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
        calc = calculator_name.value(**calculator_kwargs)
        calc_label = calculator_name.name
    elif isinstance(calculator_name, str):
        # Look the member up by name. A plain hasattr() check would also
        # accept non-member attributes (e.g. "name", "__doc__") and then fail
        # with KeyError; report every unknown name as ValueError instead.
        try:
            member = MLIPEnum[calculator_name]
        except KeyError:
            raise ValueError(f"Invalid calculator: {calculator_name}") from None
        calc = member.value(**calculator_kwargs)
        calc_label = calculator_name
    elif isinstance(calculator_name, type) and issubclass(
        calculator_name, BaseCalculator
    ):
        logger.warning(f"Using custom calculator class: {calculator_name}")
        calc = calculator_name(**calculator_kwargs)
        calc_label = calc.__class__.__name__
    elif isinstance(calculator_name, BaseCalculator):
        logger.warning(
            f"Using custom calculator object (kwargs are ignored): {calculator_name}"
        )
        calc = calculator_name
        calc_label = calc.__class__.__name__
    else:
        raise ValueError(f"Invalid calculator: {calculator_name}")

    # Kept for callers that invoke ``calc.__str__()`` directly. The label is
    # bound as a default argument so it cannot change when ``calc`` is rebound
    # to the SumCalculator below.
    calc.__str__ = lambda label=calc_label: label

    # str()/f-strings look up __str__ on the type, not the instance, so the
    # instance-level override above does not affect formatting; log the label
    # explicitly.
    logger.info(f"Using calculator: {calc_label}")
    logger.info(pformat(calculator_kwargs))

    if dispersion:
        try:
            from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
        except ImportError as e:
            raise ImportError(
                "torch_dftd is required for dispersion but is not installed."
            ) from e

        # Copy (or build the defaults) so the caller's dict is never mutated.
        dispersion_kwargs = {
            **(
                dispersion_kwargs
                or dict(damping="bj", xc="pbe", cutoff=40.0 * units.Bohr)
            ),
            "device": device,
        }

        disp_calc = TorchDFTD3Calculator(
            **dispersion_kwargs,
        )
        calc = SumCalculator([calc, disp_calc])
        # TODO: rename the SumCalculator

        logger.info(f"Using dispersion: {disp_calc}")
        logger.info(pformat(dispersion_kwargs))

    assert isinstance(calc, BaseCalculator)
    return calc