Spaces:
Running
Running
import gradio as gr | |
from huggingface_hub import HfApi, hf_hub_url | |
import os | |
from pathlib import Path | |
import gc | |
import requests | |
from requests.adapters import HTTPAdapter | |
from urllib3.util import Retry | |
from utils import get_token, set_token, is_repo_exists, get_user_agent, get_download_file | |
def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
    """Upload a local file to a Hugging Face Hub repo and return its Hub URL.

    The repo is created first (with the requested visibility) when
    ``is_repo_exists`` reports it missing. The file is stored at the repo
    root under its own base name, on the ``main`` revision.

    Args:
        filename: Path of the local file to upload.
        repo_id: Target repo id, e.g. ``"user/name"``.
        repo_type: Hub repo type, passed through to the Hub API.
        is_private: Visibility used only when the repo has to be created.
        progress: Gradio progress tracker.

    Returns:
        The URL built by ``hf_hub_url`` for the uploaded file, or ``None`` if
        any step failed. Failures are printed and surfaced as a Gradio warning
        instead of being raised.
    """
    name_in_repo = Path(filename).name
    token = get_token()
    hub = HfApi(token=token)
    try:
        if not is_repo_exists(repo_id, repo_type):
            hub.create_repo(repo_id=repo_id, repo_type=repo_type, token=token, private=is_private)
        progress(0, desc="Start uploading...")
        hub.upload_file(
            path_or_fileobj=filename,
            path_in_repo=name_in_repo,
            repo_type=repo_type,
            revision="main",
            token=token,
            repo_id=repo_id,
        )
        progress(1, desc="Uploaded.")
        file_url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=name_in_repo)
    except Exception as err:
        message = f"Error: Failed to upload to {repo_id}. {err}"
        print(message)
        gr.Warning(message)
        return None
    return file_url
