Chenhao committed on
Commit
fca1eb6
·
1 Parent(s): 5a97508

一个可以正常同步的版本,但在网络异常时同步可能会出现问题。

Browse files
Files changed (3) hide show
  1. build.sh +1 -1
  2. start.sh +3 -0
  3. sync.py +165 -80
build.sh CHANGED
@@ -2,5 +2,5 @@
2
  set -ex
3
 
4
  docker build -t one-api .
5
- docker run --rm -it -e PORT=7860 -p 7860:7860 one-api
6
 
 
2
  set -ex
3
 
4
  docker build -t one-api .
5
+ docker run --rm -it -e PORT=7860 -p 7860:7860 --name test-one-api one-api
6
 
start.sh CHANGED
@@ -6,5 +6,8 @@ python3 /app/sync.py &
6
 
7
  sleep 3
8
 
 
 
 
9
  /one-api
10
 
 
6
 
7
  sleep 3
8
 
9
+ ls /data
10
+
11
+
12
  /one-api
13
 
sync.py CHANGED
@@ -1,100 +1,185 @@
1
  import os
2
  import time
3
  import tarfile
 
 
 
 
 
 
 
4
  from dotenv import load_dotenv
5
  from huggingface_hub import CommitScheduler, HfApi
6
- import logging
7
- from pathlib import Path
8
 
9
- # 加载环境变量
10
- load_dotenv()
 
 
 
 
 
 
11
 
12
- # 全局配置变量
13
- REPO_ID = os.getenv('HF_REPO_ID')
14
- SYNC_INTERVAL = int(os.getenv('SYNC_INTERVAL', 5))
15
- DATA_PATH = "/data"
16
- ARCHIVE_NAME = "data.tar.gz"
17
- SYNC_PATH = "/sync" # CommitScheduler 监控的目录
18
- ARCHIVE_PATH = f"{SYNC_PATH}/{ARCHIVE_NAME}"
 
 
 
 
 
 
 
 
19
 
20
- # 配置日志
21
- logging.basicConfig(
22
- level=logging.INFO,
23
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
24
- )
25
- logger = logging.getLogger(__name__)
 
26
 
27
- # 环境变量检查
28
- if not REPO_ID:
29
- raise ValueError("HF_REPO_ID must be set in environment variables")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- def tar_filter(tarinfo):
32
- """tar 文件过滤器"""
33
- if tarinfo.name.startswith('data/'):
34
- tarinfo.name = tarinfo.name[5:]
35
- return tarinfo
 
36
 
37
- def download_and_extract():
38
- """下载并解压数据"""
39
- api = HfApi()
40
- try:
41
- # 下载压缩包
42
- logger.info("Downloading data archive...")
43
- api.hf_hub_download(
44
- repo_id=REPO_ID,
45
- filename=ARCHIVE_NAME,
46
- repo_type="dataset",
47
- local_dir=SYNC_PATH
48
- )
49
 
50
- # 解压到 data 目录
51
- logger.info("Extracting archive...")
52
- with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
53
- tar.extractall(path=DATA_PATH, filter=tar_filter)
 
 
 
 
 
 
 
 
 
54
 
55
- except Exception as e:
56
- logger.info(f"No existing archive found or download failed: {e}")
57
- # 确保目录存在
58
- Path(DATA_PATH).mkdir(parents=True, exist_ok=True)
59
 
60
- def create_archive():
61
- """创建压缩包"""
62
- logger.info("Creating new archive...")
63
- with tarfile.open(ARCHIVE_PATH, "w:gz") as tar:
64
- tar.add(DATA_PATH, arcname="data")
65
- logger.info("Archive created")
66
 
67
- def main():
68
- logger.info(f"Starting sync process for repo: {REPO_ID}")
69
- logger.info(f"Sync interval: {SYNC_INTERVAL} minutes")
70
-
71
- # 创建同步目录
72
- Path(SYNC_PATH).mkdir(parents=True, exist_ok=True)
73
-
74
- # 初始下载并解压
75
- download_and_extract()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- # 创建调度器
78
- scheduler = CommitScheduler(
79
- repo_id=REPO_ID,
80
- repo_type="dataset",
81
- folder_path=SYNC_PATH,
82
- path_in_repo="", # 直接将压缩包放在根目录
83
- every=SYNC_INTERVAL,
84
- squash_history=True,
85
- private=True
86
- )
 
 
 
 
87
 
88
- # 主循环:定期创建新的压缩包
89
- try:
90
- while True:
91
- create_archive() # 创建新的压缩包
92
- # 等待下一次同步
93
- logger.info(f"Waiting {SYNC_INTERVAL} minutes until next sync...")
94
- time.sleep(SYNC_INTERVAL * 60)
95
- except KeyboardInterrupt:
96
- logger.info("Stopping sync process...")
97
- scheduler.stop()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  if __name__ == "__main__":
100
  main()
 
1
  import os
2
  import time
3
  import tarfile
