jbilcke-hf (HF Staff) committed
Commit 52afd4e · verified · 1 Parent(s): b5a40cb

Update app.py

Files changed (1):
  app.py +124 -173
app.py CHANGED
@@ -24,6 +24,9 @@ import cv2
 import torch
 import numpy as np
 from PIL import Image
+from io import BytesIO
+import base64
+import re
 
 from insightface.app import FaceAnalysis
 from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
@@ -34,6 +37,26 @@ from compel import Compel, ReturnedEmbeddingsType
 
 #from gradio_imageslider import ImageSlider
 
+
+# Regex pattern to match a data URI scheme prefix
+data_uri_pattern = re.compile(r'data:image/(png|jpeg|jpg|webp);base64,')
+
+def readb64(b64):
+    # Remove any data URI scheme prefix with the regex
+    b64 = data_uri_pattern.sub("", b64)
+    # Decode and open the image with PIL
+    img = Image.open(BytesIO(base64.b64decode(b64)))
+    return img
+
+# Convert from PIL to a base64-encoded PNG string
+def writeb64(image):
+    buffered = BytesIO()
+    image.save(buffered, format="PNG")
+    b64image = base64.b64encode(buffered.getvalue())
+    b64image_str = b64image.decode("utf-8")
+    return b64image_str
+
+
 with open("sdxl_loras.json", "r") as file:
     data = json.load(file)
     sdxl_loras_raw = [
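A quick way to sanity-check the two helpers added above — a minimal sketch, assuming readb64/writeb64 are in scope and Pillow is installed; the 1×1 test image and names are illustrative only:

    from PIL import Image

    img = Image.new("RGB", (1, 1), "red")                # hypothetical test image
    b64 = writeb64(img)                                  # PIL -> base64 PNG string
    restored = readb64("data:image/png;base64," + b64)   # the regex strips the prefix
    assert restored.size == (1, 1)

The round trip also shows why the regex exists: browser clients often send full data URIs, while raw base64 payloads work just as well.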
@@ -56,7 +79,27 @@ with open("sdxl_loras.json", "r") as file:
 
 with open("defaults_data.json", "r") as file:
     lora_defaults = json.load(file)
-
+
+
+def getLoraByRepoName(repo_name):
+    # Loop through each lora in sdxl_loras_raw
+    for lora in sdxl_loras_raw:
+        if lora["repo"] == repo_name:
+            # Return the lora if the repo name matches
+            return lora
+    # If no match is found, fall back to the first lora in the list
+    return sdxl_loras_raw[0] if sdxl_loras_raw else None
+
+# Return the default values specific to this particular LoRA
+def getLoraDefaultsByRepoName(repo_name):
+    # Loop through each entry in lora_defaults (keyed by "model")
+    for lora in lora_defaults:
+        if lora["model"] == repo_name:
+            # Return the defaults if the model name matches
+            return lora
+    # If no match is found, return None so callers can skip defaults
+    return None
+
 
 device = "cuda"
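How the two lookups behave on a toy catalog — a sketch with made-up entries shaped like sdxl_loras.json and defaults_data.json (note the defaults file keys its entries by "model", which the rest of the app matches against a "repo" string):

    # hypothetical entries, for illustration only
    sdxl_loras_raw = [{"repo": "foo/toy-lora", "weights": "toy.safetensors", "is_pivotal": False}]
    lora_defaults  = [{"model": "foo/toy-lora", "weight": 0.9, "prompt": "photo of <subject>"}]

    getLoraByRepoName("foo/toy-lora")        # -> the matching catalog entry
    getLoraByRepoName("does/not-exist")      # -> falls back to the first entry
    getLoraDefaultsByRepoName("no/match")    # -> None, so callers skip defaults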
@@ -148,36 +191,6 @@ pipe.to(device)
 last_lora = ""
 last_fused = False
 
-def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative, is_new=False):
-    lora_repo = sdxl_loras[selected_state.index]["repo"]
-    new_placeholder = "Type a prompt to use your selected LoRA"
-    weight_name = sdxl_loras[selected_state.index]["weights"]
-    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨ {'(non-commercial LoRA, `cc-by-nc`)' if sdxl_loras[selected_state.index]['is_nc'] else '' }"
-
-    for lora_list in lora_defaults:
-        if lora_list["model"] == sdxl_loras[selected_state.index]["repo"]:
-            face_strength = lora_list.get("face_strength", 0.85)
-            image_strength = lora_list.get("image_strength", 0.15)
-            weight = lora_list.get("weight", 0.9)
-            depth_control_scale = lora_list.get("depth_control_scale", 0.8)
-            negative = lora_list.get("negative", "")
-
-    if(is_new):
-        if(selected_state.index == 0):
-            selected_state.index = -9999
-        else:
-            selected_state.index *= -1
-
-    return (
-        updated_text,
-        gr.update(placeholder=new_placeholder),
-        face_strength,
-        image_strength,
-        weight,
-        depth_control_scale,
-        negative,
-        selected_state
-    )
 
 def center_crop_image_as_square(img):
     square_size = min(img.size)
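With update_selection removed, the per-LoRA slider defaults it used to push into the UI would have to be applied server-side instead. A sketch of the equivalent lookup, reusing the same fallback values the old function had (the committed code only applies the prompt template from the defaults; everything else is hypothetical here):

    defaults = getLoraDefaultsByRepoName(repo_name) or {}
    face_strength = defaults.get("face_strength", 0.85)
    image_strength = defaults.get("image_strength", 0.15)
    weight = defaults.get("weight", 0.9)
    depth_control_scale = defaults.get("depth_control_scale", 0.8)
    negative = defaults.get("negative", "")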
@@ -216,13 +229,13 @@ def merge_incompatible_lora(full_path_lora, lora_scale):
     del weights_sd
     del lora_model
 #@spaces.GPU
-def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index, st):
+def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, lora, full_path_lora, lora_scale, st):
     et = time.time()
     elapsed_time = et - st
     print('Getting into the decorated function took: ', elapsed_time, 'seconds')
     global last_fused, last_lora
     print("Last LoRA: ", last_lora)
-    print("Current LoRA: ", repo_name)
+    print("Current LoRA: ", lora["repo"])
     print("Last fused: ", last_fused)
     #prepare face zoe
     st = time.time()
@@ -233,7 +246,7 @@ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index, st):
     et = time.time()
     elapsed_time = et - st
     print('Zoe Depth calculations took: ', elapsed_time, 'seconds')
-    if last_lora != repo_name:
+    if last_lora != lora["repo"]:
         if(last_fused):
             st = time.time()
             pipe.unfuse_lora()
@@ -242,17 +255,16 @@ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index, st):
             elapsed_time = et - st
             print('Unfuse and unload LoRA took: ', elapsed_time, 'seconds')
         st = time.time()
-        pipe.load_lora_weights(loaded_state_dict)
+        pipe.load_lora_weights(full_path_lora)
         pipe.fuse_lora(lora_scale)
         et = time.time()
         elapsed_time = et - st
         print('Fuse and load LoRA took: ', elapsed_time, 'seconds')
         last_fused = True
-        is_pivotal = sdxl_loras[selected_state_index]["is_pivotal"]
-        if(is_pivotal):
+        if(lora["is_pivotal"]):
             #Add the textual inversion embeddings from pivotal tuning models
-            text_embedding_name = sdxl_loras[selected_state_index]["text_embedding_weights"]
-            embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
+            text_embedding_name = lora["text_embedding_weights"]
+            embedding_path = hf_hub_download(repo_id=lora["repo"], filename=text_embedding_name, repo_type="model")
             state_dict_embedding = load_file(embedding_path)
             try:
                 pipe.unload_textual_inversion()
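The hunk above keeps exactly one LoRA fused into the pipeline at a time, paying the unfuse/reload cost only when the requested repo changes. Isolated from the timing code, the caching pattern is roughly this — a sketch, with pipe standing in for the diffusers pipeline and the same module-level cache variables:

    last_lora, last_fused = "", False

    def ensure_lora_fused(pipe, lora, full_path_lora, lora_scale):
        # sketch of the cache-by-repo pattern, not the app's literal code
        global last_lora, last_fused
        if last_lora == lora["repo"]:
            return                      # already fused, nothing to do
        if last_fused:
            pipe.unfuse_lora()          # back out the previously fused weights
            pipe.unload_lora_weights()
        pipe.load_lora_weights(full_path_lora)
        pipe.fuse_lora(lora_scale)
        last_lora, last_fused = lora["repo"], True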
@@ -293,12 +305,19 @@ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index, st):
     et = time.time()
     elapsed_time = et - st
     print('Image processing took: ', elapsed_time, 'seconds')
-    last_lora = repo_name
+    last_lora = lora["repo"]
+
     return image
 
-def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, progress=gr.Progress(track_tqdm=True)):
-    selected_state_index = selected_state.index
+def run_lora(face_image, prompt, negative, lora_scale, face_strength, image_strength, guidance_scale, depth_control_scale, lora_repo_name):
+    # get the lora and its default values
+    lora = getLoraByRepoName(lora_repo_name)
+    default_values = getLoraDefaultsByRepoName(lora_repo_name)
+
     st = time.time()
+
+    # the face image arrives as a base64 string, so decode it to PIL first
+    face_image = readb64(face_image)
+
     face_image = center_crop_image_as_square(face_image)
     try:
         face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
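End to end, run_lora now speaks base64 on both sides: it decodes the incoming string with readb64 and, further down, returns its result through writeb64. A sketch of exercising it in-process, assuming the app's models are already loaded; the photo path and LoRA repo name are placeholders:

    from PIL import Image

    face_b64 = writeb64(Image.open("face.jpg"))   # hypothetical input photo
    out_b64 = run_lora(
        face_b64, "a person", "", 0.9,            # image, prompt, negative, lora_scale
        0.75, 0.15, 7, 0.8,                       # face/image strength, guidance, depth
        "some-user/some-sdxl-lora",               # placeholder repo name
    )
    readb64(out_b64).save("result.png")           # decode the returned base64 PNG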
@@ -312,160 +331,92 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, progress=gr.Progress(track_tqdm=True)):
     print('Cropping and calculating face embeds took: ', elapsed_time, 'seconds')
 
     st = time.time()
-    for lora_list in lora_defaults:
-        if lora_list["model"] == sdxl_loras[selected_state_index]["repo"]:
-            prompt_full = lora_list.get("prompt", None)
-            if(prompt_full):
-                prompt = prompt_full.replace("<subject>", prompt)
+
+    if default_values:
+        prompt_full = default_values.get("prompt", None)
+        if(prompt_full):
+            prompt = prompt_full.replace("<subject>", prompt)
 
 
     print("Prompt:", prompt)
     if(prompt == ""):
         prompt = "a person"
-
-    print("Selected State: ", selected_state_index)
-    print(sdxl_loras[selected_state_index]["repo"])
+
     if negative == "":
         negative = None
 
-    if not selected_state:
-        raise gr.Error("You must select a LoRA")
-    repo_name = sdxl_loras[selected_state_index]["repo"]
-    weight_name = sdxl_loras[selected_state_index]["weights"]
+    if not lora:
+        raise gr.Error("You must provide a valid LoRA repo name")
+
+    weight_name = lora["weights"]
 
-    full_path_lora = state_dicts[repo_name]["saved_name"]
-    #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
+    full_path_lora = state_dicts[lora["repo"]]["saved_name"]
+    #loaded_state_dict = copy.deepcopy(state_dicts[lora_repo_name]["state_dict"])
     cross_attention_kwargs = None
     et = time.time()
     elapsed_time = et - st
     print('Small content processing took: ', elapsed_time, 'seconds')
 
     st = time.time()
-    image = generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, full_path_lora, lora_scale, sdxl_loras, selected_state_index, st)
-    return image
-
-def shuffle_gallery(sdxl_loras):
-    random.shuffle(sdxl_loras)
-    return [(item["image"], item["title"]) for item in sdxl_loras], sdxl_loras
-
-def classify_gallery(sdxl_loras):
-    sorted_gallery = sorted(sdxl_loras, key=lambda x: x.get("likes", 0), reverse=True)
-    return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
-
-def swap_gallery(order, sdxl_loras):
-    if(order == "random"):
-        return shuffle_gallery(sdxl_loras)
-    else:
-        return classify_gallery(sdxl_loras)
-
-def deselect():
-    return gr.Gallery(selected_index=None)
+    image = generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, lora, full_path_lora, lora_scale, st)
+
+    image_base64 = writeb64(image)
+
+    return image_base64
+
 
 with gr.Blocks() as demo:
-    gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
-    title = gr.HTML(
-        """<h1><img src="https://i.imgur.com/DVoGw04.png">
-        <span>Face to All<br><small style="
-        font-size: 13px;
-        display: block;
-        font-weight: normal;
-        opacity: 0.75;
-        ">🧨 diffusers InstantID + ControlNet<br> inspired by fofr's <a href="https://github.com/fofr/cog-face-to-many" target="_blank">face-to-many</a></small></span></h1>""",
-        elem_id="title",
-    )
-    selected_state = gr.State()
-    with gr.Row(elem_id="main_app"):
-        with gr.Column(scale=4):
-            with gr.Group(elem_id="gallery_box"):
-                photo = gr.Image(label="Upload a picture of yourself", interactive=True, type="pil", height=300)
-                selected_loras = gr.Gallery(label="Selected LoRAs", height=80, show_share_button=False, visible=False, elem_id="gallery_selected", )
-                #order_gallery = gr.Radio(choices=["random", "likes"], value="random", label="Order by", elem_id="order_radio")
-                #new_gallery = gr.Gallery(
-                #    label="New LoRAs",
-                #    elem_id="gallery_new",
-                #    columns=3,
-                #    value=[(item["image"], item["title"]) for item in sdxl_loras_raw_new], allow_preview=False, show_share_button=False)
-                gallery = gr.Gallery(
-                    #value=[(item["image"], item["title"]) for item in sdxl_loras],
-                    label="Style gallery",
-                    allow_preview=False,
-                    columns=4,
-                    elem_id="gallery",
-                    show_share_button=False,
-                    height=550
-                )
-                custom_model = gr.Textbox(label="Enter a custom Hugging Face or CivitAI SDXL LoRA", interactive=False, placeholder="Coming soon...")
-        with gr.Column(scale=5):
-            with gr.Row():
-                prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
-                button = gr.Button("Run", elem_id="run_button")
-            result = gr.Image(
-                interactive=False, label="Generated Image", elem_id="result-image"
-            )
-
-            with gr.Accordion("Advanced options", open=False):
-                negative = gr.Textbox(label="Negative Prompt")
-
-                # initial value was 0.9
-                weight = gr.Slider(0, 10, value=6, step=0.1, label="LoRA weight")
-
-                # initial value was 0.85
-                face_strength = gr.Slider(0, 1, value=0.75, step=0.01, label="Face strength", info="Higher values increase the face likeness but reduce the creative liberty of the models")
-
-                # initial value was 0.15
-                image_strength = gr.Slider(0, 1, value=0.15, step=0.01, label="Image strength", info="Higher values increase the similarity with the structure/colors of the original photo")
-
-                # initial value was 7
-                guidance_scale = gr.Slider(0, 50, value=7, step=0.1, label="Guidance Scale")
-
-                # initial value was 1
-                depth_control_scale = gr.Slider(0, 4, value=0.8, step=0.01, label="Zoe Depth ControlNet strenght")
-            prompt_title = gr.Markdown(
-                value="### Click on a LoRA in the gallery to select it",
-                visible=True,
-                elem_id="selected_lora",
-            )
-    #order_gallery.change(
-    #    fn=swap_gallery,
-    #    inputs=[order_gallery, gr_sdxl_loras],
-    #    outputs=[gallery, gr_sdxl_loras],
-    #    queue=False
-    #)
-    gallery.select(
-        fn=update_selection,
-        inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
-        outputs=[prompt_title, prompt, face_strength, image_strength, weight, depth_control_scale, negative, selected_state],
-        queue=False,
-        show_progress=False
-    )
-    #new_gallery.select(
-    #    fn=update_selection,
-    #    inputs=[gr_sdxl_loras_new, gr.State(True)],
-    #    outputs=[prompt_title, prompt, prompt, selected_state, gallery],
-    #    queue=False,
-    #    show_progress=False
-    #)
-    prompt.submit(
-        fn=check_selected,
-        inputs=[selected_state],
-        queue=False,
-        show_progress=False
-    ).success(
-        fn=run_lora,
-        inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras],
-        outputs=[result],
-    )
-    button.click(
-        fn=check_selected,
-        inputs=[selected_state],
-        queue=False,
-        show_progress=False
-    ).success(
+    gr.HTML("""
+    <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
+      <div style="text-align: center; color: black;">
+        <p style="color: black;">This space is a REST API to programmatically generate an image from a face.</p>
+        <p style="color: black;">Interested in using it through a UI? Please use the <a href="https://huggingface.co/spaces/multimodalart/face-to-all" target="_blank">original space</a>, thank you!</p>
+      </div>
+    </div>""")
+
+    input_image_base64 = gr.Text()
+
+    lora_repo_name = gr.Text(label="name of the LoRA repo on HF")
+
+    prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
+
+    negative = gr.Textbox(label="Negative Prompt")
+
+    # initial value was 0.9
+    weight = gr.Slider(0, 10, value=6, step=0.1, label="LoRA weight")
+
+    # initial value was 0.85
+    face_strength = gr.Slider(0, 1, value=0.75, step=0.01, label="Face strength", info="Higher values increase the face likeness but reduce the creative liberty of the models")
+
+    # initial value was 0.15
+    image_strength = gr.Slider(0, 1, value=0.15, step=0.01, label="Image strength", info="Higher values increase the similarity with the structure/colors of the original photo")
+
+    # initial value was 7
+    guidance_scale = gr.Slider(0, 50, value=7, step=0.1, label="Guidance Scale")
+
+    # initial value was 1
+    depth_control_scale = gr.Slider(0, 4, value=0.8, step=0.01, label="Zoe Depth ControlNet strength")
+
+    button = gr.Button(value="Generate")
+    output_image_base64 = gr.Text()
+    button.click(
         fn=run_lora,
-        inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras],
-        outputs=[result],
+        inputs=[
+            input_image_base64,
+            prompt,
+            negative,
+            weight,
+            face_strength,
+            image_strength,
+            guidance_scale,
+            depth_control_scale,
+            lora_repo_name
+        ],
+        outputs=output_image_base64,
+        api_name='run',
     )
 
-    demo.load(fn=classify_gallery, inputs=[gr_sdxl_loras], outputs=[gallery, gr_sdxl_loras], queue=False)
+
 demo.queue(max_size=20)
 demo.launch()
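Since the click handler is exported with api_name='run', the space can be driven entirely over HTTP. A sketch of a client call using the gradio_client package — the space id and LoRA repo below are placeholders, and the argument order mirrors the inputs list above:

    from gradio_client import Client
    import base64

    with open("face.jpg", "rb") as f:
        face_b64 = base64.b64encode(f.read()).decode("utf-8")

    client = Client("OWNER/THIS-SPACE")   # placeholder: the deployed space id
    result_b64 = client.predict(
        face_b64,                     # input_image_base64
        "a person",                   # prompt
        "",                           # negative prompt
        6,                            # LoRA weight (lora_scale)
        0.75,                         # face_strength
        0.15,                         # image_strength
        7,                            # guidance_scale
        0.8,                          # depth_control_scale
        "some-user/some-sdxl-lora",   # placeholder lora_repo_name
        api_name="/run",
    )
    with open("result.png", "wb") as f:
        f.write(base64.b64decode(result_b64))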
 