import argparse
import hashlib
import logging
import os
import shutil
import sys
import tarfile
import time
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv
from huggingface_hub import CommitScheduler, HfApi


class SyncMode(Enum):
    INIT_ONLY = auto()
    SYNC_ONLY = auto()
    BOTH = auto()


@dataclass
class Config:
    repo_id: str
    sync_interval: int  # minutes between sync checks / scheduled commits
    data_path: Path
    sync_path: Path
    tmp_path: Path
    archive_name: str

    @classmethod
    def from_env(cls) -> "Config":
        """Build a Config from environment variables (optionally loaded from a .env file)."""
        load_dotenv()
        repo_id = os.getenv('HF_DATASET_REPO_ID')
        if not repo_id:
            raise ValueError("HF_DATASET_REPO_ID must be set")

        return cls(
            repo_id=repo_id,
            sync_interval=int(os.getenv('SYNC_INTERVAL', '5')),
            data_path=Path("/data"),
            sync_path=Path("/sync"),
            tmp_path=Path("/tmp/sync"),
            archive_name="data.tar.gz"
        )
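
# Example .env file read by Config.from_env(); the values below are placeholders,
# not real settings:
#
#   HF_DATASET_REPO_ID=your-username/your-dataset
#   SYNC_INTERVAL=5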

class Logger:
    def __init__(self):
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)


class DirectoryMonitor:
    def __init__(self, path: Path):
        self.path = path
        self.last_hash: Optional[str] = None

    def get_directory_hash(self) -> str:
        """Hash every file path and file content under the monitored directory."""
        sha256_hash = hashlib.sha256()

        all_files = sorted(
            str(p) for p in self.path.rglob('*') if p.is_file()
        )

        for file_path in all_files:
            # Feed the relative path into the hash so renames count as changes.
            rel_path = os.path.relpath(file_path, self.path)
            sha256_hash.update(rel_path.encode())

            with open(file_path, 'rb') as f:
                for chunk in iter(lambda: f.read(4096), b''):
                    sha256_hash.update(chunk)

        return sha256_hash.hexdigest()

    def has_changes(self) -> bool:
        """Return True if the directory hash has changed since the last call.

        The first call always returns True because no previous hash is recorded.
        """
        current_hash = self.get_directory_hash()
        if current_hash != self.last_hash:
            self.last_hash = current_hash
            return True
        return False


class ArchiveManager:
    def __init__(self, config: Config, logger: Logger):
        self.config = config
        self.logger = logger.logger

    @contextmanager
    def safe_archive(self):
        """Context manager that safely creates the archive file.

        The archive is written to a temporary location first and only moved into
        the sync directory after it has been written completely.
        """
        self.config.tmp_path.mkdir(parents=True, exist_ok=True)
        tmp_archive = self.config.tmp_path / self.config.archive_name

        try:
            with tarfile.open(tmp_archive, "w:gz") as tar:
                yield tar

            self.config.sync_path.mkdir(parents=True, exist_ok=True)
            shutil.move(tmp_archive, self.config.sync_path / self.config.archive_name)
        finally:
            # Clean up the temporary file if it is still around (e.g. after a failure).
            if tmp_archive.exists():
                tmp_archive.unlink()

    def create_archive(self):
        """Create a compressed archive of the data directory."""
        self.logger.info("Creating new archive...")
        with self.safe_archive() as tar:
            tar.add(self.config.data_path, arcname="data")
        self.logger.info("Archive created")

    def extract_archive(self):
        """Download and extract the existing archive from the dataset repo."""
        api = HfApi()
        try:
            self.logger.info("Downloading data archive...")
            api.hf_hub_download(
                repo_id=self.config.repo_id,
                filename=self.config.archive_name,
                repo_type="dataset",
                local_dir=self.config.sync_path
            )

            self.logger.info("Extracting archive...")
            archive_path = self.config.sync_path / self.config.archive_name
            with tarfile.open(archive_path, "r:gz") as tar:
                tar.extractall(
                    path=self.config.data_path,
                    filter=self._tar_filter
                )
            return True
        except Exception as e:
            self.logger.error(f"No existing archive found or download failed: {e}")
            self.config.data_path.mkdir(parents=True, exist_ok=True)
            return False

    @staticmethod
    def _tar_filter(tarinfo, path):
        """Tar member filter: strip the leading 'data/' prefix and skip everything else."""
        if tarinfo.name.startswith('data/'):
            tarinfo.name = tarinfo.name[5:]
            return tarinfo
        return None
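
# Illustrative effect of _tar_filter on member names (paths are hypothetical):
#   "data/notes/a.txt" -> extracted as "notes/a.txt" under data_path
#   "data"             -> skipped (no "data/" prefix, returns None)
#   "unrelated.txt"    -> skipped (returns None)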

class SyncService:
    def __init__(self, config: Config, logger: Logger):
        self.config = config
        self.logger = logger.logger
        self.monitor = DirectoryMonitor(config.data_path)
        self.archive_manager = ArchiveManager(config, logger)

    def init(self) -> bool:
        """Run the initialization step.

        Returns:
            True if initialization succeeded, False otherwise.
        """
        try:
            self.logger.info("Starting initialization...")
            self.config.sync_path.mkdir(parents=True, exist_ok=True)
            success = self.archive_manager.extract_archive()
            if success:
                self.logger.info("Initialization completed successfully")
            else:
                self.logger.warning("Initialization completed with warnings")
            return success
        except Exception as e:
            self.logger.error(f"Initialization failed: {e}")
            return False

    def sync(self):
        """Run the continuous synchronization loop."""
        self.logger.info(f"Starting sync process for repo: {self.config.repo_id}")
        self.logger.info(f"Sync interval: {self.config.sync_interval} minutes")

        scheduler = CommitScheduler(
            repo_id=self.config.repo_id,
            repo_type="dataset",
            folder_path=str(self.config.sync_path),
            path_in_repo="",
            every=self.config.sync_interval,
            squash_history=True,
            private=True
        )

        try:
            while True:
                if self.monitor.has_changes():
                    self.logger.info("Directory changes detected, creating new archive...")
                    self.archive_manager.create_archive()
                else:
                    self.logger.info("No changes detected")

                self.logger.info(f"Waiting {self.config.sync_interval} minutes until next check...")
                time.sleep(self.config.sync_interval * 60)
        except KeyboardInterrupt:
            self.logger.info("Stopping sync process...")
            scheduler.stop()
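
# Note: the loop above only rebuilds the archive inside sync_path when the data
# directory changes; CommitScheduler uploads the contents of sync_path to the Hub
# on its own background thread, roughly every `sync_interval` minutes.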

def parse_args():
    parser = argparse.ArgumentParser(description='Data synchronization service')
    parser.add_argument(
        '--mode',
        type=str,
        choices=['init', 'sync', 'both'],
        default='both',
        help='Operation mode: init (initialization only), sync (synchronization only), both (default)'
    )
    return parser.parse_args()


def main():
    args = parse_args()
    config = Config.from_env()
    logger = Logger()
    service = SyncService(config, logger)

    mode = {
        'init': SyncMode.INIT_ONLY,
        'sync': SyncMode.SYNC_ONLY,
        'both': SyncMode.BOTH
    }[args.mode]

    if mode in (SyncMode.INIT_ONLY, SyncMode.BOTH):
        success = service.init()
        if not success:
            sys.exit(1)
        if mode == SyncMode.INIT_ONLY:
            return

    if mode in (SyncMode.SYNC_ONLY, SyncMode.BOTH):
        service.sync()


if __name__ == "__main__":
    main()
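
# Usage (the file name sync.py is an assumption; adjust to wherever this script lives):
#
#   python sync.py --mode init   # download and extract the existing archive, then exit
#   python sync.py --mode sync   # run only the periodic archive/commit loop
#   python sync.py               # default: initialize first, then keep syncing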