4
+ import hashlib
5
+ import shutil
6
+ from pathlib import Path
7
+ from typing import Optional
8
+ from dataclasses import dataclass
9
+ from contextlib import contextmanager
10
+ import logging
11
  from dotenv import load_dotenv
12
  from huggingface_hub import CommitScheduler, HfApi
 
 
13
 
14
@dataclass
class Config:
    """Runtime settings for the data-directory sync service.

    Instances are normally built with :meth:`from_env`. Only ``repo_id`` and
    ``sync_interval`` come from the environment; the paths and the archive
    name are fixed values chosen there.
    """

    repo_id: str  # Hub repository id, read from HF_REPO_ID
    sync_interval: int  # minutes between change checks; also the CommitScheduler `every` value
    data_path: Path  # directory that is archived and restored
    sync_path: Path  # folder watched by CommitScheduler
    tmp_path: Path  # scratch directory where the archive is built
    archive_name: str  # file name of the tarball inside sync_path / the repo

    @classmethod
    def from_env(cls):
        """Build a Config from environment variables (after loading ``.env``).

        Reads ``HF_REPO_ID`` (required) and ``SYNC_INTERVAL`` (minutes,
        default ``5``).

        Raises:
            ValueError: if ``HF_REPO_ID`` is missing or empty, or if
                ``SYNC_INTERVAL`` is not a positive integer.
        """
        load_dotenv()
        repo_id = os.getenv('HF_REPO_ID')
        if not repo_id:
            raise ValueError("HF_REPO_ID must be set")

        return cls(
            repo_id=repo_id,
            sync_interval=cls._parse_sync_interval(os.getenv('SYNC_INTERVAL', '5')),
            data_path=Path("/data"),
            sync_path=Path("/sync"),
            tmp_path=Path("/tmp/sync"),
            archive_name="data.tar.gz",
        )

    @staticmethod
    def _parse_sync_interval(raw: str) -> int:
        """Convert the raw SYNC_INTERVAL value to a positive number of minutes.

        Fails fast with a message that names the variable, instead of a bare
        ``int()`` error or a later failure once the service is running.
        """
        try:
            minutes = int(raw)
        except ValueError as err:
            raise ValueError(
                f"SYNC_INTERVAL must be an integer number of minutes, got {raw!r}"
            ) from err
        if minutes <= 0:
            raise ValueError(
                f"SYNC_INTERVAL must be a positive number of minutes, got {minutes}"
            )
        return minutes
38
 
39
class Logger:
    """Holder for this module's logger; constructing it also configures root logging."""

    def __init__(self):
        # basicConfig does nothing if the root logger already has handlers,
        # so creating several Logger objects is harmless.
        log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        logging.basicConfig(format=log_format, level=logging.INFO)
        self.logger = logging.getLogger(__name__)
46
 
47
class DirectoryMonitor:
    """Detect content changes in a directory tree by hashing its files.

    The fingerprint covers the relative path and the bytes of every regular
    file below ``path``. Empty directories and file metadata (mtime, mode)
    are not part of it.
    """

    # Read size used when streaming file contents into the hash.
    _CHUNK_SIZE = 64 * 1024

    def __init__(self, path: Path):
        self.path = path
        # Fingerprint recorded by the most recent has_changes() call.
        self.last_hash: Optional[str] = None

    def get_directory_hash(self) -> str:
        """Return a SHA-256 hex digest over all regular files below ``path``.

        Each file contributes a length-prefixed relative path followed by the
        fixed-size digest of its contents. Without that framing, a file ``a``
        containing ``bc`` and a file ``ab`` containing ``c`` would feed the
        same byte stream into the hash and be indistinguishable.
        """
        tree_hash = hashlib.sha256()

        rel_paths = sorted(
            p.relative_to(self.path).as_posix()
            for p in self.path.rglob('*')
            if p.is_file()
        )

        for rel_path in rel_paths:
            file_digest = self._hash_file(self.path / rel_path)
            if file_digest is None:
                # File was removed after the listing; treat it as absent.
                continue
            name = rel_path.encode('utf-8', 'surrogateescape')
            tree_hash.update(len(name).to_bytes(8, 'big'))
            tree_hash.update(name)
            tree_hash.update(file_digest)

        return tree_hash.hexdigest()

    def _hash_file(self, file_path: Path) -> Optional[bytes]:
        """Return the SHA-256 digest of one file, or None if it no longer exists."""
        file_hash = hashlib.sha256()
        try:
            with open(file_path, 'rb') as f:
                while chunk := f.read(self._CHUNK_SIZE):
                    file_hash.update(chunk)
        except FileNotFoundError:
            # The directory is written to by another process while we scan
            # it; a file vanishing mid-scan must not abort the whole check.
            return None
        return file_hash.digest()

    def has_changes(self) -> bool:
        """Report whether the tree differs from the previous call.

        The new fingerprint is stored as ``last_hash`` whenever it differs,
        so the first call (``last_hash`` is None) always returns True.
        """
        current_hash = self.get_directory_hash()
        if current_hash == self.last_hash:
            return False
        self.last_hash = current_hash
        return True
