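"""Sync a local data directory to a Hugging Face dataset repository.

Watches /data for changes, packs it into a tar.gz archive, and relies on
huggingface_hub's CommitScheduler to upload the archive on a fixed interval.
"""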
import os
import time
import tarfile
import hashlib
import shutil
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
from contextlib import contextmanager
import logging
from dotenv import load_dotenv
from huggingface_hub import CommitScheduler, HfApi

@dataclass
class Config:
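    """Runtime configuration, loaded from environment variables."""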
    repo_id: str
    sync_interval: int
    data_path: Path
    sync_path: Path
    tmp_path: Path
    archive_name: str

    @classmethod
    def from_env(cls):
        load_dotenv()
        repo_id = os.getenv('HF_REPO_ID')
        if not repo_id:
            raise ValueError("HF_REPO_ID must be set")
            
        return cls(
            repo_id=repo_id,
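            # SYNC_INTERVAL is interpreted in minutes (default: 5)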
            sync_interval=int(os.getenv('SYNC_INTERVAL', '5')),
            data_path=Path("/data"),
            sync_path=Path("/sync"),
            tmp_path=Path("/tmp/sync"),
            archive_name="data.tar.gz"
        )

class Logger:
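    """Thin wrapper that configures root logging and exposes a named logger."""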
    def __init__(self):
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

class DirectoryMonitor:
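    """Detects content changes by hashing every file path and file body under a directory."""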
    def __init__(self, path: Path):
        self.path = path
        self.last_hash: Optional[str] = None
    
    def get_directory_hash(self) -> str:
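        """Compute a deterministic SHA-256 over relative paths and file contents, in sorted order."""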
        sha256_hash = hashlib.sha256()
        
        all_files = sorted(
            str(p) for p in self.path.rglob('*') if p.is_file()
        )
        
        for file_path in all_files:
            rel_path = os.path.relpath(file_path, self.path)
            sha256_hash.update(rel_path.encode())
            
            with open(file_path, 'rb') as f:
                for chunk in iter(lambda: f.read(4096), b''):
                    sha256_hash.update(chunk)
        
        return sha256_hash.hexdigest()

    def has_changes(self) -> bool:
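        """Return True when the directory hash differs from the last one seen (always True on the first call)."""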
        current_hash = self.get_directory_hash()
        if current_hash != self.last_hash:
            self.last_hash = current_hash
            return True
        return False

class ArchiveManager:
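    """Creates and extracts the tar.gz archive exchanged with the Hub."""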
    def __init__(self, config: Config, logger: Logger):
        self.config = config
        self.logger = logger.logger
        
    @contextmanager
    def safe_archive(self):
        """安全地创建归档文件的上下文管理器"""
        self.config.tmp_path.mkdir(parents=True, exist_ok=True)
        tmp_archive = self.config.tmp_path / self.config.archive_name
        
        try:
            with tarfile.open(tmp_archive, "w:gz") as tar:
                yield tar
            
            # Move the archive into its final location only after it was created successfully
            self.config.sync_path.mkdir(parents=True, exist_ok=True)
            shutil.move(tmp_archive, self.config.sync_path / self.config.archive_name)
            
        finally:
            # Clean up the temporary file if it is still present (e.g. after a failure)
            if tmp_archive.exists():
                tmp_archive.unlink()

    def create_archive(self):
        """创建压缩包"""
        self.logger.info("Creating new archive...")
        with self.safe_archive() as tar:
            tar.add(self.config.data_path, arcname="data")
        self.logger.info("Archive created")

    def extract_archive(self):
        """解压现有数据"""
        api = HfApi()
        try:
            self.logger.info("Downloading data archive...")
            api.hf_hub_download(
                repo_id=self.config.repo_id,
                filename=self.config.archive_name,
                repo_type="dataset",
                local_dir=self.config.sync_path
            )
            
            self.logger.info("Extracting archive...")
            archive_path = self.config.sync_path / self.config.archive_name
            with tarfile.open(archive_path, "r:gz") as tar:
                tar.extractall(
                    path=self.config.data_path,
                    filter=self._tar_filter
                )
        except Exception as e:
            self.logger.warning(f"No existing archive found or download failed: {e}")
            self.config.data_path.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _tar_filter(tarinfo, path):
        """tar 文件过滤器"""
        if tarinfo.name.startswith('data/'):
            tarinfo.name = tarinfo.name[5:]
            return tarinfo
        else:
            return None

class SyncService:
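    """Top-level loop: restore data on startup, then re-archive whenever the data directory changes."""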
    def __init__(self, config: Config, logger: Logger):
        self.config = config
        self.logger = logger.logger
        self.monitor = DirectoryMonitor(config.data_path)
        self.archive_manager = ArchiveManager(config, logger)
        
    def run(self):
        self.logger.info(f"Starting sync process for repo: {self.config.repo_id}")
        self.logger.info(f"Sync interval: {self.config.sync_interval} minutes")
        
        # Initialize the sync directory and pull down any existing data
        self.config.sync_path.mkdir(parents=True, exist_ok=True)
        self.archive_manager.extract_archive()

        # CommitScheduler commits the contents of sync_path to the Hub every
        # `sync_interval` minutes from a background thread; squash_history
        # keeps the dataset repo's git history compact.
        scheduler = CommitScheduler(
            repo_id=self.config.repo_id,
            repo_type="dataset",
            folder_path=str(self.config.sync_path),
            path_in_repo="",
            every=self.config.sync_interval,
            squash_history=True,
            private=True
        )

        try:
            while True:
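                # Rebuild the archive only when the data directory actually changed;
                # the CommitScheduler uploads the new file on its next scheduled run.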
                if self.monitor.has_changes():
                    self.logger.info("Directory changes detected, creating new archive...")
                    self.archive_manager.create_archive()
                else:
                    self.logger.info("No changes detected")
                
                self.logger.info(f"Waiting {self.config.sync_interval} minutes until next check...")
                time.sleep(self.config.sync_interval * 60)
        except KeyboardInterrupt:
            self.logger.info("Stopping sync process...")
            scheduler.stop()

def main():
    config = Config.from_env()
    logger = Logger()
    service = SyncService(config, logger)
    service.run()

if __name__ == "__main__":
    main()