|
import gradio as gr |
|
from huggingface_hub import HfApi, hf_hub_url |
|
import os |
|
from pathlib import Path |
|
import gc |
|
import requests |
|
from requests.adapters import HTTPAdapter |
|
from urllib3.util import Retry |
|
from civitai_constants import PERIOD, SORT |
|
from utils import (get_token, set_token, is_repo_exists, get_user_agent, get_download_file, |
|
list_uniq, list_sub, duplicate_hf_repo, HF_SUBFOLDER_NAME, get_state, set_state) |
|
import re |
|
from PIL import Image |
|
import json |
|
import pandas as pd |
|
import tempfile |
|
import hashlib |
|
import logging |
|
|
|
|
|
|
|
# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Scratch directory created once at import time. Downloads, Civitai metadata
# files and preview images produced by the functions below are written here.
# tempfile.mkdtemp() does not remove the directory automatically.
TEMP_DIR = tempfile.mkdtemp()
|
|
|
|
|
def parse_urls(s):
    """Return every http(s) URL found in *s*, in order of appearance.

    Any error raised while searching (for example when *s* is not a string)
    results in an empty list.
    """
    url_pattern = "https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+"
    try:
        found = [match.group(0) for match in re.finditer(url_pattern, s)]
    except Exception:
        return []
    return found
|
|
|
|
|
def parse_repos(s):
    """Return every ``owner/name`` style token found in *s*.

    http(s) URLs are removed from the text first so their path segments are
    not reported. Any error (for example a non-string *s*) yields an empty
    list.
    """
    repo_pattern = r'[^\w_\-\.]?([\w_\-\.]+/[\w_\-\.]+)[^\w_\-\.]?'
    try:
        without_urls = re.sub("https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+", "", s)
        return [match.group(1) for match in re.finditer(repo_pattern, without_urls)]
    except Exception:
        return []
|
|
|
|
|
def to_urls(l: list[str]):
    """Join the given strings into a single newline-separated string."""
    separator = "\n"
    return separator.join(l)
|
|
|
|
|
def uniq_urls(s):
    """Collect the URLs and repo ids found in *s* and return them, passed
    through ``list_uniq``, as newline-separated text (URLs first)."""
    candidates = parse_urls(s)
    candidates = candidates + parse_repos(s)
    return to_urls(list_uniq(candidates))
|
|
|
|
|
def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
    """Upload a local file to the root of a Hub repo and delete the local copy.

    The repo is created first (with the requested privacy) when
    ``is_repo_exists`` reports it missing. The local file is removed whether
    or not the upload succeeds.

    Returns:
        The ``hf_hub_url`` of the uploaded file, or ``None`` when anything
        failed (the failure is printed and shown as a Gradio warning).
    """
    local_path = Path(filename)
    output_filename = local_path.name
    hf_token = get_token()
    api = HfApi(token=hf_token)
    try:
        if not is_repo_exists(repo_id, repo_type):
            api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
        progress(0, desc=f"Start uploading... {filename} to {repo_id}")
        api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_type=repo_type,
                        revision="main", token=hf_token, repo_id=repo_id)
        progress(1, desc="Uploaded.")
        return hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=output_filename)
    except Exception as e:
        message = f"Error: Failed to upload to {repo_id}. {e}"
        print(message)
        gr.Warning(message)
        return None
    finally:
        # The download is a temporary artifact: drop it on success and failure alike.
        if local_path.exists():
            local_path.unlink()
|
|
|
|
|
def is_same_file(filename: str, cmp_sha256: str, cmp_size: int):
    """Tell whether the local file matches a reference size and SHA-256.

    Args:
        filename: Path of the local file to inspect.
        cmp_sha256: Expected hex digest. An empty string means "no reference
            hash available", in which case only the size is compared.
        cmp_size: Expected size in bytes.

    Returns:
        True when the size matches and, if a reference hash was given, the
        digest matches too; False otherwise.

    Raises:
        OSError: If the file cannot be stat'ed or read.
    """
    # Cheap check first: a size mismatch settles the question without
    # reading the (potentially multi-gigabyte) file.
    if os.path.getsize(filename) != cmp_size:
        return False
    if not cmp_sha256:
        # Only the explicit empty string counts as "no hash"; any other
        # falsy value (e.g. None) never matches, as before.
        return cmp_sha256 == ""
    digest = hashlib.sha256()
    with open(filename, "rb") as f:
        # 1 MiB reads keep the Python-level loop short for large model files.
        while chunk := f.read(1024 * 1024):
            digest.update(chunk)
    return digest.hexdigest() == cmp_sha256
