File size: 4,243 Bytes
3b3aaa9
 
52c1bfb
 
49d0cfc
b3722a8
d390139
b3722a8
49d0cfc
d390139
d72faca
7cc6c4a
 
 
 
 
 
 
 
 
 
3b3aaa9
056d8d3
49d0cfc
52c1bfb
 
 
 
7cc6c4a
 
 
52c1bfb
1d1ee87
7cc6c4a
52c1bfb
 
 
0d1ce35
7cc6c4a
49d0cfc
 
 
 
 
7cbf186
 
7cc6c4a
 
 
7cbf186
2d8bda8
 
7cbf186
2d8bda8
0d1ce35
7cbf186
2d8bda8
 
 
49d0cfc
 
7cc6c4a
 
 
2d8bda8
 
 
 
7cbf186
49d0cfc
7cbf186
0d1ce35
 
 
7cbf186
2d8bda8
8d4e26f
7cc6c4a
 
 
2d8bda8
 
 
 
 
d390139
7cc6c4a
 
 
 
 
 
 
 
 
49d0cfc
0d1ce35
 
 
 
49d0cfc
d390139
 
d0bf60f
d390139
7cc6c4a
 
 
 
 
d390139
 
 
49d0cfc
d390139
49d0cfc
 
 
d390139
7cc6c4a
 
 
 
 
 
 
49d0cfc
7cc6c4a
d390139
7cc6c4a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from __future__ import annotations

import importlib
from enum import Enum
from pathlib import Path

import torch
import yaml
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes

# Logging backend: prefer Prefect's run logger, otherwise fall back to loguru.
# NOTE(review): get_run_logger() is evaluated once, at import time. Catching
# RuntimeError suggests it is expected to fail when no Prefect flow/task run
# is active, in which case loguru's logger is used instead.
try:
    from prefect.logging import get_run_logger

    logger = get_run_logger()
except (ImportError, RuntimeError):
    from loguru import logger

# Model registry shipped next to this file. Each entry is read below via its
# `module`, `family` and `class` keys.
with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
    REGISTRY = yaml.safe_load(f)

# Model name -> model class, for every registry entry that imports cleanly.
MLIPMap = {}

for model, metadata in REGISTRY.items():
    # Each entry's implementation lives at <package>.<module>.<family> and
    # exposes the attribute named by <class>.
    module_path = f"{__package__}.{metadata['module']}.{metadata['family']}"
    try:
        module = importlib.import_module(module_path)
        MLIPMap[model] = getattr(module, metadata["class"])
    except (ImportError, AttributeError, ValueError) as e:
        # Best effort: a model whose (optional) dependencies are missing or
        # broken is skipped instead of making the whole package unimportable.
        # ImportError also covers ModuleNotFoundError, plus failures such as
        # "cannot import name ..." raised from inside the model's module.
        logger.warning(f"Skipping model {model!r} ({module_path}): {e}")

# Enum over the successfully loaded models; each member's value is the class.
MLIPEnum = Enum("MLIPEnum", MLIPMap)


class MLIP(
    nn.Module,
    PyTorchModelHubMixin,
    tags=["atomistic-simulation", "MLIP"],
):
    """``nn.Module`` wrapper that delegates every call to an inner model.

    ``PyTorchModelHubMixin`` adds the Hugging Face Hub save/load helpers; the
    ``tags`` class keyword is consumed by that mixin.
    """

    def __init__(self, model: nn.Module) -> None:
        """Store ``model`` as the registered submodule ``self.model``."""
        super().__init__()
        # The model is kept uncompiled. Wrapping it with torch.compile (together
        # with torch._dynamo.config.compiled_autograd = True, see
        # https://github.com/pytorch/pytorch/blob/3cbc8c54fd37eb590e2a9206aecf3ab568b3e63c/torch/_dynamo/config.py#L534)
        # is currently disabled.
        self.model = model

    def forward(self, x):
        """Pass ``x`` to the wrapped model and return its output as-is."""
        inner = self.model
        return inner(x)


class MLIPCalculator(MLIP, Calculator):
    """ASE calculator backed by an :class:`MLIP`-wrapped model.

    ``calculate`` collates a single ``Atoms`` object into a batch with
    ``mlip_arena.data.collate.collate_fn``, runs ``self.forward`` on it and
    copies the requested ``energy`` / ``forces`` / ``stress`` entries of the
    model output into ``self.results``.

    ``forward`` is inherited from ``MLIP`` and calls the wrapped model directly
    on the collated batch. Subclasses whose model needs a different input
    (e.g. their own neighbor list / graph construction) can override it.
    """

    name: str
    implemented_properties: list[str] = ["energy", "forces", "stress"]

    def __init__(
        self,
        model: nn.Module,
        device: torch.device | None = None,
        cutoff: float = 6.0,
        # ASE Calculator
        restart=None,
        atoms=None,
        directory=".",
        calculator_kwargs: dict | None = None,
    ):
        """Initialise both the ``nn.Module`` part and the ASE ``Calculator`` part.

        Args:
            model: Network to wrap; it is moved to ``self.device``.
            device: Target device. When not given, ``get_freer_device()`` from
                ``mlip_arena.models.utils`` picks one.
            cutoff: Passed to ``collate_fn`` when building the model input.
            restart: Forwarded to ``Calculator.__init__``.
            atoms: Forwarded to ``Calculator.__init__``.
            directory: Forwarded to ``Calculator.__init__``.
            calculator_kwargs: Extra keyword arguments for ``Calculator.__init__``.
                ``None`` (the default) means no extra arguments.
        """
        MLIP.__init__(self, model=model)  # Initialize MLIP part
        Calculator.__init__(
            self,
            restart=restart,
            atoms=atoms,
            directory=directory,
            # None instead of a shared mutable `{}` default in the signature.
            **(calculator_kwargs or {}),
        )  # Initialize ASE Calculator part
        # NOTE(review): imported locally, presumably to avoid an import cycle.
        from mlip_arena.models.utils import get_freer_device

        self.device = device or get_freer_device()
        self.cutoff = cutoff
        self.model.to(self.device)

    def calculate(
        self,
        atoms: Atoms,
        properties: list[str],
        system_changes: list = all_changes,
    ):
        """Calculate energies, forces and/or stress for the given Atoms object.

        ``self.results`` is reset on every call and then filled with each of
        ``"energy"``, ``"forces"`` and ``"stress"`` that appears in
        ``properties``.

        Args:
            atoms: Structure to evaluate.
            properties: Names of the properties to store in ``self.results``.
            system_changes: Forwarded to ``Calculator.calculate``.
        """
        super().calculate(atoms, properties, system_changes)
        from mlip_arena.data.collate import collate_fn

        # TODO: move collate_fn to here in MLIPCalculator
        data = collate_fn([atoms], cutoff=self.cutoff).to(self.device)
        output = self.forward(data)

        # TODO: decollate_fn

        self.results = {}
        if "energy" in properties:
            self.results["energy"] = output["energy"].squeeze().item()
        if "forces" in properties:
            # reshape, not squeeze: squeeze() also drops the atom axis of a
            # single-atom structure ((1, 3) -> (3,)), while ASE expects forces
            # of shape (n_atoms, 3). Assumes the model returns n_atoms * 3
            # values with xyz as the last axis (optionally with a leading
            # batch axis of size 1) -- TODO confirm for every model family.
            self.results["forces"] = (
                output["forces"].detach().cpu().numpy().reshape(-1, 3)
            )
        if "stress" in properties:
            self.results["stress"] = output["stress"].squeeze().cpu().detach().numpy()