|
import os |
|
import time |
|
import tarfile |
|
from dotenv import load_dotenv |
|
from huggingface_hub import CommitScheduler, HfApi |
|
import logging |
|
from pathlib import Path |
|
|
|
|
|
# Merge variables from a local .env file (if present) into the process
# environment before any configuration is read.
load_dotenv()

# Hugging Face dataset repo that holds the backup archive.
REPO_ID = os.environ.get('HF_REPO_ID')
# Minutes between archive rebuilds (also used as the commit cadence).
SYNC_INTERVAL = int(os.environ.get('SYNC_INTERVAL', 5))

# Directory that is restored on startup and archived periodically.
DATA_PATH = "/data"
# Folder handed to the CommitScheduler; the archive lives inside it.
SYNC_PATH = "/sync"
ARCHIVE_NAME = "data.tar.gz"
ARCHIVE_PATH = os.path.join(SYNC_PATH, ARCHIVE_NAME)
|
|
|
|
|
# Timestamped records tagged with the logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

# Configure the root logger for INFO and above, then grab this module's logger.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
# Fail fast at import time: nothing can be restored or synced without a repo.
if REPO_ID is None or REPO_ID == "":
    raise ValueError("HF_REPO_ID must be set in environment variables")
|
|
|
def tar_filter(tarinfo): |
|
"""tar 文件过滤器""" |
|
if tarinfo.name.startswith('data/'): |
|
tarinfo.name = tarinfo.name[5:] |
|
return tarinfo |
|
|
|
def download_and_extract():
    """Restore DATA_PATH from the archive stored in the dataset repo REPO_ID.

    Downloads ARCHIVE_NAME into SYNC_PATH and extracts it into DATA_PATH
    using ``tar_filter``. Finally, makes sure DATA_PATH exists.

    A failed download is treated as "nothing to restore" (e.g. the archive
    has never been committed) and only logged. A failed extraction is
    re-raised instead. At that point a backup exists but could not be
    applied, and continuing would let the sync loop archive an empty or
    partial DATA_PATH and upload it over the only good copy.

    Raises:
        Exception: Whatever error opening or extracting the downloaded
            archive produced (e.g. ``tarfile.TarError``, ``OSError``).
    """
    api = HfApi()
    try:
        logger.info("Downloading data archive...")
        archive_file = api.hf_hub_download(
            repo_id=REPO_ID,
            filename=ARCHIVE_NAME,
            repo_type="dataset",
            local_dir=SYNC_PATH,
        )
    except Exception as e:
        # NOTE(review): this also covers network/auth failures, not only
        # "archive does not exist yet". In those cases startup continues with
        # an empty DATA_PATH that the sync loop will later upload. Consider
        # re-raising everything except the hub's not-found errors.
        logger.warning(f"No existing archive found or download failed: {e}")
    else:
        logger.info("Extracting archive...")
        try:
            # Use the path the hub client actually wrote to, rather than
            # assuming it matches ARCHIVE_PATH.
            with tarfile.open(archive_file, "r:gz") as tar:
                tar.extractall(path=DATA_PATH, filter=tar_filter)
        except Exception:
            logger.exception("Failed to extract archive %s", archive_file)
            raise

    Path(DATA_PATH).mkdir(parents=True, exist_ok=True)
|
|
|
def create_archive():
    """Pack DATA_PATH into ARCHIVE_PATH as a gzip-compressed tarball.

    Everything under DATA_PATH is stored below a top-level ``data`` entry
    (``data/...``). The tarball is first written to a temporary file next to
    ARCHIVE_PATH and then moved into place with ``os.replace``. That rename is
    atomic on POSIX, so ARCHIVE_PATH always holds either the previous complete
    archive or the new complete one.

    This matters because ARCHIVE_PATH lives in SYNC_PATH, the folder uploaded
    by the CommitScheduler. Rewriting it in place could let an upload read a
    half-written file. A failure part-way through (e.g. a file vanishing while
    being added) would also leave a truncated archive to be uploaded.

    The temporary file is inside that same folder, so callers should hold the
    scheduler's lock while calling this to keep the temp file out of uploads.

    Raises:
        OSError, tarfile.TarError: Propagated after removing the temp file;
            the existing ARCHIVE_PATH is left untouched.
    """
    tmp_path = f"{ARCHIVE_PATH}.tmp"
    logger.info("Creating new archive...")
    try:
        with tarfile.open(tmp_path, "w:gz") as tar:
            tar.add(DATA_PATH, arcname="data")
        os.replace(tmp_path, ARCHIVE_PATH)
    except BaseException:
        # Includes KeyboardInterrupt: never leave a partial temp file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    logger.info("Archive created")
|
|
|
def main():
    """Restore the data directory, then keep its archive fresh for upload.

    Creates SYNC_PATH and restores DATA_PATH via ``download_and_extract()``.
    It then starts a CommitScheduler that commits the contents of SYNC_PATH
    to the private dataset repo REPO_ID every SYNC_INTERVAL minutes. The loop
    rebuilds the archive on the same cadence until a KeyboardInterrupt
    (Ctrl+C), at which point the scheduler is stopped.
    """
    logger.info(f"Starting sync process for repo: {REPO_ID}")
    logger.info(f"Sync interval: {SYNC_INTERVAL} minutes")

    # Download target and scheduler folder; must exist before either is used.
    Path(SYNC_PATH).mkdir(parents=True, exist_ok=True)

    download_and_extract()

    scheduler = CommitScheduler(
        repo_id=REPO_ID,
        repo_type="dataset",
        folder_path=SYNC_PATH,
        path_in_repo="",
        every=SYNC_INTERVAL,
        # Squash repo history after each commit so it does not keep growing.
        squash_history=True,
        private=True,
    )

    try:
        while True:
            # The scheduler runs in the background and snapshots the folder
            # while holding scheduler.lock. Take the same lock while the
            # archive is (re)written so a commit never lists it mid-write.
            with scheduler.lock:
                create_archive()

            logger.info(f"Waiting {SYNC_INTERVAL} minutes until next sync...")
            time.sleep(SYNC_INTERVAL * 60)
    except KeyboardInterrupt:
        logger.info("Stopping sync process...")
        scheduler.stop()
|
|
|
# Start the restore/sync loop only when run as a script, not on import.
if "__main__" == __name__:
    main()
|
|