wanghaofan hysts HF staff commited on
Commit
7dd7d9c
·
verified ·
1 Parent(s): 0add3fc

Replace gr.Files with gr.Image (#8)

Browse files

- Apply formatter (b106068ca93ab77524c9462ecb68227be875efb0)
- Clean up (d3fe5c54697abbb6981b1dc434863c55ca0fde3a)
- Use gr.Image instead of gr.Files (938be836b4c274cd120d4e408e793e7247a03342)
- Add cached examples (a075b7a9c7764bfe3c4e6ace40d3b12c06e08015)


Co-authored-by: hysts <[email protected]>

.gitattributes CHANGED
@@ -36,3 +36,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  examples/kaifu_resize.png filter=lfs diff=lfs merge=lfs -text
37
  examples/sam_resize.png filter=lfs diff=lfs merge=lfs -text
38
  examples/schmidhuber_resize.png filter=lfs diff=lfs merge=lfs -text
 
 
36
  examples/kaifu_resize.png filter=lfs diff=lfs merge=lfs -text
37
  examples/sam_resize.png filter=lfs diff=lfs merge=lfs -text
38
  examples/schmidhuber_resize.png filter=lfs diff=lfs merge=lfs -text
39
+ *.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,25 +1,19 @@
1
- import os
2
- import cv2
3
  import math
4
- import spaces
5
- import torch
6
  import random
7
- import numpy as np
8
 
 
 
 
9
  import PIL
10
- from PIL import Image
11
-
12
- import diffusers
13
- from diffusers.utils import load_image
14
  from diffusers.models import ControlNetModel
15
-
16
- import insightface
17
  from insightface.app import FaceAnalysis
 
18
 
19
- from style_template import styles
20
  from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline
21
-
22
- import gradio as gr
23
 
24
  # global variable
25
  MAX_SEED = np.iinfo(np.int32).max
@@ -29,22 +23,27 @@ DEFAULT_STYLE_NAME = "Watercolor"
29
 
30
  # download checkpoints
31
  from huggingface_hub import hf_hub_download
 
32
  hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
33
- hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
 
 
 
 
34
  hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
35
 
36
  # Load face encoder
37
- app = FaceAnalysis(name='antelopev2', root='./', providers=['CPUExecutionProvider'])
38
  app.prepare(ctx_id=0, det_size=(640, 640))
39
 
40
  # Path to InstantID models
41
- face_adapter = f'./checkpoints/ip-adapter.bin'
42
- controlnet_path = f'./checkpoints/ControlNetModel'
43
 
44
  # Load pipeline
45
  controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
46
 
47
- base_model_path = 'wangqixun/YamerMIX_v8'
48
 
49
  pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
50
  base_model_path,
@@ -55,54 +54,48 @@ pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
55
  )
56
  pipe.cuda()
57
  pipe.load_ip_adapter_instantid(face_adapter)
58
- pipe.image_proj_model.to('cuda')
59
- pipe.unet.to('cuda')
 
60
 
61
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
62
  if randomize_seed:
63
  seed = random.randint(0, MAX_SEED)
64
  return seed
65
 
66
- def swap_to_gallery(images):
67
- return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
68
-
69
- def upload_example_to_gallery(images, prompt, style, negative_prompt):
70
- return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
71
-
72
- def remove_back_to_files():
73
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
74
 
75
  def remove_tips():
76
  return gr.update(visible=False)
77
 
 
78
  def get_example():
