Spaces:
Running
Running
import os | |
from pathlib import Path | |
from typing import Generator, Iterable | |
from huggingface_hub import HfApi, hf_hub_download | |
from prefect import task | |
from ase import Atoms | |
from ase.db import connect | |
from mlip_arena.tasks.utils import logger | |
def save_to_db(
    atoms_list: list[Atoms] | Iterable[Atoms] | Atoms,
    db_path: Path | str,
    upload: bool = True,
    hf_token: str | None = os.getenv("HF_TOKEN", None),
    repo_id: str = "atomind/mlip-arena",
    repo_type: str = "dataset",
    subfolder: str = Path(__file__).parent.name,
) -> Path:
    """Save ASE Atoms objects to an ASE database and optionally upload to Hugging Face Hub.

    Args:
        atoms_list: A single ``Atoms`` object or an iterable of them. A single
            object is wrapped in a one-element list before writing.
        db_path: Location of the ASE database file to write to.
        upload: When true, the database file is uploaded with
            ``HfApi.upload_file`` after all structures have been written.
        hf_token: Hugging Face access token. When ``None``, the ``HF_TOKEN``
            environment variable is consulted again at call time.
        repo_id: Target repository on the Hub.
        repo_type: Repository type passed to the Hub API.
        subfolder: Folder inside the repository; the file is uploaded as
            ``{subfolder}/{db_path.name}``. Defaults to the name of the
            directory that contains this module.

    Returns:
        ``db_path`` converted to a ``Path``.

    Raises:
        ValueError: If ``upload`` is true and no token can be found, or if an
            element of ``atoms_list`` is not an ``Atoms`` instance.
    """
    # The signature default is captured once, when the module is imported.
    # Look the variable up again so that a token exported after import is
    # still honored instead of triggering a spurious "required" error.
    if hf_token is None:
        hf_token = os.getenv("HF_TOKEN")

    # Fail before touching the database if the upload could never succeed.
    if upload and hf_token is None:
        raise ValueError("HF_TOKEN is required to upload the database.")

    db_path = Path(db_path)

    if isinstance(atoms_list, Atoms):
        atoms_list = [atoms_list]

    with connect(db_path) as db:
        for atoms in atoms_list:
            if not isinstance(atoms, Atoms):
                raise ValueError("atoms_list must contain ASE Atoms objects.")
            db.write(atoms)

    # Upload only after the database context has been exited, so the file on
    # disk is no longer being written to by this function.
    if upload:
        api = HfApi(token=hf_token)
        api.upload_file(
            path_or_fileobj=db_path,
            path_in_repo=f"{subfolder}/{db_path.name}",
            repo_id=repo_id,
            repo_type=repo_type,
        )
        logger.info(f"{db_path.name} uploaded to {repo_id}/{subfolder}")

    return db_path
def get_atoms_from_db(
    db_path: Path | str,
    hf_token: str | None = os.getenv("HF_TOKEN", None),
    repo_id: str = "atomind/mlip-arena",
    repo_type: str = "dataset",
    subfolder: str = Path(__file__).parent.name,
    force_download: bool = False,
) -> Generator[Atoms, None, None]:
    """Yield the ASE Atoms objects stored in an ASE database.

    If ``db_path`` exists locally it is read directly. Otherwise a file with
    the same name is fetched from ``subfolder`` of the given Hugging Face Hub
    repository via ``hf_hub_download`` and read from the path that call
    returns.

    Args:
        db_path: Local database path; only its file name is used when the
            file has to be downloaded.
        hf_token: Token forwarded to ``hf_hub_download``.
        repo_id: Repository to download from.
        repo_type: Repository type passed to the Hub API.
        subfolder: Folder inside the repository that holds the database.
        force_download: Forwarded to ``hf_hub_download``; it has no effect
            when the local file already exists because no download happens.

    Yields:
        One ``Atoms`` object per row returned by ``db.select()``.
    """
    local_file = Path(db_path)

    if local_file.exists():
        source = local_file
    else:
        # No local_dir is passed, so the download location is whatever
        # hf_hub_download chooses by default.
        source = hf_hub_download(
            repo_id=repo_id,
            repo_type=repo_type,
            subfolder=subfolder,
            filename=local_file.name,
            token=hf_token,
            force_download=force_download,
        )

    # The connection stays open for as long as the generator is being consumed.
    with connect(source) as database:
        for record in database.select():
            yield record.toatoms()