import os from pathlib import Path from typing import Generator, Iterable from huggingface_hub import HfApi, hf_hub_download from prefect import task from ase import Atoms from ase.db import connect from mlip_arena.tasks.utils import logger def save_to_db( atoms_list: list[Atoms] | Iterable[Atoms] | Atoms, db_path: Path | str, upload: bool = True, hf_token: str | None = os.getenv("HF_TOKEN", None), repo_id: str = "atomind/mlip-arena", repo_type: str = "dataset", subfolder: str = Path(__file__).parent.name, ): """Save ASE Atoms objects to an ASE database and optionally upload to Hugging Face Hub.""" if upload and hf_token is None: raise ValueError("HF_TOKEN is required to upload the database.") db_path = Path(db_path) if isinstance(atoms_list, Atoms): atoms_list = [atoms_list] with connect(db_path) as db: for atoms in atoms_list: if not isinstance(atoms, Atoms): raise ValueError("atoms_list must contain ASE Atoms objects.") db.write(atoms) if upload: api = HfApi(token=hf_token) api.upload_file( path_or_fileobj=db_path, path_in_repo=f"{subfolder}/{db_path.name}", repo_id=repo_id, repo_type=repo_type, ) logger.info(f"{db_path.name} uploaded to {repo_id}/{subfolder}") return db_path @task def get_atoms_from_db( db_path: Path | str, hf_token: str | None = os.getenv("HF_TOKEN", None), repo_id: str = "atomind/mlip-arena", repo_type: str = "dataset", subfolder: str = Path(__file__).parent.name, force_download: bool = False, ) -> Generator[Atoms, None, None]: """Retrieve ASE Atoms objects from an ASE database.""" db_path = Path(db_path) if not db_path.exists(): db_path = hf_hub_download( repo_id=repo_id, repo_type=repo_type, subfolder=subfolder, # local_dir=db_path.parent, filename=db_path.name, token=hf_token, force_download=force_download, ) with connect(db_path) as db: for row in db.select(): yield row.toatoms()