"""Restore a data directory from a Hugging Face dataset repo and keep it synced.

On start-up the script downloads ``data.tar.gz`` from the dataset repo and
unpacks it into ``DATA_PATH``.  It then re-creates the archive from
``DATA_PATH`` every ``SYNC_INTERVAL`` minutes inside ``SYNC_PATH``, the folder
handed to ``CommitScheduler`` for upload.
"""

import logging
import os
import tarfile
import time
from pathlib import Path

from dotenv import load_dotenv
from huggingface_hub import CommitScheduler, HfApi

# Load environment variables (from a .env file when one is present).
load_dotenv()

# Global configuration.
REPO_ID = os.getenv('HF_REPO_ID')
SYNC_INTERVAL = int(os.getenv('SYNC_INTERVAL', 5))  # minutes
DATA_PATH = "/data"
ARCHIVE_NAME = "data.tar.gz"
SYNC_PATH = "/sync"  # directory watched by CommitScheduler
ARCHIVE_PATH = f"{SYNC_PATH}/{ARCHIVE_NAME}"

# Top-level directory name used for every member inside the archive.
_ARCHIVE_ROOT = "data"
# Scratch file the archive is built in before being moved into place.
_ARCHIVE_TMP_NAME = f"{ARCHIVE_NAME}.tmp"

# Logging configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Fail fast when the target repository is not configured.
if not REPO_ID:
    raise ValueError("HF_REPO_ID must be set in environment variables")


def tar_filter(tarinfo, path=None):
    """Strip the leading ``data/`` component from an archive member.

    Usable as a ``TarFile.extractall`` filter, which tarfile calls as
    ``filter(member, dest_path)``; ``path`` is optional so existing one-argument
    callers keep working.

    Args:
        tarinfo: The ``tarfile.TarInfo`` member being processed.
        path: Extraction destination supplied by tarfile; unused.

    Returns:
        A copy of ``tarinfo`` with the prefix removed, the unchanged member
        when it is not under ``data/``, or ``None`` for the root ``data``
        entry so that it is skipped.
    """
    # NOTE(review): no safety filtering (e.g. tarfile.data_filter) is applied;
    # the archive is assumed to be one this script produced itself.
    if tarinfo.name == _ARCHIVE_ROOT:
        # The root entry maps to the destination directory itself; extracting
        # it would only create a stray nested "data" directory.
        return None

    prefix = f"{_ARCHIVE_ROOT}/"
    if not tarinfo.name.startswith(prefix):
        return tarinfo

    changes = {"name": tarinfo.name[len(prefix):]}
    # Hard links name their target by archive path, so it must be rewritten
    # the same way or the link would point at a path that is never extracted.
    if tarinfo.islnk() and tarinfo.linkname.startswith(prefix):
        changes["linkname"] = tarinfo.linkname[len(prefix):]
    # Return a modified copy instead of mutating the TarFile's own member.
    return tarinfo.replace(**changes)


def download_and_extract():
    """Download the archive from the dataset repo and unpack it into DATA_PATH.

    Both steps are best-effort: a failed download (including a missing
    archive) or a failed extraction is logged and the function still returns
    normally.  ``DATA_PATH`` is guaranteed to exist afterwards.
    """
    api = HfApi()
    try:
        # Download the archive into the sync directory.
        logger.info("Downloading data archive...")
        api.hf_hub_download(
            repo_id=REPO_ID,
            filename=ARCHIVE_NAME,
            repo_type="dataset",
            local_dir=SYNC_PATH,
        )
    except Exception as e:
        logger.info(f"No existing archive found or download failed: {e}")
    else:
        try:
            # Unpack into the data directory.
            logger.info("Extracting archive...")
            with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
                tar.extractall(path=DATA_PATH, filter=tar_filter)
        except Exception:
            # Reported separately, with traceback, so a broken restore is not
            # mistaken for "no archive yet".
            # NOTE(review): the sync loop still starts after this and will
            # replace the remote archive with whatever was restored - confirm
            # that this is the intended behaviour.
            logger.exception("Failed to extract %s", ARCHIVE_PATH)

    # Make sure the data directory exists even when nothing was restored.
    Path(DATA_PATH).mkdir(parents=True, exist_ok=True)


def create_archive():
    """Pack DATA_PATH into a fresh gzip-compressed tarball at ARCHIVE_PATH.

    The tarball is written to a temporary file next to the final path and then
    moved into place with ``os.replace``, so ARCHIVE_PATH always holds either
    the previous or the new complete archive, never a partially written one.

    Raises:
        OSError: If the archive cannot be written or moved into place.
    """
    logger.info("Creating new archive...")
    tmp_path = Path(SYNC_PATH) / _ARCHIVE_TMP_NAME
    try:
        with tarfile.open(tmp_path, "w:gz") as tar:
            tar.add(DATA_PATH, arcname=_ARCHIVE_ROOT)
        os.replace(tmp_path, ARCHIVE_PATH)
    finally:
        # Only present if something failed before the move; never leave it.
        tmp_path.unlink(missing_ok=True)
    logger.info("Archive created")


def main():
    """Restore the data directory, then archive it on a fixed interval."""
    logger.info("Starting sync process for repo: %s", REPO_ID)
    logger.info("Sync interval: %s minutes", SYNC_INTERVAL)

    # Create the sync directory.
    Path(SYNC_PATH).mkdir(parents=True, exist_ok=True)

    # Initial download and extraction.
    download_and_extract()

    # Create the scheduler that uploads the sync directory.
    scheduler = CommitScheduler(
        repo_id=REPO_ID,
        repo_type="dataset",
        folder_path=SYNC_PATH,
        path_in_repo="",  # put the archive directly at the repo root
        every=SYNC_INTERVAL,
        squash_history=True,
        private=True,
        # Keep the in-progress scratch file out of scheduled commits.
        ignore_patterns=[_ARCHIVE_TMP_NAME],
    )

    # Main loop: periodically rebuild the archive.
    try:
        while True:
            create_archive()

            # Wait for the next round.
            logger.info("Waiting %s minutes until next sync...", SYNC_INTERVAL)
            time.sleep(SYNC_INTERVAL * 60)
    except KeyboardInterrupt:
        logger.info("Stopping sync process...")
    finally:
        # Stop the scheduler on every exit path, not only on Ctrl-C.
        scheduler.stop()


if __name__ == "__main__":
    main()