Yuan (Cyrus) Chiang
Add `eqV2_86M_omat_mp_salex` model (#14)
52c1bfb unverified
raw
history blame
1.17 kB
from __future__ import annotations
from typing import Literal
from ase import Atoms
from chgnet.model.dynamics import CHGNetCalculator
from chgnet.model.model import CHGNet as CHGNetModel
from mlip_arena.models.utils import get_freer_device
class CHGNet(CHGNetCalculator):
    """ASE calculator wrapping CHGNet with automatic device selection.

    Thin subclass of ``chgnet.model.dynamics.CHGNetCalculator``. When no
    device is given, it picks one via ``get_freer_device()``. After each
    calculation it removes the ``crystal_fea`` entry from ``self.results`` so
    the results can be written with ``ase.io.write``.
    """

    def __init__(
        self,
        checkpoint: CHGNetModel | None = None,  # TODO: specify version
        device: str | None = None,
        stress_weight: float | None = 1 / 160.21766208,
        on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
        **kwargs,
    ) -> None:
        """Build the calculator and forward configuration to the parent.

        Args:
            checkpoint: Model instance forwarded to the parent as ``model=``.
                When None, the parent decides which model to load
                (presumably its default pretrained weights — confirm).
            device: Device string forwarded as ``use_device=``. When falsy
                (None or an empty string), ``str(get_freer_device())`` is
                used instead.
            stress_weight: Stress multiplier forwarded unchanged. The
                default 1/160.21766208 looks like the GPa -> eV/Å^3
                conversion factor; confirm against the CHGNet documentation.
            on_isolated_atoms: Policy forwarded unchanged to the parent.
            **kwargs: Any further keyword arguments for ``CHGNetCalculator``.
        """
        # ``or`` (not ``is None``) so an empty string also triggers
        # automatic device selection.
        use_device = device or str(get_freer_device())
        super().__init__(
            model=checkpoint,
            use_device=use_device,
            stress_weight=stress_weight,
            on_isolated_atoms=on_isolated_atoms,
            **kwargs,
        )

    def calculate(
        self,
        atoms: Atoms | None = None,
        properties: list | None = None,
        system_changes: list | None = None,
        **kwargs,
    ) -> None:
        """Run the parent calculation, then drop ``crystal_fea`` from results.

        Args:
            atoms: Structure to evaluate; forwarded to the parent.
            properties: Requested properties; forwarded to the parent.
            system_changes: Changes since the last call; forwarded to the
                parent.
            **kwargs: Extra keyword arguments accepted by the parent's
                ``calculate`` are forwarded instead of being rejected by
                this override. For example, recent CHGNet releases expose a
                ``task`` selector.
        """
        super().calculate(atoms, properties, system_changes, **kwargs)
        # The parent may store a "crystal_fea" entry in results; remove it
        # for ase.io.write compatibility. pop(..., None) tolerates its absence.
        self.results.pop("crystal_fea", None)