|
|
|
|
|
def get_safe_filename(filename, repo_id, repo_type):
    """Pick a file name that will not clobber a different file in the repo.

    While a file with the candidate name exists in the repo, its size (and
    LFS SHA-256 when available) is compared with the local file. If they
    match, the candidate is accepted; otherwise the next candidate
    ``<stem>_<i><suffix>`` is tried. When the accepted name differs from
    *filename*, the local file is renamed to it.

    Args:
        filename: Path of the local file.
        repo_id: Target Hub repo id.
        repo_type: Target Hub repo type.

    Returns:
        The chosen path. If an error occurred part-way, this is the last
        candidate that was computed (the error is printed), which may not
        have been renamed on disk.
    """
    hf_token = get_token()
    api = HfApi(token=hf_token)
    new_filename = filename
    try:
        i = 1
        while api.file_exists(repo_id=repo_id, filename=Path(new_filename).name, repo_type=repo_type, token=hf_token):
            infos = api.get_paths_info(repo_id=repo_id, paths=[Path(new_filename).name], repo_type=repo_type, token=hf_token)
            if infos and len(infos) == 1:
                repo_fs = infos[0].size
                repo_sha256 = infos[0].lfs.sha256 if infos[0].lfs is not None else ""
                if is_same_file(filename, repo_sha256, repo_fs):
                    break
            new_filename = str(Path(Path(filename).parent, f"{Path(filename).stem}_{i}{Path(filename).suffix}"))
            i += 1
        if filename != new_filename:
            print(f"{Path(filename).name} is already exists but file content is different. renaming to {Path(new_filename).name}.")
            Path(filename).rename(new_filename)
    except Exception as e:
        print(f"Error occured when renaming {filename}. {e}")
    # Deliberately NOT returned from a `finally` block: that form also
    # swallowed KeyboardInterrupt/SystemExit raised during the Hub calls.
    return new_filename
