Spaces:
Running
Running
File size: 5,295 Bytes
1485b15 a133fcb d98d701 072f65e fa55745 1547ed2 fa55745 d98d701 1547ed2 df95987 c7442c5 072f65e c7442c5 1485b15 c162771 d98d701 a133fcb c7442c5 a133fcb 80f7283 c7442c5 a133fcb e517f23 c7442c5 d98d701 5d9e01e c7442c5 5d9e01e 242e83d 072f65e 242e83d 5d9e01e 242e83d e517f23 5d9e01e 242e83d d98d701 c7442c5 242e83d 5d9e01e d98d701 5d9e01e e517f23 5d9e01e 242e83d dad15cc c7442c5 242e83d 1547ed2 c7442c5 d98d701 e517f23 d98d701 1547ed2 c7442c5 1547ed2 d98d701 1547ed2 d98d701 1547ed2 c7442c5 d98d701 c7442c5 1547ed2 c7442c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
from __future__ import annotations
import importlib
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Type, TypeVar, Union
T = TypeVar("T", bound="MLIP")  # lets MLIP.from_pretrained (cls: Type[T] -> T) return the calling subclass's type
import torch
import yaml
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from mlip_arena.data.collate import collate_fn
# Prefer Prefect's run logger when available; otherwise fall back to loguru.
try:
    from prefect.logging import get_run_logger

    logger = get_run_logger()
except (ImportError, RuntimeError):
    # ImportError: prefect is not installed. RuntimeError: presumably raised by
    # get_run_logger() when no Prefect flow/task run is active (typical at import).
    from loguru import logger

# registry.yaml maps each model name to its implementation; each entry is read
# for the keys "module", "family" and "class" below.
with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
    # An empty YAML file loads as None; treat it as an empty registry.
    REGISTRY = yaml.safe_load(f) or {}

# Model name -> implementation class, for every model whose import succeeds.
MLIPMap = {}
for model, metadata in REGISTRY.items():
    try:
        module = importlib.import_module(
            f"{__package__}.{metadata['module']}.{metadata['family']}"
        )
        MLIPMap[model] = getattr(module, metadata["class"])
    except (ImportError, AttributeError, ValueError) as e:
        # Model backends have optional, sometimes version-sensitive dependencies.
        # ImportError (the base of ModuleNotFoundError) also covers installed-but-
        # incompatible packages ("cannot import name ..."). Skip such models
        # instead of breaking import of the whole package.
        logger.warning(f"Skipping model {model!r}: {e}")

# Enum over the successfully imported models; member values are the classes.
MLIPEnum = Enum("MLIPEnum", MLIPMap)
class MLIP(
    nn.Module,
    PyTorchModelHubMixin,
    tags=["atomistic-simulation", "MLIP"],
):
    """``nn.Module`` wrapper around an interatomic potential with Hub support.

    ``PyTorchModelHubMixin`` supplies the save/load/push machinery; the ``tags``
    class keyword is handed to that mixin. The wrapped network lives in
    ``self.model`` and ``forward`` simply delegates to it.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # NOTE: wrapping the model with torch.compile (together with
        # torch._dynamo.config.compiled_autograd = True, see
        # https://github.com/pytorch/pytorch/blob/3cbc8c54fd37eb590e2a9206aecf3ab568b3e63c/torch/_dynamo/config.py#L534)
        # was considered; the model is currently stored uncompiled.
        self.model = model

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights using the mixin's default implementation."""
        return super()._save_pretrained(save_directory)

    @classmethod
    def from_pretrained(
        cls: Type[T],
        pretrained_model_name_or_path: Union[str, Path],
        *,
        force_download: bool = False,
        resume_download: Optional[bool] = None,
        proxies: Optional[Dict] = None,
        token: Optional[Union[str, bool]] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        local_files_only: bool = False,
        revision: Optional[str] = None,
        **model_kwargs,
    ) -> T:
        """Load a model from a Hub repo id or a local path.

        Every download option is forwarded unchanged to the mixin's
        ``from_pretrained``, together with any extra ``model_kwargs``.
        """
        hub_options = dict(
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
        )
        return super().from_pretrained(
            pretrained_model_name_or_path, **hub_options, **model_kwargs
        )

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output as-is."""
        wrapped = self.model
        return wrapped(x)
class MLIPCalculator(MLIP, Calculator):
    """ASE calculator backed by an :class:`MLIP` model.

    Each :meth:`calculate` call turns the ``Atoms`` into a one-structure batch
    with ``collate_fn`` (using ``self.cutoff``), moves it to ``self.device``,
    runs ``self.forward`` and copies the requested quantities into
    ``self.results``. The model output is indexed with the keys ``"energy"``,
    ``"forces"`` and ``"stress"``.
    """

    name: str  # declared only; not assigned in this class
    implemented_properties: list[str] = ["energy", "forces", "stress"]

    def __init__(
        self,
        model: nn.Module,
        device: torch.device | None = None,
        cutoff: float = 6.0,
        # ASE Calculator
        restart=None,
        atoms=None,
        directory=".",
        calculator_kwargs: dict | None = None,
    ):
        """Wrap ``model`` and initialise both the MLIP and ASE parts.

        Args:
            model: network called on the collated batch.
            device: device for the model and inputs; when falsy, chosen by
                ``get_freer_device()``.
            cutoff: neighbour cutoff passed to ``collate_fn``.
            restart, atoms, directory: forwarded to ASE ``Calculator.__init__``.
            calculator_kwargs: extra keyword arguments for ASE
                ``Calculator.__init__`` (``None`` means none).
        """
        MLIP.__init__(self, model=model)  # nn.Module / Hub-mixin part
        # NOTE(review): ASE's Calculator.__init__ may set an instance attribute
        # named `parameters`, which would shadow nn.Module.parameters() on this
        # object — verify before calling self.parameters().
        Calculator.__init__(
            self,
            restart=restart,
            atoms=atoms,
            directory=directory,
            # None default avoids a shared mutable {} default argument.
            **(calculator_kwargs or {}),
        )  # ASE Calculator part

        # Function-scope import, presumably to avoid an import cycle with
        # mlip_arena.models.utils — confirm before hoisting.
        from mlip_arena.models.utils import get_freer_device

        self.device = device or get_freer_device()
        self.cutoff = cutoff
        self.model.to(self.device)

    def calculate(
        self,
        atoms: Atoms,
        properties: list[str],
        system_changes: list = all_changes,
    ):
        """Calculate energies and forces for the given Atoms object.

        Only the entries named in ``properties`` are stored in ``self.results``
        (which is reset on every call).
        """
        super().calculate(atoms, properties, system_changes)

        # TODO: move collate_fn to here in MLIPCalculator
        data = collate_fn([atoms], cutoff=self.cutoff).to(self.device)
        output = self.forward(data)
        # TODO: decollate_fn

        self.results = {}
        if "energy" in properties:
            self.results["energy"] = output["energy"].squeeze().item()
        if "forces" in properties:
            # reshape, not squeeze: squeeze() collapses a single-atom (1, 3)
            # force array to (3,), while ASE expects shape (natoms, 3).
            self.results["forces"] = (
                output["forces"].detach().reshape(len(atoms), 3).cpu().numpy()
            )
        if "stress" in properties:
            self.results["stress"] = output["stress"].squeeze().detach().cpu().numpy()
|