Spaces:
Running
Running
File size: 4,243 Bytes
3b3aaa9 52c1bfb 49d0cfc b3722a8 d390139 b3722a8 49d0cfc d390139 d72faca 7cc6c4a 3b3aaa9 056d8d3 49d0cfc 52c1bfb 7cc6c4a 52c1bfb 1d1ee87 7cc6c4a 52c1bfb 0d1ce35 7cc6c4a 49d0cfc 7cbf186 7cc6c4a 7cbf186 2d8bda8 7cbf186 2d8bda8 0d1ce35 7cbf186 2d8bda8 49d0cfc 7cc6c4a 2d8bda8 7cbf186 49d0cfc 7cbf186 0d1ce35 7cbf186 2d8bda8 8d4e26f 7cc6c4a 2d8bda8 d390139 7cc6c4a 49d0cfc 0d1ce35 49d0cfc d390139 d0bf60f d390139 7cc6c4a d390139 49d0cfc d390139 49d0cfc d390139 7cc6c4a 49d0cfc 7cc6c4a d390139 7cc6c4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
from __future__ import annotations
import importlib
from enum import Enum
from pathlib import Path
import torch
import yaml
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
# Pick the logger used throughout this module: Prefect's run logger when it
# can be obtained, otherwise loguru's global logger. The fallback triggers
# when Prefect is not installed (ImportError) or when get_run_logger() raises
# RuntimeError (presumably because no Prefect run context is active at import
# time — confirm against the Prefect version in use).
try:
    from prefect.logging import get_run_logger

    logger = get_run_logger()
except (RuntimeError, ImportError):
    from loguru import logger
# Model registry shipped next to this file. Each top-level key is a model name
# whose metadata provides the "module", "family" and "class" fields consumed
# when building MLIPMap below.
with Path(__file__).with_name("registry.yaml").open(encoding="utf-8") as f:
    REGISTRY = yaml.safe_load(f)
# Resolve every registry entry to its implementation class by importing
# ``<package>.<module>.<family>`` and looking up ``<class>`` in it. Entries that
# cannot be loaded are skipped with a warning so that one unavailable backend
# does not break importing this package.
MLIPMap = {}
for model, metadata in (REGISTRY or {}).items():  # empty YAML loads as None
    try:
        module = importlib.import_module(
            f"{__package__}.{metadata['module']}.{metadata['family']}"
        )
        MLIPMap[model] = getattr(module, metadata["class"])
    except (ImportError, AttributeError, ValueError, KeyError) as e:
        # ImportError (parent of ModuleNotFoundError) also covers a backend
        # that is installed but fails on one of its own imports; KeyError
        # covers a registry entry that lacks a required metadata field.
        logger.warning(f"Failed to load model {model!r}: {e!r}")

# Functional-API Enum of the successfully loaded models:
# member name = registry key, member value = implementation class.
MLIPEnum = Enum("MLIPEnum", MLIPMap)
class MLIP(
    nn.Module,
    PyTorchModelHubMixin,
    tags=["atomistic-simulation", "MLIP"],
):
    """``nn.Module`` wrapper around a machine-learned interatomic potential.

    The wrapped network is kept as the ``model`` submodule and every forward
    call is delegated to it unchanged. ``PyTorchModelHubMixin`` supplies the
    Hugging Face Hub integration (e.g. ``from_pretrained``) and the class tags.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # NOTE: compiling the wrapped model is currently disabled. To try it,
        # set ``torch._dynamo.config.compiled_autograd = True`` and store
        # ``torch.compile(model)`` instead; see
        # https://github.com/pytorch/pytorch/blob/3cbc8c54fd37eb590e2a9206aecf3ab568b3e63c/torch/_dynamo/config.py#L534
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output as-is."""
        net = self.model
        return net(x)
class MLIPCalculator(MLIP, Calculator):
    """ASE calculator backed by a wrapped MLIP network.

    Combines :class:`MLIP` (the ``nn.Module`` / Hub wrapper) with
    :class:`ase.calculators.calculator.Calculator` so an instance can be
    attached to ``atoms.calc`` and queried for energy, forces and stress.
    """

    name: str
    implemented_properties: list[str] = ["energy", "forces", "stress"]

    def __init__(
        self,
        model: nn.Module,
        device: torch.device | None = None,
        cutoff: float = 6.0,
        # ASE Calculator
        restart=None,
        atoms=None,
        directory=".",
        calculator_kwargs: dict | None = None,
    ):
        """Initialise both parent classes and move the model to its device.

        Args:
            model: Network to wrap; it is called with the batch produced by
                ``collate_fn`` in :meth:`calculate`.
            device: Target device; when not given, ``get_freer_device()``
                chooses one.
            cutoff: Neighbour cutoff passed to ``collate_fn`` when building
                the input graph.
            restart: Forwarded to ``Calculator.__init__``.
            atoms: Forwarded to ``Calculator.__init__``.
            directory: Forwarded to ``Calculator.__init__``.
            calculator_kwargs: Extra keyword arguments forwarded to
                ``Calculator.__init__``; ``None`` (the default) means none.
        """
        MLIP.__init__(self, model=model)  # nn.Module / Hub-mixin part
        Calculator.__init__(
            self,
            restart=restart,
            atoms=atoms,
            directory=directory,
            # Fresh dict per call rather than a shared mutable default.
            **(calculator_kwargs or {}),
        )  # ASE Calculator part

        from mlip_arena.models.utils import get_freer_device

        self.device = device or get_freer_device()
        self.cutoff = cutoff
        self.model.to(self.device)

    # NOTE: if MLIP ever wraps its model with torch.compile, pickling this
    # calculator will need __getstate__/__setstate__ that unwrap the compiled
    # module via its ``_orig_mod`` attribute and re-compile on restore.

    def calculate(
        self,
        atoms: Atoms,
        properties: list[str],
        system_changes: list = all_changes,
    ):
        """Evaluate the model on ``atoms`` and store the requested properties.

        ``self.results`` is reset, then filled only with the entries of
        ``properties`` among "energy", "forces" and "stress".
        """
        # Resolves to Calculator.calculate (nn.Module and the Hub mixin do not
        # define ``calculate``); it records the atoms for ASE's state tracking.
        super().calculate(atoms, properties, system_changes)

        from mlip_arena.data.collate import collate_fn

        # TODO: move collate_fn to here in MLIPCalculator
        data = collate_fn([atoms], cutoff=self.cutoff).to(self.device)
        output = self.forward(data)

        # TODO: decollate_fn
        self.results = {}
        if "energy" in properties:
            self.results["energy"] = output["energy"].squeeze().item()
        if "forces" in properties:
            # ASE expects forces with shape (natoms, 3). ``squeeze()`` would
            # collapse a single-atom system to shape (3,), so reshape instead.
            self.results["forces"] = (
                output["forces"].detach().cpu().numpy().reshape(-1, 3)
            )
        if "stress" in properties:
            self.results["stress"] = (
                output["stress"].squeeze().detach().cpu().numpy()
            )
|