79
  case = [
80
  [
81
- ['./examples/yann-lecun_resize.jpg'],
82
  "a man",
83
  "Snow",
84
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
85
  ],
86
  [
87
- ['./examples/musk_resize.jpeg'],
88
  "a man",
89
  "Mars",
90
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
91
  ],
92
  [
93
- ['./examples/sam_resize.png'],
94
  "a man",
95
  "Jungle",
96
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
97
  ],
98
  [
99
- ['./examples/schmidhuber_resize.png'],
100
  "a man",
101
  "Neon",
102
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
103
  ],
104
  [
105
- ['./examples/kaifu_resize.png'],
106
  "a man",
107
  "Vibrant Color",
108
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
@@ -110,16 +103,20 @@ def get_example():
110
  ]
111
  return case
112
 
113
- def run_for_examples(face_files, prompt, style, negative_prompt):
114
- return generate_image(face_files, None, prompt, negative_prompt, style, True, 30, 0.8, 0.8, 5, 42)
 
 
115
 
116
  def convert_from_cv2_to_image(img: np.ndarray) -> Image:
117
  return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
118
 
 
119
  def convert_from_image_to_cv2(img: Image) -> np.ndarray:
120
  return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
121
 
122
- def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
 
123
  stickwidth = 4
124
  limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
125
  kps = np.array(kps)
@@ -135,7 +132,9 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
135
  y = kps[index][:, 1]
136
  length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
137
  angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
138
- polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
 
 
139
  out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
140
  out_img = (out_img * 0.6).astype(np.uint8)
141
 
@@ -147,89 +146,114 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
147
  out_img_pil = Image.fromarray(out_img.astype(np.uint8))
148
  return out_img_pil
149
 
150
- def resize_img(input_image, max_side=1280, min_side=1024, size=None,
151
- pad_to_max_side=False, mode=PIL.Image.BILINEAR, base_pixel_number=64):
152
-
153
- w, h = input_image.size
154
- if size is not None:
155
- w_resize_new, h_resize_new = size
156
- else:
157
- ratio = min_side / min(h, w)
158
- w, h = round(ratio*w), round(ratio*h)
159
- ratio = max_side / max(h, w)
160
- input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
161
- w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
162
- h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
163
- input_image = input_image.resize([w_resize_new, h_resize_new], mode)
164
-
165
- if pad_to_max_side:
166
- res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
167
- offset_x = (max_side - w_resize_new) // 2
168
- offset_y = (max_side - h_resize_new) // 2
169
- res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
170
- input_image = Image.fromarray(res)
171
- return input_image
 
 
 
 
 
 
 
 
172
 
173
  def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
174
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
175
- return p.replace("{prompt}", positive), n + ' ' + negative
176
 
177
- @spaces.GPU
178
- def generate_image(face_image, pose_image, prompt, negative_prompt, style_name, enhance_face_region, num_steps, identitynet_strength_ratio, adapter_strength_ratio, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
179
 
 
180
  if face_image is None:
181
- raise gr.Error(f"Cannot find any input face image! Please upload the face image")
182
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  if prompt is None:
184
  prompt = "a person"
185
-
186
  # apply the style template
187
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
188
-
189
- face_image = load_image(face_image[0])
190
  face_image = resize_img(face_image)
191
  face_image_cv2 = convert_from_image_to_cv2(face_image)
192
  height, width, _ = face_image_cv2.shape
193
-
194
  # Extract face features
195
  face_info = app.get(face_image_cv2)
196
-
197
  if len(face_info) == 0:
198
- raise gr.Error(f"Cannot find any face in the image! Please upload another person image")
199
-
200
- face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face
201
- face_emb = face_info['embedding']
202
- face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info['kps'])
203
-
204
- if pose_image is not None:
205
- pose_image = load_image(pose_image[0])
 
 
206
  pose_image = resize_img(pose_image)
207
  pose_image_cv2 = convert_from_image_to_cv2(pose_image)
208
-
209
  face_info = app.get(pose_image_cv2)
210
-
211
  if len(face_info) == 0:
212
- raise gr.Error(f"Cannot find any face in the reference image! Please upload another person image")
213
-
214
  face_info = face_info[-1]
215
- face_kps = draw_kps(pose_image, face_info['kps'])
216
-
217
  width, height = face_kps.size
218
-
219
  if enhance_face_region:
220
  control_mask = np.zeros([height, width, 3])
221
- x1, y1, x2, y2 = face_info['bbox']
222
  x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
223
  control_mask[y1:y2, x1:x2] = 255
224
  control_mask = Image.fromarray(control_mask.astype(np.uint8))
225
  else:
226
  control_mask = None
227
-
228
  generator = torch.Generator(device=device).manual_seed(seed)
229
-
230
  print("Start inference...")
231
  print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
232
-
233
  pipe.set_ip_adapter_scale(adapter_strength_ratio)
234
  images = pipe(
235
  prompt=prompt,
@@ -242,10 +266,11 @@ def generate_image(face_image, pose_image, prompt, negative_prompt, style_name,
242
  guidance_scale=guidance_scale,
243
  height=height,
244
  width=width,
245
- generator=generator
246
  ).images
247
 
248
- return images, gr.update(visible=True)
 
249
 
250
  ### Description
251
  title = r"""
@@ -289,46 +314,34 @@ tips = r"""
289
  4. Find a good base model always makes a difference.
290
  """
291
 
292
- css = '''
293
  .gradio-container {width: 85% !important}
294
- '''
295
  with gr.Blocks(css=css) as demo:
296
-
297
  # description
298
  gr.Markdown(title)
299
  gr.Markdown(description)
300
 
301
  with gr.Row():
302
  with gr.Column():
303
-
304
  # upload face image
305
- face_files = gr.Files(
306
- label="Upload a photo of your face",
307
- file_types=["image"]
308
- )
309
- uploaded_faces = gr.Gallery(label="Your images", visible=False, columns=1, rows=1, height=512)
310
- with gr.Column(visible=False) as clear_button_face:
311
- remove_and_reupload_faces = gr.ClearButton(value="Remove and upload new ones", components=face_files, size="sm")
312
-
313
  # optional: upload a reference pose image
314
- pose_files = gr.Files(
315
- label="Upload a reference pose image (optional)",
316
- file_types=["image"]
317
- )
318
- uploaded_poses = gr.Gallery(label="Your images", visible=False, columns=1, rows=1, height=512)
319
- with gr.Column(visible=False) as clear_button_pose:
320
- remove_and_reupload_poses = gr.ClearButton(value="Remove and upload new ones", components=pose_files, size="sm")
321
-
322
  # prompt
323
- prompt = gr.Textbox(label="Prompt",
324
- info="Give simple prompt is enough to achieve good face fedility",
325
- placeholder="A photo of a person",
326
- value="")
327
-
 
 
328
  submit = gr.Button("Submit", variant="primary")
329
-
330
  style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
331
-
332
  # strength
333
  identitynet_strength_ratio = gr.Slider(
334
  label="IdentityNet strength (for fedility)",
@@ -344,14 +357,14 @@ with gr.Blocks(css=css) as demo:
344
  step=0.05,
345
  value=0.80,
346
  )
347
-
348
  with gr.Accordion(open=False, label="Advanced Options"):
349
  negative_prompt = gr.Textbox(
350
- label="Negative Prompt",
351
  placeholder="low quality",
352
  value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
353
  )
354
- num_steps = gr.Slider(
355
  label="Number of sample steps",
356
  minimum=20,
357
  maximum=100,
@@ -376,18 +389,14 @@ with gr.Blocks(css=css) as demo:
376
  enhance_face_region = gr.Checkbox(label="Enhance non-face region", value=True)
377
 
378
  with gr.Column():
379
- gallery = gr.Gallery(label="Generated Images")
380
- usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips ,visible=False)
381
-
382
- face_files.upload(fn=swap_to_gallery, inputs=face_files, outputs=[uploaded_faces, clear_button_face, face_files])
383
- pose_files.upload(fn=swap_to_gallery, inputs=pose_files, outputs=[uploaded_poses, clear_button_pose, pose_files])
384
-
385
- remove_and_reupload_faces.click(fn=remove_back_to_files, outputs=[uploaded_faces, clear_button_face, face_files])
386
- remove_and_reupload_poses.click(fn=remove_back_to_files, outputs=[uploaded_poses, clear_button_pose, pose_files])
387
 
388
  submit.click(
389
  fn=remove_tips,
390
- outputs=usage_tips,
 
 
391
  ).then(
392
  fn=randomize_seed_fn,
393
  inputs=[seed, randomize_seed],
@@ -395,22 +404,37 @@ with gr.Blocks(css=css) as demo:
395
  queue=False,
396
  api_name=False,
397
  ).then(
 
 
 
 
 
398
  fn=generate_image,
399
- inputs=[face_files, pose_files, prompt, negative_prompt, style, enhance_face_region, num_steps, identitynet_strength_ratio, adapter_strength_ratio, guidance_scale, seed],
400
- outputs=[gallery, usage_tips]
 
 
 
 
 
 
 
 
 
 
 
 
401
  )
402
-
403
  gr.Examples(
404
  examples=get_example(),
405
- inputs=[face_files, prompt, style, negative_prompt],
406
- run_on_click=True,
407
- fn=upload_example_to_gallery,
408
- outputs=[uploaded_faces, clear_button_face, face_files],
409
- cache_examples=True
410
  )
411
-
412
- gr.Markdown(article)
413
 
 
414
 
415
  demo.queue(api_open=False)
416
- demo.launch()
 
 
 
1
  import math
 
 
2
  import random
 
3
 
4
+ import cv2
5
+ import gradio as gr
6
+ import numpy as np
7
  import PIL
8
+ import spaces
9
+ import torch
 
 
10
  from diffusers.models import ControlNetModel
11
+ from diffusers.utils import load_image
 
12
  from insightface.app import FaceAnalysis
13
+ from PIL import Image
14
 
 
15
  from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline
16
+ from style_template import styles
 
17
 
18
  # global variable
19
  MAX_SEED = np.iinfo(np.int32).max
 
23
 
24
  # download checkpoints
25
  from huggingface_hub import hf_hub_download
26
+
27
  hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
28
+ hf_hub_download(
29
+ repo_id="InstantX/InstantID",
30
+ filename="ControlNetModel/diffusion_pytorch_model.safetensors",
31
+ local_dir="./checkpoints",
32
+ )
33
  hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
34
 
35
  # Load face encoder
36
+ app = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
37
  app.prepare(ctx_id=0, det_size=(640, 640))
38
 
39
  # Path to InstantID models
40
+ face_adapter = "./checkpoints/ip-adapter.bin"
41
+ controlnet_path = "./checkpoints/ControlNetModel"
42
 
43
  # Load pipeline
44
  controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
45
 
46
+ base_model_path = "wangqixun/YamerMIX_v8"
47
 
48
  pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
49
  base_model_path,
 
54
  )
55
  pipe.cuda()
56
  pipe.load_ip_adapter_instantid(face_adapter)
57
+ pipe.image_proj_model.to("cuda")
58
+ pipe.unet.to("cuda")
59
+
60
 
61
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
62
  if randomize_seed:
63
  seed = random.randint(0, MAX_SEED)
64
  return seed
65
 
 
 
 
 
 
 
 
 
66
 
67
  def remove_tips():
68
  return gr.update(visible=False)
69
 
70
+
71
  def get_example():
72
  case = [
73
  [
74
+ "./examples/yann-lecun_resize.jpg",
75
  "a man",
76
  "Snow",
77
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
78
  ],
79
  [
80
+ "./examples/musk_resize.jpeg",
81
  "a man",
82
  "Mars",
83
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
84
  ],
85
  [
86
+ "./examples/sam_resize.png",
87
  "a man",
88
  "Jungle",
89
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
90
  ],
91
  [
92
+ "./examples/schmidhuber_resize.png",
93
  "a man",
94
  "Neon",
95
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
96
  ],
97
  [
98
+ "./examples/kaifu_resize.png",
99
  "a man",
100
  "Vibrant Color",
101
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
 
103
  ]
104
  return case
105
 
106
+
107
+ def run_for_examples(face_file, prompt, style, negative_prompt):
108
+ return generate_image(face_file, None, prompt, negative_prompt, style, True, 30, 0.8, 0.8, 5, 42)
109
+
110
 
111
  def convert_from_cv2_to_image(img: np.ndarray) -> Image:
112
  return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
113
 
114
+
115
  def convert_from_image_to_cv2(img: Image) -> np.ndarray:
116
  return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
117
 
118
+
119
+ def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
120
  stickwidth = 4
121
  limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
122
  kps = np.array(kps)
 
132
  y = kps[index][:, 1]
133
  length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
134
  angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
135
+ polygon = cv2.ellipse2Poly(
136
+ (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
137
+ )
138
  out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
139
  out_img = (out_img * 0.6).astype(np.uint8)
140
 
 
146
  out_img_pil = Image.fromarray(out_img.astype(np.uint8))
147
  return out_img_pil
148
 
149
+
150
+ def resize_img(
151
+ input_image,
152
+ max_side=1280,
153
+ min_side=1024,
154
+ size=None,
155
+ pad_to_max_side=False,
156
+ mode=PIL.Image.BILINEAR,
157
+ base_pixel_number=64,
158
+ ):
159
+ w, h = input_image.size
160
+ if size is not None:
161
+ w_resize_new, h_resize_new = size
162
+ else:
163
+ ratio = min_side / min(h, w)
164
+ w, h = round(ratio * w), round(ratio * h)
165
+ ratio = max_side / max(h, w)
166
+ input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
167
+ w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
168
+ h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
169
+ input_image = input_image.resize([w_resize_new, h_resize_new], mode)
170
+
171
+ if pad_to_max_side:
172
+ res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
173
+ offset_x = (max_side - w_resize_new) // 2
174
+ offset_y = (max_side - h_resize_new) // 2
175
+ res[offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new] = np.array(input_image)
176
+ input_image = Image.fromarray(res)
177
+ return input_image
178
+
179
 
180
  def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
181
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
182
+ return p.replace("{prompt}", positive), n + " " + negative
183
 
 
 
184
 
185
+ def check_input_image(face_image):
186
  if face_image is None:
187
+ raise gr.Error("Cannot find any input face image! Please upload the face image")
188
+
189
+
190
+ @spaces.GPU
191
+ def generate_image(
192
+ face_image_path,
193
+ pose_image_path,
194
+ prompt,
195
+ negative_prompt,
196
+ style_name,
197
+ enhance_face_region,
198
+ num_steps,
199
+ identitynet_strength_ratio,
200
+ adapter_strength_ratio,
201
+ guidance_scale,
202
+ seed,
203
+ progress=gr.Progress(track_tqdm=True),
204
+ ):
205
  if prompt is None:
206
  prompt = "a person"
207
+
208
  # apply the style template
209
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
210
+
211
+ face_image = load_image(face_image_path)
212
  face_image = resize_img(face_image)
213
  face_image_cv2 = convert_from_image_to_cv2(face_image)
214
  height, width, _ = face_image_cv2.shape
215
+
216
  # Extract face features
217
  face_info = app.get(face_image_cv2)
218
+
219
  if len(face_info) == 0:
220
+ raise gr.Error("Cannot find any face in the image! Please upload another person image")
221
+
222
+ face_info = sorted(face_info, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1])[
223
+ -1
224
+ ] # only use the maximum face
225
+ face_emb = face_info["embedding"]
226
+ face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
227
+
228
+ if pose_image_path is not None:
229
+ pose_image = load_image(pose_image_path)
230
  pose_image = resize_img(pose_image)
231
  pose_image_cv2 = convert_from_image_to_cv2(pose_image)
232
+
233
  face_info = app.get(pose_image_cv2)
234
+
235
  if len(face_info) == 0:
236
+ raise gr.Error("Cannot find any face in the reference image! Please upload another person image")
237
+
238
  face_info = face_info[-1]
239
+ face_kps = draw_kps(pose_image, face_info["kps"])
240
+
241
  width, height = face_kps.size
242
+
243
  if enhance_face_region:
244
  control_mask = np.zeros([height, width, 3])
245
+ x1, y1, x2, y2 = face_info["bbox"]
246
  x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
247
  control_mask[y1:y2, x1:x2] = 255
248
  control_mask = Image.fromarray(control_mask.astype(np.uint8))
249
  else:
250
  control_mask = None
251
+
252
  generator = torch.Generator(device=device).manual_seed(seed)
253
+
254
  print("Start inference...")
255
  print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
256
+
257
  pipe.set_ip_adapter_scale(adapter_strength_ratio)
258
  images = pipe(
259
  prompt=prompt,
 
266
  guidance_scale=guidance_scale,
267
  height=height,
268
  width=width,
269
+ generator=generator,
270
  ).images
271
 
272
+ return images[0], gr.update(visible=True)
273
+
274
 
275
  ### Description
276
  title = r"""
 
314
  4. Find a good base model always makes a difference.
315
  """
316
 
317
+ css = """
318
  .gradio-container {width: 85% !important}
319
+ """
320
  with gr.Blocks(css=css) as demo:
 
321
  # description
322
  gr.Markdown(title)
323
  gr.Markdown(description)
324
 
325
  with gr.Row():
326
  with gr.Column():
 
327
  # upload face image
328
+ face_file = gr.Image(label="Upload a photo of your face", type="filepath")
329
+
 
 
 
 
 
 
330
  # optional: upload a reference pose image
331
+ pose_file = gr.Image(label="Upload a reference pose image (optional)", type="filepath")
332
+
 
 
 
 
 
 
333
  # prompt
334
+ prompt = gr.Textbox(
335
+ label="Prompt",
336
+ info="Give simple prompt is enough to achieve good face fedility",
337
+ placeholder="A photo of a person",
338
+ value="",
339
+ )
340
+
341
  submit = gr.Button("Submit", variant="primary")
342
+
343
  style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
344
+
345
  # strength
346
  identitynet_strength_ratio = gr.Slider(
347
  label="IdentityNet strength (for fedility)",
 
357
  step=0.05,
358
  value=0.80,
359
  )
360
+
361
  with gr.Accordion(open=False, label="Advanced Options"):
362
  negative_prompt = gr.Textbox(
363
+ label="Negative Prompt",
364
  placeholder="low quality",
365
  value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
366
  )
367
+ num_steps = gr.Slider(
368
  label="Number of sample steps",
369
  minimum=20,
370
  maximum=100,
 
389
  enhance_face_region = gr.Checkbox(label="Enhance non-face region", value=True)
390
 
391
  with gr.Column():
392
+ output_image = gr.Image(label="Generated Image")
393
+ usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips, visible=False)
 
 
 
 
 
 
394
 
395
  submit.click(
396
  fn=remove_tips,
397
+ outputs=usage_tips,
398
+ queue=False,
399
+ api_name=False,
400
  ).then(
401
  fn=randomize_seed_fn,
402
  inputs=[seed, randomize_seed],
 
404
  queue=False,
405
  api_name=False,
406
  ).then(
407
+ fn=check_input_image,
408
+ inputs=face_file,
409
+ queue=False,
410
+ api_name=False,
411
+ ).success(
412
  fn=generate_image,
413
+ inputs=[
414
+ face_file,
415
+ pose_file,
416
+ prompt,
417
+ negative_prompt,
418
+ style,
419
+ enhance_face_region,
420
+ num_steps,
421
+ identitynet_strength_ratio,
422
+ adapter_strength_ratio,
423
+ guidance_scale,
424
+ seed,
425
+ ],
426
+ outputs=[output_image, usage_tips],
427
  )
428
+
429
  gr.Examples(
430
  examples=get_example(),
431
+ inputs=[face_file, prompt, style, negative_prompt],
432
+ outputs=[output_image, usage_tips],
433
+ fn=run_for_examples,
434
+ cache_examples=True,
 
435
  )
 
 
436
 
437
+ gr.Markdown(article)
438
 
439
  demo.queue(api_open=False)
440
+ demo.launch()
gradio_cached_examples/25/Generated Image/2880a3d19ef9b42e3ed2/image.png ADDED

Git LFS Details

  • SHA256: 573444e88e4bf4ab7bf4a693cf53cea3988366f4ca35b41523bf7c027802d0a6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.57 MB
gradio_cached_examples/25/Generated Image/38dde1388c41109c5d39/image.png ADDED

Git LFS Details

  • SHA256: c4e80ada96212f1acd058324b25602dc611de920de9e710c26286e751a5f1a9a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.31 MB
gradio_cached_examples/25/Generated Image/6cb5a3af223906666bfd/image.png ADDED

Git LFS Details

  • SHA256: f4b7543b3b1fd8ae301ee77c8479a7b3170e2ee063cd501a04e5e8ecc38417f6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.28 MB
gradio_cached_examples/25/Generated Image/7e6ea85f77dd925d842c/image.png ADDED

Git LFS Details

  • SHA256: abde3e97dc47d0b95f9909b2575834d6792db49816e5b40b0c8474a48fe467b2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.38 MB
gradio_cached_examples/25/Generated Image/e56b029833685ff77e6a/image.png ADDED

Git LFS Details

  • SHA256: 2f0992f67fe4839bffb56cf3c527e08c652406c55729257374ee8d630ac21501
  • Pointer size: 132 Bytes
  • Size of remote file: 2.69 MB
gradio_cached_examples/25/log.csv ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Generated Image,Usage tips of InstantID,flag,username,timestamp
2
+ "{""path"":""gradio_cached_examples/25/Generated Image/7e6ea85f77dd925d842c/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:55:38.846769
3
+ "{""path"":""gradio_cached_examples/25/Generated Image/2880a3d19ef9b42e3ed2/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:56:11.432078
4
+ "{""path"":""gradio_cached_examples/25/Generated Image/38dde1388c41109c5d39/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:56:45.563918
5
+ "{""path"":""gradio_cached_examples/25/Generated Image/e56b029833685ff77e6a/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:57:20.321876
6
+ "{""path"":""gradio_cached_examples/25/Generated Image/6cb5a3af223906666bfd/image.png"",""url"":null,""size"":null,""orig_name"":""image.png"",""mime_type"":null}","{'visible': True, '__type__': 'update'}",,,2024-01-24 08:57:53.871716