import logging
import os
import tarfile
import time
from pathlib import Path

from dotenv import load_dotenv
from huggingface_hub import CommitScheduler, HfApi
# Load environment variables
load_dotenv()

# Global configuration
REPO_ID = os.getenv('HF_REPO_ID')
SYNC_INTERVAL = int(os.getenv('SYNC_INTERVAL', 5))  # minutes
DATA_PATH = "/data"
ARCHIVE_NAME = "data.tar.gz"
SYNC_PATH = "/sync"  # directory watched by CommitScheduler
ARCHIVE_PATH = f"{SYNC_PATH}/{ARCHIVE_NAME}"
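
# A minimal .env sketch for the variables read above (the values and the
# repo name are illustrative assumptions, not part of the original file):
#
#   HF_REPO_ID=your-username/your-backup-dataset
#   SYNC_INTERVAL=5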
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Fail fast if the target repo is not configured
if not REPO_ID:
    raise ValueError("HF_REPO_ID must be set in environment variables")
def tar_filter(member, path):
    """Extraction filter: strip the leading 'data/' prefix so the archive's
    contents land directly under DATA_PATH. extractall() calls the filter
    as filter(member, path) on Python 3.12+ (and the security backports),
    so it must accept both arguments."""
    if member.name == 'data':
        return None  # skip the top-level directory entry itself
    if member.name.startswith('data/'):
        member.name = member.name[5:]
    return member
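
# Illustrative effect of the filter above (hypothetical member names):
#   'data/app.db'        -> 'app.db'
#   'data/logs/run.log'  -> 'logs/run.log'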
def download_and_extract():
    """Download the archive from the Hub and extract it into DATA_PATH."""
    api = HfApi()
    try:
        # Download the archive
        logger.info("Downloading data archive...")
        api.hf_hub_download(
            repo_id=REPO_ID,
            filename=ARCHIVE_NAME,
            repo_type="dataset",
            local_dir=SYNC_PATH
        )
        # Extract into the data directory
        logger.info("Extracting archive...")
        with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
            tar.extractall(path=DATA_PATH, filter=tar_filter)
    except Exception as e:
        logger.info(f"No existing archive found or download failed: {e}")
    # Make sure the data directory exists either way
    Path(DATA_PATH).mkdir(parents=True, exist_ok=True)
def create_archive():
    """Pack DATA_PATH into a gzipped tarball inside the watched sync directory."""
    logger.info("Creating new archive...")
    with tarfile.open(ARCHIVE_PATH, "w:gz") as tar:
        tar.add(DATA_PATH, arcname="data")
    logger.info("Archive created")
def main():
    logger.info(f"Starting sync process for repo: {REPO_ID}")
    logger.info(f"Sync interval: {SYNC_INTERVAL} minutes")

    # Create the directory watched by CommitScheduler
    Path(SYNC_PATH).mkdir(parents=True, exist_ok=True)

    # Initial download and extraction
    download_and_extract()

    # Create the scheduler: it commits the contents of SYNC_PATH to the Hub
    # in a background thread every SYNC_INTERVAL minutes
    scheduler = CommitScheduler(
        repo_id=REPO_ID,
        repo_type="dataset",
        folder_path=SYNC_PATH,
        path_in_repo="",  # place the archive at the repo root
        every=SYNC_INTERVAL,
        squash_history=True,
        private=True
    )

    # Main loop: periodically rebuild the archive. Hold scheduler.lock while
    # writing so a background commit never picks up a half-written file.
    try:
        while True:
            with scheduler.lock:
                create_archive()
            logger.info(f"Waiting {SYNC_INTERVAL} minutes until next sync...")
            time.sleep(SYNC_INTERVAL * 60)
    except KeyboardInterrupt:
        logger.info("Stopping sync process...")
        scheduler.stop()

if __name__ == "__main__":
    main()
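
# Example invocation (a sketch; assumes /data and /sync exist and are
# writable, and that an HF_TOKEN with write access is available in the
# environment for huggingface_hub to authenticate):
#
#   HF_REPO_ID=user/backup-dataset SYNC_INTERVAL=5 python sync.py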