Spaces:
Running
Running
File size: 2,210 Bytes
b5053f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import os
from pathlib import Path
from typing import Generator, Iterable
from huggingface_hub import HfApi, hf_hub_download
from prefect import task
from ase import Atoms
from ase.db import connect
from mlip_arena.tasks.utils import logger
def save_to_db(
atoms_list: list[Atoms] | Iterable[Atoms] | Atoms,
db_path: Path | str,
upload: bool = True,
hf_token: str | None = os.getenv("HF_TOKEN", None),
repo_id: str = "atomind/mlip-arena",
repo_type: str = "dataset",
subfolder: str = Path(__file__).parent.name,
):
"""Save ASE Atoms objects to an ASE database and optionally upload to Hugging Face Hub."""
if upload and hf_token is None:
raise ValueError("HF_TOKEN is required to upload the database.")
db_path = Path(db_path)
if isinstance(atoms_list, Atoms):
atoms_list = [atoms_list]
with connect(db_path) as db:
for atoms in atoms_list:
if not isinstance(atoms, Atoms):
raise ValueError("atoms_list must contain ASE Atoms objects.")
db.write(atoms)
if upload:
api = HfApi(token=hf_token)
api.upload_file(
path_or_fileobj=db_path,
path_in_repo=f"{subfolder}/{db_path.name}",
repo_id=repo_id,
repo_type=repo_type,
)
logger.info(f"{db_path.name} uploaded to {repo_id}/{subfolder}")
return db_path
@task
def get_atoms_from_db(
db_path: Path | str,
hf_token: str | None = os.getenv("HF_TOKEN", None),
repo_id: str = "atomind/mlip-arena",
repo_type: str = "dataset",
subfolder: str = Path(__file__).parent.name,
force_download: bool = False,
) -> Generator[Atoms, None, None]:
"""Retrieve ASE Atoms objects from an ASE database."""
db_path = Path(db_path)
if not db_path.exists():
db_path = hf_hub_download(
repo_id=repo_id,
repo_type=repo_type,
subfolder=subfolder,
# local_dir=db_path.parent,
filename=db_path.name,
token=hf_token,
force_download=force_download,
)
with connect(db_path) as db:
for row in db.select():
yield row.toatoms()
|