File size: 7,604 Bytes
5a97508
 
 
fca1eb6
 
be8a1bb
 
 
fca1eb6
 
 
 
 
5a97508
 
 
be8a1bb
 
 
 
 
fca1eb6
 
 
 
 
 
 
 
5a97508
fca1eb6
 
 
e6a7d9d
fca1eb6
e6a7d9d
fca1eb6
 
 
 
 
 
 
 
 
5a97508
fca1eb6
 
 
 
 
 
 
5a97508
fca1eb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a97508
fca1eb6
 
 
 
 
 
5a97508
fca1eb6
 
 
 
5a97508
fca1eb6
 
 
 
 
 
 
 
 
 
 
 
 
5a97508
fca1eb6
 
 
 
5a97508
fca1eb6
 
 
 
 
 
5a97508
fca1eb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8a1bb
fca1eb6
be8a1bb
fca1eb6
be8a1bb
fca1eb6
 
 
 
 
 
 
be8a1bb
5a97508
fca1eb6
 
 
 
 
 
be8a1bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fca1eb6
 
5a97508
fca1eb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8a1bb
 
 
 
 
 
 
 
 
 
 
fca1eb6
be8a1bb
fca1eb6
 
 
be8a1bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a97508
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import os
import time
import tarfile
import hashlib
import shutil
import argparse
import sys
from enum import Enum, auto
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
from contextlib import contextmanager
import logging
from dotenv import load_dotenv
from huggingface_hub import CommitScheduler, HfApi

class SyncMode(Enum):
    """Operating mode selected via the --mode CLI flag (see parse_args)."""
    INIT_ONLY = auto()  # run initialization only
    SYNC_ONLY = auto()  # run continuous synchronization only
    BOTH = auto()       # run initialization, then synchronization

@dataclass
class Config:
    """Runtime configuration for the sync service."""

    repo_id: str        # Hugging Face dataset repository id
    sync_interval: int  # minutes between change checks / uploads
    data_path: Path     # live data directory to archive
    sync_path: Path     # staging directory watched by the uploader
    tmp_path: Path      # scratch space for building the archive
    archive_name: str   # tarball file name inside sync_path

    @classmethod
    def from_env(cls):
        """Build a Config from environment variables (honors a .env file).

        Raises:
            ValueError: if HF_DATASET_REPO_ID is unset or empty.
        """
        load_dotenv()
        repo = os.getenv('HF_DATASET_REPO_ID')
        if not repo:
            raise ValueError("HF_DATASET_REPO_ID must be set")
        interval = int(os.getenv('SYNC_INTERVAL', '5'))
        return cls(
            repo_id=repo,
            sync_interval=interval,
            data_path=Path("/data"),
            sync_path=Path("/sync"),
            tmp_path=Path("/tmp/sync"),
            archive_name="data.tar.gz",
        )

class Logger:
    """Thin wrapper that configures root logging once and exposes a module logger."""

    _FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

    def __init__(self):
        logging.basicConfig(level=logging.INFO, format=self._FORMAT)
        # All service classes share this single named logger instance.
        self.logger = logging.getLogger(__name__)

class DirectoryMonitor:
    """Detects content changes in a directory tree via a SHA-256 digest."""

    def __init__(self, path: Path):
        self.path = path
        # Digest from the previous scan; None until the first scan runs.
        self.last_hash: Optional[str] = None

    def get_directory_hash(self) -> str:
        """Return a SHA-256 digest over all file paths and contents.

        Files are visited in sorted order so the digest is deterministic.
        Files that vanish or become unreadable between listing and reading
        (e.g. deleted by a concurrent writer) are skipped instead of
        crashing the scan — previously a FileNotFoundError here would
        propagate up and kill the sync loop.
        """
        sha256_hash = hashlib.sha256()

        all_files = sorted(
            str(p) for p in self.path.rglob('*') if p.is_file()
        )

        for file_path in all_files:
            try:
                with open(file_path, 'rb') as f:
                    # Hash the relative path too so renames count as changes.
                    rel_path = os.path.relpath(file_path, self.path)
                    sha256_hash.update(rel_path.encode())
                    for chunk in iter(lambda: f.read(4096), b''):
                        sha256_hash.update(chunk)
            except (FileNotFoundError, PermissionError):
                # File disappeared or is unreadable mid-scan; skip it.
                continue

        return sha256_hash.hexdigest()

    def has_changes(self) -> bool:
        """Return True exactly once per observed change of the digest."""
        current_hash = self.get_directory_hash()
        if current_hash != self.last_hash:
            self.last_hash = current_hash
            return True
        return False

