Spaces:
Running
Running
from __future__ import annotations | |
from pathlib import Path | |
import yaml | |
from mattersim.forcefield import MatterSimCalculator | |
from ase import Atoms | |
from mlip_arena.models.utils import get_freer_device | |
# from pymatgen.io.ase import AseAtomsAdaptor, MSONAtoms | |
# Parse registry.yaml, located one directory above this module's directory,
# once at import time. REGISTRY holds the parsed result and supplies the
# default checkpoint used by the MatterSim constructor below.
with (Path(__file__).parents[1] / "registry.yaml").open(encoding="utf-8") as f:
    REGISTRY = yaml.safe_load(f)
class MatterSim(MatterSimCalculator):
    """MatterSimCalculator subclass preconfigured from the registry.

    The default checkpoint is read from ``REGISTRY["MatterSim"]["checkpoint"]``
    when the class is defined, and the instance's ``potential`` attribute is
    left out of the pickled state.
    """

    def __init__(
        self,
        checkpoint=REGISTRY["MatterSim"]["checkpoint"],
        device=None,
        **kwargs,
    ):
        """Initialise the underlying calculator.

        Args:
            checkpoint: Value forwarded to the parent constructor as
                ``load_path``. Defaults to the registry entry for MatterSim.
            device: Device to run on. Any falsy value (including the default
                ``None``) is replaced by the result of ``get_freer_device()``
                -- presumably the least-loaded device; confirm against that
                helper. The chosen value is passed on as a string.
            **kwargs: Forwarded unchanged to ``MatterSimCalculator.__init__``.
        """
        if not device:
            device = get_freer_device()
        super().__init__(load_path=checkpoint, device=str(device), **kwargs)

    def __getstate__(self):
        """Return the instance state for pickling, without ``potential``.

        The potential is dropped because it cannot be pickled. No
        ``__setstate__`` is defined in this class, so an unpickled instance
        will not carry a ``potential`` attribute unless the parent class
        restores it -- NOTE(review): confirm callers account for this.
        """
        return {
            name: value
            for name, value in self.__dict__.items()
            if name != "potential"
        }
# def calculate( | |
# self, | |
# atoms: Atoms | None = None, | |
# properties: list | None = None, | |
# system_changes: list | None = None, | |
# ): | |
# super().calculate(atoms, properties, system_changes) | |
# # convert unpicklable atoms back to picklable atoms to avoid Prefect's pickling error
# if isinstance(self.atoms, MSONAtoms): | |
# atoms = self.atoms.copy() | |
# strucutre = AseAtomsAdaptor().get_structure(atoms) | |
# self.atoms = AseAtomsAdaptor().get_atoms(strucutre, msonable=False) | |