def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Download ``dl_url`` into the current working directory.

    This is a thin wrapper around ``utils.get_download_file`` that first
    resets the Gradio progress bar.

    Args:
        dl_url: URL to download.
        civitai_key: Civitai API key forwarded to the downloader.
        progress: Gradio progress tracker.

    Returns:
        Whatever ``get_download_file`` returns. Presumably this is the local
        path of the downloaded file (TODO confirm against utils).
    """
    progress(0, desc="Start downloading...")
    return get_download_file(".", dl_url, civitai_key)
def download_civitai(dl_url, civitai_key, hf_token, urls,
                     newrepo_id, repo_type="model", is_private=True, progress=gr.Progress(track_tqdm=True)):
    """Download a file from Civitai, re-upload it to the Hub, and list uploads.

    Args:
        dl_url: Civitai download URL.
        civitai_key: Civitai API key. Falls back to ``$CIVITAI_API_KEY``.
        hf_token: Hugging Face write token. Falls back to ``$HF_TOKEN``.
        urls: Previously uploaded URLs; may be ``None`` or empty. It is not
            mutated; a new list is returned.
        newrepo_id: Target Hub repo id.
        repo_type: Hub repo type.
        is_private: Visibility used if the repo has to be created.
        progress: Gradio progress tracker.

    Returns:
        A pair of ``gr.update`` objects:
          * the URL list, used as both value and choices;
          * markdown linking every uploaded URL.

    Raises:
        gr.Error: if no HF token or Civitai key is available, or if the
            download produced no file.
    """
    # Resolve the env fallbacks into the locals themselves. Previously only
    # set_token() saw $HF_TOKEN, so the check below still failed.
    hf_token = hf_token or os.environ.get("HF_TOKEN")
    civitai_key = civitai_key or os.environ.get("CIVITAI_API_KEY")
    set_token(hf_token)
    if not hf_token or not civitai_key:
        raise gr.Error("HF write token and Civitai API key is required.")

    file = download_file(dl_url, civitai_key)
    # NOTE(review): this assumes get_download_file signals failure with a falsy
    # value; confirm against utils. Without this guard, Path(None) raises
    # TypeError further down.
    if not file:
        raise gr.Error(f"Failed to download {dl_url}.")

    # Copy the list so the caller's object is not mutated in place.
    urls = list(urls) if urls else []
    try:
        url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
    finally:
        # Always remove the local copy, even if the upload raised.
        Path(file).unlink(missing_ok=True)
    progress(1, desc="Processing...")
    if url:
        urls.append(url)

    md = "".join(f"[Uploaded {u}]({u})<br>" for u in urls)
    gc.collect()
    return gr.update(value=urls, choices=urls), gr.update(value=md)
# Option lists for the Civitai search. search_civitai takes its default sort
# and period from CIVITAI_SORT[0] and CIVITAI_PERIOD[0]. The strings are
# presumably forwarded to the Civitai API or matched against its
# `baseModel` field, so keep the spelling exact.

# Model types.
CIVITAI_TYPE = [
    "Checkpoint",
    "TextualInversion",
    "Hypernetwork",
    "AestheticGradient",
    "LORA",
    "Controlnet",
    "Poses",
]

# Base models.
CIVITAI_BASEMODEL = [
    "Pony",
    "SD 1.5",
    "SDXL 1.0",
    "Flux.1 D",
    "Flux.1 S",
]

# Sort orders; the first entry is the default.
CIVITAI_SORT = [
    "Highest Rated",
    "Most Downloaded",
    "Newest",
]

# Time windows; the first entry is the default.
CIVITAI_PERIOD = [
    "AllTime",
    "Year",
    "Month",
    "Week",
    "Day",
]
def search_on_civitai(query: str, types: list[str], allow_model: "list[str] | None" = None, limit: int = 100,
                      sort: str = "Highest Rated", period: str = "AllTime", tag: str = ""):
    """Query the Civitai model search API and flatten the results per version.

    Args:
        query: Free-text search. Omitted from the request when empty.
        types: Civitai model types to filter on. Omitted when empty or None.
        allow_model: Base models to keep. If non-empty, only versions whose
            ``baseModel`` is listed are kept. ``None`` (the default) or an
            empty collection keeps every version.
        limit: Maximum number of models requested from the API.
        sort: Passed through as the ``sort`` query parameter.
        period: Passed through as the ``period`` query parameter.
        tag: Optional tag filter. Omitted when empty.

    Returns:
        A list of dicts, one per matching model version. Each has the keys
        ``name``, ``creator``, ``tags``, ``model_name``, ``base_model``,
        ``dl_url`` and ``md`` (thumbnail plus link markdown).

        Returns ``None`` if any of these happen:
          * the request failed;
          * the response status was not OK;
          * the body was not JSON;
          * the body had no ``items`` key.
    """
    headers = {'User-Agent': get_user_agent(), 'content-type': 'application/json'}
    base_url = 'https://civitai.com/api/v1/models'
    params = {'sort': sort, 'period': period, 'limit': limit, 'nsfw': 'true'}
    if types:
        params["types"] = types
    if query:
        params["query"] = query
    if tag:
        params["tag"] = tag

    retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    # The context manager closes the session and its pooled connections.
    # The body is read fully via .json(), so stream=True is not needed.
    with requests.Session() as session:
        session.mount("https://", HTTPAdapter(max_retries=retries))
        try:
            r = session.get(base_url, params=params, headers=headers, timeout=(3.0, 30))
        except Exception as e:  # best-effort: any network failure -> None
            print(e)
            return None
        if not r.ok:
            return None
        try:
            data = r.json()
        except ValueError as e:  # requests' JSONDecodeError subclasses ValueError
            print(e)
            return None

    if not isinstance(data, dict) or 'items' not in data:
        return None

    # Build the allow-list once instead of once per model version.
    allowed = set(allow_model) if allow_model else None
    items = []
    for entry in data['items']:
        for version in entry['modelVersions']:
            if allowed is not None and version['baseModel'] not in allowed:
                continue
            # Some versions have no images; skip the thumbnail rather than
            # raising IndexError.
            images = version.get('images') or []
            thumb = (f'<img src="{images[0]["url"]}" alt="thumbnail" width="150" height="240"><br>'
                     if images else "")
            items.append({
                'name': entry['name'],
                'creator': entry['creator']['username'],
                'tags': entry['tags'],
                'model_name': version['name'],
                'base_model': version['baseModel'],
                'dl_url': version['downloadUrl'],
                'md': f'{thumb}[LoRA Model URL](https://civitai.com/models/{entry["id"]})',
            })
    return items
# Items from the most recent successful search, keyed by download URL.
# search_civitai rebuilds it; select_civitai_item reads it to show a preview.
civitai_last_results = dict()
def search_civitai(query, types, base_model=None, sort=CIVITAI_SORT[0], period=CIVITAI_PERIOD[0], tag=""):
    """Run a Civitai search and build Gradio updates for the results UI.

    On success, replaces the module-level ``civitai_last_results`` cache
    (download URL -> item dict). On an empty or failed search, the previous
    cache is left untouched.

    Args:
        query: Free-text search.
        types: Civitai model types to filter on.
        base_model: Base models to keep. ``None`` or empty means all.
        sort: Sort order.
        period: Time window.
        tag: Optional tag filter.

    Returns:
        A 4-tuple of ``gr.update`` objects:
          * result dropdown (choices are ``(label, download URL)`` pairs);
          * preview markdown of the first result;
          * two further components, always set visible.
        The first two are hidden when nothing was found.
    """
    global civitai_last_results
    # Normalise None to [] so any search_on_civitai signature gets a list.
    items = search_on_civitai(query, types, base_model or [], 100, sort, period, tag)
    if not items:
        return gr.update(choices=[("", "")], value="", visible=False),\
            gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)

    # Build the cache locally and publish it in one step, so a failure
    # mid-loop cannot leave a half-filled global.
    results = {}
    choices = []
    for item in items:
        base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
        name = f"{item['name']} (for {base_model_name} / By: {item['creator']} / Tags: {', '.join(item['tags'])})"
        value = item['dl_url']
        choices.append((name, value))
        results[value] = item
    civitai_last_results = results

    # items is non-empty, so choices[0] exists and is a key of results.
    first_value = choices[0][1]
    md = results[first_value]['md']
    return gr.update(choices=choices, value=first_value, visible=True), gr.update(value=md, visible=True),\
        gr.update(visible=True), gr.update(visible=True)
def select_civitai_item(search_result):
    """Build updates for the URL box and preview when a search result is picked.

    Args:
        search_result: The selected dropdown value. This is a Civitai download
            URL from the last search, or a placeholder such as ``""`` or
            ``None``.

    Returns:
        ``(url_update, markdown_update)``:
          * For a non-URL selection, the URL box is cleared and the preview
            shows ``"None"``.
          * For a URL, the preview shows the cached thumbnail markdown.
          * For a URL missing from the latest search's cache (e.g. a stale
            selection), the preview is ``""``.
    """
    # Guard against None first: `"http" in None` raises TypeError.
    if not search_result or "http" not in search_result:
        return gr.update(value=""), gr.update(value="None", visible=True)
    # Default to None, not the truthy string "None", which would make
    # result['md'] raise TypeError for a stale selection.
    result = civitai_last_results.get(search_result)
    md = result['md'] if result else ""
    return gr.update(value=search_result), gr.update(value=md, visible=True)