multimodalart (HF Staff) committed
Commit d1d8628 · verified · 1 Parent(s): 2bf7ead

Update app.py

Files changed (1):
  1. app.py +124 -117
app.py CHANGED
@@ -7,7 +7,7 @@ import spaces
 
 import PIL
 from PIL import Image
-from typing import Tuple
+from typing import Tuple, List
 
 import diffusers
 from diffusers.utils import load_image
@@ -21,8 +21,6 @@ from insightface.app import FaceAnalysis
 from style_template import styles
 from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
 
-# from controlnet_aux import OpenposeDetector
-
 import gradio as gr
 
 from depth_anything.dpt import DepthAnything
@@ -58,8 +56,6 @@ app = FaceAnalysis(
 )
 app.prepare(ctx_id=0, det_size=(640, 640))
 
-# openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
-
 depth_anything = DepthAnything.from_pretrained('LiheYoung/depth_anything_vitl14').to(device).eval()
 
 transform = Compose([
@@ -85,14 +81,10 @@ controlnet_identitynet = ControlNetModel.from_pretrained(
     controlnet_path, torch_dtype=dtype
 )
 
-# controlnet-pose/canny/depth
-# controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0"
+# controlnet-canny/depth
 controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0"
 controlnet_depth_model = "diffusers/controlnet-depth-sdxl-1.0-small"
 
-# controlnet_pose = ControlNetModel.from_pretrained(
-#     controlnet_pose_model, torch_dtype=dtype
-# ).to(device)
 controlnet_canny = ControlNetModel.from_pretrained(
     controlnet_canny_model, torch_dtype=dtype
 ).to(device)
@@ -127,12 +119,10 @@ def get_canny_image(image, t1=100, t2=200):
     return Image.fromarray(edges, "L")
 
 controlnet_map = {
-    #"pose": controlnet_pose,
     "canny": controlnet_canny,
     "depth": controlnet_depth,
 }
 controlnet_map_fn = {
-    #"pose": openpose,
     "canny": get_canny_image,
     "depth": get_depth_map,
 }
@@ -180,67 +170,6 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
 def remove_tips():
     return gr.update(visible=False)
 
-def get_example():
-    case = [
-        [
-            "./examples/yann-lecun_resize.jpg",
-            None,
-            "a man",
-            "Spring Festival",
-            "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
-        ],
-        [
-            "./examples/musk_resize.jpeg",
-            "./examples/poses/pose2.jpg",
-            "a man flying in the sky in Mars",
-            "Mars",
-            "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
-        ],
-        [
-            "./examples/sam_resize.png",
-            "./examples/poses/pose4.jpg",
-            "a man doing a silly pose wearing a suite",
-            "Jungle",
-            "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
-        ],
-        [
-            "./examples/schmidhuber_resize.png",
-            "./examples/poses/pose3.jpg",
-            "a man sit on a chair",
-            "Neon",
-            "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
-        ],
-        [
-            "./examples/kaifu_resize.png",
-            "./examples/poses/pose.jpg",
-            "a man",
-            "Vibrant Color",
-            "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
-        ],
-    ]
-    return case
-
-def run_for_examples(face_file, pose_file, prompt, style, negative_prompt):
-    return generate_image(
-        face_file,
-        pose_file,
-        prompt,
-        negative_prompt,
-        style,
-        20,  # num_steps
-        0.8,  # identitynet_strength_ratio
-        0.8,  # adapter_strength_ratio
-        #0.4,  # pose_strength
-        0.3,  # canny_strength
-        0.5,  # depth_strength
-        ["depth", "canny"],  # controlnet_selection
-        5.0,  # guidance_scale
-        42,  # seed
-        "EulerDiscreteScheduler",  # scheduler
-        False,  # enable_LCM
-        True,  # enable_Face_Region
-    )
-
 def convert_from_cv2_to_image(img: np.ndarray) -> Image:
     return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
 
@@ -284,9 +213,12 @@ def apply_style(
     p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
     return p.replace("{prompt}", positive), n + " " + negative
 
+def update_face_gallery(files):
+    return gr.update(value=files, visible=True)
+
 @spaces.GPU
 def generate_image(
-    face_image_path,
+    face_images_path,  # Now accepts a list of image paths
     pose_image_path,
     prompt,
     negative_prompt,
@@ -294,7 +226,6 @@
     num_steps,
     identitynet_strength_ratio,
     adapter_strength_ratio,
-    #pose_strength,
    canny_strength,
    depth_strength,
    controlnet_selection,
@@ -321,9 +252,9 @@
     scheduler = getattr(diffusers, scheduler_class_name)
     pipe.scheduler = scheduler.from_config(pipe.scheduler.config, **add_kwargs)
 
-    if face_image_path is None:
+    if face_images_path is None or len(face_images_path) == 0:
         raise gr.Error(
-            f"Cannot find any input face image! Please upload the face image"
+            "Cannot find any input face images! Please upload at least one face image."
         )
 
     if prompt is None:
@@ -332,28 +263,67 @@
     # apply the style template
     prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
 
-    face_image = load_image(face_image_path)
-    face_image = resize_img(face_image, max_side=1024)
-    face_image_cv2 = convert_from_image_to_cv2(face_image)
-    height, width, _ = face_image_cv2.shape
-
-    # Extract face features
-    face_info = app.get(face_image_cv2)
-
-    if len(face_info) == 0:
-        raise gr.Error(
-            f"Unable to detect a face in the image. Please upload a different photo with a clear face."
-        )
-
-    face_info = sorted(
-        face_info,
-        key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1],
-    )[
-        -1
-    ]  # only use the maximum face
-    face_emb = face_info["embedding"]
-    face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
-    img_controlnet = face_image
+    # Use the first face image for face keypoints and size reference
+    reference_face_path = face_images_path[0] if isinstance(face_images_path, list) else face_images_path
+    reference_face_image = load_image(reference_face_path)
+    reference_face_image = resize_img(reference_face_image, max_side=1024)
+    reference_face_cv2 = convert_from_image_to_cv2(reference_face_image)
+    height, width, _ = reference_face_cv2.shape
+
+    # Initialize a list to collect face embeddings
+    face_embeddings = []
+
+    # Process each face image if multiple images are provided
+    face_image_paths = face_images_path if isinstance(face_images_path, list) else [face_images_path]
+
+    for face_path in face_image_paths:
+        face_img = load_image(face_path)
+        face_img = resize_img(face_img, max_side=1024)
+        face_img_cv2 = convert_from_image_to_cv2(face_img)
+
+        # Extract face features
+        face_info = app.get(face_img_cv2)
+
+        if len(face_info) == 0:
+            print(f"Warning: unable to detect a face in {face_path}. Skipping this image.")
+            continue
+
+        # Use the largest face in each image (sort by bounding-box area)
+        face_info = sorted(
+            face_info,
+            key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]),
+        )[-1]
+
+        # Collect the embedding
+        face_embeddings.append(torch.tensor(face_info["embedding"]).unsqueeze(0))
+
+    if len(face_embeddings) == 0:
+        raise gr.Error(
+            "Unable to detect a face in any of the uploaded images. Please upload different photos with clear faces."
+        )
+
+    # Average the face embeddings
+    if len(face_embeddings) == 1:
+        face_emb = face_embeddings[0].squeeze().numpy()  # use as-is if only one image
+    else:
+        # Stack and compute the mean along the batch dimension
+        face_emb = torch.mean(torch.cat(face_embeddings, dim=0), dim=0).numpy()
+        print(f"Averaged {len(face_embeddings)} face embeddings")
+
+    # Extract keypoints from the reference face for ControlNet
+    reference_face_info = app.get(reference_face_cv2)
+    if len(reference_face_info) == 0:
+        raise gr.Error(
+            "Unable to detect a face in the reference image for keypoints. Please upload a different photo with a clear face."
+        )
+    reference_face_info = sorted(
+        reference_face_info,
+        key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]),
+    )[-1]  # use the largest face
+
+    face_kps = draw_kps(convert_from_cv2_to_image(reference_face_cv2), reference_face_info["kps"])
+    img_controlnet = reference_face_image
+
     if pose_image_path is not None:
         pose_image = load_image(pose_image_path)
         pose_image = resize_img(pose_image, max_side=1024)
@@ -383,7 +353,6 @@
 
     if len(controlnet_selection) > 0:
         controlnet_scales = {
-            #"pose": pose_strength,
             "canny": canny_strength,
             "depth": depth_strength,
         }
@@ -425,9 +394,42 @@
 
     return images[0], gr.update(visible=True)
 
+def get_example():
+    case = [
+        [
+            "./examples/yann-lecun_resize.jpg",
+            None,
+            "a man",
+            "Spring Festival",
+            "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
+        ],
+        # Add more examples as needed
+    ]
+    return case
+
+def run_for_examples(face_file, pose_file, prompt, style, negative_prompt):
+    return generate_image(
+        face_file,
+        pose_file,
+        prompt,
+        negative_prompt,
+        style,
+        20,  # num_steps
+        0.8,  # identitynet_strength_ratio
+        0.8,  # adapter_strength_ratio
+        0.3,  # canny_strength
+        0.5,  # depth_strength
+        ["depth", "canny"],  # controlnet_selection
+        5.0,  # guidance_scale
+        42,  # seed
+        "EulerDiscreteScheduler",  # scheduler
+        False,  # enable_LCM
+        True,  # enable_Face_Region
+    )
+
 # Description
 title = r"""
-<h1 align="center">InstantID: Zero-shot Identity-Preserving Generation in Seconds</h1>
+<h1 align="center">InstantID: Zero-shot Identity-Preserving Generation with Multi-Face Averaging</h1>
 """
 
 article = r"""
@@ -449,11 +451,12 @@ If you have any questions, please feel free to open an issue or directly reach us
 """
 
 tips = r"""
-### Usage tips of InstantID
-1. If you're not satisfied with the similarity, try increasing the weight of "IdentityNet Strength" and "Adapter Strength."
-2. If you feel that the saturation is too high, first decrease the Adapter strength. If it remains too high, then decrease the IdentityNet strength.
-3. If you find that text control is not as expected, decrease Adapter strength.
-4. If you find that realistic style is not good enough, go for our Github repo and use a more realistic base model.
+### Usage tips of InstantID with Multi-Face Averaging
+1. Upload multiple photos of the same person for better identity preservation through face-embedding averaging.
+2. If you're not satisfied with the similarity, try increasing the weight of "IdentityNet Strength" and "Adapter Strength."
+3. If you feel that the saturation is too high, first decrease the Adapter strength. If it remains too high, then decrease the IdentityNet strength.
+4. If you find that text control is not as expected, decrease Adapter strength.
+5. If you find that the realistic style is not good enough, go to our GitHub repo and use a more realistic base model.
 """
 
 css = """
@@ -466,10 +469,19 @@ with gr.Blocks(css=css) as demo:
         with gr.Row():
             with gr.Column():
                 with gr.Row(equal_height=True):
-                    # upload face image
-                    face_file = gr.Image(
-                        label="Upload a photo of your face", type="filepath"
+                    # Change from single image to multiple files
+                    face_files = gr.Files(
+                        label="Upload photos of your face (1 or more)",
+                        file_types=["image"]
                     )
+
+                    face_gallery = gr.Gallery(
+                        label="Your uploaded face images",
+                        visible=True,
+                        columns=5,
+                        rows=1,
+                        height=150
+                    )
 
                 # prompt
                 prompt = gr.Textbox(
@@ -514,28 +526,21 @@
                     )
                     controlnet_selection = gr.CheckboxGroup(
                         ["canny", "depth"], label="Controlnet", value=[],
-                        info="Use pose for skeleton inference, canny for edge detection, and depth for depth map estimation. You can try all three to control the generation process"
+                        info="Use canny for edge detection and depth for depth-map estimation to control the generation process"
                     )
-                    # pose_strength = gr.Slider(
-                    #     label="Pose strength",
-                    #     minimum=0,
-                    #     maximum=1.5,
-                    #     step=0.05,
-                    #     value=0.40,
-                    # )
                     canny_strength = gr.Slider(
                         label="Canny strength",
                         minimum=0,
                         maximum=1.5,
                         step=0.05,
-                        value=0,
+                        value=0.3,
                     )
                     depth_strength = gr.Slider(
                         label="Depth strength",
                         minimum=0,
                         maximum=1.5,
                         step=0.05,
-                        value=0,
+                        value=0.5,
                     )
                 with gr.Accordion(open=False, label="Advanced Options"):
                     negative_prompt = gr.Textbox(
@@ -586,6 +591,9 @@ with gr.Blocks(css=css) as demo:
                 label="InstantID Usage Tips", value=tips, visible=False
             )
 
+        # Connect file uploads to update the gallery
+        face_files.upload(fn=update_face_gallery, inputs=face_files, outputs=face_gallery)
+
         submit.click(
             fn=remove_tips,
             outputs=usage_tips,
@@ -598,7 +606,7 @@
         ).then(
             fn=generate_image,
             inputs=[
-                face_file,
+                face_files,  # Changed from face_file to face_files
                 pose_file,
                 prompt,
                 negative_prompt,
@@ -606,7 +614,6 @@
             num_steps,
             identitynet_strength_ratio,
             adapter_strength_ratio,
-            #pose_strength,
             canny_strength,
             depth_strength,
             controlnet_selection,
@@ -628,7 +635,7 @@
 
     gr.Examples(
         examples=get_example(),
-        inputs=[face_file, pose_file, prompt, style, negative_prompt],
+        inputs=[face_files, pose_file, prompt, style, negative_prompt],
         fn=run_for_examples,
         outputs=[gallery, usage_tips],
         cache_examples=True,
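
The substance of this commit is the embedding-averaging step added to generate_image. A minimal standalone sketch of the same idea, assuming insightface's antelopev2 detector set up as in app.py (the constructor arguments, helper name, and image paths below are illustrative assumptions, not part of the commit):

import cv2
import torch
from insightface.app import FaceAnalysis

# Detector setup mirroring app.prepare(ctx_id=0, det_size=(640, 640)) above;
# name/root/providers are assumed, since the diff truncates the constructor call.
app = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))

def averaged_face_embedding(image_paths):
    """Return the mean ArcFace embedding of the largest face in each image."""
    embeddings = []
    for path in image_paths:
        img = cv2.imread(path)  # BGR array, as insightface expects
        faces = app.get(img)
        if len(faces) == 0:
            continue  # skip photos with no detectable face
        # Pick the largest face by bounding-box area
        face = max(faces, key=lambda f: (f["bbox"][2] - f["bbox"][0]) * (f["bbox"][3] - f["bbox"][1]))
        embeddings.append(torch.tensor(face["embedding"]).unsqueeze(0))
    if len(embeddings) == 0:
        raise ValueError("No face detected in any input image")
    return torch.cat(embeddings, dim=0).mean(dim=0).numpy()

face_emb = averaged_face_embedding(["face1.jpg", "face2.jpg"])  # hypothetical inputs

The commit averages raw, unnormalized embeddings; an alternative worth testing is L2-normalizing each embedding before taking the mean, so that no single atypical photo dominates the identity vector.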
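The UI half of the change swaps gr.Image for gr.Files and mirrors uploads into a gr.Gallery. Reduced to a self-contained sketch (component labels taken from the diff; this is an illustration, not the full app):

import gradio as gr

def update_face_gallery(files):
    # gr.Files passes a list of uploaded file paths; show them in the gallery
    return gr.update(value=files, visible=True)

with gr.Blocks() as demo:
    face_files = gr.Files(
        label="Upload photos of your face (1 or more)",
        file_types=["image"],
    )
    face_gallery = gr.Gallery(
        label="Your uploaded face images",
        columns=5,
        rows=1,
        height=150,
    )
    face_files.upload(fn=update_face_gallery, inputs=face_files, outputs=face_gallery)

demo.launch()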
 