# Last change: commit 8d4e26f by cyrusyc - "move imports into calculators"
from __future__ import annotations
import importlib
from enum import Enum
from pathlib import Path
import torch
import yaml
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
# Logger selection: use Prefect's run logger when it can be obtained at import
# time; otherwise fall back to loguru's global logger.
try:
    from prefect.logging import get_run_logger

    logger = get_run_logger()
except (ImportError, RuntimeError):
    # ImportError: Prefect is not installed.
    # RuntimeError: presumably raised by get_run_logger() when there is no
    # active Prefect run context (e.g. plain library use) - TODO confirm.
    from loguru import logger
# Discover the available MLIP calculator classes from ``registry.yaml``.
#
# Every registry entry must provide ``module``, ``family`` and ``class`` keys.
# The class ``class`` is imported from ``<package>.<module>.<family>``. An entry
# whose module cannot be imported (e.g. an optional model backend is missing or
# incompatible) is skipped with a warning, so that one broken backend does not
# make the whole package unimportable.
with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
    REGISTRY = yaml.safe_load(f)

MLIPMap = {}
for model, metadata in REGISTRY.items():
    try:
        module = importlib.import_module(
            f"{__package__}.{metadata['module']}.{metadata['family']}"
        )
        MLIPMap[model] = getattr(module, metadata["class"])
    except (ImportError, AttributeError, ValueError) as e:
        # ImportError also covers its subclass ModuleNotFoundError, plus
        # failures such as "cannot import name ..." raised while a calculator
        # module imports its own third-party dependencies.
        # AttributeError: the module exists but lacks the registered class.
        logger.warning(f"Failed to load model {model!r}: {e}")

# Enum of successfully loaded models: member name = registry key,
# member value = calculator class.
MLIPEnum = Enum("MLIPEnum", MLIPMap)
class MLIP(
    nn.Module,
    PyTorchModelHubMixin,
    tags=["atomistic-simulation", "MLIP"],
):
    """Minimal ``nn.Module`` wrapper that delegates ``forward`` to ``model``.

    ``PyTorchModelHubMixin`` supplies Hugging Face Hub helpers such as
    ``from_pretrained``. The ``tags`` given to the mixin are presumably
    attached to the model metadata it produces - TODO confirm.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # NOTE: compiling the wrapped model (torch.compile, optionally with
        # torch._dynamo.config.compiled_autograd = True) is intentionally
        # disabled. For reference:
        # https://github.com/pytorch/pytorch/blob/3cbc8c54fd37eb590e2a9206aecf3ab568b3e63c/torch/_dynamo/config.py#L534
        self.model = model

    def forward(self, x):
        """Return the wrapped model's output for ``x``, unchanged."""
        wrapped = self.model
        return wrapped(x)
class MLIPCalculator(MLIP, Calculator):
    """ASE calculator backed by a PyTorch MLIP model.

    Combines :class:`MLIP` (an ``nn.Module`` wrapper) with ASE's ``Calculator``,
    so one object can be moved between devices like a module and attached to
    ``Atoms`` like a calculator.

    NOTE(review): if ``torch.compile`` is ever enabled in ``MLIP``, pickling
    will presumably need ``__getstate__``/``__setstate__`` that unwrap and
    re-wrap the compiled module (``_orig_mod``) - TODO confirm.
    """

    name: str
    implemented_properties: list[str] = ["energy", "forces", "stress"]

    def __init__(
        self,
        model: nn.Module,
        device: torch.device | None = None,
        cutoff: float = 6.0,
        # ASE Calculator
        restart=None,
        atoms=None,
        directory=".",
        calculator_kwargs: dict | None = None,
    ):
        """Wrap ``model`` and initialise the ASE calculator state.

        Args:
            model: PyTorch model; it is moved to the selected device.
            device: Device to run on. ``None`` picks one with
                ``mlip_arena.models.utils.get_freer_device``.
            cutoff: Neighbor cutoff passed to ``collate_fn`` when building the
                model input in :meth:`calculate`.
            restart, atoms, directory: Forwarded to ASE's ``Calculator``.
            calculator_kwargs: Extra keyword arguments forwarded to ASE's
                ``Calculator`` constructor. ``None`` (default) means none.
        """
        # Initialise the nn.Module side (via MLIP) first, so that the module's
        # internal state exists before ASE's initializer assigns attributes on
        # ``self`` (which go through nn.Module.__setattr__).
        MLIP.__init__(self, model=model)
        Calculator.__init__(
            self,
            restart=restart,
            atoms=atoms,
            directory=directory,
            **(calculator_kwargs or {}),
        )

        if device is None:
            # Imported lazily, matching the module's "imports inside
            # calculators" convention.
            from mlip_arena.models.utils import get_freer_device

            device = get_freer_device()
        self.device = device
        self.cutoff = cutoff
        self.model.to(self.device)

    def calculate(
        self,
        atoms: Atoms,
        properties: list[str],
        system_changes: list = all_changes,
    ):
        """Calculate energies and forces for the given Atoms object.

        Runs one forward pass on ``atoms`` and stores each requested property
        in ``self.results``:

        - ``"energy"``: Python float.
        - ``"forces"``: numpy array of shape ``(n_atoms, 3)``.
        - ``"stress"``: squeezed numpy array, in whatever layout the model emits.
        """
        super().calculate(atoms, properties, system_changes)

        from mlip_arena.data.collate import collate_fn

        # TODO: move collate_fn to here in MLIPCalculator
        data = collate_fn([atoms], cutoff=self.cutoff).to(self.device)
        output = self.forward(data)

        # TODO: decollate_fn
        self.results = {}
        if "energy" in properties:
            self.results["energy"] = output["energy"].squeeze().item()
        if "forces" in properties:
            # ASE expects forces shaped (n_atoms, 3); a plain squeeze() would
            # collapse a single-atom system to shape (3,). Assumes the model
            # returns three force components per atom.
            self.results["forces"] = (
                output["forces"].detach().cpu().numpy().reshape(-1, 3)
            )
        if "stress" in properties:
            self.results["stress"] = (
                output["stress"].squeeze().detach().cpu().numpy()
            )