import gradio as gr
from huggingface_hub import HfApi, hf_hub_url
import os
from pathlib import Path
import gc
import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
from utils import get_token, set_token, is_repo_exists, get_user_agent, get_download_file
def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
    """Upload a local file to a Hugging Face Hub repository.

    The repository is created first (with the requested visibility) when
    ``is_repo_exists`` reports that it does not exist. The file is stored at
    the repository root under its own basename, on the ``main`` revision.

    Args:
        filename: Path of the local file to upload.
        repo_id: Target repository id.
        repo_type: Hub repository type, passed through to ``HfApi``.
        is_private: Visibility used only when the repository is created here.
        progress: Gradio progress tracker.

    Returns:
        The ``hf_hub_url`` of the uploaded file, or ``None`` if any step of
        the create/upload sequence raised. In that case the error is printed
        and shown as a Gradio warning.
    """
    target_name = Path(filename).name
    token = get_token()
    client = HfApi(token=token)
    try:
        if not is_repo_exists(repo_id, repo_type):
            client.create_repo(repo_id=repo_id, repo_type=repo_type, token=token, private=is_private)
        progress(0, desc="Start uploading...")
        client.upload_file(
            path_or_fileobj=filename,
            path_in_repo=target_name,
            repo_type=repo_type,
            revision="main",
            token=token,
            repo_id=repo_id,
        )
        progress(1, desc="Uploaded.")
        file_url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=target_name)
    except Exception as err:
        # UI boundary: report the failure to the console and the user, then signal it with None.
        message = f"Error: Failed to upload to {repo_id}. {err}"
        print(message)
        gr.Warning(message)
        return None
    return file_url
