File size: 6,062 Bytes
c25cab5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
from huggingface_hub import HfApi, hf_hub_url
import os
from pathlib import Path
import gc
import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
from utils import get_token, set_token, is_repo_exists, get_user_agent, get_download_file


def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
    """Upload a local file to the root of a Hugging Face Hub repository.

    The repository is created first (with the requested privacy) when
    ``is_repo_exists`` reports that it does not exist yet. The file is stored
    under its base name on the ``main`` revision.

    Args:
        filename: Path of the local file to upload.
        repo_id: Target repository id.
        repo_type: Hub repository type passed through to the Hub API.
        is_private: Privacy flag used only when the repository is created.
        progress: Gradio progress tracker.

    Returns:
        The Hub URL of the uploaded file, or ``None`` if any step inside the
        upload failed. The failure is printed and shown as a Gradio warning.
    """
    name_in_repo = Path(filename).name
    token = get_token()
    hub = HfApi(token=token)
    try:
        if not is_repo_exists(repo_id, repo_type):
            hub.create_repo(repo_id=repo_id, repo_type=repo_type, token=token, private=is_private)
        progress(0, desc="Start uploading...")
        hub.upload_file(
            path_or_fileobj=filename,
            path_in_repo=name_in_repo,
            repo_type=repo_type,
            revision="main",
            token=token,
            repo_id=repo_id,
        )
        progress(1, desc="Uploaded.")
        return hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=name_in_repo)
    except Exception as err:
        # Best-effort: report the failure to both the console and the UI, then signal it with None.
        print(f"Error: Failed to upload to {repo_id}. {err}")
        gr.Warning(f"Error: Failed to upload to {repo_id}. {err}")
        return None


