from __future__ import annotations

from pathlib import Path

import yaml
from ase import Atoms
from fairchem.core import OCPCalculator

# Load the model registry from "registry.yaml", located in the directory one
# level above the folder that contains this module. The class definitions in
# this file read their default checkpoint names from it at import time.
with (Path(__file__).parents[1] / "registry.yaml").open(encoding="utf-8") as f:
    REGISTRY = yaml.safe_load(f)


class EquiformerV2(OCPCalculator):
    """`OCPCalculator` preconfigured with the "EquiformerV2(OC22)" checkpoint.

    The default checkpoint name is read from the module-level ``REGISTRY``
    when the class is defined; all constructor arguments are forwarded to
    ``OCPCalculator``.
    """

    def __init__(
        self,
        checkpoint=REGISTRY["EquiformerV2(OC22)"]["checkpoint"],
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """Initialise the underlying ``OCPCalculator``.

        Args:
            checkpoint: Passed to the parent as ``model_name``. Defaults to
                the registry's "EquiformerV2(OC22)" checkpoint entry.
            local_cache: Forwarded unchanged as ``local_cache``
                (presumably where checkpoints are stored; confirm against
                the fairchem documentation).
            cpu: Forwarded unchanged as ``cpu``.
            seed: Forwarded unchanged as ``seed``.
            **kwargs: Any further keyword arguments for ``OCPCalculator``.
        """
        super().__init__(
            seed=seed,
            cpu=cpu,
            local_cache=local_cache,
            model_name=checkpoint,
            **kwargs,
        )

    def calculate(self, atoms: Atoms, properties, system_changes) -> None:
        """Run the parent calculation, then also expose forces as ``"force"``.

        Args:
            atoms: Structure to evaluate.
            properties: Forwarded unchanged to the parent ``calculate``.
            system_changes: Forwarded unchanged to the parent ``calculate``.
        """
        super().calculate(atoms, properties, system_changes)

        # NOTE(review): forces are read back through the Atoms object rather
        # than from self.results; this presumably relies on ``atoms`` having
        # this calculator attached. Confirm against callers.
        results = self.results
        results.update({"force": atoms.get_forces()})


class EquiformerV2OC20(OCPCalculator):
    """`OCPCalculator` preconfigured with the EquiformerV2 OC20 checkpoint.

    The default checkpoint name is read from the module-level ``REGISTRY``
    when the class is defined; all constructor arguments are forwarded to
    ``OCPCalculator``.
    """

    def __init__(
        self,
        # Use the registry entry this class is named after. The default used
        # to be read from the "EquiformerV2(OC22)" entry, which made this
        # class load the same checkpoint as ``EquiformerV2``. The OC22 entry
        # is kept as a fallback so that a registry without an OC20 entry
        # still imports and behaves as before.
        checkpoint=REGISTRY.get(
            "EquiformerV2(OC20)", REGISTRY["EquiformerV2(OC22)"]
        )["checkpoint"],
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """Initialise the underlying ``OCPCalculator``.

        Args:
            checkpoint: Passed to the parent as ``model_name``. Defaults to
                the registry's "EquiformerV2(OC20)" checkpoint entry, or to
                the "EquiformerV2(OC22)" one if no OC20 entry exists.
            local_cache: Forwarded unchanged as ``local_cache``
                (presumably where checkpoints are stored; confirm against
                the fairchem documentation).
            cpu: Forwarded unchanged as ``cpu``.
            seed: Forwarded unchanged as ``seed``.
            **kwargs: Any further keyword arguments for ``OCPCalculator``.
        """
        super().__init__(
            model_name=checkpoint,
            local_cache=local_cache,
            cpu=cpu,
            seed=seed,
            **kwargs,
        )