Spaces:
Running
Running
Upload 2 files (#2)
Browse files- Upload 2 files (96498f17d1846904bb61dabd4d8a467a77f6ccf2)
Co-authored-by: John Smith <[email protected]>
- StableGR.py +356 -357
- app.py +116 -120
StableGR.py
CHANGED
@@ -1,357 +1,356 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from huggingface_hub import HfApi, hf_hub_url
|
3 |
-
import os
|
4 |
-
from pathlib import Path
|
5 |
-
import gc
|
6 |
-
import requests
|
7 |
-
from requests.adapters import HTTPAdapter
|
8 |
-
from urllib3.util import Retry
|
9 |
-
from utils import (get_user_agent, get_download_file,
|
10 |
-
list_uniq, list_sub, get_state, set_state)
|
11 |
-
import re
|
12 |
-
from PIL import Image
|
13 |
-
import json
|
14 |
-
import pandas as pd
|
15 |
-
import tempfile
|
16 |
-
import hashlib
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
return
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
return
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
if not
|
97 |
-
if not
|
98 |
-
if not
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
file
|
106 |
-
if
|
107 |
-
|
108 |
-
url
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
url
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
gr.
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
if
|
144 |
-
if
|
145 |
-
if
|
146 |
-
if
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
item['
|
191 |
-
item['
|
192 |
-
item['
|
193 |
-
item['
|
194 |
-
item['
|
195 |
-
item['
|
196 |
-
|
197 |
-
|
198 |
-
item['
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
i =
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
item
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
set_state(state, "
|
219 |
-
set_state(state, "
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
value
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
set_state(state, "
|
238 |
-
set_state(state, "
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
url = base_url +
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
url =
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
url =
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
df =
|
304 |
-
df.
|
305 |
-
tags =
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
json =
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
if
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
new_button_name =
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
return gr.update(value=selected)
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from huggingface_hub import HfApi, hf_hub_url
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
import gc
|
6 |
+
import requests
|
7 |
+
from requests.adapters import HTTPAdapter
|
8 |
+
from urllib3.util import Retry
|
9 |
+
from utils import (get_user_agent, get_download_file,
|
10 |
+
list_uniq, list_sub, get_state, set_state)
|
11 |
+
import re
|
12 |
+
from PIL import Image
|
13 |
+
import json
|
14 |
+
import pandas as pd
|
15 |
+
import tempfile
|
16 |
+
import hashlib
|
17 |
+
from stablepy import Model_Diffusers
|
18 |
+
|
19 |
+
|
20 |
+
TEMP_DIR = os.getcwd()
|
21 |
+
|
22 |
+
|
23 |
+
def parse_urls(s):
|
24 |
+
url_pattern = "https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+"
|
25 |
+
try:
|
26 |
+
urls = re.findall(url_pattern, s)
|
27 |
+
return list(urls)
|
28 |
+
except Exception:
|
29 |
+
return []
|
30 |
+
|
31 |
+
|
32 |
+
def parse_repos(s):
|
33 |
+
repo_pattern = r'[^\w_\-\.]?([\w_\-\.]+/[\w_\-\.]+)[^\w_\-\.]?'
|
34 |
+
try:
|
35 |
+
s = re.sub("https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+", "", s)
|
36 |
+
repos = re.findall(repo_pattern, s)
|
37 |
+
return list(repos)
|
38 |
+
except Exception:
|
39 |
+
return []
|
40 |
+
|
41 |
+
|
42 |
+
def to_urls(l: list[str]):
|
43 |
+
return "\n".join(l)
|
44 |
+
|
45 |
+
|
46 |
+
def uniq_urls(s):
|
47 |
+
return to_urls(list_uniq(parse_urls(s) + parse_repos(s)))
|
48 |
+
|
49 |
+
|
50 |
+
def generate_image(model_id, prompt, lora_A, num_steps, guidance_scale, sampler, img_width, img_height):
|
51 |
+
model = Model_Diffusers(
|
52 |
+
base_model_id=model_id,
|
53 |
+
task_name='txt2img',
|
54 |
+
)
|
55 |
+
|
56 |
+
image, info_image = model(
|
57 |
+
prompt=prompt,
|
58 |
+
lora_A=lora_A,
|
59 |
+
num_steps=num_steps,
|
60 |
+
guidance_scale=guidance_scale,
|
61 |
+
sampler=sampler,
|
62 |
+
img_width=img_width,
|
63 |
+
img_height=img_height,
|
64 |
+
)
|
65 |
+
return image[0]
|
66 |
+
|
67 |
+
|
68 |
+
def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
|
69 |
+
download_dir = TEMP_DIR
|
70 |
+
progress(0, desc=f"Start downloading... {dl_url}")
|
71 |
+
output_filename = get_download_file(download_dir, dl_url, civitai_key)
|
72 |
+
return output_filename
|
73 |
+
|
74 |
+
|
75 |
+
def save_civitai_info(dl_url, filename, civitai_key="", progress=gr.Progress(track_tqdm=True)):
|
76 |
+
json_str, html_str, image_path = get_civitai_json(dl_url, True, filename, civitai_key)
|
77 |
+
if not json_str: return "", "", ""
|
78 |
+
json_path = str(Path(TEMP_DIR, Path(filename).stem + ".json"))
|
79 |
+
html_path = str(Path(TEMP_DIR, Path(filename).stem + ".html"))
|
80 |
+
try:
|
81 |
+
with open(json_path, 'w') as f:
|
82 |
+
json.dump(json_str, f, indent=2)
|
83 |
+
with open(html_path, mode='w', encoding="utf-8") as f:
|
84 |
+
f.write(html_str)
|
85 |
+
return json_path, html_path, image_path
|
86 |
+
except Exception as e:
|
87 |
+
print(f"Error: Failed to save info file {json_path}, {html_path} {e}")
|
88 |
+
return "", "", ""
|
89 |
+
|
90 |
+
|
91 |
+
def download_civitai(dl_url, civitai_key, hf_token, urls,
|
92 |
+
newrepo_id, repo_type="model", is_private=True, is_info=False, is_rename=True, progress=gr.Progress(track_tqdm=True)):
|
93 |
+
if hf_token: set_token(hf_token)
|
94 |
+
else: set_token(os.environ.get("HF_TOKEN")) # default huggingface write token
|
95 |
+
if not civitai_key: civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
|
96 |
+
if not newrepo_id: newrepo_id = os.environ.get("HF_REPO") # default repo to upload
|
97 |
+
if not get_token() or not civitai_key: raise gr.Error("HF write token and Civitai API key is required.")
|
98 |
+
if not urls: urls = []
|
99 |
+
dl_urls = parse_urls(dl_url)
|
100 |
+
remain_urls = dl_urls.copy()
|
101 |
+
try:
|
102 |
+
md = f'### Your repo: [{newrepo_id}]({"https://huggingface.co/datasets/" if repo_type == "dataset" else "https://huggingface.co/"}{newrepo_id})\n'
|
103 |
+
for u in dl_urls:
|
104 |
+
file = download_file(u, civitai_key)
|
105 |
+
if not Path(file).exists() or not Path(file).is_file(): continue
|
106 |
+
if is_rename: file = get_safe_filename(file, newrepo_id, repo_type)
|
107 |
+
url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
|
108 |
+
if url:
|
109 |
+
if is_info: upload_info_to_repo(u, file, newrepo_id, repo_type, is_private, civitai_key)
|
110 |
+
urls.append(url)
|
111 |
+
remain_urls.remove(u)
|
112 |
+
md += f"- Uploaded [{str(u)}]({str(u)})\n"
|
113 |
+
dp_repos = parse_repos(dl_url)
|
114 |
+
for r in dp_repos:
|
115 |
+
url = duplicate_hf_repo(r, newrepo_id, "model", repo_type, is_private, HF_SUBFOLDER_NAME[1])
|
116 |
+
if url: urls.append(url)
|
117 |
+
return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=False)
|
118 |
+
except Exception as e:
|
119 |
+
gr.Info(f"Error occured: {e}")
|
120 |
+
return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=True)
|
121 |
+
finally:
|
122 |
+
gc.collect()
|
123 |
+
|
124 |
+
|
125 |
+
CIVITAI_TYPE = ["Checkpoint", "TextualInversion", "Hypernetwork", "AestheticGradient", "LORA", "LoCon", "DoRA",
|
126 |
+
"Controlnet", "Upscaler", "MotionModule", "VAE", "Poses", "Wildcards", "Workflows", "Other"]
|
127 |
+
CIVITAI_FILETYPE = ["Model", "VAE", "Config", "Training Data"]
|
128 |
+
CIVITAI_BASEMODEL = ["Pony", "Illustrious", "SDXL 1.0", "SD 1.5", "Flux.1 D", "Flux.1 S", "SD 3.5"]
|
129 |
+
#CIVITAI_SORT = ["Highest Rated", "Most Downloaded", "Newest"]
|
130 |
+
CIVITAI_SORT = ["Highest Rated", "Most Downloaded", "Most Liked", "Most Discussed", "Most Collected", "Most Buzz", "Newest"]
|
131 |
+
CIVITAI_PERIOD = ["AllTime", "Year", "Month", "Week", "Day"]
|
132 |
+
|
133 |
+
|
134 |
+
def search_on_civitai(query: str, types: list[str], allow_model: list[str] = [], limit: int = 100,
|
135 |
+
sort: str = "Highest Rated", period: str = "AllTime", tag: str = "", user: str = "", page: int = 1,
|
136 |
+
filetype: list[str] = [], api_key: str = "", progress=gr.Progress(track_tqdm=True)):
|
137 |
+
user_agent = get_user_agent()
|
138 |
+
headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
|
139 |
+
if api_key: headers['Authorization'] = f'Bearer {{{api_key}}}'
|
140 |
+
base_url = 'https://civitai.com/api/v1/models'
|
141 |
+
params = {'sort': sort, 'period': period, 'limit': int(limit), 'nsfw': 'true'}
|
142 |
+
if len(types) != 0: params["types"] = types
|
143 |
+
if query: params["query"] = query
|
144 |
+
if tag: params["tag"] = tag
|
145 |
+
if user: params["username"] = user
|
146 |
+
if page != 0: params["page"] = int(page)
|
147 |
+
session = requests.Session()
|
148 |
+
retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
|
149 |
+
session.mount("https://", HTTPAdapter(max_retries=retries))
|
150 |
+
rs = []
|
151 |
+
try:
|
152 |
+
if page == 0:
|
153 |
+
progress(0, desc="Searching page 1...")
|
154 |
+
print("Searching page 1...")
|
155 |
+
r = session.get(base_url, params=params | {'page': 1}, headers=headers, stream=True, timeout=(7.0, 30))
|
156 |
+
rs.append(r)
|
157 |
+
if r.ok:
|
158 |
+
json = r.json()
|
159 |
+
next_url = json['metadata']['nextPage'] if 'metadata' in json and 'nextPage' in json['metadata'] else None
|
160 |
+
i = 2
|
161 |
+
while(next_url is not None):
|
162 |
+
progress(0, desc=f"Searching page {i}...")
|
163 |
+
print(f"Searching page {i}...")
|
164 |
+
r = session.get(next_url, headers=headers, stream=True, timeout=(7.0, 30))
|
165 |
+
rs.append(r)
|
166 |
+
if r.ok:
|
167 |
+
json = r.json()
|
168 |
+
next_url = json['metadata']['nextPage'] if 'metadata' in json and 'nextPage' in json['metadata'] else None
|
169 |
+
else: next_url = None
|
170 |
+
i += 1
|
171 |
+
else:
|
172 |
+
progress(0, desc="Searching page 1...")
|
173 |
+
print("Searching page 1...")
|
174 |
+
r = session.get(base_url, params=params, headers=headers, stream=True, timeout=(7.0, 30))
|
175 |
+
rs.append(r)
|
176 |
+
except requests.exceptions.ConnectTimeout:
|
177 |
+
print("Request timed out.")
|
178 |
+
except Exception as e:
|
179 |
+
print(e)
|
180 |
+
items = []
|
181 |
+
for r in rs:
|
182 |
+
if not r.ok: continue
|
183 |
+
json = r.json()
|
184 |
+
if 'items' not in json: continue
|
185 |
+
for j in json['items']:
|
186 |
+
for model in j['modelVersions']:
|
187 |
+
item = {}
|
188 |
+
if len(allow_model) != 0 and model['baseModel'] not in set(allow_model): continue
|
189 |
+
item['name'] = j['name']
|
190 |
+
item['creator'] = j['creator']['username'] if 'creator' in j.keys() and 'username' in j['creator'].keys() else ""
|
191 |
+
item['tags'] = j['tags'] if 'tags' in j.keys() else []
|
192 |
+
item['model_name'] = model['name'] if 'name' in model.keys() else ""
|
193 |
+
item['base_model'] = model['baseModel'] if 'baseModel' in model.keys() else ""
|
194 |
+
item['description'] = model['description'] if 'description' in model.keys() else ""
|
195 |
+
item['md'] = ""
|
196 |
+
if 'images' in model.keys() and len(model["images"]) != 0:
|
197 |
+
item['img_url'] = model["images"][0]["url"]
|
198 |
+
item['md'] += f'<img src="{model["images"][0]["url"]}#float" alt="thumbnail" width="150" height="240"><br>'
|
199 |
+
else: item['img_url'] = "/home/user/app/null.png"
|
200 |
+
item['md'] += f'''Model URL: [https://civitai.com/models/{j["id"]}](https://civitai.com/models/{j["id"]})<br>Model Name: {item["name"]}<br>
|
201 |
+
Creator: {item["creator"]}<br>Tags: {", ".join(item["tags"])}<br>Base Model: {item["base_model"]}<br>Description: {item["description"]}'''
|
202 |
+
if 'files' in model.keys():
|
203 |
+
for f in model['files']:
|
204 |
+
i = item.copy()
|
205 |
+
i['dl_url'] = f['downloadUrl']
|
206 |
+
if len(filetype) != 0 and f['type'] not in set(filetype): continue
|
207 |
+
items.append(i)
|
208 |
+
else:
|
209 |
+
item['dl_url'] = model['downloadUrl']
|
210 |
+
items.append(item)
|
211 |
+
return items if len(items) > 0 else None
|
212 |
+
|
213 |
+
|
214 |
+
def search_civitai(query, types, base_model=[], sort=CIVITAI_SORT[0], period=CIVITAI_PERIOD[0], tag="", user="", limit=100, page=1,
|
215 |
+
filetype=[], api_key="", gallery=[], state={}, progress=gr.Progress(track_tqdm=True)):
|
216 |
+
civitai_last_results = {}
|
217 |
+
set_state(state, "civitai_last_choices", [("", "")])
|
218 |
+
set_state(state, "civitai_last_gallery", [])
|
219 |
+
set_state(state, "civitai_last_results", civitai_last_results)
|
220 |
+
results_info = "No item found."
|
221 |
+
items = search_on_civitai(query, types, base_model, int(limit), sort, period, tag, user, int(page), filetype, api_key)
|
222 |
+
if not items: return gr.update(choices=[("", "")], value=[], visible=True),\
|
223 |
+
gr.update(value="", visible=False), gr.update(), gr.update(), gr.update(), gr.update(), results_info, state
|
224 |
+
choices = []
|
225 |
+
gallery = []
|
226 |
+
for item in items:
|
227 |
+
base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
|
228 |
+
name = f"{item['name']} (for {base_model_name} / By: {item['creator']})"
|
229 |
+
value = item['dl_url']
|
230 |
+
choices.append((name, value))
|
231 |
+
gallery.append((item['img_url'], name))
|
232 |
+
civitai_last_results[value] = item
|
233 |
+
if len(choices) >= 1: results_info = f"{int(len(choices))} items found."
|
234 |
+
else: choices = [("", "")]
|
235 |
+
md = ""
|
236 |
+
set_state(state, "civitai_last_choices", choices)
|
237 |
+
set_state(state, "civitai_last_gallery", gallery)
|
238 |
+
set_state(state, "civitai_last_results", civitai_last_results)
|
239 |
+
return gr.update(choices=choices, value=[], visible=True), gr.update(value=md, visible=True),\
|
240 |
+
gr.update(), gr.update(), gr.update(value=gallery), gr.update(choices=choices, value=[]), results_info, state
|
241 |
+
|
242 |
+
|
243 |
+
def get_civitai_json(dl_url: str, is_html: bool=False, image_baseurl: str="", api_key=""):
|
244 |
+
if not image_baseurl: image_baseurl = dl_url
|
245 |
+
default = ("", "", "") if is_html else ""
|
246 |
+
if "https://civitai.com/api/download/models/" not in dl_url: return default
|
247 |
+
user_agent = get_user_agent()
|
248 |
+
headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
|
249 |
+
if api_key: headers['Authorization'] = f'Bearer {{{api_key}}}'
|
250 |
+
base_url = 'https://civitai.com/api/v1/model-versions/'
|
251 |
+
params = {}
|
252 |
+
session = requests.Session()
|
253 |
+
retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
|
254 |
+
session.mount("https://", HTTPAdapter(max_retries=retries))
|
255 |
+
model_id = re.sub('https://civitai.com/api/download/models/(\\d+)(?:.+)?', '\\1', dl_url)
|
256 |
+
url = base_url + model_id
|
257 |
+
#url = base_url + str(dl_url.split("/")[-1])
|
258 |
+
try:
|
259 |
+
r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
|
260 |
+
if not r.ok: return default
|
261 |
+
json = dict(r.json()).copy()
|
262 |
+
html = ""
|
263 |
+
image = ""
|
264 |
+
if "modelId" in json.keys():
|
265 |
+
url = f"https://civitai.com/models/{json['modelId']}"
|
266 |
+
r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
|
267 |
+
if not r.ok: return json, html, image
|
268 |
+
html = r.text
|
269 |
+
if 'images' in json.keys() and len(json["images"]) != 0:
|
270 |
+
url = json["images"][0]["url"]
|
271 |
+
r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
|
272 |
+
if not r.ok: return json, html, image
|
273 |
+
image_temp = str(Path(TEMP_DIR, "image" + Path(url.split("/")[-1]).suffix))
|
274 |
+
image = str(Path(TEMP_DIR, Path(image_baseurl.split("/")[-1]).stem + ".png"))
|
275 |
+
with open(image_temp, 'wb') as f:
|
276 |
+
f.write(r.content)
|
277 |
+
Image.open(image_temp).convert('RGBA').save(image)
|
278 |
+
return json, html, image
|
279 |
+
except Exception as e:
|
280 |
+
print(e)
|
281 |
+
return default
|
282 |
+
|
283 |
+
|
284 |
+
def get_civitai_tag():
|
285 |
+
default = [""]
|
286 |
+
user_agent = get_user_agent()
|
287 |
+
headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
|
288 |
+
base_url = 'https://civitai.com/api/v1/tags'
|
289 |
+
params = {'limit': 200}
|
290 |
+
session = requests.Session()
|
291 |
+
retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
|
292 |
+
session.mount("https://", HTTPAdapter(max_retries=retries))
|
293 |
+
url = base_url
|
294 |
+
try:
|
295 |
+
r = session.get(url, params=params, headers=headers, stream=True, timeout=(7.0, 15))
|
296 |
+
if not r.ok: return default
|
297 |
+
j = dict(r.json()).copy()
|
298 |
+
if "items" not in j.keys(): return default
|
299 |
+
items = []
|
300 |
+
for item in j["items"]:
|
301 |
+
items.append([str(item.get("name", "")), int(item.get("modelCount", 0))])
|
302 |
+
df = pd.DataFrame(items)
|
303 |
+
df.sort_values(1, ascending=False)
|
304 |
+
tags = df.values.tolist()
|
305 |
+
tags = [""] + [l[0] for l in tags]
|
306 |
+
return tags
|
307 |
+
except Exception as e:
|
308 |
+
print(e)
|
309 |
+
return default
|
310 |
+
|
311 |
+
|
312 |
+
def select_civitai_item(results: list[str], state: dict):
|
313 |
+
json = {}
|
314 |
+
if "http" not in "".join(results) or len(results) == 0: return gr.update(value="", visible=True), gr.update(value=json, visible=False), state
|
315 |
+
result = get_state(state, "civitai_last_results")
|
316 |
+
last_selects = get_state(state, "civitai_last_selects")
|
317 |
+
selects = list_sub(results, last_selects if last_selects else [])
|
318 |
+
md = result.get(selects[-1]).get('md', "") if result and isinstance(result, dict) and len(selects) > 0 else ""
|
319 |
+
set_state(state, "civitai_last_selects", results)
|
320 |
+
return gr.update(value=md, visible=True), gr.update(value=json, visible=False), state
|
321 |
+
|
322 |
+
|
323 |
+
def add_civitai_item(results: list[str], dl_url: str):
|
324 |
+
if "http" not in "".join(results): return gr.update(value=dl_url)
|
325 |
+
new_url = dl_url if dl_url else ""
|
326 |
+
for result in results:
|
327 |
+
if "http" not in result: continue
|
328 |
+
new_url += f"\n{result}" if new_url else f"{result}"
|
329 |
+
new_url = uniq_urls(new_url)
|
330 |
+
return gr.update(value=new_url)
|
331 |
+
|
332 |
+
|
333 |
+
def select_civitai_all_item(button_name: str, state: dict):
|
334 |
+
if button_name not in ["Select All", "Deselect All"]: return gr.update(value=button_name), gr.Update(visible=True)
|
335 |
+
civitai_last_choices = get_state(state, "civitai_last_choices")
|
336 |
+
selected = [t[1] for t in civitai_last_choices if t[1] != ""] if button_name == "Select All" else []
|
337 |
+
new_button_name = "Select All" if button_name == "Deselect All" else "Deselect All"
|
338 |
+
return gr.update(value=new_button_name), gr.update(value=selected, choices=civitai_last_choices)
|
339 |
+
|
340 |
+
|
341 |
+
def update_civitai_selection(evt: gr.SelectData, value: list[str], state: dict):
|
342 |
+
try:
|
343 |
+
civitai_last_choices = get_state(state, "civitai_last_choices")
|
344 |
+
selected_index = evt.index
|
345 |
+
selected = list_uniq([v for v in value if v != ""] + [civitai_last_choices[selected_index][1]])
|
346 |
+
return gr.update(value=selected)
|
347 |
+
except Exception:
|
348 |
+
return gr.update()
|
349 |
+
|
350 |
+
|
351 |
+
def update_civitai_checkbox(selected: list[str]):
|
352 |
+
return gr.update(value=selected)
|
353 |
+
|
354 |
+
|
355 |
+
def from_civitai_checkbox(selected: list[str]):
|
356 |
+
return gr.update(value=selected)
|
|
app.py
CHANGED
@@ -1,120 +1,116 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
.
|
11 |
-
.
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
with gr.
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
with gr.Row():
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
)
|
112 |
-
|
113 |
-
.
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
search_civitai_gallery.select(update_civitai_selection, [search_civitai_result, state], [search_civitai_result], queue=False, show_api=False)
|
118 |
-
|
119 |
-
demo.queue()
|
120 |
-
demo.launch(debug=True, share=True)
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from StableGR import (generate_image, search_civitai, download_civitai, select_civitai_item, add_civitai_item, get_civitai_tag, select_civitai_all_item,
|
3 |
+
update_civitai_selection, update_civitai_checkbox, from_civitai_checkbox,
|
4 |
+
CIVITAI_TYPE, CIVITAI_BASEMODEL, CIVITAI_SORT, CIVITAI_PERIOD, CIVITAI_FILETYPE, download_file)
|
5 |
+
|
6 |
+
|
7 |
+
css = """
|
8 |
+
.title { font-size: 3em; align-items: center; text-align: center; }
|
9 |
+
.info { align-items: center; text-align: center; }
|
10 |
+
.block.result { margin: 1em 0; padding: 1em; box-shadow: 0 0 3px 3px #664422, 0 0 3px 2px #664422 inset; border-radius: 6px; background: #665544; }
|
11 |
+
.desc [src$='#float'] { float: right; margin: 20px; }
|
12 |
+
"""
|
13 |
+
|
14 |
+
|
15 |
+
with gr.Blocks(fill_width=True, css=css) as demo:
|
16 |
+
with gr.Column():
|
17 |
+
gr.Markdown("# StableGR", elem_classes="title")
|
18 |
+
state = gr.State(value={})
|
19 |
+
with gr.TabItem("Search Civitai"):
|
20 |
+
with gr.Row():
|
21 |
+
search_civitai_type = gr.CheckboxGroup(label="Type", choices=CIVITAI_TYPE, value=["Checkpoint", "LORA"])
|
22 |
+
search_civitai_basemodel = gr.CheckboxGroup(label="Base Model", choices=CIVITAI_BASEMODEL, value=[])
|
23 |
+
search_civitai_filetype = gr.CheckboxGroup(label="File type", choices=CIVITAI_FILETYPE, value=["Model"])
|
24 |
+
with gr.Row():
|
25 |
+
search_civitai_sort = gr.Radio(label="Sort", choices=CIVITAI_SORT, value=CIVITAI_SORT[0])
|
26 |
+
search_civitai_period = gr.Radio(label="Period", choices=CIVITAI_PERIOD, value="Month")
|
27 |
+
search_civitai_limit = gr.Number(label="Limit", minimum=1, maximum=100, step=1, value=100)
|
28 |
+
search_civitai_page = gr.Number(label="Page", info="If 0, retrieve all pages", minimum=0, maximum=10, step=1, value=1)
|
29 |
+
with gr.Row(equal_height=True):
|
30 |
+
search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
|
31 |
+
search_civitai_tag = gr.Dropdown(label="Tag", choices=get_civitai_tag(), value=get_civitai_tag()[0], allow_custom_value=True)
|
32 |
+
search_civitai_user = gr.Textbox(label="Username", lines=1)
|
33 |
+
search_civitai_submit = gr.Button("Search on Civitai")
|
34 |
+
with gr.TabItem("Results"):
|
35 |
+
with gr.Row():
|
36 |
+
search_civitai_desc = gr.Markdown(value="", visible=False, elem_classes="desc")
|
37 |
+
search_civitai_json = gr.JSON(value={}, visible=False)
|
38 |
+
with gr.Row(equal_height=True):
|
39 |
+
with gr.Column(scale=9):
|
40 |
+
with gr.TabItem("Select from Gallery"):
|
41 |
+
search_civitai_gallery = gr.Gallery([], label="Results", allow_preview=False, columns=5, elem_id="gallery", show_share_button=False, interactive=False)
|
42 |
+
with gr.TabItem("Select by Checkbox"):
|
43 |
+
search_civitai_result_checkbox = gr.CheckboxGroup(label="", choices=[], value=[])
|
44 |
+
search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value=[],
|
45 |
+
allow_custom_value=True, visible=True, multiselect=True)
|
46 |
+
search_civitai_result_info = gr.Markdown("Search result.", elem_classes="info")
|
47 |
+
with gr.Column(scale=1):
|
48 |
+
search_civitai_add = gr.Button("Add to download URLs")
|
49 |
+
search_civitai_select_all = gr.Button("Select All", variant="secondary", size="sm")
|
50 |
+
with gr.Group():
|
51 |
+
dl_url = gr.Textbox(label="Download URL(s)", placeholder="https://civitai.com/api/download/models/28907\n...", value="", lines=3, max_lines=255)
|
52 |
+
with gr.Column():
|
53 |
+
civitai_key = gr.Textbox(label="Your Civitai Key", value="", max_lines=1)
|
54 |
+
gr.Markdown("Your Civitai API key is available at [https://civitai.com/user/account](https://civitai.com/user/account).", elem_classes="info")
|
55 |
+
|
56 |
+
with gr.Row():
|
57 |
+
run_base = gr.Button(value="Download Base Model", variant="primary")
|
58 |
+
run_lora = gr.Button(value="Download Lora", variant="primary")
|
59 |
+
uploaded_urls = gr.CheckboxGroup(visible=False, choices=[], value=None) # hidden
|
60 |
+
#urls_md = gr.Markdown("<br><br>", elem_classes="result")
|
61 |
+
urls_remain = gr.Textbox("Remaining URLs", value="", show_copy_button=True, visible=False)
|
62 |
+
with gr.Column():
|
63 |
+
base_model = gr.File(label="Base Models")
|
64 |
+
lora_A = gr.File(label="Lora")
|
65 |
+
with gr.Row():
|
66 |
+
prompt = gr.Textbox(label="Prompt", value="A highly detailed portrait of an underwater city, with towering spires and domes rising up from the ocean floor")
|
67 |
+
|
68 |
+
num_steps = gr.Slider(label="Number of Steps", minimum=1, maximum=100, value=30, step=1)
|
69 |
+
guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=20.0, value=7.5, step=0.1)
|
70 |
+
sampler = gr.Dropdown(label="Sampler", choices=["DPM++ 2M", "Euler", "TCD"], value="DPM++ 2M")
|
71 |
+
img_width = gr.Slider(label="Image Width", minimum=64, maximum=2048, value=512, step=64)
|
72 |
+
img_height = gr.Slider(label="Image Height", minimum=64, maximum=2048, value=1024, step=64)
|
73 |
+
generate_button = gr.Button("Generate Image")
|
74 |
+
output_image = gr.Image(label="output")
|
75 |
+
|
76 |
+
gr.on(
|
77 |
+
triggers=[run_base.click],
|
78 |
+
fn=download_file,
|
79 |
+
inputs=[dl_url, civitai_key],
|
80 |
+
outputs=base_model,
|
81 |
+
queue=True,
|
82 |
+
)
|
83 |
+
gr.on(
|
84 |
+
triggers=[run_lora.click],
|
85 |
+
fn=download_file,
|
86 |
+
inputs=[dl_url, civitai_key],
|
87 |
+
outputs=lora_A,
|
88 |
+
queue=True,
|
89 |
+
)
|
90 |
+
gr.on(
|
91 |
+
triggers=[generate_button.click],
|
92 |
+
fn=generate_image,
|
93 |
+
inputs=[base_model, prompt, lora_A, num_steps, guidance_scale, sampler, img_width, img_height],
|
94 |
+
outputs=output_image,
|
95 |
+
queue=True,
|
96 |
+
)
|
97 |
+
gr.on(
|
98 |
+
triggers=[search_civitai_submit.click, search_civitai_query.submit, search_civitai_user.submit],
|
99 |
+
fn=search_civitai,
|
100 |
+
inputs=[search_civitai_query, search_civitai_type, search_civitai_basemodel, search_civitai_sort,
|
101 |
+
search_civitai_period, search_civitai_tag, search_civitai_user, search_civitai_limit,
|
102 |
+
search_civitai_page, search_civitai_filetype, civitai_key, search_civitai_gallery, state],
|
103 |
+
outputs=[search_civitai_result, search_civitai_desc, search_civitai_submit, search_civitai_query, search_civitai_gallery,
|
104 |
+
search_civitai_result_checkbox, search_civitai_result_info, state],
|
105 |
+
queue=False,
|
106 |
+
show_api=False,
|
107 |
+
)
|
108 |
+
search_civitai_result.change(select_civitai_item, [search_civitai_result, state], [search_civitai_desc, search_civitai_json, state], queue=False, show_api=False)\
|
109 |
+
.success(update_civitai_checkbox, [search_civitai_result], [search_civitai_result_checkbox], queue=True, show_api=False)
|
110 |
+
search_civitai_result_checkbox.select(from_civitai_checkbox, [search_civitai_result_checkbox], [search_civitai_result], queue=False, show_api=False)
|
111 |
+
search_civitai_add.click(add_civitai_item, [search_civitai_result, dl_url], [dl_url], queue=False, show_api=False)
|
112 |
+
search_civitai_select_all.click(select_civitai_all_item, [search_civitai_select_all, state], [search_civitai_select_all, search_civitai_result], queue=False, show_api=False)
|
113 |
+
search_civitai_gallery.select(update_civitai_selection, [search_civitai_result, state], [search_civitai_result], queue=False, show_api=False)
|
114 |
+
|
115 |
+
demo.queue()
|
116 |
+
demo.launch(debug=True, share=True)
|
|
|
|
|
|
|
|