LPX55 commited on
Commit
130c21c
·
verified ·
1 Parent(s): 792c5ff

Update civitai_api.py

Browse files
Files changed (1) hide show
  1. civitai_api.py +450 -446
civitai_api.py CHANGED
@@ -1,446 +1,450 @@
1
- import gradio as gr
2
- from huggingface_hub import HfApi, hf_hub_url
3
- import os
4
- from pathlib import Path
5
- import gc
6
- import requests
7
- from requests.adapters import HTTPAdapter
8
- from urllib3.util import Retry
9
- from civitai_constants import PERIOD, SORT
10
- from utils import (get_token, set_token, is_repo_exists, get_user_agent, get_download_file,
11
- list_uniq, list_sub, duplicate_hf_repo, HF_SUBFOLDER_NAME, get_state, set_state)
12
- import re
13
- from PIL import Image
14
- import json
15
- import pandas as pd
16
- import tempfile
17
- import hashlib
18
-
19
- # Huge shoutout to @John6666, saved me many hours.
20
-
21
- TEMP_DIR = tempfile.mkdtemp()
22
-
23
-
24
- def parse_urls(s):
25
- url_pattern = "https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+"
26
- try:
27
- urls = re.findall(url_pattern, s)
28
- return list(urls)
29
- except Exception:
30
- return []
31
-
32
-
33
- def parse_repos(s):
34
- repo_pattern = r'[^\w_\-\.]?([\w_\-\.]+/[\w_\-\.]+)[^\w_\-\.]?'
35
- try:
36
- s = re.sub("https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+", "", s)
37
- repos = re.findall(repo_pattern, s)
38
- return list(repos)
39
- except Exception:
40
- return []
41
-
42
-
43
- def to_urls(l: list[str]):
44
- return "\n".join(l)
45
-
46
-
47
- def uniq_urls(s):
48
- return to_urls(list_uniq(parse_urls(s) + parse_repos(s)))
49
-
50
-
51
- def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
52
- output_filename = Path(filename).name
53
- hf_token = get_token()
54
- api = HfApi(token=hf_token)
55
- try:
56
- if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
57
- progress(0, desc=f"Start uploading... {filename} to {repo_id}")
58
- api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
59
- progress(1, desc="Uploaded.")
60
- url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=output_filename)
61
- except Exception as e:
62
- print(f"Error: Failed to upload to {repo_id}. {e}")
63
- gr.Warning(f"Error: Failed to upload to {repo_id}. {e}")
64
- return None
65
- finally:
66
- if Path(filename).exists(): Path(filename).unlink()
67
- return url
68
-
69
-
70
- def is_same_file(filename: str, cmp_sha256: str, cmp_size: int):
71
- if cmp_sha256:
72
- sha256_hash = hashlib.sha256()
73
- with open(filename, "rb") as f:
74
- for byte_block in iter(lambda: f.read(4096), b""):
75
- sha256_hash.update(byte_block)
76
- sha256 = sha256_hash.hexdigest()
77
- else: sha256 = ""
78
- size = os.path.getsize(filename)
79
- if size == cmp_size and sha256 == cmp_sha256: return True
80
- else: return False
81
-
82
-
83
- def get_safe_filename(filename, repo_id, repo_type):
84
- hf_token = get_token()
85
- api = HfApi(token=hf_token)
86
- new_filename = filename
87
- try:
88
- i = 1
89
- while api.file_exists(repo_id=repo_id, filename=Path(new_filename).name, repo_type=repo_type, token=hf_token):
90
- infos = api.get_paths_info(repo_id=repo_id, paths=[Path(new_filename).name], repo_type=repo_type, token=hf_token)
91
- if infos and len(infos) == 1:
92
- repo_fs = infos[0].size
93
- repo_sha256 = infos[0].lfs.sha256 if infos[0].lfs is not None else ""
94
- if is_same_file(filename, repo_sha256, repo_fs): break
95
- new_filename = str(Path(Path(filename).parent, f"{Path(filename).stem}_{i}{Path(filename).suffix}"))
96
- i += 1
97
- if filename != new_filename:
98
- print(f"{Path(filename).name} is already exists but file content is different. renaming to {Path(new_filename).name}.")
99
- Path(filename).rename(new_filename)
100
- except Exception as e:
101
- print(f"Error occured when renaming {filename}. {e}")
102
- finally:
103
- return new_filename
104
-
105
-
106
- def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
107
- download_dir = TEMP_DIR
108
- progress(0, desc=f"Start downloading... {dl_url}")
109
- output_filename = get_download_file(download_dir, dl_url, civitai_key)
110
- return output_filename
111
-
112
-
113
- def save_civitai_info(dl_url, filename, civitai_key="", progress=gr.Progress(track_tqdm=True)):
114
- json_str, html_str, image_path = get_civitai_json(dl_url, True, filename, civitai_key)
115
- if not json_str: return "", "", ""
116
- json_path = str(Path(TEMP_DIR, Path(filename).stem + ".json"))
117
- html_path = str(Path(TEMP_DIR, Path(filename).stem + ".html"))
118
- try:
119
- with open(json_path, 'w') as f:
120
- json.dump(json_str, f, indent=2)
121
- with open(html_path, mode='w', encoding="utf-8") as f:
122
- f.write(html_str)
123
- return json_path, html_path, image_path
124
- except Exception as e:
125
- print(f"Error: Failed to save info file {json_path}, {html_path} {e}")
126
- return "", "", ""
127
-
128
-
129
- def upload_info_to_repo(dl_url, filename, repo_id, repo_type, is_private, civitai_key="", progress=gr.Progress(track_tqdm=True)):
130
- def upload_file(api, filename, repo_id, repo_type, hf_token):
131
- if not Path(filename).exists(): return
132
- api.upload_file(path_or_fileobj=filename, path_in_repo=Path(filename).name, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
133
- Path(filename).unlink()
134
-
135
- hf_token = get_token()
136
- api = HfApi(token=hf_token)
137
- try:
138
- if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
139
- progress(0, desc=f"Downloading info... {filename}")
140
- json_path, html_path, image_path = save_civitai_info(dl_url, filename, civitai_key)
141
- progress(0, desc=f"Start uploading info... {filename} to {repo_id}")
142
- if not json_path: return
143
- else: upload_file(api, json_path, repo_id, repo_type, hf_token)
144
- if html_path: upload_file(api, html_path, repo_id, repo_type, hf_token)
145
- if image_path: upload_file(api, image_path, repo_id, repo_type, hf_token)
146
- progress(1, desc="Info uploaded.")
147
- return
148
- except Exception as e:
149
- print(f"Error: Failed to upload info to {repo_id}. {e}")
150
- gr.Warning(f"Error: Failed to upload info to {repo_id}. {e}")
151
- return
152
-
153
-
154
- def download_civitai(dl_url, civitai_key, hf_token, urls,
155
- newrepo_id, repo_type="model", is_private=True, is_info=False, is_rename=True, progress=gr.Progress(track_tqdm=True)):
156
- if hf_token: set_token(hf_token)
157
- else: set_token(os.environ.get("HF_TOKEN")) # default huggingface write token
158
- if not civitai_key: civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
159
- if not newrepo_id: newrepo_id = os.environ.get("HF_REPO") # default repo to upload
160
- if not get_token() or not civitai_key: raise gr.Error("HF write token and Civitai API key is required.")
161
- if not urls: urls = []
162
- dl_urls = parse_urls(dl_url)
163
- remain_urls = dl_urls.copy()
164
- try:
165
- md = f'### Your repo: [{newrepo_id}]({"https://huggingface.co/datasets/" if repo_type == "dataset" else "https://huggingface.co/"}{newrepo_id})\n'
166
- for u in dl_urls:
167
- file = download_file(u, civitai_key)
168
- if not Path(file).exists() or not Path(file).is_file(): continue
169
- if is_rename: file = get_safe_filename(file, newrepo_id, repo_type)
170
- url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
171
- if url:
172
- if is_info: upload_info_to_repo(u, file, newrepo_id, repo_type, is_private, civitai_key)
173
- urls.append(url)
174
- remain_urls.remove(u)
175
- md += f"- Uploaded [{str(u)}]({str(u)})\n"
176
- dp_repos = parse_repos(dl_url)
177
- for r in dp_repos:
178
- url = duplicate_hf_repo(r, newrepo_id, "model", repo_type, is_private, HF_SUBFOLDER_NAME[1])
179
- if url: urls.append(url)
180
- return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=False)
181
- except Exception as e:
182
- gr.Info(f"Error occured: {e}")
183
- return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=True)
184
- finally:
185
- gc.collect()
186
-
187
-
188
- def search_on_civitai(query: str, types: list[str], allow_model: list[str] = [], limit: int = 100,
189
- sort: str = "Highest Rated", period: str = "AllTime", tag: str = "", user: str = "", page: int = 1,
190
- filetype: list[str] = [], api_key: str = "", progress=gr.Progress(track_tqdm=True)):
191
- user_agent = get_user_agent()
192
- headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
193
- if api_key: headers['Authorization'] = f'Bearer {{{api_key}}}'
194
- base_url = 'https://civitai.com/api/v1/models'
195
- params = {'sort': sort, 'period': period, 'limit': int(limit), 'nsfw': 'true'}
196
- if len(types) != 0: params["types"] = types
197
- if query: params["query"] = query
198
- if tag: params["tag"] = tag
199
- if user: params["username"] = user
200
- if page != 0: params["page"] = int(page)
201
- session = requests.Session()
202
- retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
203
- session.mount("https://", HTTPAdapter(max_retries=retries))
204
- rs = []
205
- try:
206
- if page == 0:
207
- progress(0, desc="Searching page 1...")
208
- print("Searching page 1...")
209
- r = session.get(base_url, params=params | {'page': 1}, headers=headers, stream=True, timeout=(7.0, 30))
210
- rs.append(r)
211
- if r.ok:
212
- json = r.json()
213
- next_url = json['metadata']['nextPage'] if 'metadata' in json and 'nextPage' in json['metadata'] else None
214
- i = 2
215
- while(next_url is not None):
216
- progress(0, desc=f"Searching page {i}...")
217
- print(f"Searching page {i}...")
218
- r = session.get(next_url, headers=headers, stream=True, timeout=(7.0, 30))
219
- rs.append(r)
220
- if r.ok:
221
- json = r.json()
222
- next_url = json['metadata']['nextPage'] if 'metadata' in json and 'nextPage' in json['metadata'] else None
223
- else: next_url = None
224
- i += 1
225
- else:
226
- progress(0, desc="Searching page 1...")
227
- print("Searching page 1...")
228
- r = session.get(base_url, params=params, headers=headers, stream=True, timeout=(7.0, 30))
229
- rs.append(r)
230
- except requests.exceptions.ConnectTimeout:
231
- print("Request timed out.")
232
- except Exception as e:
233
- print(e)
234
- items = []
235
- for r in rs:
236
- if not r.ok: continue
237
- json = r.json()
238
- if 'items' not in json: continue
239
- for j in json['items']:
240
- for model in j['modelVersions']:
241
- item = {}
242
- if len(allow_model) != 0 and model['baseModel'] not in set(allow_model): continue
243
- item['name'] = j['name']
244
- item['creator'] = j['creator']['username'] if 'creator' in j.keys() and 'username' in j['creator'].keys() else ""
245
- item['tags'] = j['tags'] if 'tags' in j.keys() else []
246
- item['model_name'] = model['name'] if 'name' in model.keys() else ""
247
- item['base_model'] = model['baseModel'] if 'baseModel' in model.keys() else ""
248
- item['description'] = model['description'] if 'description' in model.keys() else ""
249
- item['md'] = ""
250
-
251
- # Handle both images and videos
252
- if 'images' in model.keys() and len(model["images"]) != 0:
253
- first_media = model["images"][0]
254
- item['img_url'] = first_media["url"]
255
- item['is_video'] = first_media.get("type", "image") == "video"
256
- item['video_url'] = first_media.get("meta", {}).get("video", "") if item['is_video'] else ""
257
-
258
- if item['is_video']:
259
- item['md'] += f'<video src="{item["img_url"]}" poster="{item["img_url"]}" muted loop autoplay width="300" height="480" style="float:right;padding:16px;"></video><br>'
260
- else:
261
- item['md'] += f'<img src="{item["img_url"]}#float" alt="thumbnail" width="150" height="240"><br>'
262
- else:
263
- item['img_url'] = "/home/user/app/null.png"
264
- item['is_video'] = False
265
- item['video_url'] = ""
266
-
267
- item['md'] += f'''Model URL: [https://civitai.com/models/{j["id"]}](https://civitai.com/models/{j["id"]})<br>Model Name: {item["name"]}<br>
268
- Creator: {item["creator"]}<br>Tags: {", ".join(item["tags"])}<br>Base Model: {item["base_model"]}<br>Description: {item["description"]}'''
269
- if 'files' in model.keys():
270
- for f in model['files']:
271
- i = item.copy()
272
- i['dl_url'] = f['downloadUrl']
273
- if len(filetype) != 0 and f['type'] not in set(filetype): continue
274
- items.append(i)
275
- else:
276
- item['dl_url'] = model['downloadUrl']
277
- items.append(item)
278
- return items if len(items) > 0 else None
279
-
280
-
281
- def search_civitai(query, types, base_model=[], sort=SORT[0], period=PERIOD[0], tag="", user="", limit=100, page=1,
282
- filetype=[], api_key="", gallery=[], state={}, progress=gr.Progress(track_tqdm=True)):
283
- civitai_last_results = {}
284
- set_state(state, "civitai_last_choices", [("", "")])
285
- set_state(state, "civitai_last_gallery", [])
286
- set_state(state, "civitai_last_results", civitai_last_results)
287
- results_info = "No item found."
288
- items = search_on_civitai(query, types, base_model, int(limit), sort, period, tag, user, int(page), filetype, api_key)
289
- if not items: return gr.update(choices=[("", "")], value=[], visible=True),\
290
- gr.update(value="", visible=False), gr.update(), gr.update(), gr.update(), gr.update(), results_info, state
291
- choices = []
292
- gallery = []
293
- for item in items:
294
- base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
295
- name = f"{item['name']} (for {base_model_name} / By: {item['creator']})"
296
- value = item['dl_url']
297
- choices.append((name, value))
298
-
299
- # For gallery, use tuples with HTML that includes both image and video
300
- if item.get('is_video') and item.get('video_url'):
301
- # Create an HTML element that contains both image and video
302
- media_html = f"""
303
- <div class="media-container">
304
- <img src="{item['img_url']}" alt="{name}">
305
- <video src="{item['video_url']}" muted loop poster="{item['img_url']}"></video>
306
- </div>
307
- """
308
- gallery.append((item['img_url'], name)) # Keep using image as thumbnail
309
- else:
310
- gallery.append((item['img_url'], name))
311
-
312
- civitai_last_results[value] = item
313
- if len(choices) >= 1:
314
- results_info = f"{int(len(choices))} items found."
315
- else:
316
- choices = [("", "")]
317
-
318
- md = ""
319
- set_state(state, "civitai_last_choices", choices)
320
- set_state(state, "civitai_last_gallery", gallery)
321
- set_state(state, "civitai_last_results", civitai_last_results)
322
-
323
- return gr.update(choices=choices, value=[], visible=True),\
324
- gr.update(value=md, visible=True),\
325
- gr.update(),\
326
- gr.update(),\
327
- gr.update(value=gallery),\
328
- gr.update(choices=choices, value=[]),\
329
- results_info,\
330
- state
331
-
332
-
333
- def get_civitai_json(dl_url: str, is_html: bool=False, image_baseurl: str="", api_key=""):
334
- if not image_baseurl: image_baseurl = dl_url
335
- default = ("", "", "") if is_html else ""
336
- if "https://civitai.com/api/download/models/" not in dl_url: return default
337
- user_agent = get_user_agent()
338
- headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
339
- if api_key: headers['Authorization'] = f'Bearer {{{api_key}}}'
340
- base_url = 'https://civitai.com/api/v1/model-versions/'
341
- params = {}
342
- session = requests.Session()
343
- retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
344
- session.mount("https://", HTTPAdapter(max_retries=retries))
345
- model_id = re.sub('https://civitai.com/api/download/models/(\\d+)(?:.+)?', '\\1', dl_url)
346
- url = base_url + model_id
347
- #url = base_url + str(dl_url.split("/")[-1])
348
- try:
349
- r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
350
- if not r.ok: return default
351
- json = dict(r.json()).copy()
352
- html = ""
353
- image = ""
354
- if "modelId" in json.keys():
355
- url = f"https://civitai.com/models/{json['modelId']}"
356
- r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
357
- if not r.ok: return json, html, image
358
- html = r.text
359
- if 'images' in json.keys() and len(json["images"]) != 0:
360
- url = json["images"][0]["url"]
361
- r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
362
- if not r.ok: return json, html, image
363
- image_temp = str(Path(TEMP_DIR, "image" + Path(url.split("/")[-1]).suffix))
364
- image = str(Path(TEMP_DIR, Path(image_baseurl.split("/")[-1]).stem + ".png"))
365
- with open(image_temp, 'wb') as f:
366
- f.write(r.content)
367
- Image.open(image_temp).convert('RGBA').save(image)
368
- return json, html, image
369
- except Exception as e:
370
- print(e)
371
- return default
372
-
373
-
374
- def get_civitai_tag():
375
- default = [""]
376
- user_agent = get_user_agent()
377
- headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
378
- base_url = 'https://civitai.com/api/v1/tags'
379
- params = {'limit': 200}
380
- session = requests.Session()
381
- retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
382
- session.mount("https://", HTTPAdapter(max_retries=retries))
383
- url = base_url
384
- try:
385
- r = session.get(url, params=params, headers=headers, stream=True, timeout=(7.0, 15))
386
- if not r.ok: return default
387
- j = dict(r.json()).copy()
388
- if "items" not in j.keys(): return default
389
- items = []
390
- for item in j["items"]:
391
- items.append([str(item.get("name", "")), int(item.get("modelCount", 0))])
392
- df = pd.DataFrame(items)
393
- df.sort_values(1, ascending=False)
394
- tags = df.values.tolist()
395
- tags = [""] + [l[0] for l in tags]
396
- return tags
397
- except Exception as e:
398
- print(e)
399
- return default
400
-
401
-
402
- def select_civitai_item(results: list[str], state: dict):
403
- json = {}
404
- if "http" not in "".join(results) or len(results) == 0: return gr.update(value="", visible=True), gr.update(value=json, visible=False), state
405
- result = get_state(state, "civitai_last_results")
406
- last_selects = get_state(state, "civitai_last_selects")
407
- selects = list_sub(results, last_selects if last_selects else [])
408
- md = result.get(selects[-1]).get('md', "") if result and isinstance(result, dict) and len(selects) > 0 else ""
409
- set_state(state, "civitai_last_selects", results)
410
- return gr.update(value=md, visible=True), gr.update(value=json, visible=False), state
411
-
412
-
413
- def add_civitai_item(results: list[str], dl_url: str):
414
- if "http" not in "".join(results): return gr.update(value=dl_url)
415
- new_url = dl_url if dl_url else ""
416
- for result in results:
417
- if "http" not in result: continue
418
- new_url += f"\n{result}" if new_url else f"{result}"
419
- new_url = uniq_urls(new_url)
420
- return gr.update(value=new_url)
421
-
422
-
423
- def select_civitai_all_item(button_name: str, state: dict):
424
- if button_name not in ["Select All", "Deselect All"]: return gr.update(value=button_name), gr.Update(visible=True)
425
- civitai_last_choices = get_state(state, "civitai_last_choices")
426
- selected = [t[1] for t in civitai_last_choices if t[1] != ""] if button_name == "Select All" else []
427
- new_button_name = "Select All" if button_name == "Deselect All" else "Deselect All"
428
- return gr.update(value=new_button_name), gr.update(value=selected, choices=civitai_last_choices)
429
-
430
-
431
- def update_civitai_selection(evt: gr.SelectData, value: list[str], state: dict):
432
- try:
433
- civitai_last_choices = get_state(state, "civitai_last_choices")
434
- selected_index = evt.index
435
- selected = list_uniq([v for v in value if v != ""] + [civitai_last_choices[selected_index][1]])
436
- return gr.update(value=selected)
437
- except Exception:
438
- return gr.update()
439
-
440
-
441
- def update_civitai_checkbox(selected: list[str]):
442
- return gr.update(value=selected)
443
-
444
-
445
- def from_civitai_checkbox(selected: list[str]):
446
- return gr.update(value=selected)
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi, hf_hub_url
3
+ import os
4
+ from pathlib import Path
5
+ import gc
6
+ import requests
7
+ from requests.adapters import HTTPAdapter
8
+ from urllib3.util import Retry
9
+ from civitai_constants import PERIOD, SORT
10
+ from utils import (get_token, set_token, is_repo_exists, get_user_agent, get_download_file,
11
+ list_uniq, list_sub, duplicate_hf_repo, HF_SUBFOLDER_NAME, get_state, set_state)
12
+ import re
13
+ from PIL import Image
14
+ import json
15
+ import pandas as pd
16
+ import tempfile
17
+ import hashlib
18
+ import logging
19
+ # Huge shoutout to @John6666, saved me many hours.
20
+
21
+ # Configure logging
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+ TEMP_DIR = tempfile.mkdtemp()
26
+
27
+
28
+ def parse_urls(s):
29
+ url_pattern = "https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+"
30
+ try:
31
+ urls = re.findall(url_pattern, s)
32
+ return list(urls)
33
+ except Exception:
34
+ return []
35
+
36
+
37
+ def parse_repos(s):
38
+ repo_pattern = r'[^\w_\-\.]?([\w_\-\.]+/[\w_\-\.]+)[^\w_\-\.]?'
39
+ try:
40
+ s = re.sub("https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+", "", s)
41
+ repos = re.findall(repo_pattern, s)
42
+ return list(repos)
43
+ except Exception:
44
+ return []
45
+
46
+
47
+ def to_urls(l: list[str]):
48
+ return "\n".join(l)
49
+
50
+
51
+ def uniq_urls(s):
52
+ return to_urls(list_uniq(parse_urls(s) + parse_repos(s)))
53
+
54
+
55
+ def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
56
+ output_filename = Path(filename).name
57
+ hf_token = get_token()
58
+ api = HfApi(token=hf_token)
59
+ try:
60
+ if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
61
+ progress(0, desc=f"Start uploading... {filename} to {repo_id}")
62
+ api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
63
+ progress(1, desc="Uploaded.")
64
+ url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=output_filename)
65
+ except Exception as e:
66
+ print(f"Error: Failed to upload to {repo_id}. {e}")
67
+ gr.Warning(f"Error: Failed to upload to {repo_id}. {e}")
68
+ return None
69
+ finally:
70
+ if Path(filename).exists(): Path(filename).unlink()
71
+ return url
72
+
73
+
74
+ def is_same_file(filename: str, cmp_sha256: str, cmp_size: int):
75
+ if cmp_sha256:
76
+ sha256_hash = hashlib.sha256()
77
+ with open(filename, "rb") as f:
78
+ for byte_block in iter(lambda: f.read(4096), b""):
79
+ sha256_hash.update(byte_block)
80
+ sha256 = sha256_hash.hexdigest()
81
+ else: sha256 = ""
82
+ size = os.path.getsize(filename)
83
+ if size == cmp_size and sha256 == cmp_sha256: return True
84
+ else: return False
85
+
86
+
87
+ def get_safe_filename(filename, repo_id, repo_type):
88
+ hf_token = get_token()
89
+ api = HfApi(token=hf_token)
90
+ new_filename = filename
91
+ try:
92
+ i = 1
93
+ while api.file_exists(repo_id=repo_id, filename=Path(new_filename).name, repo_type=repo_type, token=hf_token):
94
+ infos = api.get_paths_info(repo_id=repo_id, paths=[Path(new_filename).name], repo_type=repo_type, token=hf_token)
95
+ if infos and len(infos) == 1:
96
+ repo_fs = infos[0].size
97
+ repo_sha256 = infos[0].lfs.sha256 if infos[0].lfs is not None else ""
98
+ if is_same_file(filename, repo_sha256, repo_fs): break
99
+ new_filename = str(Path(Path(filename).parent, f"{Path(filename).stem}_{i}{Path(filename).suffix}"))
100
+ i += 1
101
+ if filename != new_filename:
102
+ print(f"{Path(filename).name} is already exists but file content is different. renaming to {Path(new_filename).name}.")
103
+ Path(filename).rename(new_filename)
104
+ except Exception as e:
105
+ print(f"Error occured when renaming {filename}. {e}")
106
+ finally:
107
+ return new_filename
108
+
109
+
110
+ def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
111
+ download_dir = TEMP_DIR
112
+ progress(0, desc=f"Start downloading... {dl_url}")
113
+ output_filename = get_download_file(download_dir, dl_url, civitai_key)
114
+ return output_filename
115
+
116
+
117
+ def save_civitai_info(dl_url, filename, civitai_key="", progress=gr.Progress(track_tqdm=True)):
118
+ json_str, html_str, image_path = get_civitai_json(dl_url, True, filename, civitai_key)
119
+ if not json_str: return "", "", ""
120
+ json_path = str(Path(TEMP_DIR, Path(filename).stem + ".json"))
121
+ html_path = str(Path(TEMP_DIR, Path(filename).stem + ".html"))
122
+ try:
123
+ with open(json_path, 'w') as f:
124
+ json.dump(json_str, f, indent=2)
125
+ with open(html_path, mode='w', encoding="utf-8") as f:
126
+ f.write(html_str)
127
+ return json_path, html_path, image_path
128
+ except Exception as e:
129
+ print(f"Error: Failed to save info file {json_path}, {html_path} {e}")
130
+ return "", "", ""
131
+
132
+
133
+ def upload_info_to_repo(dl_url, filename, repo_id, repo_type, is_private, civitai_key="", progress=gr.Progress(track_tqdm=True)):
134
+ def upload_file(api, filename, repo_id, repo_type, hf_token):
135
+ if not Path(filename).exists(): return
136
+ api.upload_file(path_or_fileobj=filename, path_in_repo=Path(filename).name, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
137
+ Path(filename).unlink()
138
+
139
+ hf_token = get_token()
140
+ api = HfApi(token=hf_token)
141
+ try:
142
+ if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
143
+ progress(0, desc=f"Downloading info... {filename}")
144
+ json_path, html_path, image_path = save_civitai_info(dl_url, filename, civitai_key)
145
+ progress(0, desc=f"Start uploading info... {filename} to {repo_id}")
146
+ if not json_path: return
147
+ else: upload_file(api, json_path, repo_id, repo_type, hf_token)
148
+ if html_path: upload_file(api, html_path, repo_id, repo_type, hf_token)
149
+ if image_path: upload_file(api, image_path, repo_id, repo_type, hf_token)
150
+ progress(1, desc="Info uploaded.")
151
+ return
152
+ except Exception as e:
153
+ print(f"Error: Failed to upload info to {repo_id}. {e}")
154
+ gr.Warning(f"Error: Failed to upload info to {repo_id}. {e}")
155
+ return
156
+
157
+
158
+ def download_civitai(dl_url, civitai_key, hf_token, urls,
159
+ newrepo_id, repo_type="model", is_private=True, is_info=False, is_rename=True, progress=gr.Progress(track_tqdm=True)):
160
+ if hf_token: set_token(hf_token)
161
+ else: set_token(os.environ.get("HF_TOKEN")) # default huggingface write token
162
+ if not civitai_key: civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
163
+ if not newrepo_id: newrepo_id = os.environ.get("HF_REPO") # default repo to upload
164
+ if not get_token() or not civitai_key: raise gr.Error("HF write token and Civitai API key is required.")
165
+ if not urls: urls = []
166
+ dl_urls = parse_urls(dl_url)
167
+ remain_urls = dl_urls.copy()
168
+ try:
169
+ md = f'### Your repo: [{newrepo_id}]({"https://huggingface.co/datasets/" if repo_type == "dataset" else "https://huggingface.co/"}{newrepo_id})\n'
170
+ for u in dl_urls:
171
+ file = download_file(u, civitai_key)
172
+ if not Path(file).exists() or not Path(file).is_file(): continue
173
+ if is_rename: file = get_safe_filename(file, newrepo_id, repo_type)
174
+ url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
175
+ if url:
176
+ if is_info: upload_info_to_repo(u, file, newrepo_id, repo_type, is_private, civitai_key)
177
+ urls.append(url)
178
+ remain_urls.remove(u)
179
+ md += f"- Uploaded [{str(u)}]({str(u)})\n"
180
+ dp_repos = parse_repos(dl_url)
181
+ for r in dp_repos:
182
+ url = duplicate_hf_repo(r, newrepo_id, "model", repo_type, is_private, HF_SUBFOLDER_NAME[1])
183
+ if url: urls.append(url)
184
+ return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=False)
185
+ except Exception as e:
186
+ gr.Info(f"Error occured: {e}")
187
+ return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=True)
188
+ finally:
189
+ gc.collect()
190
+
191
+
192
+ def search_on_civitai(query: str, types: list[str], allow_model: list[str] = [], limit: int = 100,
193
+ sort: str = "Highest Rated", period: str = "AllTime", tag: str = "", user: str = "", page: int = 1,
194
+ filetype: list[str] = [], api_key: str = "", progress=gr.Progress(track_tqdm=True)):
195
+ user_agent = get_user_agent()
196
+ headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
197
+ if api_key: headers['Authorization'] = f'Bearer {{{api_key}}}'
198
+ base_url = 'https://civitai.com/api/v1/models'
199
+ params = {'sort': sort, 'period': period, 'limit': int(limit), 'nsfw': 'true'}
200
+ if len(types) != 0: params["types"] = types
201
+ if query: params["query"] = query
202
+ if tag: params["tag"] = tag
203
+ if user: params["username"] = user
204
+ if page != 0: params["page"] = int(page)
205
+ session = requests.Session()
206
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
207
+ session.mount("https://", HTTPAdapter(max_retries=retries))
208
+ rs = []
209
+ try:
210
+ if page == 0:
211
+ progress(0, desc="Searching page 1...")
212
+ print("Searching page 1...")
213
+ r = session.get(base_url, params=params | {'page': 1}, headers=headers, stream=True, timeout=(7.0, 30))
214
+ rs.append(r)
215
+ if r.ok:
216
+ json = r.json()
217
+ next_url = json['metadata']['nextPage'] if 'metadata' in json and 'nextPage' in json['metadata'] else None
218
+ i = 2
219
+ while(next_url is not None):
220
+ progress(0, desc=f"Searching page {i}...")
221
+ print(f"Searching page {i}...")
222
+ r = session.get(next_url, headers=headers, stream=True, timeout=(7.0, 30))
223
+ rs.append(r)
224
+ if r.ok:
225
+ json = r.json()
226
+ next_url = json['metadata']['nextPage'] if 'metadata' in json and 'nextPage' in json['metadata'] else None
227
+ else: next_url = None
228
+ i += 1
229
+ else:
230
+ progress(0, desc="Searching page 1...")
231
+ print("Searching page 1...")
232
+ r = session.get(base_url, params=params, headers=headers, stream=True, timeout=(7.0, 30))
233
+ rs.append(r)
234
+ except requests.exceptions.ConnectTimeout:
235
+ print("Request timed out.")
236
+ except Exception as e:
237
+ print(e)
238
+ items = []
239
+ for r in rs:
240
+ if not r.ok: continue
241
+ json = r.json()
242
+ if 'items' not in json: continue
243
+ for j in json['items']:
244
+ for model in j['modelVersions']:
245
+ item = {}
246
+ if len(allow_model) != 0 and model['baseModel'] not in set(allow_model): continue
247
+ item['name'] = j['name']
248
+ item['creator'] = j['creator']['username'] if 'creator' in j.keys() and 'username' in j['creator'].keys() else ""
249
+ item['tags'] = j['tags'] if 'tags' in j.keys() else []
250
+ item['model_name'] = model['name'] if 'name' in model.keys() else ""
251
+ item['base_model'] = model['baseModel'] if 'baseModel' in model.keys() else ""
252
+ item['description'] = model['description'] if 'description' in model.keys() else ""
253
+ item['md'] = ""
254
+
255
+ # Handle both images and videos
256
+ if 'images' in model.keys() and len(model["images"]) != 0:
257
+ first_media = model["images"][0]
258
+ item['img_url'] = first_media["url"]
259
+ item['is_video'] = first_media.get("type", "image") == "video"
260
+ item['video_url'] = first_media.get("meta", {}).get("video", "") if item['is_video'] else ""
261
+
262
+ if item['is_video']:
263
+ item['md'] += f'<video src="{item["img_url"]}" poster="{item["img_url"]}" muted loop autoplay width="300" height="480" style="float:right;padding:16px;"></video><br>'
264
+ else:
265
+ item['md'] += f'<img src="{item["img_url"]}#float" alt="thumbnail" width="150" height="240"><br>'
266
+ else:
267
+ item['img_url'] = "/home/user/app/null.png"
268
+ item['is_video'] = False
269
+ item['video_url'] = ""
270
+
271
+ item['md'] += f'''Model URL: [https://civitai.com/models/{j["id"]}](https://civitai.com/models/{j["id"]})<br>Model Name: {item["name"]}<br>
272
+ Creator: {item["creator"]}<br>Tags: {", ".join(item["tags"])}<br>Base Model: {item["base_model"]}<br>Description: {item["description"]}'''
273
+ if 'files' in model.keys():
274
+ for f in model['files']:
275
+ i = item.copy()
276
+ i['dl_url'] = f['downloadUrl']
277
+ if len(filetype) != 0 and f['type'] not in set(filetype): continue
278
+ items.append(i)
279
+ else:
280
+ item['dl_url'] = model['downloadUrl']
281
+ items.append(item)
282
+ return items if len(items) > 0 else None
283
+
284
+
285
+ def search_civitai(query, types, base_model=[], sort=SORT[0], period=PERIOD[0], tag="", user="", limit=100, page=1,
286
+ filetype=[], api_key="", gallery=[], state={}, progress=gr.Progress(track_tqdm=True)):
287
+ civitai_last_results = {}
288
+ set_state(state, "civitai_last_choices", [("", "")])
289
+ set_state(state, "civitai_last_gallery", [])
290
+ set_state(state, "civitai_last_results", civitai_last_results)
291
+ results_info = "No item found."
292
+ items = search_on_civitai(query, types, base_model, int(limit), sort, period, tag, user, int(page), filetype, api_key)
293
+ if not items: return gr.update(choices=[("", "")], value=[], visible=True),\
294
+ gr.update(value="", visible=False), gr.update(), gr.update(), gr.update(), gr.update(), results_info, state
295
+ choices = []
296
+ gallery = []
297
+ for item in items:
298
+ base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
299
+ name = f"{item['name']} (for {base_model_name} / By: {item['creator']})"
300
+ value = item['dl_url']
301
+ choices.append((name, value))
302
+
303
+ # For gallery, use tuples with HTML that includes both image and video
304
+ if item.get('is_video') and item.get('video_url'):
305
+ # Create an HTML element that contains both image and video
306
+ media_html = f"""
307
+ <div class="media-container">
308
+ <img src="{item['img_url']}" alt="{name}">
309
+ <video src="{item['video_url']}" muted loop poster="{item['img_url']}"></video>
310
+ </div>
311
+ """
312
+ gallery.append((item['img_url'], name)) # Keep using image as thumbnail
313
+ else:
314
+ gallery.append((item['img_url'], name))
315
+
316
+ civitai_last_results[value] = item
317
+ if len(choices) >= 1:
318
+ results_info = f"{int(len(choices))} items found."
319
+ else:
320
+ choices = [("", "")]
321
+
322
+ md = ""
323
+ set_state(state, "civitai_last_choices", choices)
324
+ set_state(state, "civitai_last_gallery", gallery)
325
+ set_state(state, "civitai_last_results", civitai_last_results)
326
+
327
+ return gr.update(choices=choices, value=[], visible=True),\
328
+ gr.update(value=md, visible=True),\
329
+ gr.update(),\
330
+ gr.update(),\
331
+ gr.update(value=gallery),\
332
+ gr.update(choices=choices, value=[]),\
333
+ results_info,\
334
+ state
335
+
336
+
337
+ def get_civitai_json(dl_url: str, is_html: bool=False, image_baseurl: str="", api_key=""):
338
+ if not image_baseurl: image_baseurl = dl_url
339
+ default = ("", "", "") if is_html else ""
340
+ if "https://civitai.com/api/download/models/" not in dl_url: return default
341
+ user_agent = get_user_agent()
342
+ headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
343
+ if api_key: headers['Authorization'] = f'Bearer {{{api_key}}}'
344
+ base_url = 'https://civitai.com/api/v1/model-versions/'
345
+ params = {}
346
+ session = requests.Session()
347
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
348
+ session.mount("https://", HTTPAdapter(max_retries=retries))
349
+ model_id = re.sub('https://civitai.com/api/download/models/(\\d+)(?:.+)?', '\\1', dl_url)
350
+ url = base_url + model_id
351
+ #url = base_url + str(dl_url.split("/")[-1])
352
+ try:
353
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
354
+ if not r.ok: return default
355
+ json = dict(r.json()).copy()
356
+ html = ""
357
+ image = ""
358
+ if "modelId" in json.keys():
359
+ url = f"https://civitai.com/models/{json['modelId']}"
360
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
361
+ if not r.ok: return json, html, image
362
+ html = r.text
363
+ if 'images' in json.keys() and len(json["images"]) != 0:
364
+ url = json["images"][0]["url"]
365
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
366
+ if not r.ok: return json, html, image
367
+ image_temp = str(Path(TEMP_DIR, "image" + Path(url.split("/")[-1]).suffix))
368
+ image = str(Path(TEMP_DIR, Path(image_baseurl.split("/")[-1]).stem + ".png"))
369
+ with open(image_temp, 'wb') as f:
370
+ f.write(r.content)
371
+ Image.open(image_temp).convert('RGBA').save(image)
372
+ return json, html, image
373
+ except Exception as e:
374
+ print(e)
375
+ return default
376
+
377
+
378
+ def get_civitai_tag():
379
+ default = [""]
380
+ user_agent = get_user_agent()
381
+ headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
382
+ base_url = 'https://civitai.com/api/v1/tags'
383
+ params = {'limit': 200}
384
+ session = requests.Session()
385
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
386
+ session.mount("https://", HTTPAdapter(max_retries=retries))
387
+ url = base_url
388
+ try:
389
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(7.0, 15))
390
+ if not r.ok: return default
391
+ j = dict(r.json()).copy()
392
+ if "items" not in j.keys(): return default
393
+ items = []
394
+ for item in j["items"]:
395
+ items.append([str(item.get("name", "")), int(item.get("modelCount", 0))])
396
+ df = pd.DataFrame(items)
397
+ df.sort_values(1, ascending=False)
398
+ tags = df.values.tolist()
399
+ tags = [""] + [l[0] for l in tags]
400
+ return tags
401
+ except Exception as e:
402
+ print(e)
403
+ return default
404
+
405
+
406
+ def select_civitai_item(results: list[str], state: dict):
407
+ json = {}
408
+ if "http" not in "".join(results) or len(results) == 0: return gr.update(value="", visible=True), gr.update(value=json, visible=False), state
409
+ result = get_state(state, "civitai_last_results")
410
+ last_selects = get_state(state, "civitai_last_selects")
411
+ selects = list_sub(results, last_selects if last_selects else [])
412
+ md = result.get(selects[-1]).get('md', "") if result and isinstance(result, dict) and len(selects) > 0 else ""
413
+ set_state(state, "civitai_last_selects", results)
414
+ return gr.update(value=md, visible=True), gr.update(value=json, visible=False), state
415
+
416
+
417
+ def add_civitai_item(results: list[str], dl_url: str):
418
+ if "http" not in "".join(results): return gr.update(value=dl_url)
419
+ new_url = dl_url if dl_url else ""
420
+ for result in results:
421
+ if "http" not in result: continue
422
+ new_url += f"\n{result}" if new_url else f"{result}"
423
+ new_url = uniq_urls(new_url)
424
+ return gr.update(value=new_url)
425
+
426
+
427
+ def select_civitai_all_item(button_name: str, state: dict):
428
+ if button_name not in ["Select All", "Deselect All"]: return gr.update(value=button_name), gr.Update(visible=True)
429
+ civitai_last_choices = get_state(state, "civitai_last_choices")
430
+ selected = [t[1] for t in civitai_last_choices if t[1] != ""] if button_name == "Select All" else []
431
+ new_button_name = "Select All" if button_name == "Deselect All" else "Deselect All"
432
+ return gr.update(value=new_button_name), gr.update(value=selected, choices=civitai_last_choices)
433
+
434
+
435
+ def update_civitai_selection(evt: gr.SelectData, value: list[str], state: dict):
436
+ try:
437
+ civitai_last_choices = get_state(state, "civitai_last_choices")
438
+ selected_index = evt.index
439
+ selected = list_uniq([v for v in value if v != ""] + [civitai_last_choices[selected_index][1]])
440
+ return gr.update(value=selected)
441
+ except Exception:
442
+ return gr.update()
443
+
444
+
445
+ def update_civitai_checkbox(selected: list[str]):
446
+ return gr.update(value=selected)
447
+
448
+
449
+ def from_civitai_checkbox(selected: list[str]):
450
+ return gr.update(value=selected)