import os
import time
import tarfile
from dotenv import load_dotenv
from huggingface_hub import CommitScheduler, HfApi
import logging
from pathlib import Path

# Load environment variables from .env
load_dotenv()

# Global configuration
REPO_ID = os.getenv('HF_REPO_ID')
SYNC_INTERVAL = int(os.getenv('SYNC_INTERVAL', 5))
DATA_PATH = "/data"
ARCHIVE_NAME = "data.tar.gz"
SYNC_PATH = "/sync"  # directory watched by CommitScheduler
ARCHIVE_PATH = f"{SYNC_PATH}/{ARCHIVE_NAME}"
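
# Example .env (values are illustrative; huggingface_hub also picks up an
# HF_TOKEN environment variable for authentication):
#   HF_REPO_ID=your-username/your-dataset
#   SYNC_INTERVAL=5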

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Validate required environment variables
if not REPO_ID:
    raise ValueError("HF_REPO_ID must be set in environment variables")

def tar_filter(member, path):
    """Extraction filter: strip the leading 'data/' prefix from member names."""
    # tarfile extraction filters receive (member, path) and return the
    # (possibly modified) member; this requires a Python build with the
    # extraction-filter API (3.12+, or the 3.8+ security backports).
    if member.name.startswith('data/'):
        member.name = member.name[5:]
    return member
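
# For example, a member stored in the tarball as "data/notes/a.txt" (a
# hypothetical path) is extracted to /data/notes/a.txt rather than
# /data/data/notes/a.txt.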

def download_and_extract():
    """下载并解压数据"""
    api = HfApi()
    try:
        # Download the archive
        logger.info("Downloading data archive...")
        api.hf_hub_download(
            repo_id=REPO_ID,
            filename=ARCHIVE_NAME,
            repo_type="dataset",
            local_dir=SYNC_PATH
        )
        
        # Extract into the data directory
        logger.info("Extracting archive...")
        with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
            tar.extractall(path=DATA_PATH, filter=tar_filter)
            
    except Exception as e:
        logger.warning(f"No existing archive found or download failed: {e}")
        # Make sure the data directory exists regardless
        Path(DATA_PATH).mkdir(parents=True, exist_ok=True)

def create_archive():
    """创建压缩包"""
    logger.info("Creating new archive...")
    with tarfile.open(ARCHIVE_PATH, "w:gz") as tar:
        tar.add(DATA_PATH, arcname="data")
    logger.info("Archive created")

def main():
    logger.info(f"Starting sync process for repo: {REPO_ID}")
    logger.info(f"Sync interval: {SYNC_INTERVAL} minutes")
    
    # Create the sync directory
    Path(SYNC_PATH).mkdir(parents=True, exist_ok=True)
    
    # Initial download and extraction
    download_and_extract()

    # Create the scheduler
    scheduler = CommitScheduler(
        repo_id=REPO_ID,
        repo_type="dataset",
        folder_path=SYNC_PATH,
        path_in_repo="",  # put the archive at the repo root
        every=SYNC_INTERVAL,  # minutes between commits
        squash_history=True,
        private=True
    )
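    # CommitScheduler runs a background thread that commits the contents of
    # folder_path every `every` minutes, uploading only files that changed
    # since the last push; squash_history periodically squashes the repo's
    # git history so it doesn't grow by one commit per sync.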

    # Main loop: periodically rebuild the archive
    try:
        while True:
            # Hold the scheduler's lock while rewriting the archive so a
            # background commit never uploads a half-written file
            # (CommitScheduler exposes a `lock` attribute for this purpose)
            with scheduler.lock:
                create_archive()
            # Wait until the next sync
            logger.info(f"Waiting {SYNC_INTERVAL} minutes until next sync...")
            time.sleep(SYNC_INTERVAL * 60)
    except KeyboardInterrupt:
        logger.info("Stopping sync process...")
        scheduler.stop()

if __name__ == "__main__":
    main()
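
# Usage sketch (assumes this file is saved as sync.py, a .env sits next to
# it, and a Hugging Face write token is available to huggingface_hub):
#   pip install huggingface_hub python-dotenv
#   python sync.py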