File size: 1,165 Bytes
52c1bfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9692e2
52c1bfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from __future__ import annotations

from typing import Literal

from ase import Atoms
from chgnet.model.dynamics import CHGNetCalculator
from chgnet.model.model import CHGNet as CHGNetModel

from mlip_arena.models.utils import get_freer_device


class CHGNet(CHGNetCalculator):
    """ASE calculator wrapping CHGNet with automatic device selection.

    Thin subclass of ``chgnet.model.dynamics.CHGNetCalculator`` that:

    * picks a device via ``get_freer_device()`` when none is given, and
    * strips the ``crystal_fea`` entry from ``self.results`` after every
      calculation so the results can be written with ``ase.io.write``.
    """

    def __init__(
        self,
        checkpoint: CHGNetModel | None = None,  # TODO: specify version
        device: str | None = None,
        stress_weight: float | None = 1 / 160.21766208,
        on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
        **kwargs,
    ) -> None:
        """Build the calculator.

        Args:
            checkpoint: Pre-loaded CHGNet model, forwarded to the parent as
                ``model``. ``None`` defers to the parent's default model
                loading (presumably the packaged pretrained weights — confirm
                against the installed chgnet version).
            device: Device string such as ``"cpu"`` or ``"cuda:0"``. When
                falsy, ``get_freer_device()`` chooses one.
            stress_weight: Factor applied by the parent to predicted stress.
                The default, 1 / 160.21766208 (the value of ``ase.units.GPa``),
                rescales GPa to eV/Å^3.
            on_isolated_atoms: Policy forwarded to the parent for structures
                that contain isolated atoms.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``CHGNetCalculator.__init__``.
        """
        # str() normalizes whatever get_freer_device() returns (presumably a
        # torch.device — verify) into the plain string the parent expects.
        use_device = str(device or get_freer_device())
        super().__init__(
            model=checkpoint,
            use_device=use_device,
            stress_weight=stress_weight,
            on_isolated_atoms=on_isolated_atoms,
            **kwargs,
        )

    def calculate(
        self,
        atoms: Atoms | None = None,
        properties: list | None = None,
        system_changes: list | None = None,
        **kwargs,
    ) -> None:
        """Run the parent calculation, then drop ``crystal_fea`` from results.

        Args:
            atoms: Structure to evaluate; forwarded to the parent.
            properties: Properties requested by ASE; forwarded to the parent.
            system_changes: Changes since the last calculation; forwarded to
                the parent.
            **kwargs: Extra keyword arguments forwarded to
                ``CHGNetCalculator.calculate`` (e.g. ``task`` in chgnet
                versions that accept it), so this override does not narrow
                the parent's signature.
        """
        super().calculate(atoms, properties, system_changes, **kwargs)

        # crystal_fea is not a standard ASE property; drop it (if present)
        # so ase.io.write can serialize the calculator results.
        self.results.pop("crystal_fea", None)