Chenhao commited on
Commit
be8a1bb
·
1 Parent(s): fca1eb6

添加了重启时的强制同步拉取功能

Browse files
Files changed (2) hide show
  1. start.sh +8 -6
  2. sync.py +61 -10
start.sh CHANGED
@@ -1,13 +1,15 @@
1
 
2
-
3
  set -ex
4
 
5
- python3 /app/sync.py &
6
-
7
- sleep 3
8
 
9
- ls /data
 
10
 
 
 
11
 
 
 
12
  /one-api
13
-
 
1
 
 
2
  set -ex
3
 
4
+ # 首先执行初始化并等待完成
5
+ python3 /app/sync.py --mode init
 
6
 
7
+ # 如果初始化成功,启动同步服务
8
+ python3 /app/sync.py --mode sync &
9
 
10
+ # 等待几秒确保同步服务正常启动
11
+ sleep 1
12
 
13
+ # 启动 one-api 服务
14
+ ls /data
15
  /one-api
 
sync.py CHANGED
@@ -3,6 +3,9 @@ import time
3
  import tarfile
4
  import hashlib
5
  import shutil
 
 
 
6
  from pathlib import Path
7
  from typing import Optional
8
  from dataclasses import dataclass
@@ -11,6 +14,11 @@ import logging
11
  from dotenv import load_dotenv
12
  from huggingface_hub import CommitScheduler, HfApi
13
 
 
 
 
 
 
14
  @dataclass
15
  class Config:
16
  repo_id: str
@@ -123,9 +131,11 @@ class ArchiveManager:
123
  path=self.config.data_path,
124
  filter=self._tar_filter
125
  )
 
126
  except Exception as e:
127
- self.logger.info(f"No existing archive found or download failed: {e}")
128
  self.config.data_path.mkdir(parents=True, exist_ok=True)
 
129
 
130
  @staticmethod
131
  def _tar_filter(tarinfo, path):
@@ -133,8 +143,7 @@ class ArchiveManager:
133
  if tarinfo.name.startswith('data/'):
134
  tarinfo.name = tarinfo.name[5:]
135
  return tarinfo
136
- else:
137
- return None
138
 
139
  class SyncService:
140
  def __init__(self, config: Config, logger: Logger):
@@ -142,14 +151,29 @@ class SyncService:
142
  self.logger = logger.logger
143
  self.monitor = DirectoryMonitor(config.data_path)
144
  self.archive_manager = ArchiveManager(config, logger)
145
-
146
- def run(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  self.logger.info(f"Starting sync process for repo: {self.config.repo_id}")
148
  self.logger.info(f"Sync interval: {self.config.sync_interval} minutes")
149
-
150
- # 初始化目录和下载数据
151
- self.config.sync_path.mkdir(parents=True, exist_ok=True)
152
- self.archive_manager.extract_archive()
153
 
154
  scheduler = CommitScheduler(
155
  repo_id=self.config.repo_id,
@@ -175,11 +199,38 @@ class SyncService:
175
  self.logger.info("Stopping sync process...")
176
  scheduler.stop()
177
 
 
 
 
 
 
 
 
 
 
 
 
178
  def main():
 
179
  config = Config.from_env()
180
  logger = Logger()
181
  service = SyncService(config, logger)
182
- service.run()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  if __name__ == "__main__":
185
  main()
 
3
  import tarfile
4
  import hashlib
5
  import shutil
6
+ import argparse
7
+ import sys
8
+ from enum import Enum, auto
9
  from pathlib import Path
10
  from typing import Optional
11
  from dataclasses import dataclass
 
14
  from dotenv import load_dotenv
15
  from huggingface_hub import CommitScheduler, HfApi
16
 
17
+ class SyncMode(Enum):
18
+ INIT_ONLY = auto() # 只执行初始化
19
+ SYNC_ONLY = auto() # 只执行同步
20
+ BOTH = auto() # 执行初始化和同步
21
+
22
  @dataclass
23
  class Config:
24
  repo_id: str
 
131
  path=self.config.data_path,
132
  filter=self._tar_filter
133
  )
134
+ return True
135
  except Exception as e:
136
+ self.logger.error(f"No existing archive found or download failed: {e}")
137
  self.config.data_path.mkdir(parents=True, exist_ok=True)
138
+ return False
139
 
140
  @staticmethod
141
  def _tar_filter(tarinfo, path):
 
143
  if tarinfo.name.startswith('data/'):
144
  tarinfo.name = tarinfo.name[5:]
145
  return tarinfo
146
+ return None
 
147
 
148
  class SyncService:
149
  def __init__(self, config: Config, logger: Logger):
 
151
  self.logger = logger.logger
152
  self.monitor = DirectoryMonitor(config.data_path)
153
  self.archive_manager = ArchiveManager(config, logger)
154
+
155
+ def init(self) -> bool:
156
+ """
157
+ 执行初始化操作
158
+ 返回: 是否成功初始化
159
+ """
160
+ try:
161
+ self.logger.info("Starting initialization...")
162
+ self.config.sync_path.mkdir(parents=True, exist_ok=True)
163
+ success = self.archive_manager.extract_archive()
164
+ if success:
165
+ self.logger.info("Initialization completed successfully")
166
+ else:
167
+ self.logger.warning("Initialization completed with warnings")
168
+ return success
169
+ except Exception as e:
170
+ self.logger.error(f"Initialization failed: {e}")
171
+ return False
172
+
173
+ def sync(self):
174
+ """执行持续同步操作"""
175
  self.logger.info(f"Starting sync process for repo: {self.config.repo_id}")
176
  self.logger.info(f"Sync interval: {self.config.sync_interval} minutes")
 
 
 
 
177
 
178
  scheduler = CommitScheduler(
179
  repo_id=self.config.repo_id,
 
199
  self.logger.info("Stopping sync process...")
200
  scheduler.stop()
201
 
202
+ def parse_args():
203
+ parser = argparse.ArgumentParser(description='Data synchronization service')
204
+ parser.add_argument(
205
+ '--mode',
206
+ type=str,
207
+ choices=['init', 'sync', 'both'],
208
+ default='both',
209
+ help='Operation mode: init (initialization only), sync (synchronization only), both (default)'
210
+ )
211
+ return parser.parse_args()
212
+
213
  def main():
214
+ args = parse_args()
215
  config = Config.from_env()
216
  logger = Logger()
217
  service = SyncService(config, logger)
218
+
219
+ mode = {
220
+ 'init': SyncMode.INIT_ONLY,
221
+ 'sync': SyncMode.SYNC_ONLY,
222
+ 'both': SyncMode.BOTH
223
+ }[args.mode]
224
+
225
+ if mode in (SyncMode.INIT_ONLY, SyncMode.BOTH):
226
+ success = service.init()
227
+ if not success:
228
+ sys.exit(1)
229
+ if mode == SyncMode.INIT_ONLY:
230
+ return
231
+
232
+ if mode in (SyncMode.SYNC_ONLY, SyncMode.BOTH):
233
+ service.sync()
234
 
235
  if __name__ == "__main__":
236
  main()