class ArchiveManager:
    """Creates and extracts the data tarball exchanged with the HF dataset repo."""

    def __init__(self, config: Config, logger: Logger):
        self.config = config
        self.logger = logger.logger
        
    @contextmanager
    def safe_archive(self):
        """Context manager that safely creates the archive file.

        The tarball is written to a temporary path first and only moved into
        the sync directory after it has been fully written, so the uploader
        never observes a half-written archive. Any leftover temp file is
        removed in the finally block.
        """
        self.config.tmp_path.mkdir(parents=True, exist_ok=True)
        tmp_archive = self.config.tmp_path / self.config.archive_name
        
        try:
            with tarfile.open(tmp_archive, "w:gz") as tar:
                yield tar
            
            # Move to the final location only after a successful write.
            self.config.sync_path.mkdir(parents=True, exist_ok=True)
            shutil.move(tmp_archive, self.config.sync_path / self.config.archive_name)
            
        finally:
            # Clean up the temp file (still present only if creation or the
            # move above failed).
            if tmp_archive.exists():
                tmp_archive.unlink()

    def create_archive(self):
        """Create the compressed archive of the whole data directory."""
        self.logger.info("Creating new archive...")
        with self.safe_archive() as tar:
            tar.add(self.config.data_path, arcname="data")
        self.logger.info("Archive created")

    def extract_archive(self):
        """Download the existing archive from the repo and extract it.

        Returns True on success. On any failure (e.g. no archive exists in
        the repo yet) the data directory is created empty and False is
        returned — callers treat this as a soft, best-effort outcome.
        """
        api = HfApi()
        try:
            self.logger.info("Downloading data archive...")
            api.hf_hub_download(
                repo_id=self.config.repo_id,
                filename=self.config.archive_name,
                repo_type="dataset",
                local_dir=self.config.sync_path
            )
            
            self.logger.info("Extracting archive...")
            archive_path = self.config.sync_path / self.config.archive_name
            with tarfile.open(archive_path, "r:gz") as tar:
                tar.extractall(
                    path=self.config.data_path,
                    filter=self._tar_filter
                )
            return True
        except Exception as e:
            self.logger.error(f"No existing archive found or download failed: {e}")
            self.config.data_path.mkdir(parents=True, exist_ok=True)
            return False

    @staticmethod
    def _tar_filter(tarinfo, path):
        """tarfile extraction filter: strip the leading 'data/' prefix added
        by create_archive; members without it (including the bare 'data'
        directory entry) are dropped by returning None."""
        if tarinfo.name.startswith('data/'):
            tarinfo.name = tarinfo.name[5:]
            return tarinfo
        return None

class SyncService:
    """Coordinates one-time initialization and the periodic archive/sync loop."""

    def __init__(self, config: Config, logger: Logger):
        self.config = config
        self.logger = logger.logger
        self.monitor = DirectoryMonitor(config.data_path)
        self.archive_manager = ArchiveManager(config, logger)
    
    def init(self) -> bool:
        """Run initialization: restore existing data from the repo, if any.

        Returns:
            True if an existing archive was downloaded and extracted,
            False on any failure (logged, never raised).
        """
        try:
            self.logger.info("Starting initialization...")
            self.config.sync_path.mkdir(parents=True, exist_ok=True)
            success = self.archive_manager.extract_archive()
            if success:
                self.logger.info("Initialization completed successfully")
            else:
                self.logger.warning("Initialization completed with warnings")
            return success
        except Exception as e:
            self.logger.error(f"Initialization failed: {e}")
            return False

    def sync(self):
        """Run the continuous sync loop until interrupted.

        A background CommitScheduler uploads the sync directory every
        sync_interval minutes; this loop re-creates the archive whenever
        the data directory's content hash changes.
        """
        self.logger.info(f"Starting sync process for repo: {self.config.repo_id}")
        self.logger.info(f"Sync interval: {self.config.sync_interval} minutes")

        scheduler = CommitScheduler(
            repo_id=self.config.repo_id,
            repo_type="dataset",
            folder_path=str(self.config.sync_path),
            path_in_repo="",
            every=self.config.sync_interval,
            squash_history=True,
            private=True
        )

        try:
            while True:
                if self.monitor.has_changes():
                    self.logger.info("Directory changes detected, creating new archive...")
                    self.archive_manager.create_archive()
                else:
                    self.logger.info("No changes detected")
                
                self.logger.info(f"Waiting {self.config.sync_interval} minutes until next check...")
                time.sleep(self.config.sync_interval * 60)
        except KeyboardInterrupt:
            self.logger.info("Stopping sync process...")
        finally:
            # Always release the background scheduler, even when the loop
            # dies from an unexpected exception — previously stop() ran
            # only on KeyboardInterrupt, leaking the scheduler thread.
            scheduler.stop()

def parse_args():
    """Parse command-line arguments for the sync service."""
    mode_help = (
        'Operation mode: init (initialization only), '
        'sync (synchronization only), both (default)'
    )
    parser = argparse.ArgumentParser(description='Data synchronization service')
    parser.add_argument(
        '--mode',
        type=str,
        default='both',
        choices=['init', 'sync', 'both'],
        help=mode_help,
    )
    return parser.parse_args()

def main():
    """Entry point: run initialization and/or sync according to --mode."""
    args = parse_args()
    service = SyncService(Config.from_env(), Logger())

    mode_table = {
        'init': SyncMode.INIT_ONLY,
        'sync': SyncMode.SYNC_ONLY,
        'both': SyncMode.BOTH,
    }
    mode = mode_table[args.mode]

    # BOTH runs init first, then sync; each exclusive mode runs one phase.
    run_init = mode is not SyncMode.SYNC_ONLY
    run_sync = mode is not SyncMode.INIT_ONLY

    if run_init:
        if not service.init():
            sys.exit(1)

    if run_sync:
        service.sync()

if __name__ == "__main__":
    main()