|
|
|
|
|
def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Download *dl_url* into ``TEMP_DIR`` via ``get_download_file`` and
    return whatever that helper returns."""
    progress(0, desc=f"Start downloading... {dl_url}")
    return get_download_file(TEMP_DIR, dl_url, civitai_key)
|
|
|
|
|
def save_civitai_info(dl_url, filename, civitai_key="", progress=gr.Progress(track_tqdm=True)):
    """Fetch Civitai metadata for *dl_url* and write it into ``TEMP_DIR``.

    The metadata is written as ``<stem>.json`` and the model page as
    ``<stem>.html``, where ``<stem>`` is the stem of *filename*.

    Returns:
        ``(json_path, html_path, image_path)``, or three empty strings when
        no metadata was obtained or the files could not be written.
    """
    nothing = ("", "", "")
    json_str, html_str, image_path = get_civitai_json(dl_url, True, filename, civitai_key)
    if not json_str:
        return nothing
    stem = Path(filename).stem
    json_path = str(Path(TEMP_DIR, stem + ".json"))
    html_path = str(Path(TEMP_DIR, stem + ".html"))
    try:
        with open(json_path, 'w') as json_file:
            json.dump(json_str, json_file, indent=2)
        with open(html_path, mode='w', encoding="utf-8") as html_file:
            html_file.write(html_str)
    except Exception as e:
        print(f"Error: Failed to save info file {json_path}, {html_path} {e}")
        return nothing
    return json_path, html_path, image_path
|
|
|
|
|
def upload_info_to_repo(dl_url, filename, repo_id, repo_type, is_private, civitai_key="", progress=gr.Progress(track_tqdm=True)):
    """Upload the Civitai metadata files for *dl_url* to a Hub repo.

    The JSON, HTML and preview image produced by ``save_civitai_info`` are
    uploaded to the repo root one after another; each local file is removed
    after its upload. Nothing is uploaded when no JSON metadata was
    obtained. Failures are printed and shown as a Gradio warning.

    Returns:
        None in every case.
    """
    hf_token = get_token()
    api = HfApi(token=hf_token)

    def push_then_remove(path):
        # Upload one file to the repo root, then delete the local copy.
        target = Path(path)
        if not target.exists():
            return
        api.upload_file(path_or_fileobj=path, path_in_repo=target.name, repo_type=repo_type,
                        revision="main", token=hf_token, repo_id=repo_id)
        target.unlink()

    try:
        if not is_repo_exists(repo_id, repo_type):
            api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
        progress(0, desc=f"Downloading info... {filename}")
        json_path, html_path, image_path = save_civitai_info(dl_url, filename, civitai_key)
        progress(0, desc=f"Start uploading info... {filename} to {repo_id}")
        if not json_path:
            return
        for info_path in (json_path, html_path, image_path):
            if info_path:
                push_then_remove(info_path)
        progress(1, desc="Info uploaded.")
    except Exception as e:
        message = f"Error: Failed to upload info to {repo_id}. {e}"
        print(message)
        gr.Warning(message)
|
|
|
|
|
def download_civitai(dl_url, civitai_key, hf_token, urls,
                     newrepo_id, repo_type="model", is_private=True, is_info=False, is_rename=True, progress=gr.Progress(track_tqdm=True)):
    """Download every URL in *dl_url* and upload the files to a Hub repo.

    Missing credentials and repo id fall back to the ``HF_TOKEN``,
    ``CIVITAI_API_KEY`` and ``HF_REPO`` environment variables. Each URL is
    downloaded, optionally renamed to avoid clobbering a different file in
    the repo, uploaded, and optionally followed by its Civitai info files.
    Repo ids found in *dl_url* are duplicated via ``duplicate_hf_repo``.

    Returns:
        Three ``gr.update`` values: the list of uploaded URLs, a markdown
        summary, and the newline-joined URLs that were not uploaded (made
        visible only when an exception interrupted processing).

    Raises:
        gr.Error: When no HF token or no Civitai API key is available.
    """
    set_token(hf_token if hf_token else os.environ.get("HF_TOKEN"))
    civitai_key = civitai_key or os.environ.get("CIVITAI_API_KEY")
    newrepo_id = newrepo_id or os.environ.get("HF_REPO")
    if not (get_token() and civitai_key):
        raise gr.Error("HF write token and Civitai API key is required.")
    urls = urls or []
    dl_urls = parse_urls(dl_url)
    remain_urls = list(dl_urls)
    try:
        hub_root = "https://huggingface.co/datasets/" if repo_type == "dataset" else "https://huggingface.co/"
        md = f'### Your repo: [{newrepo_id}]({hub_root}{newrepo_id})\n'
        for u in dl_urls:
            file = download_file(u, civitai_key)
            downloaded = Path(file)
            if not (downloaded.exists() and downloaded.is_file()):
                continue
            if is_rename:
                file = get_safe_filename(file, newrepo_id, repo_type)
            url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
            if not url:
                continue
            if is_info:
                upload_info_to_repo(u, file, newrepo_id, repo_type, is_private, civitai_key)
            urls.append(url)
            remain_urls.remove(u)
            md += f"- Uploaded [{str(u)}]({str(u)})\n"
        for r in parse_repos(dl_url):
            url = duplicate_hf_repo(r, newrepo_id, "model", repo_type, is_private, HF_SUBFOLDER_NAME[1])
            if url:
                urls.append(url)
        return (gr.update(value=urls, choices=urls), gr.update(value=md),
                gr.update(value="\n".join(remain_urls), visible=False))
    except Exception as e:
        gr.Info(f"Error occured: {e}")
        return (gr.update(value=urls, choices=urls), gr.update(value=md),
                gr.update(value="\n".join(remain_urls), visible=True))
    finally:
        gc.collect()
|
|
|
|
|
def _civitai_items_for_version(entry, version, wanted_types):
    """Build the result dicts for one model version.

    One dict is produced per file of the version (restricted to
    *wanted_types* when that set is non-empty); a version without a
    ``files`` key yields a single dict using the version's own download URL.
    """
    item = {
        'name': entry.get('name', ""),
        'creator': (entry.get('creator') or {}).get('username', ""),
        'tags': entry.get('tags') or [],
        'model_name': version.get('name', ""),
        'base_model': version.get('baseModel', ""),
        # The API may send null here; avoid rendering the text "None".
        'description': version.get('description') or "",
        'md': "",
    }

    images = version.get('images') or []
    if images:
        first_media = images[0]
        item['img_url'] = first_media["url"]
        item['is_video'] = first_media.get("type", "image") == "video"
        # "meta" may be present but null, so .get("meta", {}) is not enough.
        item['video_url'] = (first_media.get("meta") or {}).get("video", "") if item['is_video'] else ""
        if item['is_video']:
            item['md'] += f'<video src="{item["img_url"]}" poster="{item["img_url"]}" muted loop autoplay width="300" height="480" style="float:right;padding:16px;"></video><br>'
        else:
            item['md'] += f'<img src="{item["img_url"]}#float" alt="thumbnail" width="150" height="240"><br>'
    else:
        item['img_url'] = "/home/user/app/null.png"
        item['is_video'] = False
        item['video_url'] = ""

    model_url = f'https://civitai.com/models/{entry.get("id", "")}'
    tag_text = ", ".join(str(t) for t in item["tags"])
    item['md'] += (
        f'Model URL: [{model_url}]({model_url})<br>Model Name: {item["name"]}<br>\n'
        f'Creator: {item["creator"]}<br>Tags: {tag_text}<br>'
        f'Base Model: {item["base_model"]}<br>Description: {item["description"]}'
    )

    if 'files' not in version:
        item['dl_url'] = version['downloadUrl']
        return [item]

    per_file = []
    for f in version['files'] or []:
        if wanted_types and f.get('type') not in wanted_types:
            continue
        per_file.append({**item, 'dl_url': f['downloadUrl']})
    return per_file


def search_on_civitai(query: str, types: list[str], allow_model: list[str] = [], limit: int = 100,
                      sort: str = "Highest Rated", period: str = "AllTime", tag: str = "", user: str = "", page: int = 1,
                      filetype: list[str] = [], api_key: str = "", progress=gr.Progress(track_tqdm=True)):
    """Query the Civitai models endpoint and flatten the results.

    Args:
        query: Free-text query (omitted from the request when empty).
        types: Model types to request (omitted when empty).
        allow_model: Base-model names to keep; empty keeps everything.
            Read-only, never mutated.
        limit: Page size sent to the API.
        sort: Sort order sent to the API.
        period: Period sent to the API.
        tag: Tag filter (omitted when empty).
        user: Username filter (omitted when empty).
        page: Page to request. ``0`` means: start at page 1 and keep
            following ``metadata.nextPage`` until there is none or a request
            fails.
        filetype: File types to keep; empty keeps everything. Read-only.
        api_key: Civitai API key, sent as a Bearer token when non-empty.
        progress: Gradio progress tracker.

    Returns:
        A list of dicts with the keys ``name``, ``creator``, ``tags``,
        ``model_name``, ``base_model``, ``description``, ``md``, ``img_url``,
        ``is_video``, ``video_url`` and ``dl_url``, or ``None`` when nothing
        matched. Request errors are printed; pages fetched before the error
        are still used.
    """
    headers = {'User-Agent': get_user_agent(), 'content-type': 'application/json'}
    if api_key:
        # Must be "Bearer <key>". The former triple-brace f-string rendered
        # "Bearer {<key>}" with literal braces around the key.
        headers['Authorization'] = f'Bearer {api_key}'
    base_url = 'https://civitai.com/api/v1/models'
    params = {'sort': sort, 'period': period, 'limit': int(limit), 'nsfw': 'true'}
    if types: params["types"] = types
    if query: params["query"] = query
    if tag: params["tag"] = tag
    if user: params["username"] = user
    if page != 0: params["page"] = int(page)
    retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])

    payloads = []
    # The context manager closes the pooled connections when we are done.
    with requests.Session() as session:
        session.mount("https://", HTTPAdapter(max_retries=retries))

        def fetch(url, page_label, query_params=None):
            # Return the decoded JSON body, or None for a non-2xx response.
            progress(0, desc=f"Searching page {page_label}...")
            print(f"Searching page {page_label}...")
            r = session.get(url, params=query_params, headers=headers, stream=True, timeout=(7.0, 30))
            return r.json() if r.ok else None

        try:
            if page == 0:
                payload = fetch(base_url, 1, params | {'page': 1})
                page_no = 2
                while payload is not None:
                    payloads.append(payload)
                    next_url = (payload.get('metadata') or {}).get('nextPage')
                    if not next_url:
                        break
                    payload = fetch(next_url, page_no)
                    page_no += 1
            else:
                payload = fetch(base_url, 1, params)
                if payload is not None:
                    payloads.append(payload)
        except requests.exceptions.ConnectTimeout:
            print("Request timed out.")
        except Exception as e:
            print(e)

    # Build the filter sets once instead of once per model version / file.
    allowed_models = set(allow_model) if allow_model else set()
    wanted_types = set(filetype) if filetype else set()

    items = []
    for payload in payloads:
        for entry in payload.get('items') or []:
            for version in entry.get('modelVersions') or []:
                if allowed_models and version.get('baseModel') not in allowed_models:
                    continue
                items.extend(_civitai_items_for_version(entry, version, wanted_types))
    return items if items else None
|
|
|
|
|
def search_civitai(query, types, base_model=[], sort=SORT[0], period=PERIOD[0], tag="", user="", limit=100, page=1,
                   filetype=[], api_key="", gallery=[], state=None, progress=gr.Progress(track_tqdm=True)):
    """Run a Civitai search and prepare the Gradio updates for its results.

    The per-session *state* is reset first (empty choices, gallery and
    results) and then filled with the new search results through
    ``set_state``.

    Args:
        query, types, base_model, sort, period, tag, user, limit, page,
        filetype, api_key: Forwarded to ``search_on_civitai``.
        gallery: Ignored; the gallery is rebuilt from the search results.
        state: Per-session state container. A fresh dict is created when
            omitted, so calls never share state through the default value.
        progress: Gradio progress tracker.

    Returns:
        An 8-tuple: choice-list update, description update, two no-op
        updates, gallery update, second choice-list update, a status message
        and the state. When nothing was found the gallery and second choice
        list are left untouched and the message is ``"No item found."``.
    """
    if state is None:
        # The former `state={}` default was mutated by set_state and would
        # therefore have been shared by every call relying on the default.
        state = {}
    civitai_last_results = {}
    set_state(state, "civitai_last_choices", [("", "")])
    set_state(state, "civitai_last_gallery", [])
    set_state(state, "civitai_last_results", civitai_last_results)

    items = search_on_civitai(query, types, base_model, int(limit), sort, period, tag, user, int(page), filetype, api_key)
    if not items:
        return (gr.update(choices=[("", "")], value=[], visible=True),
                gr.update(value="", visible=False), gr.update(), gr.update(), gr.update(), gr.update(),
                "No item found.", state)

    choices = []
    gallery = []
    for item in items:
        base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
        name = f"{item['name']} (for {base_model_name} / By: {item['creator']})"
        value = item['dl_url']
        choices.append((name, value))
        # Videos and still images are both shown through their thumbnail URL.
        gallery.append((item['img_url'], name))
        civitai_last_results[value] = item

    results_info = f"{len(choices)} items found."
    set_state(state, "civitai_last_choices", choices)
    set_state(state, "civitai_last_gallery", gallery)
    set_state(state, "civitai_last_results", civitai_last_results)

    return (gr.update(choices=choices, value=[], visible=True),
            gr.update(value="", visible=True),
            gr.update(),
            gr.update(),
            gr.update(value=gallery),
            gr.update(choices=choices, value=[]),
            results_info,
            state)
|
|
|
|
|
def get_civitai_json(dl_url: str, is_html: bool=False, image_baseurl: str="", api_key=""):
    """Fetch model-version metadata, the model page and a preview for a
    Civitai download URL.

    Args:
        dl_url: A ``https://civitai.com/api/download/models/<id>`` URL.
        is_html: Only selects the shape of the failure value (see Returns).
        image_baseurl: Path or URL whose last component's stem names the
            saved preview PNG; defaults to *dl_url*.
        api_key: Civitai API key, sent as a Bearer token when non-empty.

    Returns:
        On success a tuple ``(metadata_dict, page_html, preview_png_path)``,
        regardless of *is_html*; ``page_html`` and ``preview_png_path`` are
        empty strings when that part could not be fetched or converted.
        On failure (not a Civitai download URL, metadata request not OK, or
        an exception) ``("", "", "")`` when *is_html* is true, else ``""``.
    """
    if not image_baseurl: image_baseurl = dl_url
    default = ("", "", "") if is_html else ""
    if "https://civitai.com/api/download/models/" not in dl_url: return default
    headers = {'User-Agent': get_user_agent(), 'content-type': 'application/json'}
    if api_key:
        # Must be "Bearer <key>". The former triple-brace f-string rendered
        # "Bearer {<key>}" with literal braces around the key.
        headers['Authorization'] = f'Bearer {api_key}'
    base_url = 'https://civitai.com/api/v1/model-versions/'
    retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    model_id = re.sub('https://civitai.com/api/download/models/(\\d+)(?:.+)?', '\\1', dl_url)
    url = base_url + model_id

    try:
        # The context manager closes the pooled connections when we are done.
        with requests.Session() as session:
            session.mount("https://", HTTPAdapter(max_retries=retries))
            r = session.get(url, headers=headers, stream=True, timeout=(5.0, 15))
            if not r.ok: return default
            info = dict(r.json())
            html = ""
            image = ""
            if "modelId" in info:
                url = f"https://civitai.com/models/{info['modelId']}"
                r = session.get(url, headers=headers, stream=True, timeout=(5.0, 15))
                if not r.ok: return info, html, image
                html = r.text
            images = info.get("images") or []
            if images:
                url = images[0]["url"]
                r = session.get(url, headers=headers, stream=True, timeout=(5.0, 15))
                if not r.ok: return info, html, image
                image_temp = str(Path(TEMP_DIR, "image" + Path(url.split("/")[-1]).suffix))
                image_path = str(Path(TEMP_DIR, Path(image_baseurl.split("/")[-1]).stem + ".png"))
                try:
                    with open(image_temp, 'wb') as f:
                        f.write(r.content)
                    with Image.open(image_temp) as preview:
                        preview.convert('RGBA').save(image_path)
                    image = image_path
                except Exception as e:
                    # The preview is optional (the first media entry may not be
                    # a decodable image); keep the metadata and HTML we have.
                    print(e)
                finally:
                    # Remove the raw download unless it is the output file itself.
                    if image_temp != image_path and Path(image_temp).exists():
                        Path(image_temp).unlink()
            return info, html, image
    except Exception as e:
        print(e)
        return default
|
|
|
|
|
def get_civitai_tag():
    """Fetch up to 200 Civitai tags, most-used first.

    Returns:
        A list starting with an empty string followed by the tag names in
        descending ``modelCount`` order, or ``[""]`` when the request fails,
        the response has no ``items`` key, or any error occurs (the error is
        printed).
    """
    default = [""]
    headers = {'User-Agent': get_user_agent(), 'content-type': 'application/json'}
    base_url = 'https://civitai.com/api/v1/tags'
    params = {'limit': 200}
    retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    try:
        with requests.Session() as session:
            session.mount("https://", HTTPAdapter(max_retries=retries))
            r = session.get(base_url, params=params, headers=headers, stream=True, timeout=(7.0, 15))
            if not r.ok: return default
            j = dict(r.json())
        if "items" not in j: return default
        counted = [(str(item.get("name", "")), int(item.get("modelCount", 0))) for item in j["items"]]
        # The earlier DataFrame.sort_values() call discarded its result, so
        # the tags were never actually ordered; sort explicitly instead.
        counted.sort(key=lambda pair: pair[1], reverse=True)
        return [""] + [name for name, _ in counted]
    except Exception as e:
        print(e)
        return default
|
|
|
|
|
def select_civitai_item(results: list[str], state: dict):
    """Show the description of the most recently selected search result.

    The newest selection is determined by ``list_sub`` of the current
    *results* and the selections stored in *state* by the previous call; its
    ``md`` text is looked up in the cached search results. The current
    selections are then stored in *state*.

    Args:
        results: Currently selected download URLs.
        state: Per-session state container.

    Returns:
        ``(markdown_update, json_update, state)``. The markdown is empty
        when nothing URL-like is selected or the newest selection is not in
        the cached results; the JSON component is always hidden.
    """
    empty_json = {}
    # `not results` also covers None, which "".join() would reject.
    if not results or "http" not in "".join(results):
        return gr.update(value="", visible=True), gr.update(value=empty_json, visible=False), state
    result = get_state(state, "civitai_last_results")
    last_selects = get_state(state, "civitai_last_selects")
    selects = list_sub(results, last_selects if last_selects else [])
    md = ""
    if isinstance(result, dict) and len(selects) > 0:
        # A selection may be absent from the cache; treat that as "no text"
        # instead of failing on None.get().
        md = (result.get(selects[-1]) or {}).get('md', "")
    set_state(state, "civitai_last_selects", results)
    return gr.update(value=md, visible=True), gr.update(value=empty_json, visible=False), state
|
|
|
|
|
def add_civitai_item(results: list[str], dl_url: str):
    """Append the selected result URLs to the download-URL text.

    When the selections contain nothing URL-like, *dl_url* is returned
    unchanged. Otherwise the existing text and the selected URLs are joined
    with newlines and normalized through ``uniq_urls``.

    Returns:
        A ``gr.update`` carrying the new text value.
    """
    if "http" not in "".join(results):
        return gr.update(value=dl_url)
    lines = [dl_url] if dl_url else []
    lines.extend(entry for entry in results if "http" in entry)
    merged = uniq_urls("\n".join(lines))
    return gr.update(value=merged)
|
|
|
|
|
def select_civitai_all_item(button_name: str, state: dict):
    """Toggle between selecting and deselecting every search result.

    Args:
        button_name: Current button label, ``"Select All"`` or
            ``"Deselect All"``. Any other label leaves everything as is.
        state: Per-session state container holding the last choice list.

    Returns:
        ``(button_update, choice_list_update)``: the button gets the opposite
        label, and the choice list gets either every non-empty choice value
        or an empty selection.
    """
    if button_name not in ["Select All", "Deselect All"]:
        # gr.update (lower-case), as used by every other handler in this file.
        return gr.update(value=button_name), gr.update(visible=True)
    # Fall back to an empty list when no search has populated the state yet.
    civitai_last_choices = get_state(state, "civitai_last_choices") or []
    selected = [t[1] for t in civitai_last_choices if t[1] != ""] if button_name == "Select All" else []
    new_button_name = "Select All" if button_name == "Deselect All" else "Deselect All"
    return gr.update(value=new_button_name), gr.update(value=selected, choices=civitai_last_choices)
|
|
|
|
|
def update_civitai_selection(evt: gr.SelectData, value: list[str], state: dict):
    """Add the item picked by a select event to the current selection.

    The event index is looked up in the choice list cached in *state*; its
    value is appended to the non-empty entries of *value* and the result is
    passed through ``list_uniq``.

    Returns:
        A ``gr.update`` with the new selection, or a no-op ``gr.update()``
        when anything goes wrong.
    """
    try:
        cached_choices = get_state(state, "civitai_last_choices")
        picked_value = cached_choices[evt.index][1]
        kept = [v for v in value if v != ""]
        return gr.update(value=list_uniq(kept + [picked_value]))
    except Exception:
        return gr.update()
|
|
|
|
|
def update_civitai_checkbox(selected: list[str]):
    """Return a ``gr.update`` that sets a component's value to *selected*."""
    update = gr.update(value=selected)
    return update
|
|
|
|
|
def from_civitai_checkbox(selected: list[str]):
    """Return a ``gr.update`` that sets a component's value to *selected*."""
    update = gr.update(value=selected)
    return update
|
|