75
 
76
+ class ArchiveManager:
77
+ def __init__(self, config: Config, logger: Logger):
78
+ self.config = config
79
+ self.logger = logger.logger
 
 
 
 
 
 
 
 
80
 
81
+ @contextmanager
82
+ def safe_archive(self):
83
+ """安全地创建归档文件的上下文管理器"""
84
+ self.config.tmp_path.mkdir(parents=True, exist_ok=True)
85
+ tmp_archive = self.config.tmp_path / self.config.archive_name
86
+
87
+ try:
88
+ with tarfile.open(tmp_archive, "w:gz") as tar:
89
+ yield tar
90
+
91
+ # 成功创建后移动到最终位置
92
+ self.config.sync_path.mkdir(parents=True, exist_ok=True)
93
+ shutil.move(tmp_archive, self.config.sync_path / self.config.archive_name)
94
 
95
+ finally:
96
+ # 清理临时文件
97
+ if tmp_archive.exists():
98
+ tmp_archive.unlink()
99
 
100
+ def create_archive(self):
101
+ """创建压缩包"""
102
+ self.logger.info("Creating new archive...")
103
+ with self.safe_archive() as tar:
104
+ tar.add(self.config.data_path, arcname="data")
105
+ self.logger.info("Archive created")
106
 
107
+ def extract_archive(self):
108
+ """解压现有数据"""
109
+ api = HfApi()
110
+ try:
111
+ self.logger.info("Downloading data archive...")
112
+ api.hf_hub_download(
113
+ repo_id=self.config.repo_id,
114
+ filename=self.config.archive_name,
115
+ repo_type="dataset",
116
+ local_dir=self.config.sync_path
117
+ )
118
+
119
+ self.logger.info("Extracting archive...")
120
+ archive_path = self.config.sync_path / self.config.archive_name
121
+ with tarfile.open(archive_path, "r:gz") as tar:
122
+ tar.extractall(
123
+ path=self.config.data_path,
124
+ filter=self._tar_filter
125
+ )
126
+ except Exception as e:
127
+ self.logger.info(f"No existing archive found or download failed: {e}")
128
+ self.config.data_path.mkdir(parents=True, exist_ok=True)
129
+
130
+ @staticmethod
131
+ def _tar_filter(tarinfo, path):
132
+ """tar 文件过滤器"""
133
+ if tarinfo.name.startswith('data/'):
134
+ tarinfo.name = tarinfo.name[5:]
135
+ return tarinfo
136
+ else:
137
+ return None
138
 
139
class SyncService:
    """Restore the data directory on start, then re-archive it whenever it changes."""

    def __init__(self, config: "Config", logger: "Logger"):
        self.config = config
        self.logger = logger.logger
        self.monitor = DirectoryMonitor(config.data_path)
        self.archive_manager = ArchiveManager(config, logger)

    def run(self):
        """Run the sync loop until interrupted with Ctrl-C.

        Downloads and extracts the existing archive, starts a CommitScheduler
        on ``config.sync_path``, then checks the data directory every
        ``config.sync_interval`` minutes. The scheduler is stopped on every
        exit path, not only on KeyboardInterrupt.
        """
        self.logger.info(f"Starting sync process for repo: {self.config.repo_id}")
        self.logger.info(f"Sync interval: {self.config.sync_interval} minutes")

        # Prepare the sync folder and restore any previously uploaded data.
        self.config.sync_path.mkdir(parents=True, exist_ok=True)
        self.archive_manager.extract_archive()

        scheduler = CommitScheduler(
            repo_id=self.config.repo_id,
            repo_type="dataset",
            folder_path=str(self.config.sync_path),
            path_in_repo="",
            every=self.config.sync_interval,
            squash_history=True,
            private=True
        )

        try:
            while True:
                self._sync_once()
                self.logger.info(f"Waiting {self.config.sync_interval} minutes until next check...")
                time.sleep(self.config.sync_interval * 60)
        except KeyboardInterrupt:
            self.logger.info("Stopping sync process...")
        finally:
            scheduler.stop()

    def _sync_once(self):
        """Run one check-and-archive cycle without letting a failure end the loop.

        ``has_changes()`` records the new fingerprint before the archive is
        written. If the cycle fails, the fingerprint is cleared so the next
        cycle archives again instead of reporting "no changes".
        """
        try:
            if self.monitor.has_changes():
                self.logger.info("Directory changes detected, creating new archive...")
                self.archive_manager.create_archive()
            else:
                self.logger.info("No changes detected")
        except Exception:
            self.monitor.last_hash = None
            self.logger.exception("Sync cycle failed; will retry on next check")
177
+
178
def main():
    """Entry point: load settings from the environment, set up logging, run the sync service."""
    SyncService(Config.from_env(), Logger()).run()
183
 
184
  if __name__ == "__main__":
185
  main()