Peter-Young committed
Commit 075ab90 · verified · 1 Parent(s): 4c3a879

Upload 2 files

Files changed (2)
  1. download_utils.py +341 -0
  2. start.sh +22 -0
download_utils.py ADDED
@@ -0,0 +1,341 @@
+ import os
+ import re
+ import shutil
+ import tempfile
+ from contextlib import contextmanager
+ from pathlib import Path
+ from urllib.parse import unquote, urlparse
+
+ import fal
+ from fal.toolkit.utils.download_utils import (
+     FAL_MODEL_WEIGHTS_DIR,
+     DownloadError,
+     _hash_url,
+ )
+
+ FAL_VERSION = getattr(fal, "__version__", "<1.0.0")
+ _REQUEST_HEADERS = {"User-Agent": f"fal-client ({FAL_VERSION}/python)"}
+
+
+ def get_civitai_headers() -> dict[str, str]:
+     headers: dict[str, str] = {}
+
+     civitai_token = os.getenv("CIVITAI_TOKEN", None)
+
+     if not civitai_token:
+         print("CIVITAI_TOKEN is not set in the environment variables.")
+         return headers
+
+     headers["Authorization"] = f"Bearer {civitai_token}"
+
+     return headers
+
+
+ def get_huggingface_headers() -> dict[str, str]:
+     headers: dict[str, str] = {}
+
+     hf_token = os.getenv("HF_TOKEN", None)
+
+     if not hf_token:
+         print("HF_TOKEN is not set in the environment variables.")
+         return headers
+
+     headers["Authorization"] = f"Bearer {hf_token}"
+
+     return headers
+
+
+ def get_local_file_content_length(file_path: Path) -> int:
+     return file_path.stat().st_size
+
+
+ def download_url_to_file(
+     url: str,
+     dst: str | Path,
+     progress: bool = True,
+     headers: dict[str, str] | None = None,
+     chunk_size_in_mb: int = 16,
+     file_integrity_check_callback=None,
+ ) -> Path:
+     """Download the object at the given URL to a local path.
+
+     Args:
+         url (str): URL of the object to download
+         dst (str): full path where the object will be saved, e.g. ``/tmp/temporary_file``
+         progress (bool, optional): whether to display a progress bar on stderr
+             Default: True
+         headers (dict, optional): HTTP headers to include with the request
+             Default: None
+         chunk_size_in_mb (int, optional): size of each chunk in MB
+             Default: 16
+         file_integrity_check_callback (callable, optional): callback used to verify file integrity
+             Default: None
+
+     """
+     from tqdm import tqdm
+
+     file_size = None
+
+     request_headers = {
+         **_REQUEST_HEADERS,
+         **(headers or {}),
+     }
+
+     url = url.strip()
+
+     if url.startswith("data:"):
+         return _download_data_url_to_file(url, dst)
+
+     import requests
+
+     req = requests.get(url, headers=request_headers, stream=True, allow_redirects=True)
+     req.raise_for_status()
+
+     headers = req.headers  # type: ignore
+     content_length = headers.get("Content-Length", None)  # type: ignore
+     if content_length is not None and len(content_length) > 0:
+         file_size = int(content_length)
+
+     with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+         file_path = temp_file.name
+         try:
+             with tqdm(
+                 total=file_size,
+                 disable=not progress,
+                 unit="B",
+                 unit_scale=True,
+                 unit_divisor=1024,
+             ) as pbar, open(file_path, "wb") as f:
+                 for chunk in req.iter_content(
+                     chunk_size=chunk_size_in_mb * 1024 * 1024
+                 ):
+                     if chunk:
+                         f.write(chunk)
+                         pbar.update(len(chunk))
+
+             # Move the file only once the download has completed. Since a
+             # temporary file is used, an interrupted download loses its
+             # content, so it is safe to re-download the file in such cases.
+             shutil.move(file_path, dst)
+         finally:
+             Path(temp_file.name).unlink(missing_ok=True)
+
+     if file_integrity_check_callback:
+         file_integrity_check_callback(dst)
+
+     return Path(dst)
+
+
+ def _download_data_url_to_file(url: str, dst: str | Path):
+     import base64
+
+     data = url.split(",")[1]
+     data = base64.b64decode(data)
+
+     with open(dst, "wb") as fp:
+         fp.write(data)
+
+     return Path(dst)
+
+
+ def download_model_weights(url: str, force: bool = False) -> Path:
+     parsed_url = urlparse(url)
+
+     headers = {}
+     if parsed_url.netloc == "civitai.com":
+         headers.update(get_civitai_headers())
+     elif parsed_url.netloc == "huggingface.co":
+         headers.update(get_huggingface_headers())
+
+     return download_model_weights_fal(url, request_headers=headers, force=force)
+
+
+ def url_without_query(url: str) -> str:
+     # Find the position of '?'
+     query_index = url.find("?")
+
+     # If '?' is found, keep only the part before it
+     if query_index != -1:
+         url_without_query = url[:query_index]
+     else:
+         url_without_query = url  # No query string, keep the original URL
+     return url_without_query
+
+
+ def download_model_weights_fal(
+     url: str, force: bool = False, request_headers: dict[str, str] | None = None
+ ) -> Path:
+     without_query = os.environ.get("CKPT_DOWNLOAD_WITHOUT_QUERY", "false") == "true"
+     if without_query:
+         url = url_without_query(url)
+
+     weights_dir = Path(FAL_MODEL_WEIGHTS_DIR / _hash_url(url))
+
+     if weights_dir.exists() and not force:
+         try:
+             weights_path = next(weights_dir.glob("*"))
+             is_safetensors_file(weights_path)
+             return weights_path
+
+         # The model weights directory is empty, so we need to download the weights
+         except StopIteration:
+             pass
+
+     try:
+         file_name, file_content_length = _get_remote_file_properties(
+             url, request_headers=request_headers
+         )
+     except Exception as e:
+         print(e)
+         raise DownloadError(f"Failed to get remote file properties for {url}")
+
+     target_path = weights_dir / file_name
+
+     if (
+         target_path.exists()
+         and get_local_file_content_length(target_path) == file_content_length
+         and not force
+     ):
+         is_safetensors_file(target_path)
+         return target_path
+
+     # Make sure the parent directory exists
+     target_path.parent.mkdir(parents=True, exist_ok=True)
+
+     # Try the network volume first, if one is configured
+     ckpt_download_dir = os.environ.get("CKPT_DOWNLOAD_DIR", None)
+     if ckpt_download_dir:
+         src_path = os.path.join(ckpt_download_dir, file_name)
+         if not os.path.exists(src_path):
+             src_path = os.path.join(ckpt_download_dir, _hash_url(url), file_name)
+         if not os.path.exists(src_path):
+             src_path = os.path.join(
+                 ckpt_download_dir, _hash_url(url_without_query(url)), file_name
+             )
+
+         # If the file exists on the volume, copy it to the target path and return
+         if os.path.exists(src_path):
+             src_file_content_length = get_local_file_content_length(Path(src_path))
+             if src_file_content_length == file_content_length:
+                 print(f"copy start from:{src_path} to:{target_path}")
+                 shutil.copy(src_path, target_path)
+                 print(f"copy done from:{src_path} to:{target_path}")
+                 return target_path
+             else:
+                 print(
+                     f"cannot copy, file lengths differ src_path:{src_path} "
+                     f"len:{src_file_content_length} target_path:{target_path} "
+                     f"len:{file_content_length}"
+                 )
+
+     try:
+         download_url_to_file(
+             url,
+             target_path,
+             progress=True,
+             headers=request_headers,
+             file_integrity_check_callback=is_safetensors_file,
+         )
+     except Exception as e:
+         print(e)
+         raise DownloadError(f"Failed to download {url}")
+
+     return target_path
+
+
+
+ def _get_filename_from_content_disposition(cd: str | None) -> str | None:
+     if not cd:
+         return None
+
+     filenames = re.findall('filename="(.+)"', cd)
+
+     if len(filenames) == 0:
+         filenames = re.findall("filename=(.+)", cd)
+
+     if len(filenames) == 0:
+         return None
+
+     return unquote(filenames[0])
+
+
+ def _parse_filename(url: str, cd: str | None) -> str:
+     url = url.strip()
+     file_name = _get_filename_from_content_disposition(cd)
+
+     if not file_name:
+         parsed_url = urlparse(url)
+
+         if parsed_url.scheme == "data":
+             file_name = _hash_url(url)
+         else:
+             url_path = parsed_url.path
+             file_name = Path(url_path).name or _hash_url(url)
+
+     if url.startswith("data:"):
+         import mimetypes
+
+         mime_type = url.split(",")[0].split(":")[1].split(";")[0]
+         extension = mimetypes.guess_extension(mime_type)
+         if extension:
+             file_name += extension
+
+     return file_name  # type: ignore
+
+
+ def _get_remote_file_properties(
+     url: str, request_headers: dict[str, str] | None = None
+ ) -> tuple[str, int]:
+     import requests
+
+     headers = {
+         **_REQUEST_HEADERS,
+         **(request_headers or {}),
+     }
+
+     req = requests.get(
+         url, headers=headers, stream=True, allow_redirects=True, verify=False
+     )
+     req.raise_for_status()
+
+     headers = req.headers  # type: ignore
+     content_disposition = headers.get("Content-Disposition", None)
+     file_name = _parse_filename(url, content_disposition)
+     content_length = int(headers.get("Content-Length", -1))
+
+     return file_name, content_length
+
+
+ def is_safetensors_file(path: str | Path):
+     from safetensors import safe_open
+
+     path = str(path)
+
+     if not path.endswith(".safetensors"):
+         raise ValueError(f"File {path} is not a .safetensors file")
+
+     try:
+         with safe_open(path, framework="pt"):
+             pass
+     except Exception as e:
+         print(e)
+         error_message = e.args[0]
+         if error_message == "Error while deserializing header: HeaderTooLarge":
+             raise ValueError(f"File {path} is not a .safetensors file")
+         else:
+             raise e
+
+
+ @contextmanager
+ def download_file_temp(
+     url: str,
+     progress: bool = True,
+     headers: dict[str, str] | None = None,
+     chunk_size_in_mb: int = 16,
+     file_integrity_check_callback=None,
+ ):
+     file_name = _parse_filename(url, None)
+
+     with tempfile.TemporaryDirectory() as temp_dir:
+         file_path = download_url_to_file(
+             url,
+             f"{temp_dir}/{file_name}",
+             progress=progress,
+             headers=headers,
+             chunk_size_in_mb=chunk_size_in_mb,
+             file_integrity_check_callback=file_integrity_check_callback,
+         )
+         yield file_path
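
The module is driven mainly by download_model_weights, which caches weights under FAL_MODEL_WEIGHTS_DIR and copies from the CKPT_DOWNLOAD_DIR network volume when possible, and by the download_file_temp context manager for throwaway downloads. A minimal usage sketch, assuming the file is importable as download_utils, the URL is a placeholder, and HF_TOKEN / CIVITAI_TOKEN are exported when the host requires authentication:

# Minimal usage sketch -- the URL below is a placeholder, not part of this commit.
from download_utils import download_model_weights, download_file_temp

# Cached path: reuses FAL_MODEL_WEIGHTS_DIR/<url hash>/<file> when the size matches,
# copies from $CKPT_DOWNLOAD_DIR if available, and only then downloads over HTTP.
weights_path = download_model_weights(
    "https://huggingface.co/org/model/resolve/main/model.safetensors"
)
print(f"weights cached at {weights_path}")

# Temporary path: the downloaded file is removed when the with-block exits.
with download_file_temp(
    "https://huggingface.co/org/model/resolve/main/model.safetensors"
) as tmp:
    print(f"temporary copy at {tmp}")
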
start.sh ADDED
@@ -0,0 +1,22 @@
+ #!/usr/bin/env bash
+
+ # Use libtcmalloc for better memory management
+ TCMALLOC="$(ldconfig -p | grep -Po "libtcmalloc.so.\d" | head -n 1)"
+ export LD_PRELOAD="${TCMALLOC}"
+
+ echo "Worker Initiated"
+
+ # Serve the API and don't shut down the container
+ if [ "$SERVE_API_LOCALLY" == "true" ]; then
+     echo "runpod-worker-comfy: Starting ComfyUI"
+     python3 /comfyui/main.py --disable-auto-launch --disable-metadata --listen &
+
+     echo "runpod-worker-comfy: Starting RunPod Handler"
+     python3 -u /rp_handler.py --rp_serve_api --rp_api_host=0.0.0.0
+ else
+     echo "runpod-worker-comfy: Starting ComfyUI"
+     python3 /comfyui/main.py --disable-auto-launch --disable-metadata &
+
+     echo "runpod-worker-comfy: Starting RunPod Handler"
+     python3 -u /rp_handler.py
+ fi
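
When SERVE_API_LOCALLY=true, the handler is started with --rp_serve_api, which in the RunPod Python SDK exposes a local HTTP test server for the worker. A hedged sketch of calling it, assuming the SDK's default port 8000, a /runsync endpoint, and a placeholder workflow payload:

# Hedged sketch: assumes the RunPod SDK's local test API listens on port 8000
# and accepts the standard {"input": ...} payload; the workflow body is a placeholder.
import requests

payload = {"input": {"workflow": {}}}  # placeholder ComfyUI workflow
resp = requests.post("http://127.0.0.1:8000/runsync", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json())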