"""Utility functions for MLIP models."""

from __future__ import annotations

from pprint import pformat

import torch
from ase import units
from ase.calculators.calculator import BaseCalculator
from ase.calculators.mixing import SumCalculator
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator

from mlip_arena.models import MLIPEnum

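# Prefer Prefect's run logger when executing inside a Prefect flow or task;
# get_run_logger() raises outside a run context, so fall back to loguru.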
try:
    from prefect.logging import get_run_logger

    logger = get_run_logger()
except (ImportError, RuntimeError):
    from loguru import logger


def get_freer_device() -> torch.device:
    """Get the GPU with the most free memory, or use MPS if available.
    s
        Returns:
            torch.device: The selected GPU device or MPS.

        Raises:
            ValueError: If no GPU or MPS is available.
    """
    device_count = torch.cuda.device_count()
    if device_count > 0:
        # If CUDA GPUs are available, select the one with the most free memory
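        # Note: total minus allocated overestimates free memory when the
        # caching allocator holds reserved-but-unallocated blocks.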
        mem_free = [
            torch.cuda.get_device_properties(i).total_memory
            - torch.cuda.memory_allocated(i)
            for i in range(device_count)
        ]
        free_gpu_index = mem_free.index(max(mem_free))
        device = torch.device(f"cuda:{free_gpu_index}")
        logger.info(
            f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
        )
    elif torch.backends.mps.is_available():
        # If no CUDA GPUs are available but MPS is, use MPS
        logger.info("No GPU available. Using MPS.")
        device = torch.device("mps")
    else:
        # Fallback to CPU if neither CUDA GPUs nor MPS are available
        logger.info("No GPU or MPS available. Using CPU.")
        device = torch.device("cpu")

    return device
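

# Illustrative usage (not executed here): select the freest device before
# moving a model or batch, e.g.
#
#     device = get_freer_device()
#     model = model.to(device)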


def get_calculator(
    calculator_name: str | MLIPEnum | BaseCalculator | type[BaseCalculator],
    calculator_kwargs: dict | None = None,
    dispersion: bool = False,
    dispersion_kwargs: dict | None = None,
    device: str | None = None,
) -> BaseCalculator:
    """Get a calculator with optional dispersion correction."""

    device = device or str(get_freer_device())

    # Copy to avoid mutating the caller's dict when injecting the device.
    calculator_kwargs = dict(calculator_kwargs or {})
    calculator_kwargs.update({"device": device})

    logger.info(f"Using device: {device}")

    if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
        calc = calculator_name.value(**calculator_kwargs)
        calc.__str__ = lambda: calculator_name.name
    elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
        calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
    elif isinstance(calculator_name, type) and issubclass(
        calculator_name, BaseCalculator
    ):
        logger.warning(f"Using custom calculator class: {calculator_name}")
        calc = calculator_name(**calculator_kwargs)
        calc.__str__ = lambda: f"{calc.__class__.__name__}"
    elif isinstance(calculator_name, BaseCalculator):
        logger.warning(
            f"Using custom calculator object (kwargs are ignored): {calculator_name}"
        )
        calc = calculator_name
        calc.__str__ = lambda: f"{calc.__class__.__name__}"
    else:
        raise ValueError(f"Invalid calculator: {calculator_name}")

    logger.info(f"Using calculator: {calc}")
    if calculator_kwargs:
        logger.info(pformat(calculator_kwargs))

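    # Default D3 settings: Becke-Johnson damping with the PBE parametrization
    # and a 40 Bohr cutoff (converted to ASE length units via units.Bohr).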
    # Copy to avoid mutating the caller's dict when injecting the device.
    dispersion_kwargs = dict(
        dispersion_kwargs or dict(damping="bj", xc="pbe", cutoff=40.0 * units.Bohr)
    )
    dispersion_kwargs.update({"device": device})

    if dispersion:
        disp_calc = TorchDFTD3Calculator(
            **dispersion_kwargs,
        )
        calc = SumCalculator([calc, disp_calc])

        logger.info(f"Using dispersion: {disp_calc}")
        if dispersion_kwargs:
            logger.info(pformat(dispersion_kwargs))

    assert isinstance(calc, BaseCalculator)
    return calc
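

if __name__ == "__main__":
    # Minimal smoke test, illustrative only: report the freest device. The
    # commented call below uses a placeholder model name; substitute any
    # member of MLIPEnum available in your installation.
    logger.info(f"Freest device: {get_freer_device()}")
    # calc = get_calculator("SomeModelName", dispersion=True)  # hypothetical name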