def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Download ``dl_url`` into the current working directory.

    Args:
        dl_url: URL of the file to fetch.
        civitai_key: API key forwarded to ``get_download_file``.
        progress: Gradio progress tracker, reset to 0 before the download starts.

    Returns:
        Whatever ``get_download_file`` returns. This is presumably the path of
        the downloaded file; TODO confirm against ``utils``.
    """
    progress(0, desc="Start downloading...")
    return get_download_file(".", dl_url, civitai_key)
def download_civitai(dl_url, civitai_key, hf_token, urls,
                     newrepo_id, repo_type="model", is_private=True, progress=gr.Progress(track_tqdm=True)):
    """Download a file from Civitai and re-upload it to a Hugging Face repository.

    Args:
        dl_url: Civitai download URL.
        civitai_key: Civitai API key. Falls back to ``$CIVITAI_API_KEY`` when empty.
        hf_token: Hugging Face write token. Falls back to ``$HF_TOKEN`` when empty.
        urls: Previously uploaded file URLs (may be ``None``). The caller's
            list is copied, not mutated.
        newrepo_id: Target Hugging Face repository id.
        repo_type: Hub repository type.
        is_private: Visibility used if the repository has to be created.
        progress: Gradio progress tracker.

    Returns:
        A pair of ``gr.update`` objects:
        - the accumulated URL list, used as both value and choices;
        - a Markdown string with one link line per URL.

    Raises:
        gr.Error: If no HF token or Civitai key is available, or if the
            download produced no file.
    """
    # Resolve the env fallback into the local variable so the check below
    # sees it. Previously only set_token() saw $HF_TOKEN, and the check
    # still rejected the call.
    if not hf_token:
        hf_token = os.environ.get("HF_TOKEN")
    set_token(hf_token)
    if not civitai_key:
        civitai_key = os.environ.get("CIVITAI_API_KEY")
    if not hf_token or not civitai_key:
        raise gr.Error("HF write token and Civitai API key is required.")
    file = download_file(dl_url, civitai_key)
    if not file:
        # Without this guard, Path(None) would fail later with an opaque TypeError.
        raise gr.Error(f"Failed to download {dl_url}.")
    # Work on a copy so the caller's list is not mutated in place.
    urls = list(urls) if urls else []
    try:
        url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
        progress(1, desc="Processing...")
        if url:
            urls.append(url)
    finally:
        # Always remove the local copy, even if the upload step raised.
        Path(file).unlink(missing_ok=True)
    md = "".join(f"[Uploaded {u}]({u})\n" for u in urls)
    gc.collect()
    return gr.update(value=urls, choices=urls), gr.update(value=md)
# Civitai model types offered as search filters.
# search_on_civitai sends a non-empty selection as the API's `types` parameter.
CIVITAI_TYPE = [
    "Checkpoint",
    "TextualInversion",
    "Hypernetwork",
    "AestheticGradient",
    "LORA",
    "Controlnet",
    "Poses",
]
# Base-model names. search_on_civitai compares a selection from these locally
# against each version's `baseModel`.
CIVITAI_BASEMODEL = [
    "Pony",
    "SD 1.5",
    "SDXL 1.0",
    "Flux.1 D",
    "Flux.1 S",
]
# Sort orders. The first entry is search_civitai's default.
CIVITAI_SORT = [
    "Highest Rated",
    "Most Downloaded",
    "Newest",
]
# Time periods. The first entry is search_civitai's default.
CIVITAI_PERIOD = [
    "AllTime",
    "Year",
    "Month",
    "Week",
    "Day",
]
def _civitai_item(entry, version):
    """Flatten one model version of a Civitai search hit into a result dict."""
    import html  # local import keeps this edit self-contained

    creator = entry.get('creator') or {}
    images = version.get('images') or []
    md = f'[Model URL](https://civitai.com/models/{entry["id"]})'
    first = images[0] if images else None
    thumbnail = first.get('url') if isinstance(first, dict) else None
    if thumbnail:
        # Preview of the first sample image above the model link. The URL
        # comes from the remote API, so escape it before embedding it in HTML.
        md = f'<img src="{html.escape(thumbnail, quote=True)}" alt="thumbnail" width="150">\n{md}'
    return {
        'name': entry['name'],
        'creator': creator.get('username', ""),
        'tags': entry.get('tags', []),
        'model_name': version.get('name', ""),
        'base_model': version.get('baseModel', ""),
        'dl_url': version['downloadUrl'],
        'md': md,
    }


def search_on_civitai(query: str, types: list[str], allow_model: "list[str] | None" = None, limit: int = 100,
                      sort: str = "Highest Rated", period: str = "AllTime", tag: str = ""):
    """Query the Civitai models API and flatten the hits, one entry per model version.

    Args:
        query: Free-text search query. Omitted from the request when empty.
        types: Civitai model types to filter on. Omitted when empty or ``None``.
        allow_model: If non-empty, keep only versions whose ``baseModel`` is
            in this collection. ``None`` or empty keeps every version.
        limit: Maximum number of models requested from the API.
        sort: Sort order sent as the ``sort`` parameter.
        period: Time period sent as the ``period`` parameter.
        tag: Tag filter. Omitted when empty.

    Returns:
        A list of dicts with keys ``name``, ``creator``, ``tags``,
        ``model_name``, ``base_model``, ``dl_url`` and ``md``, or ``None``
        when:
        - the request raises;
        - the response is not OK (status >= 400);
        - the body is not valid JSON;
        - the body has no ``items`` key.
        Versions without a ``downloadUrl`` are skipped.
    """
    headers = {'User-Agent': get_user_agent(), 'content-type': 'application/json'}
    base_url = 'https://civitai.com/api/v1/models'
    params = {'sort': sort, 'period': period, 'limit': limit, 'nsfw': 'true'}
    if types:
        params["types"] = types
    if query:
        params["query"] = query
    if tag:
        params["tag"] = tag
    retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    # The context manager closes the session and its pooled connections on every path.
    with requests.Session() as session:
        session.mount("https://", HTTPAdapter(max_retries=retries))
        try:
            r = session.get(base_url, params=params, headers=headers, timeout=(3.0, 30))
        except requests.exceptions.RequestException as e:
            print(e)
            return None
        if not r.ok:
            return None
        try:
            data = r.json()
        except ValueError as e:  # body was not valid JSON
            print(e)
            return None
    if not isinstance(data, dict) or 'items' not in data:
        return None
    # Build the allow-list once instead of once per version.
    allowed = set(allow_model) if allow_model else None
    items = []
    for entry in data['items'] or []:
        for version in entry.get('modelVersions') or []:
            if allowed is not None and version.get('baseModel') not in allowed:
                continue
            if 'downloadUrl' not in version:
                continue  # nothing to download for this version
            items.append(_civitai_item(entry, version))
    return items
# Cache of the most recent Civitai search hits, keyed by download URL.
# search_civitai rebuilds it and select_civitai_item reads it.
civitai_last_results = dict()
def search_civitai(query, types, base_model=None, sort=CIVITAI_SORT[0], period=CIVITAI_PERIOD[0], tag=""):
    """Run a Civitai search and build the Gradio updates for the result widgets.

    On a non-empty result, this also replaces the module-level
    ``civitai_last_results`` cache (keyed by download URL) so that
    ``select_civitai_item`` can look hits up later.

    Args:
        query, types, sort, period, tag: Forwarded to ``search_on_civitai``.
        base_model: Base-model allow-list, forwarded as ``allow_model``.
            ``None`` or empty means no filtering.

    Returns:
        Four ``gr.update`` objects:
        - the ``(label, dl_url)`` choice list with the first hit selected;
        - that hit's Markdown preview;
        - two visibility-only updates.
        When nothing is found, the first two are hidden and emptied.
    """
    global civitai_last_results
    # Pass a list, never None, so this works with either search_on_civitai signature.
    items = search_on_civitai(query, types, base_model or [], 100, sort, period, tag)
    if not items:
        return gr.update(choices=[("", "")], value="", visible=False),\
            gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
    civitai_last_results = {}
    choices = []
    for item in items:
        base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
        label = f"{item['name']} (for {base_model_name} / By: {item['creator']} / Tags: {', '.join(item['tags'])})"
        choices.append((label, item['dl_url']))
        civitai_last_results[item['dl_url']] = item
    # items is non-empty here, so choices has at least one entry.
    first_value = choices[0][1]
    # Look the preview up through the cache rather than items[0]. When several
    # hits share a download URL, the cache holds the last one, which is also
    # what select_civitai_item will show for that URL.
    md = civitai_last_results[first_value]['md']
    return gr.update(choices=choices, value=first_value, visible=True), gr.update(value=md, visible=True),\
        gr.update(visible=True), gr.update(visible=True)
def select_civitai_item(search_result):
    """Build the Gradio updates for when a search result is selected.

    Args:
        search_result: The selected value. It is expected to be a download
            URL produced by ``search_civitai``, and may be ``None`` or empty
            when the selection is cleared.

    Returns:
        A pair of ``gr.update`` objects: the selected URL (or ``""``) and the
        Markdown preview for it. The preview is ``""`` when the URL is not in
        the cache of the latest search.
    """
    if not search_result or "http" not in search_result:
        return gr.update(value=""), gr.update(value="None", visible=True)
    # Default to None, not the string "None": a truthy string default made
    # result['md'] raise TypeError for URLs missing from the latest search.
    result = civitai_last_results.get(search_result)
    md = result['md'] if result else ""
    return gr.update(value=search_result), gr.update(value=md, visible=True)