def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Download ``dl_url`` into the current working directory.

    Args:
        dl_url: URL of the file to fetch.
        civitai_key: API key forwarded to ``get_download_file``.
        progress: Gradio progress tracker (reset to 0 before the download).

    Returns:
        Whatever ``get_download_file(".", dl_url, civitai_key)`` returns.
        Presumably this is the local file path; confirm against ``utils``.
    """
    progress(0, desc="Start downloading...")
    return get_download_file(".", dl_url, civitai_key)


def download_civitai(dl_url, civitai_key, hf_token, urls,
                     newrepo_id, repo_type="model", is_private=True, progress=gr.Progress(track_tqdm=True)):
    """Download a file from Civitai and re-upload it to a Hugging Face Hub repository.

    Args:
        dl_url: Civitai download URL.
        civitai_key: Civitai API key. Falls back to the ``CIVITAI_API_KEY``
            environment variable when empty.
        hf_token: Hugging Face write token. Falls back to the ``HF_TOKEN``
            environment variable when empty.
        urls: Previously uploaded file URLs (may be empty or ``None``). The
            caller's list is not modified; a copy is extended and returned.
        newrepo_id: Target repository id.
        repo_type: Hub repository type (default ``"model"``).
        is_private: Privacy flag used if the repository has to be created.
        progress: Gradio progress tracker.

    Returns:
        A pair of ``gr.update`` objects:
        - the URL list (used as both value and choices);
        - Markdown with one "Uploaded" link per URL.

    Raises:
        gr.Error: If either credential is missing after the env fallbacks, or
            the download produced no file.
    """
    # Resolve env fallbacks *before* validating. Otherwise a token supplied only
    # via HF_TOKEN would be installed by set_token and still be rejected below.
    if not hf_token: hf_token = os.environ.get("HF_TOKEN")
    set_token(hf_token)
    if not civitai_key: civitai_key = os.environ.get("CIVITAI_API_KEY")
    if not hf_token or not civitai_key: raise gr.Error("HF write token and Civitai API key is required.")

    local_path = download_file(dl_url, civitai_key)
    # NOTE(review): assumes get_download_file returns a falsy value on failure; confirm in utils.
    # Path(None) would raise TypeError, and Path("") would point at the current directory.
    if not local_path:
        raise gr.Error(f"Failed to download {dl_url}.")

    urls = list(urls) if urls else []
    try:
        url = upload_safetensors_to_repo(local_path, newrepo_id, repo_type, is_private)
    finally:
        # Always remove the downloaded copy, even if the upload raised.
        Path(local_path).unlink(missing_ok=True)
    progress(1, desc="Processing...")
    if url: urls.append(url)
    md = "".join(f"[Uploaded {u}]({u})<br>" for u in urls)
    gc.collect()
    return gr.update(value=urls, choices=urls), gr.update(value=md)


# Option lists for the Civitai search UI (presumably used as dropdown choices; confirm in the UI module).
# CIVITAI_SORT[0] and CIVITAI_PERIOD[0] are the default "sort"/"period" API parameters in search_civitai.
CIVITAI_TYPE = [
    "Checkpoint",
    "TextualInversion",
    "Hypernetwork",
    "AestheticGradient",
    "LORA",
    "Controlnet",
    "Poses",
]
CIVITAI_BASEMODEL = [
    "Pony",
    "SD 1.5",
    "SDXL 1.0",
    "Flux.1 D",
    "Flux.1 S",
]
CIVITAI_SORT = [
    "Highest Rated",
    "Most Downloaded",
    "Newest",
]
CIVITAI_PERIOD = [
    "AllTime",
    "Year",
    "Month",
    "Week",
    "Day",
]


def search_on_civitai(query: str, types: list[str], allow_model: "list[str] | None" = None, limit: int = 100,
                      sort: str = "Highest Rated", period: str = "AllTime", tag: str = ""):
    """Query the Civitai models API and flatten the result into one dict per model version.

    Args:
        query: Free-text search string. Omitted from the request when empty.
        types: Civitai model types, sent as the ``types`` parameter. Omitted when empty.
        allow_model: If non-empty, keep only model versions whose ``baseModel``
            is in this list. ``None`` (default) or an empty list keeps all.
        limit: Sent as the ``limit`` request parameter.
        sort: Sent as the ``sort`` request parameter (e.g. an entry of ``CIVITAI_SORT``).
        period: Sent as the ``period`` request parameter (e.g. an entry of ``CIVITAI_PERIOD``).
        tag: Tag filter. Omitted when empty.

    Returns:
        A list of dicts with keys ``name``, ``creator``, ``tags``, ``model_name``,
        ``base_model``, ``dl_url`` and ``md`` (HTML/Markdown preview snippet).
        Returns ``None`` in any of these cases:
        - the request raises;
        - the response has an HTTP error status;
        - the body is not JSON;
        - the body has no ``items`` key.
    """
    # Build the filter set once instead of once per model version.
    allowed = set(allow_model) if allow_model else None
    headers = {'User-Agent': get_user_agent(), 'content-type': 'application/json'}
    base_url = 'https://civitai.com/api/v1/models'
    params = {'sort': sort, 'period': period, 'limit': limit, 'nsfw': 'true'}
    if types: params["types"] = types
    if query: params["query"] = query
    if tag: params["tag"] = tag
    retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    # The context manager closes the session's connection pool on every exit path.
    with requests.Session() as session:
        session.mount("https://", HTTPAdapter(max_retries=retries))
        try:
            r = session.get(base_url, params=params, headers=headers, stream=True, timeout=(3.0, 30))
        except Exception as e:
            print(e)
            return None
        if not r.ok: return None
        # stream=True defers reading the body, so it must be consumed before the session closes.
        try:
            data = r.json()
        except ValueError as e:
            print(e)
            return None
    if not isinstance(data, dict) or 'items' not in data: return None

    items = []
    for j in data['items']:
        for model in j['modelVersions']:
            if allowed is not None and model['baseModel'] not in allowed: continue
            # Some versions have no images; skip the thumbnail rather than failing the whole search.
            images = model.get('images') or []
            thumbnail = (f'<img src="{images[0]["url"]}" alt="thumbnail" width="150" height="240"><br>'
                         if images else "")
            items.append({
                'name': j['name'],
                'creator': j['creator']['username'],
                'tags': j['tags'],
                'model_name': model['name'],
                'base_model': model['baseModel'],
                'dl_url': model['downloadUrl'],
                'md': f'{thumbnail}[LoRA Model URL](https://civitai.com/models/{j["id"]})',
            })
    return items


# Results of the most recent search_civitai call: download URL -> item dict.
# Rebound by search_civitai and read by select_civitai_item.
civitai_last_results: dict = dict()


def _civitai_no_results():
    """Return the four Gradio updates used when a Civitai search yields nothing."""
    return (gr.update(choices=[("", "")], value="", visible=False),
            gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True))


def search_civitai(query, types, base_model=None, sort=CIVITAI_SORT[0], period=CIVITAI_PERIOD[0], tag=""):
    """Run a Civitai search and build the Gradio updates for the results UI.

    Fills the module-level ``civitai_last_results`` cache (download URL -> item)
    so that ``select_civitai_item`` can look up previews later.

    Args:
        query: Free-text search string.
        types: Civitai model types to filter by.
        base_model: Base models to keep. ``None`` or empty keeps all.
        sort: Sort order forwarded to the API.
        period: Time period forwarded to the API.
        tag: Tag filter.

    Returns:
        A 4-tuple of ``gr.update`` objects:
        - the results dropdown, preselecting the first result;
        - its Markdown preview;
        - two components that are made visible.
        When there are no results, the dropdown and preview are hidden.
    """
    global civitai_last_results
    items = search_on_civitai(query, types, base_model if base_model is not None else [], 100, sort, period, tag)
    if not items: return _civitai_no_results()
    civitai_last_results = {}
    choices = []
    for item in items:
        base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
        label = f"{item['name']} (for {base_model_name} / By: {item['creator']} / Tags: {', '.join(item['tags'])})"
        choices.append((label, item['dl_url']))
        civitai_last_results[item['dl_url']] = item
    # items is non-empty here, so choices is too.
    first_value = choices[0][1]
    # Default to None (not the truthy string "None") so a miss yields an empty preview.
    result = civitai_last_results.get(first_value)
    md = result['md'] if result else ""
    return (gr.update(choices=choices, value=first_value, visible=True), gr.update(value=md, visible=True),
            gr.update(visible=True), gr.update(visible=True))


def select_civitai_item(search_result):
    """Handle a selection in the Civitai results dropdown.

    Args:
        search_result: Selected dropdown value, normally a download URL. May
            be ``None`` or empty when the dropdown is cleared.

    Returns:
        A pair of ``gr.update`` objects: the chosen download URL, and the
        Markdown preview taken from ``civitai_last_results``. The preview is
        empty if the URL is not cached. When no URL is selected, the returned
        URL is empty and the preview shows "None".
    """
    # Guard against None before the substring test, which would raise TypeError.
    if not search_result or "http" not in search_result:
        return gr.update(value=""), gr.update(value="None", visible=True)
    # The cache is a shared global that a later search may have replaced, so a miss is possible.
    # Default to None rather than the string "None", which would be subscripted below.
    result = civitai_last_results.get(search_result)
    md = result['md'] if result else ""
    return gr.update(value=search_result), gr.update(value=md, visible=True)