multimodalart (HF Staff) committed
Commit de3030d · verified · 1 Parent(s): d1d8628

Update app.py

Files changed (1):
  app.py +117 -124
app.py CHANGED
@@ -7,7 +7,7 @@ import spaces
 
 import PIL
 from PIL import Image
-from typing import Tuple, List
+from typing import Tuple
 
 import diffusers
 from diffusers.utils import load_image
@@ -21,6 +21,8 @@ from insightface.app import FaceAnalysis
 from style_template import styles
 from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
 
+# from controlnet_aux import OpenposeDetector
+
 import gradio as gr
 
 from depth_anything.dpt import DepthAnything
@@ -56,6 +58,8 @@ app = FaceAnalysis(
 )
 app.prepare(ctx_id=0, det_size=(640, 640))
 
+# openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
+
 depth_anything = DepthAnything.from_pretrained('LiheYoung/depth_anything_vitl14').to(device).eval()
 
 transform = Compose([
@@ -81,10 +85,14 @@ controlnet_identitynet = ControlNetModel.from_pretrained(
     controlnet_path, torch_dtype=dtype
 )
 
-# controlnet-canny/depth
+# controlnet-pose/canny/depth
+# controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0"
 controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0"
 controlnet_depth_model = "diffusers/controlnet-depth-sdxl-1.0-small"
 
+# controlnet_pose = ControlNetModel.from_pretrained(
+#     controlnet_pose_model, torch_dtype=dtype
+# ).to(device)
 controlnet_canny = ControlNetModel.from_pretrained(
     controlnet_canny_model, torch_dtype=dtype
 ).to(device)
@@ -119,10 +127,12 @@ def get_canny_image(image, t1=100, t2=200):
     return Image.fromarray(edges, "L")
 
 controlnet_map = {
+    #"pose": controlnet_pose,
     "canny": controlnet_canny,
     "depth": controlnet_depth,
 }
 controlnet_map_fn = {
+    #"pose": openpose,
     "canny": get_canny_image,
     "depth": get_depth_map,
 }
@@ -170,6 +180,67 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
 def remove_tips():
     return gr.update(visible=False)
 
+def get_example():
+    case = [
+        [
+            "./examples/yann-lecun_resize.jpg",
+            None,
+            "a man",
+            "Spring Festival",
+            "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
+        ],
+        [
+            "./examples/musk_resize.jpeg",
+            "./examples/poses/pose2.jpg",
+            "a man flying in the sky in Mars",
+            "Mars",
+            "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
+        ],
+        [
+            "./examples/sam_resize.png",
+            "./examples/poses/pose4.jpg",
+            "a man doing a silly pose wearing a suite",
+            "Jungle",
+            "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
+        ],
+        [
+            "./examples/schmidhuber_resize.png",
+            "./examples/poses/pose3.jpg",
+            "a man sit on a chair",
+            "Neon",
+            "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
+        ],
+        [
+            "./examples/kaifu_resize.png",
+            "./examples/poses/pose.jpg",
+            "a man",
+            "Vibrant Color",
+            "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
+        ],
+    ]
+    return case
+
+def run_for_examples(face_file, pose_file, prompt, style, negative_prompt):
+    return generate_image(
+        face_file,
+        pose_file,
+        prompt,
+        negative_prompt,
+        style,
+        20, # num_steps
+        0.8, # identitynet_strength_ratio
+        0.8, # adapter_strength_ratio
+        #0.4, # pose_strength
+        0.3, # canny_strength
+        0.5, # depth_strength
+        ["depth", "canny"], # controlnet_selection
+        5.0, # guidance_scale
+        42, # seed
+        "EulerDiscreteScheduler", # scheduler
+        False, # enable_LCM
+        True, # enable_Face_Region
+    )
+
 def convert_from_cv2_to_image(img: np.ndarray) -> Image:
     return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
 
@@ -213,12 +284,9 @@ def apply_style(
     p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
     return p.replace("{prompt}", positive), n + " " + negative
 
-def update_face_gallery(files):
-    return gr.update(value=files, visible=True)
-
 @spaces.GPU
 def generate_image(
-    face_images_path, # Now accepts a list of image paths
+    face_image_path,
     pose_image_path,
     prompt,
     negative_prompt,
@@ -226,6 +294,7 @@ def generate_image(
     num_steps,
     identitynet_strength_ratio,
     adapter_strength_ratio,
+    #pose_strength,
    canny_strength,
    depth_strength,
    controlnet_selection,
@@ -252,9 +321,9 @@ def generate_image(
     scheduler = getattr(diffusers, scheduler_class_name)
     pipe.scheduler = scheduler.from_config(pipe.scheduler.config, **add_kwargs)
 
-    if face_images_path is None or len(face_images_path) == 0:
+    if face_image_path is None:
         raise gr.Error(
-            f"Cannot find any input face images! Please upload at least one face image"
+            f"Cannot find any input face image! Please upload the face image"
         )
 
     if prompt is None:
@@ -263,67 +332,28 @@ def generate_image(
     # apply the style template
     prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
 
-    # Use the first face image for face keypoints and size reference
-    reference_face_path = face_images_path[0] if isinstance(face_images_path, list) else face_images_path
-    reference_face_image = load_image(reference_face_path)
-    reference_face_image = resize_img(reference_face_image, max_side=1024)
-    reference_face_cv2 = convert_from_image_to_cv2(reference_face_image)
-    height, width, _ = reference_face_cv2.shape
+    face_image = load_image(face_image_path)
+    face_image = resize_img(face_image, max_side=1024)
+    face_image_cv2 = convert_from_image_to_cv2(face_image)
+    height, width, _ = face_image_cv2.shape
 
-    # Initialize a list to collect face embeddings
-    face_embeddings = []
-
-    # Process each face image if multiple images are provided
-    face_image_paths = face_images_path if isinstance(face_images_path, list) else [face_images_path]
-
-    for face_path in face_image_paths:
-        face_img = load_image(face_path)
-        face_img = resize_img(face_img, max_side=1024)
-        face_img_cv2 = convert_from_image_to_cv2(face_img)
-
-        # Extract face features
-        face_info = app.get(face_img_cv2)
-
-        if len(face_info) == 0:
-            print(f"Warning: Unable to detect a face in {face_path}. Skipping this image.")
-            continue
-
-        # Use the largest face in each image
-        face_info = sorted(
-            face_info,
-            key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1],
-        )[-1]
-
-        # Collect the embedding
-        face_embeddings.append(torch.tensor(face_info["embedding"]).unsqueeze(0))
-
-    if len(face_embeddings) == 0:
-        raise gr.Error(
-            f"Unable to detect a face in any of the uploaded images. Please upload different photos with clear faces."
-        )
-
-    # Average the face embeddings
-    if len(face_embeddings) == 1:
-        face_emb = face_embeddings[0].squeeze().numpy() # Use as is if only one image
-    else:
-        # Stack and compute mean along the batch dimension
-        face_emb = torch.mean(torch.cat(face_embeddings, dim=0), dim=0).numpy()
-        print(f"Averaged {len(face_embeddings)} face embeddings")
+    # Extract face features
+    face_info = app.get(face_image_cv2)
 
-    # Extract keypoints from the reference face for ControlNet
-    reference_face_info = app.get(reference_face_cv2)
-    if len(reference_face_info) == 0:
+    if len(face_info) == 0:
         raise gr.Error(
-            f"Unable to detect a face in the reference image for keypoints. Please upload a different photo with a clear face."
+            f"Unable to detect a face in the image. Please upload a different photo with a clear face."
         )
-    reference_face_info = sorted(
-        reference_face_info,
+
+    face_info = sorted(
+        face_info,
         key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1],
-    )[-1] # Use the largest face
-
-    face_kps = draw_kps(convert_from_cv2_to_image(reference_face_cv2), reference_face_info["kps"])
-    img_controlnet = reference_face_image
-
+    )[
+        -1
+    ] # only use the maximum face
+    face_emb = face_info["embedding"]
+    face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
+    img_controlnet = face_image
     if pose_image_path is not None:
         pose_image = load_image(pose_image_path)
         pose_image = resize_img(pose_image, max_side=1024)
@@ -353,6 +383,7 @@ def generate_image(
 
     if len(controlnet_selection) > 0:
         controlnet_scales = {
+            #"pose": pose_strength,
             "canny": canny_strength,
             "depth": depth_strength,
         }
@@ -394,42 +425,9 @@ def generate_image(
 
     return images[0], gr.update(visible=True)
 
-def get_example():
-    case = [
-        [
-            "./examples/yann-lecun_resize.jpg",
-            None,
-            "a man",
-            "Spring Festival",
-            "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
-        ],
-        # Add more examples as needed
-    ]
-    return case
-
-def run_for_examples(face_file, pose_file, prompt, style, negative_prompt):
-    return generate_image(
-        face_file,
-        pose_file,
-        prompt,
-        negative_prompt,
-        style,
-        20, # num_steps
-        0.8, # identitynet_strength_ratio
-        0.8, # adapter_strength_ratio
-        0.3, # canny_strength
-        0.5, # depth_strength
-        ["depth", "canny"], # controlnet_selection
-        5.0, # guidance_scale
-        42, # seed
-        "EulerDiscreteScheduler", # scheduler
-        False, # enable_LCM
-        True, # enable_Face_Region
-    )
-
 # Description
 title = r"""
-<h1 align="center">InstantID: Zero-shot Identity-Preserving Generation with Multi-Face Averaging</h1>
+<h1 align="center">InstantID: Zero-shot Identity-Preserving Generation in Seconds</h1>
 """
 
 article = r"""
@@ -451,12 +449,11 @@ If you have any questions, please feel free to open an issue or directly reach u
 """
 
 tips = r"""
-### Usage tips of InstantID with Multi-Face Averaging
-1. Upload multiple photos of the same person for better identity preservation through face embedding averaging.
-2. If you're not satisfied with the similarity, try increasing the weight of "IdentityNet Strength" and "Adapter Strength."
-3. If you feel that the saturation is too high, first decrease the Adapter strength. If it remains too high, then decrease the IdentityNet strength.
-4. If you find that text control is not as expected, decrease Adapter strength.
-5. If you find that realistic style is not good enough, go for our Github repo and use a more realistic base model.
+### Usage tips of InstantID
+1. If you're not satisfied with the similarity, try increasing the weight of "IdentityNet Strength" and "Adapter Strength."
+2. If you feel that the saturation is too high, first decrease the Adapter strength. If it remains too high, then decrease the IdentityNet strength.
+3. If you find that text control is not as expected, decrease Adapter strength.
+4. If you find that realistic style is not good enough, go for our Github repo and use a more realistic base model.
 """
 
 css = """
@@ -469,19 +466,10 @@ with gr.Blocks(css=css) as demo:
     with gr.Row():
         with gr.Column():
             with gr.Row(equal_height=True):
-                # Change from single image to multiple files
-                face_files = gr.Files(
-                    label="Upload photos of your face (1 or more)",
-                    file_types=["image"]
+                # upload face image
+                face_file = gr.Image(
+                    label="Upload a photo of your face", type="filepath"
                 )
-
-                face_gallery = gr.Gallery(
-                    label="Your uploaded face images",
-                    visible=True,
-                    columns=5,
-                    rows=1,
-                    height=150
-                )
 
             # prompt
             prompt = gr.Textbox(
@@ -526,21 +514,28 @@ with gr.Blocks(css=css) as demo:
             )
             controlnet_selection = gr.CheckboxGroup(
                 ["canny", "depth"], label="Controlnet", value=[],
-                info="Use canny for edge detection, and depth for depth map estimation to control the generation process"
+                info="Use pose for skeleton inference, canny for edge detection, and depth for depth map estimation. You can try all three to control the generation process"
             )
+            # pose_strength = gr.Slider(
+            #     label="Pose strength",
+            #     minimum=0,
+            #     maximum=1.5,
+            #     step=0.05,
+            #     value=0.40,
+            # )
             canny_strength = gr.Slider(
                 label="Canny strength",
                 minimum=0,
                 maximum=1.5,
                 step=0.05,
-                value=0.3,
+                value=0,
             )
             depth_strength = gr.Slider(
                 label="Depth strength",
                 minimum=0,
                 maximum=1.5,
                 step=0.05,
-                value=0.5,
+                value=0,
             )
             with gr.Accordion(open=False, label="Advanced Options"):
                 negative_prompt = gr.Textbox(
@@ -591,9 +586,6 @@ with gr.Blocks(css=css) as demo:
                 label="InstantID Usage Tips", value=tips, visible=False
             )
 
-    # Connect file uploads to update the gallery
-    face_files.upload(fn=update_face_gallery, inputs=face_files, outputs=face_gallery)
-
     submit.click(
         fn=remove_tips,
         outputs=usage_tips,
@@ -606,7 +598,7 @@ with gr.Blocks(css=css) as demo:
     ).then(
         fn=generate_image,
         inputs=[
-            face_files, # Changed from face_file to face_files
+            face_file,
             pose_file,
             prompt,
             negative_prompt,
@@ -614,6 +606,7 @@ with gr.Blocks(css=css) as demo:
             num_steps,
             identitynet_strength_ratio,
             adapter_strength_ratio,
+            #pose_strength,
             canny_strength,
             depth_strength,
             controlnet_selection,
@@ -635,7 +628,7 @@ with gr.Blocks(css=css) as demo:
 
     gr.Examples(
         examples=get_example(),
-        inputs=[face_files, pose_file, prompt, style, negative_prompt],
+        inputs=[face_file, pose_file, prompt, style, negative_prompt],
        fn=run_for_examples,
        outputs=[gallery, usage_tips],
        cache_examples=True,
 
425
 
426
  return images[0], gr.update(visible=True)
427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  # Description
429
  title = r"""
430
+ <h1 align="center">InstantID: Zero-shot Identity-Preserving Generation in Seconds</h1>
431
  """
432
 
433
  article = r"""
 
449
  """
450
 
451
  tips = r"""
452
+ ### Usage tips of InstantID
453
+ 1. If you're not satisfied with the similarity, try increasing the weight of "IdentityNet Strength" and "Adapter Strength."
454
+ 2. If you feel that the saturation is too high, first decrease the Adapter strength. If it remains too high, then decrease the IdentityNet strength.
455
+ 3. If you find that text control is not as expected, decrease Adapter strength.
456
+ 4. If you find that realistic style is not good enough, go for our Github repo and use a more realistic base model.
 
457
  """
458
 
459
  css = """
 
466
  with gr.Row():
467
  with gr.Column():
468
  with gr.Row(equal_height=True):
469
+ # upload face image
470
+ face_file = gr.Image(
471
+ label="Upload a photo of your face", type="filepath"
 
472
  )
 
 
 
 
 
 
 
 
473
 
474
  # prompt
475
  prompt = gr.Textbox(
 
514
  )
515
  controlnet_selection = gr.CheckboxGroup(
516
  ["canny", "depth"], label="Controlnet", value=[],
517
+ info="Use pose for skeleton inference, canny for edge detection, and depth for depth map estimation. You can try all three to control the generation process"
518
  )
519
+ # pose_strength = gr.Slider(
520
+ # label="Pose strength",
521
+ # minimum=0,
522
+ # maximum=1.5,
523
+ # step=0.05,
524
+ # value=0.40,
525
+ # )
526
  canny_strength = gr.Slider(
527
  label="Canny strength",
528
  minimum=0,
529
  maximum=1.5,
530
  step=0.05,
531
+ value=0,
532
  )
533
  depth_strength = gr.Slider(
534
  label="Depth strength",
535
  minimum=0,
536
  maximum=1.5,
537
  step=0.05,
538
+ value=0,
539
  )
540
  with gr.Accordion(open=False, label="Advanced Options"):
541
  negative_prompt = gr.Textbox(
 
586
  label="InstantID Usage Tips", value=tips, visible=False
587
  )
588
 
 
 
 
589
  submit.click(
590
  fn=remove_tips,
591
  outputs=usage_tips,
 
598
  ).then(
599
  fn=generate_image,
600
  inputs=[
601
+ face_file,
602
  pose_file,
603
  prompt,
604
  negative_prompt,
 
606
  num_steps,
607
  identitynet_strength_ratio,
608
  adapter_strength_ratio,
609
+ #pose_strength,
610
  canny_strength,
611
  depth_strength,
612
  controlnet_selection,
 
628
 
629
  gr.Examples(
630
  examples=get_example(),
631
+ inputs=[face_file, pose_file, prompt, style, negative_prompt],
632
  fn=run_for_examples,
633
  outputs=[gallery, usage_tips],
634
  cache_examples=True,
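A further note on the face-selection key, which is identical on both sides of this diff: by Python operator precedence, `(x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1]` evaluates as `(width * y1) - y0`, not as the bounding-box area that the `# only use the maximum face` comment suggests. If the intent is "largest detected face by area", the key needs one more pair of parentheses. A minimal corrected sketch, with `bbox_area` as a hypothetical helper that is not part of this commit:

```python
def bbox_area(face: dict) -> float:
    # InsightFace bounding boxes are [x0, y0, x1, y1]; area = width * height.
    x0, y0, x1, y1 = face["bbox"]
    return (x1 - x0) * (y1 - y0)

# Largest detected face by area:
faces = [{"bbox": [0, 0, 10, 20]}, {"bbox": [0, 0, 30, 30]}]
assert sorted(faces, key=bbox_area)[-1]["bbox"] == [0, 0, 30, 30]
```

In practice the existing key still tends to favor large faces, since a bigger width and a lower bottom edge both increase it, which may be why the precedence slip has gone unnoticed.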