adaface-neurips committed
Commit 936cd75 · 1 Parent(s): 40ca865

update code

.gitignore CHANGED
@@ -1,10 +1,4 @@
- models/awportrait/*
- models/awportrait
  __pycache__/*
  __pycache__
- samples-ada/*
- samples-ada
- models/ensemble/awp14-unet/*
- models/ensemble/awp14-unet
  .gradio/certificate.pem
-
+ models/*
ConsistentID/app.py CHANGED
@@ -26,8 +26,8 @@ pipe = ConsistentIDPipeline.from_pretrained(

  ### Load consistentID_model checkpoint
  pipe.load_ConsistentID_model(
- consistentID_weight_path="./models/ConsistentID-v1.bin",
- bise_net_weight_path="./models/BiSeNet_pretrained_for_ConsistentID.pth",
+ consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin",
+ bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth",
  )
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
  pipe = pipe.to(device, torch.float16)
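ConsistentID/app.py now looks for both checkpoints under models/ConsistentID/, matching the new catch-all models/* ignore rule. A minimal pre-flight check (illustrative only, not part of the commit) that fails early if the files have not been moved to the new layout:

import os

# Paths taken from the diff above; the check itself is a hypothetical convenience.
consistentID_weight_path = "./models/ConsistentID/ConsistentID-v1.bin"
bise_net_weight_path = "./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth"

for path in (consistentID_weight_path, bise_net_weight_path):
    if not os.path.isfile(path):
        # Fail with a clear message instead of a deep traceback from the loader.
        raise FileNotFoundError(f"Expected checkpoint at {path}; place it under models/ConsistentID/ first.")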
adaface/adaface_infer.py CHANGED
@@ -45,8 +45,7 @@ def parse_args():
  help="Type of pipeline to use (default: txt2img)")
  parser.add_argument("--base_model_path", type=str, default=None,
  help="Type of checkpoints to use (default: None, using the official model)")
- parser.add_argument('--adaface_ckpt_paths', type=str, nargs="+",
- default=['models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt'])
+ parser.add_argument('--adaface_ckpt_path', type=str, required=True)
  parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
  choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
  parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
@@ -60,23 +59,18 @@ def parse_args():
  parser.add_argument("--extra_unet_dirpaths", type=str, nargs="*",
  default=[],
  help="Extra paths to the checkpoints of the UNet models")
- parser.add_argument('--unet_weights', type=float, nargs="+", default=[1],
+ parser.add_argument('--unet_weights_in_ensemble', type=float, nargs="+", default=[1],
  help="Weights for the UNet models")
  parser.add_argument("--subject", type=str)
  parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use")
  parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate")
  parser.add_argument("--prompt", type=str, default="a woman z in superman costume")
- parser.add_argument("--noise", dest='perturb_std', type=float, default=0)
+ parser.add_argument("--perturb_std", type=float, default=0)
  parser.add_argument("--randface", action="store_true")
  parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
  help="Guidance scale for the diffusion model")
- parser.add_argument("--subject_string",
- type=str, default="z",
- help="Subject placeholder string used in prompts to denote the concept.")
  parser.add_argument("--num_images_per_row", type=int, default=4,
  help="Number of images to display in a row in the output grid image.")
- parser.add_argument("--num_inference_steps", type=int, default=50,
- help="Number of inference steps")
  parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
  parser.add_argument("--seed", type=int, default=42,
  help="the seed (for reproducible sampling). Set to -1 to disable.")
@@ -95,16 +89,15 @@ if __name__ == "__main__":

  if args.pipeline not in ["text2img", "img2img"]:
  args.extra_unet_dirpaths = None
- args.unet_weights = None
+ args.unet_weights_in_ensemble = None

  adaface = AdaFaceWrapper(args.pipeline, args.base_model_path,
- args.adaface_encoder_types, args.adaface_ckpt_paths,
+ args.adaface_encoder_types, args.adaface_ckpt_path,
  args.adaface_encoder_cfg_scales, args.enabled_encoders,
- args.subject_string, args.num_inference_steps,
  unet_types=None,
  main_unet_filepath=args.main_unet_filepath,
  extra_unet_dirpaths=args.extra_unet_dirpaths,
- unet_weights=args.unet_weights, device=args.device)
+ unet_weights_in_ensemble=args.unet_weights_in_ensemble, device=args.device)

  if not args.randface:
  image_folder = args.subject
@@ -143,7 +136,7 @@ if __name__ == "__main__":
  rand_init_id_embs = torch.randn(1, 512)

  init_id_embs = rand_init_id_embs if args.randface else None
- noise = torch.randn(args.out_image_count, 4, 64, 64).cuda()
+ init_noise = torch.randn(args.out_image_count, 4, 64, 64).cuda()
  # args.perturb_std: the *relative* std of the noise added to the face embeddings.
  # A noise level of 0.08 could change gender, but 0.06 is usually safe.
  # adaface_subj_embs is not used. It is generated for the purpose of updating the text encoder (within this function call).
@@ -151,5 +144,7 @@ if __name__ == "__main__":
  adaface.prepare_adaface_embeddings(image_paths, init_id_embs,
  perturb_at_stage='img_prompt_emb',
  perturb_std=args.perturb_std, update_text_encoder=True)
- images = adaface(noise, args.prompt, None, 'append', args.guidance_scale, args.out_image_count, verbose=True)
+ images = adaface(init_noise, args.prompt, None, None,
+ 'append', args.guidance_scale,
+ args.out_image_count, verbose=True)
  save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.perturb_std)
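In short, adaface_infer.py now takes a single required --adaface_ckpt_path (instead of the --adaface_ckpt_paths list with a hard-coded default), renames --noise to --perturb_std and --unet_weights to --unet_weights_in_ensemble, and passes an extra None to adaface(...), which fills the new prompt_embeds parameter of AdaFaceWrapper.forward() (see the wrapper diff below). A rough invocation sketch; the checkpoint and subject paths are placeholders:

import shlex, subprocess

# Hypothetical paths -- substitute a real AdaFace checkpoint and subject folder.
cmd = (
    "python adaface/adaface_infer.py"
    " --adaface_ckpt_path models/adaface/your_adaface_ckpt.pt"   # now a single required path
    " --subject examples/your_subject"
    " --prompt 'a woman z in superman costume'"
    " --perturb_std 0.06"                                        # renamed from --noise
    " --out_image_count 4"
)
subprocess.run(shlex.split(cmd), check=True)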
adaface/adaface_translate.py CHANGED
@@ -25,10 +25,9 @@ def seed_everything(seed):

  def parse_args():
  parser = argparse.ArgumentParser()
- parser.add_argument("--base_model_path", type=str, default='models/realisticvision/realisticVisionV40_v40VAE.safetensors',
- help="Path to the UNet checkpoint (default: RealisticVision 4.0)")
- parser.add_argument('--adaface_ckpt_paths', type=str, nargs="+",
- default=['models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt'])
+ parser.add_argument("--base_model_path", type=str, default='models/sar/sar.safetensors',
+ help="Path to the UNet checkpoint (Default: SAR)")
+ parser.add_argument('--adaface_ckpt_path', type=str, required=True)
  parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
  choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
  parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
@@ -40,9 +39,11 @@ def parse_args():
  parser.add_argument('--extra_unet_dirpaths', type=str, nargs="*",
  default=[],
  help="Extra paths to the checkpoints of the UNet models")
- parser.add_argument('--unet_weights', type=float, nargs="+", default=[1],
+ parser.add_argument('--unet_weights_in_ensemble', type=float, nargs="+", default=[1],
  help="Weights for the UNet models")
  parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images")
+ parser.add_argument("--restore_image", type=str, default=None,
+ help="Path to the image to be restored")
  # If True, the input folder contains images of mixed subjects.
  # If False, the input folder contains multiple subfolders, each of which contains images of the same subject.
  parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?",
@@ -52,19 +53,14 @@ def parse_args():
  parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images")
  parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image")
  parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder")
- parser.add_argument("--noise", dest='perturb_std', type=float, default=0)
+ parser.add_argument("--perturb_std", type=float, default=0)
  parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
  help="Guidance scale for the diffusion model")
  parser.add_argument("--ref_img_strength", type=float, default=0.8,
  help="Strength of the reference image in the output image.")
- parser.add_argument("--subject_string",
- type=str, default="z",
- help="Subject placeholder string used in prompts to denote the concept.")
  parser.add_argument("--prompt", type=str, default="a person z")
  parser.add_argument("--num_images_per_row", type=int, default=4,
  help="Number of images to display in a row in the output grid image.")
- parser.add_argument("--num_inference_steps", type=int, default=50,
- help="Number of DDIM inference steps")
  parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.")
  parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
  parser.add_argument("--seed", type=int, default=42,
@@ -93,15 +89,16 @@ if __name__ == "__main__":
  process_index = 0

  adaface = AdaFaceWrapper("img2img", args.base_model_path,
- args.adaface_encoder_types, args.adaface_ckpt_paths,
+ args.adaface_encoder_types, args.adaface_ckpt_path,
  args.adaface_encoder_cfg_scales, args.enabled_encoders,
- args.subject_string, args.num_inference_steps,
  unet_types=None,
- extra_unet_dirpaths=args.extra_unet_dirpaths, unet_weights=args.unet_weights,
+ extra_unet_dirpaths=args.extra_unet_dirpaths,
+ unet_weights_in_ensemble=args.unet_weights_in_ensemble,
  device=args.device)

  in_folder = args.in_folder
  if os.path.isfile(in_folder):
+ args.in_folder = os.path.dirname(args.in_folder)
  subject_folders = [ os.path.dirname(in_folder) ]
  images_by_subject = [[in_folder]]
  else:
@@ -157,6 +154,24 @@ if __name__ == "__main__":
  images_by_subject = images_by_subject[process_index::args.num_gpus]
  #subject_folders, images_by_subject = distributed_state.split_between_processes(zip(subject_folders, images_by_subject))

+ if args.restore_image is not None:
+ in_images = []
+ for image_path in [args.restore_image]:
+ image = Image.open(image_path).convert("RGB").resize((512, 512))
+ # [512, 512, 3] -> [3, 512, 512].
+ image = np.array(image).transpose(2, 0, 1)
+ # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
+ image = torch.tensor(image).unsqueeze(0).float().cuda()
+ in_images.append(image)
+
+ # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
+ # NOTE: For simplicity, we do not check overly large batch sizes.
+ in_images = torch.cat(in_images, dim=0)
+ # in_images: [5, 3, 512, 512].
+ # Normalize the pixel values to [0, 1].
+ in_images = in_images / 255.0
+ num_out_images = len(in_images) * args.out_count_per_input_image
+
  for (subject_folder, image_paths) in zip(subject_folders, images_by_subject):
  # If is_mix_subj_folder, then image_paths only contains 1 image, and we use the file name as the signature of the image.
  # Otherwise, we use the folder name as the signature of the images.
@@ -176,29 +191,32 @@ if __name__ == "__main__":
  os.makedirs(subject_out_folder)
  print(f"Output images will be saved to {subject_out_folder}")

- in_images = []
- for image_path in image_paths:
- image = Image.open(image_path).convert("RGB").resize((512, 512))
- # [512, 512, 3] -> [3, 512, 512].
- image = np.array(image).transpose(2, 0, 1)
- # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
- image = torch.tensor(image).unsqueeze(0).float().cuda()
- in_images.append(image)
-
- # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
- # NOTE: For simplicity, we do not check overly large batch sizes.
- in_images = torch.cat(in_images, dim=0)
- # in_images: [5, 3, 512, 512].
- # Normalize the pixel values to [0, 1].
- in_images = in_images / 255.0
- num_out_images = len(in_images) * args.out_count_per_input_image
+ if args.restore_image is None:
+ in_images = []
+ for image_path in image_paths:
+ image = Image.open(image_path).convert("RGB").resize((512, 512))
+ # [512, 512, 3] -> [3, 512, 512].
+ image = np.array(image).transpose(2, 0, 1)
+ # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
+ image = torch.tensor(image).unsqueeze(0).float().cuda()
+ in_images.append(image)
+
+ # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
+ # NOTE: For simplicity, we do not check overly large batch sizes.
+ in_images = torch.cat(in_images, dim=0)
+ # in_images: [5, 3, 512, 512].
+ # Normalize the pixel values to [0, 1].
+ in_images = in_images / 255.0
+ num_out_images = len(in_images) * args.out_count_per_input_image

  with torch.no_grad():
  # args.perturb_std: the *relative* std of the noise added to the face embeddings.
  # A noise level of 0.08 could change gender, but 0.06 is usually safe.
  # The returned adaface_subj_embs are already incorporated in the text encoder, and not used explicitly.
  # NOTE: We assume out_count_per_input_image == 1, so that the output images are of the same number as the input images.
- out_images = adaface(in_images, args.prompt, None, 'append', args.guidance_scale, num_out_images, ref_img_strength=args.ref_img_strength)
+ out_images = adaface(in_images, args.prompt, None, None,
+ 'append', args.guidance_scale, num_out_images,
+ ref_img_strength=args.ref_img_strength)

  for img_i, img in enumerate(out_images):
  # out_images: subj_1, subj_2, ..., subj_n, subj_1, subj_2, ..., subj_n, ...
@@ -206,9 +224,11 @@ if __name__ == "__main__":
  copy_i = img_i // len(in_images)
  image_filename_stem, image_fileext = os.path.splitext(os.path.basename(image_paths[subj_i]))
  if copy_i == 0:
- img.save(os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}"))
+ save_path = os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}")
  else:
- img.save(os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}"))
+ save_path = os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}")
+ img.save(save_path)
+ print(f"Saved {save_path}")

  if args.copy_masks:
  mask_path = image_paths[subj_i].replace(image_fileext, "_mask.png")
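The same image preprocessing now appears twice in adaface_translate.py (once for --restore_image, once per subject folder). A small refactor sketch (not in the commit) that mirrors those steps — RGB, resize to 512x512, HWC to CHW, batch, scale to [0, 1]:

import numpy as np
import torch
from PIL import Image

def load_images_as_batch(image_paths, device="cuda"):
    tensors = []
    for image_path in image_paths:
        image = Image.open(image_path).convert("RGB").resize((512, 512))
        # [512, 512, 3] -> [3, 512, 512]
        array = np.array(image).transpose(2, 0, 1)
        tensors.append(torch.tensor(array).unsqueeze(0).float())
    # [N, 3, 512, 512], pixel values scaled to [0, 1]
    return torch.cat(tensors, dim=0).to(device) / 255.0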
adaface/adaface_wrapper.py CHANGED
@@ -8,22 +8,29 @@ from diffusers import (
8
  StableDiffusion3Pipeline,
9
  #FluxPipeline,
10
  DDIMScheduler,
 
 
11
  AutoencoderKL,
 
12
  )
13
  from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
14
  from adaface.util import UNetEnsemble
15
  from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
 
16
  from safetensors.torch import load_file as safetensors_load_file
17
  import re, os
18
  import numpy as np
 
19
 
20
  class AdaFaceWrapper(nn.Module):
21
  def __init__(self, pipeline_name, base_model_path, adaface_encoder_types,
22
  adaface_ckpt_paths, adaface_encoder_cfg_scales=None,
23
- enabled_encoders=None,
24
- subject_string='z', num_inference_steps=50, negative_prompt=None,
25
  use_840k_vae=False, use_ds_text_encoder=False,
26
- main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights=None,
 
 
27
  device='cuda', is_training=False):
28
  '''
29
  pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
@@ -38,15 +45,23 @@ class AdaFaceWrapper(nn.Module):
38
  self.adaface_ckpt_paths = adaface_ckpt_paths
39
  self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales
40
  self.enabled_encoders = enabled_encoders
 
 
 
 
 
 
41
  self.subject_string = subject_string
 
42
 
43
- self.num_inference_steps = num_inference_steps
 
44
  self.use_840k_vae = use_840k_vae
45
  self.use_ds_text_encoder = use_ds_text_encoder
46
  self.main_unet_filepath = main_unet_filepath
47
  self.unet_types = unet_types
48
  self.extra_unet_dirpaths = extra_unet_dirpaths
49
- self.unet_weights = unet_weights
50
  self.device = device
51
  self.is_training = is_training
52
 
@@ -62,7 +77,14 @@ class AdaFaceWrapper(nn.Module):
62
  self.initialize_pipeline()
63
  # During inference, we never use static image suffix embeddings.
64
  # So num_id_vecs is the length of the returned adaface embeddings for each encoder.
65
- self.encoders_num_id_vecs = self.id2ada_prompt_encoder.encoders_num_id_vecs
 
 
 
 
 
 
 
66
  self.extend_tokenizer_and_text_encoder()
67
 
68
  def to(self, device):
@@ -76,7 +98,8 @@ class AdaFaceWrapper(nn.Module):
76
  self.id2ada_prompt_encoder = create_id2ada_prompt_encoder(self.adaface_encoder_types,
77
  self.adaface_ckpt_paths,
78
  self.adaface_encoder_cfg_scales,
79
- self.enabled_encoders)
 
80
 
81
  self.id2ada_prompt_encoder.to(self.device)
82
  print(f"adaface_encoder_cfg_scales: {self.adaface_encoder_cfg_scales}")
@@ -118,10 +141,10 @@ class AdaFaceWrapper(nn.Module):
118
 
119
  if self.base_model_path is None:
120
  base_model_path_dict = {
121
- 'text2img': 'models/sd15-dste8-vae.safetensors',
122
- 'text2imgxl': 'stabilityai/stable-diffusion-xl-base-1.0',
123
- 'text2img3': 'stabilityai/stable-diffusion-3-medium-diffusers',
124
- 'flux': 'black-forest-labs/FLUX.1-schnell',
125
  }
126
  self.base_model_path = base_model_path_dict[self.pipeline_name]
127
 
@@ -137,6 +160,20 @@ class AdaFaceWrapper(nn.Module):
137
  safety_checker=None
138
  )
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  if self.main_unet_filepath is not None:
141
  print(f"Replacing the UNet with the UNet from {self.main_unet_filepath}.")
142
  ret = pipeline.unet.load_state_dict(self.load_unet_from_file(self.main_unet_filepath, device='cpu'))
@@ -147,12 +184,19 @@ class AdaFaceWrapper(nn.Module):
147
 
148
  if (self.unet_types is not None and len(self.unet_types) > 0) \
149
  or (self.extra_unet_dirpaths is not None and len(self.extra_unet_dirpaths) > 0):
150
- unet_ensemble = UNetEnsemble([pipeline.unet], self.unet_types, self.extra_unet_dirpaths, self.unet_weights,
151
  device=self.device, torch_dtype=torch.float16)
152
  pipeline.unet = unet_ensemble
153
 
154
  print(f"Loaded pipeline from {self.base_model_path}.")
155
-
 
 
 
 
 
 
 
156
  if self.use_840k_vae:
157
  pipeline.vae = vae
158
  print("Replaced the VAE with the 840k-step VAE.")
@@ -167,19 +211,56 @@ class AdaFaceWrapper(nn.Module):
167
  pipeline.vae = None
168
  print("Removed UNet and VAE from the pipeline.")
169
 
170
- if self.pipeline_name not in ["text2imgxl", "text2img3", "flux"]:
171
- noise_scheduler = DDIMScheduler(
172
- num_train_timesteps=1000,
173
- beta_start=0.00085,
174
- beta_end=0.012,
175
- beta_schedule="scaled_linear",
176
- clip_sample=False,
177
- set_alpha_to_one=False,
178
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  pipeline.scheduler = noise_scheduler
180
- # Otherwise, pipeline.scheduler == FlowMatchEulerDiscreteScheduler
 
181
  self.pipeline = pipeline.to(self.device)
182
 
 
 
 
 
183
  def load_unet_from_file(self, unet_path, device=None):
184
  if os.path.isfile(unet_path):
185
  if unet_path.endswith(".safetensors"):
@@ -208,7 +289,109 @@ class AdaFaceWrapper(nn.Module):
208
  else:
209
  raise ValueError(f"UNet path {unet_path} is not a file.")
210
  return unet_state_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  def extend_tokenizer_and_text_encoder(self):
213
  if np.sum(self.encoders_num_id_vecs) < 1:
214
  raise ValueError(f"encoders_num_id_vecs has to be larger or equal to 1, but is {self.encoders_num_id_vecs}")
@@ -218,6 +401,7 @@ class AdaFaceWrapper(nn.Module):
218
  # We add z_0_0, z_0_1, z_0_2, ..., z_0_15, z_1_0, z_1_1, z_1_2, z_1_3 to the tokenizer.
219
  self.all_placeholder_tokens = []
220
  self.placeholder_tokens_strs = []
 
221
  for i in range(len(self.adaface_encoder_types)):
222
  placeholder_tokens = []
223
  for j in range(self.encoders_num_id_vecs[i]):
@@ -225,9 +409,11 @@ class AdaFaceWrapper(nn.Module):
225
  placeholder_tokens_str = " ".join(placeholder_tokens)
226
 
227
  self.all_placeholder_tokens.extend(placeholder_tokens)
 
228
  self.placeholder_tokens_strs.append(placeholder_tokens_str)
229
 
230
  self.all_placeholder_tokens_str = " ".join(self.placeholder_tokens_strs)
 
231
  # all_null_placeholder_tokens_str: ", , , , ..." (20 times).
232
  # It just contains the commas and spaces with the same length, but no actual tokens.
233
  self.all_null_placeholder_tokens_str = " ".join([", "] * len(self.all_placeholder_tokens))
@@ -241,7 +427,7 @@ class AdaFaceWrapper(nn.Module):
241
 
242
  print(f"Added {num_added_tokens} tokens ({self.all_placeholder_tokens_str}) to the tokenizer.")
243
 
244
- # placeholder_token_ids: [49408, ..., 49423].
245
  self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.all_placeholder_tokens)
246
  #print("New tokens:", self.placeholder_token_ids)
247
  # Resize the token embeddings as we are adding new special tokens to the tokenizer
@@ -252,24 +438,49 @@ class AdaFaceWrapper(nn.Module):
252
 
253
  # Extend pipeline.text_encoder with the adaface subject emeddings.
254
  # subj_embs: [16, 768].
255
- def update_text_encoder_subj_embeddings(self, subj_embs):
256
  # Initialise the newly added placeholder token with the embeddings of the initializer token
257
  # token_embeds: [49412, 768]
258
  token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
 
 
 
 
259
  with torch.no_grad():
260
- for i, token_id in enumerate(self.placeholder_token_ids):
261
- token_embeds[token_id] = subj_embs[i]
262
- print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.all_placeholder_tokens_str}) in the text encoder.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  def update_prompt(self, prompt, placeholder_tokens_pos='append',
 
265
  use_null_placeholders=False):
266
  if prompt is None:
267
  prompt = ""
268
 
269
  if use_null_placeholders:
270
  all_placeholder_tokens_str = self.all_null_placeholder_tokens_str
 
 
 
271
  else:
272
- all_placeholder_tokens_str = self.all_placeholder_tokens_str
273
 
274
  # Delete the subject_string from the prompt.
275
  prompt = re.sub(r'\b(a|an|the)\s+' + self.subject_string + r'\b,?', "", prompt)
@@ -279,15 +490,29 @@ class AdaFaceWrapper(nn.Module):
279
  # When we do joint training, seems both work better if they are appended to the prompt.
280
  # Therefore we simply appended all placeholder_tokens_str's to the prompt.
281
  # NOTE: Prepending them hurts compositional prompts.
282
- if placeholder_tokens_pos == 'prepend':
283
- prompt = all_placeholder_tokens_str + " " + prompt
284
- elif placeholder_tokens_pos == 'append':
285
- prompt = prompt + " " + all_placeholder_tokens_str
 
 
 
 
 
 
 
286
  else:
287
- breakpoint()
 
 
 
 
 
288
 
289
  return prompt
290
 
 
 
291
  # If face_id_embs is None, then it extracts face_id_embs from the images,
292
  # then map them to ada prompt embeddings.
293
  # avg_at_stage: 'id_emb', 'img_prompt_emb', or None.
@@ -298,27 +523,29 @@ class AdaFaceWrapper(nn.Module):
298
  perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
299
  perturb_std=0, update_text_encoder=True):
300
 
301
- all_adaface_subj_embs = \
302
  self.id2ada_prompt_encoder.generate_adaface_embeddings(\
303
  image_paths, face_id_embs=face_id_embs,
304
  img_prompt_embs=None,
305
  avg_at_stage=avg_at_stage,
306
  perturb_at_stage=perturb_at_stage,
307
  perturb_std=perturb_std,
308
- enable_static_img_suffix_embs=False)
309
 
310
  if all_adaface_subj_embs is None:
311
  return None
312
 
 
 
313
  if all_adaface_subj_embs.ndim == 4:
314
- # [1, 1, 16, 768] -> [16, 768]
315
  all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0).squeeze(0)
316
  elif all_adaface_subj_embs.ndim == 3:
317
- # [1, 16, 768] -> [16, 768]
318
  all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0)
319
 
320
  if update_text_encoder:
321
- self.update_text_encoder_subj_embeddings(all_adaface_subj_embs)
322
  return all_adaface_subj_embs
323
 
324
  def diffusers_encode_prompts(self, prompt, plain_prompt, negative_prompt, device):
@@ -368,6 +595,7 @@ class AdaFaceWrapper(nn.Module):
368
  else:
369
  breakpoint()
370
  else:
 
371
  # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
372
  prompt_embeds_, negative_prompt_embeds_ = \
373
  self.pipeline.encode_prompt(prompt, device=device,
@@ -378,9 +606,53 @@ class AdaFaceWrapper(nn.Module):
378
  return prompt_embeds_, negative_prompt_embeds_, \
379
  pooled_prompt_embeds_, negative_pooled_prompt_embeds_
380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  def encode_prompt(self, prompt, negative_prompt=None,
382
  placeholder_tokens_pos='append',
383
- do_neg_id_prompt_weight=0,
 
 
 
 
384
  device=None, verbose=False):
385
  if negative_prompt is None:
386
  negative_prompt = self.negative_prompt
@@ -389,59 +661,81 @@ class AdaFaceWrapper(nn.Module):
389
  device = self.device
390
 
391
  plain_prompt = prompt
392
- prompt = self.update_prompt(prompt, placeholder_tokens_pos=placeholder_tokens_pos)
 
 
 
 
 
 
393
  if verbose:
394
  print(f"Subject prompt:\n{prompt}")
395
 
396
- if do_neg_id_prompt_weight > 0:
397
- # Use 'prepend' for the negative prompt, since it's long and we want to make sure
398
- # the placeholder tokens are not cut off.
399
- negative_prompt0 = negative_prompt
400
- negative_prompt = self.update_prompt(negative_prompt0, placeholder_tokens_pos='prepend')
401
- null_negative_prompt = self.update_prompt(negative_prompt0, placeholder_tokens_pos='prepend',
402
- use_null_placeholders=True)
403
- ''' if verbose:
404
- print(f"Negative prompt:\n{negative_prompt}")
405
- print(f"Null negative prompt:\n{null_negative_prompt}")
406
-
407
- '''
408
- else:
409
- null_negative_prompt = None
410
-
411
  # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
412
  # So we manually move it to GPU here.
413
  self.pipeline.text_encoder.to(device)
414
 
415
  prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
416
  self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)
417
-
418
- if 0 < do_neg_id_prompt_weight < 1:
419
- _, negative_prompt_embeds_null, _, _ = \
420
- self.diffusers_encode_prompts(prompt, plain_prompt, null_negative_prompt, device)
421
- negative_prompt_embeds_ = negative_prompt_embeds_ * do_neg_id_prompt_weight + \
422
- negative_prompt_embeds_null * (1 - do_neg_id_prompt_weight)
423
-
 
 
 
 
 
 
 
424
  return prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_
425
 
426
  # ref_img_strength is used only in the img2img pipeline.
427
- def forward(self, noise, prompt, negative_prompt=None,
428
  placeholder_tokens_pos='append',
429
- do_neg_id_prompt_weight=0,
430
  guidance_scale=6.0, out_image_count=4,
431
- ref_img_strength=0.8, generator=None, verbose=False):
 
 
 
 
 
 
432
  noise = noise.to(device=self.device, dtype=torch.float16)
 
 
433
 
434
  if negative_prompt is None:
435
  negative_prompt = self.negative_prompt
436
  # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
437
- prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, \
438
- negative_pooled_prompt_embeds_ = \
439
- self.encode_prompt(prompt, negative_prompt,
440
- placeholder_tokens_pos=placeholder_tokens_pos,
441
- do_neg_id_prompt_weight=do_neg_id_prompt_weight,
442
- device=self.device, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  # Repeat the prompt embeddings for all images in the batch.
444
  prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
 
445
  if negative_prompt_embeds_ is not None:
446
  negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
447
 
 
8
  StableDiffusion3Pipeline,
9
  #FluxPipeline,
10
  DDIMScheduler,
11
+ PNDMScheduler,
12
+ DPMSolverSinglestepScheduler,
13
  AutoencoderKL,
14
+ LCMScheduler,
15
  )
16
  from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
17
  from adaface.util import UNetEnsemble
18
  from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
19
+ from adaface.diffusers_attn_lora_capture import set_up_attn_processors, set_up_ffn_loras, set_lora_and_capture_flags
20
  from safetensors.torch import load_file as safetensors_load_file
21
  import re, os
22
  import numpy as np
23
+ from peft.utils.constants import DUMMY_TARGET_MODULES
24
 
25
  class AdaFaceWrapper(nn.Module):
26
  def __init__(self, pipeline_name, base_model_path, adaface_encoder_types,
27
  adaface_ckpt_paths, adaface_encoder_cfg_scales=None,
28
+ enabled_encoders=None, use_lcm=False, default_scheduler_name='ddim',
29
+ num_inference_steps=50, subject_string='z', negative_prompt=None,
30
  use_840k_vae=False, use_ds_text_encoder=False,
31
+ main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
32
+ enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
33
+ attn_lora_layer_names=['q', 'k', 'v', 'out'], shrink_cross_attn=False, q_lora_updates_query=False,
34
  device='cuda', is_training=False):
35
  '''
36
  pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
 
45
  self.adaface_ckpt_paths = adaface_ckpt_paths
46
  self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales
47
  self.enabled_encoders = enabled_encoders
48
+ # None, or a list of two bools for two encoders. If None, both are disabled.
49
+ self.enable_static_img_suffix_embs = enable_static_img_suffix_embs
50
+ self.unet_uses_attn_lora = unet_uses_attn_lora
51
+ self.attn_lora_layer_names = attn_lora_layer_names
52
+ self.q_lora_updates_query = q_lora_updates_query
53
+ self.use_lcm = use_lcm
54
  self.subject_string = subject_string
55
+ self.shrink_cross_attn = shrink_cross_attn
56
 
57
+ self.default_scheduler_name = default_scheduler_name
58
+ self.num_inference_steps = num_inference_steps if not use_lcm else 4
59
  self.use_840k_vae = use_840k_vae
60
  self.use_ds_text_encoder = use_ds_text_encoder
61
  self.main_unet_filepath = main_unet_filepath
62
  self.unet_types = unet_types
63
  self.extra_unet_dirpaths = extra_unet_dirpaths
64
+ self.unet_weights_in_ensemble = unet_weights_in_ensemble
65
  self.device = device
66
  self.is_training = is_training
67
 
 
77
  self.initialize_pipeline()
78
  # During inference, we never use static image suffix embeddings.
79
  # So num_id_vecs is the length of the returned adaface embeddings for each encoder.
80
+ self.encoders_num_id_vecs = np.array(self.id2ada_prompt_encoder.encoders_num_id_vecs)
81
+ self.encoders_num_static_img_suffix_embs = np.array(self.id2ada_prompt_encoder.encoders_num_static_img_suffix_embs)
82
+ if self.enable_static_img_suffix_embs is not None:
83
+ assert len(self.enable_static_img_suffix_embs) == len(self.encoders_num_id_vecs)
84
+ self.encoders_num_static_img_suffix_embs *= np.array(self.enable_static_img_suffix_embs)
85
+ self.encoders_num_id_vecs += self.encoders_num_static_img_suffix_embs
86
+
87
+ self.img_prompt_embs = None
88
  self.extend_tokenizer_and_text_encoder()
89
 
90
  def to(self, device):
 
98
  self.id2ada_prompt_encoder = create_id2ada_prompt_encoder(self.adaface_encoder_types,
99
  self.adaface_ckpt_paths,
100
  self.adaface_encoder_cfg_scales,
101
+ self.enabled_encoders,
102
+ num_static_img_suffix_embs=4)
103
 
104
  self.id2ada_prompt_encoder.to(self.device)
105
  print(f"adaface_encoder_cfg_scales: {self.adaface_encoder_cfg_scales}")
 
141
 
142
  if self.base_model_path is None:
143
  base_model_path_dict = {
144
+ 'text2img': 'models/sd15-dste8-vae.safetensors',
145
+ 'text2imgxl': 'stabilityai/stable-diffusion-xl-base-1.0',
146
+ 'text2img3': 'stabilityai/stable-diffusion-3-medium-diffusers',
147
+ 'flux': 'black-forest-labs/FLUX.1-schnell',
148
  }
149
  self.base_model_path = base_model_path_dict[self.pipeline_name]
150
 
 
160
  safety_checker=None
161
  )
162
 
163
+ if self.use_lcm:
164
+ lcm_path_dict = {
165
+ 'text2img': 'latent-consistency/lcm-lora-sdv1-5',
166
+ 'text2imgxl': 'latent-consistency/lcm-lora-sdxl',
167
+ }
168
+ if self.pipeline_name not in lcm_path_dict:
169
+ raise ValueError(f"Pipeline {self.pipeline_name} does not support LCM.")
170
+
171
+ lcm_path = lcm_path_dict[self.pipeline_name]
172
+ pipeline.load_lora_weights(lcm_path)
173
+ pipeline.fuse_lora()
174
+ print(f"Loaded LCM weights from {lcm_path}.")
175
+ pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
176
+
177
  if self.main_unet_filepath is not None:
178
  print(f"Replacing the UNet with the UNet from {self.main_unet_filepath}.")
179
  ret = pipeline.unet.load_state_dict(self.load_unet_from_file(self.main_unet_filepath, device='cpu'))
 
184
 
185
  if (self.unet_types is not None and len(self.unet_types) > 0) \
186
  or (self.extra_unet_dirpaths is not None and len(self.extra_unet_dirpaths) > 0):
187
+ unet_ensemble = UNetEnsemble([pipeline.unet], self.unet_types, self.extra_unet_dirpaths, self.unet_weights_in_ensemble,
188
  device=self.device, torch_dtype=torch.float16)
189
  pipeline.unet = unet_ensemble
190
 
191
  print(f"Loaded pipeline from {self.base_model_path}.")
192
+ if not remove_unet and (self.unet_uses_attn_lora or self.shrink_cross_attn):
193
+ unet2 = self.load_unet_lora_weights(pipeline.unet, use_attn_lora=self.unet_uses_attn_lora,
194
+ attn_lora_layer_names=self.attn_lora_layer_names,
195
+ shrink_cross_attn=self.shrink_cross_attn,
196
+ q_lora_updates_query=self.q_lora_updates_query)
197
+
198
+ pipeline.unet = unet2
199
+
200
  if self.use_840k_vae:
201
  pipeline.vae = vae
202
  print("Replaced the VAE with the 840k-step VAE.")
 
211
  pipeline.vae = None
212
  print("Removed UNet and VAE from the pipeline.")
213
 
214
+ if self.pipeline_name not in ["text2imgxl", "text2img3", "flux"] and not self.use_lcm:
215
+ if self.default_scheduler_name == 'ddim':
216
+ noise_scheduler = DDIMScheduler(
217
+ num_train_timesteps=1000,
218
+ beta_start=0.00085,
219
+ beta_end=0.012,
220
+ beta_schedule="scaled_linear",
221
+ clip_sample=False,
222
+ set_alpha_to_one=False,
223
+ steps_offset=1,
224
+ timestep_spacing="leading",
225
+ rescale_betas_zero_snr=False,
226
+ )
227
+ elif self.default_scheduler_name == 'pndm':
228
+ noise_scheduler = PNDMScheduler(
229
+ num_train_timesteps=1000,
230
+ beta_start=0.00085,
231
+ beta_end=0.012,
232
+ beta_schedule="scaled_linear",
233
+ set_alpha_to_one=False,
234
+ steps_offset=1,
235
+ timestep_spacing="leading",
236
+ skip_prk_steps=True,
237
+ )
238
+ elif self.default_scheduler_name == 'dpm++':
239
+ noise_scheduler = DPMSolverSinglestepScheduler(
240
+ beta_start=0.00085,
241
+ beta_end=0.012,
242
+ beta_schedule="scaled_linear",
243
+ prediction_type="epsilon",
244
+ num_train_timesteps=1000,
245
+ trained_betas=None,
246
+ thresholding=False,
247
+ algorithm_type="dpmsolver++",
248
+ solver_type="midpoint",
249
+ lower_order_final=True,
250
+ use_karras_sigmas=True,
251
+ )
252
+ else:
253
+ breakpoint()
254
+
255
  pipeline.scheduler = noise_scheduler
256
+ # Otherwise, if not use_lcm, pipeline.scheduler == FlowMatchEulerDiscreteScheduler
257
+ # if use_lcm, pipeline.scheduler == LCMScheduler
258
  self.pipeline = pipeline.to(self.device)
259
 
260
+ def set_adaface_encoder_cfg_scales(self, adaface_encoder_cfg_scales):
261
+ self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales
262
+ self.id2ada_prompt_encoder.set_out_id_embs_cfg_scale(adaface_encoder_cfg_scales)
263
+
264
  def load_unet_from_file(self, unet_path, device=None):
265
  if os.path.isfile(unet_path):
266
  if unet_path.endswith(".safetensors"):
 
289
  else:
290
  raise ValueError(f"UNet path {unet_path} is not a file.")
291
  return unet_state_dict
292
+
293
+ # Adapted from ConsistentIDPipeline:set_ip_adapter().
294
+ def load_unet_loras(self, unet, unet_lora_modules_state_dict,
295
+ use_attn_lora=True, use_ffn_lora=False,
296
+ attn_lora_layer_names=['q', 'k', 'v', 'out'],
297
+ shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
298
+ q_lora_updates_query=False):
299
+ attn_capture_procs, attn_opt_modules = \
300
+ set_up_attn_processors(unet, use_attn_lora=True, attn_lora_layer_names=attn_lora_layer_names,
301
+ lora_rank=192, lora_scale_down=8,
302
+ cross_attn_shrink_factor=cross_attn_shrink_factor,
303
+ q_lora_updates_query=q_lora_updates_query)
304
+ # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut. [12] matches 1 or 2.
305
+ if use_ffn_lora:
306
+ target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
307
+ else:
308
+ # A special pattern, "dummy-target-modules" tells PEFT to add loras on NONE of the layers.
309
+ # We couldn't simply skip PEFT initialization (converting unet to a PEFT model),
310
+ # otherwise the attn lora layers will cause nan quickly during a fp16 training.
311
+ target_modules_pat = DUMMY_TARGET_MODULES
312
+
313
+ unet, ffn_lora_layers, ffn_opt_modules = \
314
+ set_up_ffn_loras(unet, target_modules_pat=target_modules_pat, lora_uses_dora=True)
315
+
316
+ # self.attn_capture_procs and ffn_lora_layers will be used in set_lora_and_capture_flags().
317
+ self.attn_capture_procs = list(attn_capture_procs.values())
318
+ self.ffn_lora_layers = list(ffn_lora_layers.values())
319
+ # Combine attn_opt_modules and ffn_opt_modules into unet_lora_modules.
320
+ # unet_lora_modules is for optimization and loading/saving.
321
+ unet_lora_modules = {}
322
+ # attn_opt_modules and ffn_opt_modules have different depths of keys.
323
+ # attn_opt_modules:
324
+ # up_blocks_3_attentions_1_transformer_blocks_0_attn2_processor_std_shrink_factor,
325
+ # up_blocks_3_attentions_1_transformer_blocks_0_attn2_processor_to_q_lora_lora_A, ...
326
+ # ffn_opt_modules:
327
+ # base_model_model_up_blocks_3_resnets_1_conv1_lora_A, ...
328
+ # with the prefix 'base_model_model_'. Because ffn_opt_modules are extracted from the peft-wrapped model,
329
+ # and attn_opt_modules are extracted from the original unet model.
330
+ # To be compatible with old param keys, we append 'base_model_model_' to the keys of attn_opt_modules.
331
+ unet_lora_modules.update({ f'base_model_model_{k}': v for k, v in attn_opt_modules.items() })
332
+ unet_lora_modules.update(ffn_opt_modules)
333
+ # ParameterDict can contain both Parameter and nn.Module.
334
+ # TODO: maybe in the future, we couldn't put nn.Module in nn.ParameterDict.
335
+ self.unet_lora_modules = torch.nn.ParameterDict(unet_lora_modules)
336
+
337
+ missing, unexpected = self.unet_lora_modules.load_state_dict(unet_lora_modules_state_dict, strict=False)
338
+ if len(missing) > 0:
339
+ print(f"Missing Keys: {missing}")
340
+ if len(unexpected) > 0:
341
+ print(f"Unexpected Keys: {unexpected}")
342
+
343
+ print(f"Loaded {len(unet_lora_modules_state_dict)} LoRA weights on the UNet:\n{unet_lora_modules.keys()}")
344
+ self.outfeat_capture_blocks.append(unet.up_blocks[3])
345
+
346
+ # If shrink_cross_attn is True and use_attn_lora is False, we load all these params from ckpt,
347
+ # but since we set use_attn_lora to False, attn loras won't be used during inference nonetheless.
348
+ set_lora_and_capture_flags(unet, None, self.attn_capture_procs, self.outfeat_capture_blocks,
349
+ use_attn_lora, use_ffn_lora, 'recon_loss', capture_ca_activations=False,
350
+ shrink_cross_attn=shrink_cross_attn)
351
+
352
+ return unet
353
+
354
+ def load_unet_lora_weights(self, unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
355
+ shrink_cross_attn=False, q_lora_updates_query=False):
356
+ unet_lora_weight_found = False
357
+ if isinstance(self.adaface_ckpt_paths, str):
358
+ adaface_ckpt_paths = [self.adaface_ckpt_paths]
359
+ else:
360
+ adaface_ckpt_paths = self.adaface_ckpt_paths
361
+
362
+ for adaface_ckpt_path in adaface_ckpt_paths:
363
+ ckpt_dict = torch.load(adaface_ckpt_path, map_location='cpu')
364
+ if 'unet_lora_modules' in ckpt_dict:
365
+ unet_lora_modules_state_dict = ckpt_dict['unet_lora_modules']
366
+ print(f"{len(unet_lora_modules_state_dict)} LoRA weights found in {adaface_ckpt_path}.")
367
+ unet_lora_weight_found = True
368
+ break
369
+
370
+ # Since unet lora weights are not found in the adaface ckpt, we give up on loading unet attn processors.
371
+ if not unet_lora_weight_found:
372
+ print(f"LoRA weights not found in {self.adaface_ckpt_paths}.")
373
+ return unet
374
 
375
+ self.outfeat_capture_blocks = []
376
+
377
+ if isinstance(unet, UNetEnsemble):
378
+ for i, unet_ in enumerate(unet.unets):
379
+ unet_ = self.load_unet_loras(unet_, unet_lora_modules_state_dict,
380
+ use_attn_lora=use_attn_lora,
381
+ attn_lora_layer_names=attn_lora_layer_names,
382
+ shrink_cross_attn=shrink_cross_attn,
383
+ q_lora_updates_query=q_lora_updates_query)
384
+ unet.unets[i] = unet_
385
+ print(f"Loaded LoRA processors on UNetEnsemble of {len(unet.unets)} UNets.")
386
+ else:
387
+ unet = self.load_unet_loras(unet, unet_lora_modules_state_dict,
388
+ use_attn_lora=use_attn_lora,
389
+ attn_lora_layer_names=attn_lora_layer_names,
390
+ shrink_cross_attn=shrink_cross_attn,
391
+ q_lora_updates_query=q_lora_updates_query)
392
+
393
+ return unet
394
+
395
  def extend_tokenizer_and_text_encoder(self):
396
  if np.sum(self.encoders_num_id_vecs) < 1:
397
  raise ValueError(f"encoders_num_id_vecs has to be larger or equal to 1, but is {self.encoders_num_id_vecs}")
 
401
  # We add z_0_0, z_0_1, z_0_2, ..., z_0_15, z_1_0, z_1_1, z_1_2, z_1_3 to the tokenizer.
402
  self.all_placeholder_tokens = []
403
  self.placeholder_tokens_strs = []
404
+ self.encoder_placeholder_tokens = []
405
  for i in range(len(self.adaface_encoder_types)):
406
  placeholder_tokens = []
407
  for j in range(self.encoders_num_id_vecs[i]):
 
409
  placeholder_tokens_str = " ".join(placeholder_tokens)
410
 
411
  self.all_placeholder_tokens.extend(placeholder_tokens)
412
+ self.encoder_placeholder_tokens.append(placeholder_tokens)
413
  self.placeholder_tokens_strs.append(placeholder_tokens_str)
414
 
415
  self.all_placeholder_tokens_str = " ".join(self.placeholder_tokens_strs)
416
+ self.updated_tokens_str = self.all_placeholder_tokens_str
417
  # all_null_placeholder_tokens_str: ", , , , ..." (20 times).
418
  # It just contains the commas and spaces with the same length, but no actual tokens.
419
  self.all_null_placeholder_tokens_str = " ".join([", "] * len(self.all_placeholder_tokens))
 
427
 
428
  print(f"Added {num_added_tokens} tokens ({self.all_placeholder_tokens_str}) to the tokenizer.")
429
 
430
+ # placeholder_token_ids: [49408, ..., 49427].
431
  self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.all_placeholder_tokens)
432
  #print("New tokens:", self.placeholder_token_ids)
433
  # Resize the token embeddings as we are adding new special tokens to the tokenizer
 
438
 
439
  # Extend pipeline.text_encoder with the adaface subject emeddings.
440
  # subj_embs: [16, 768].
441
+ def update_text_encoder_subj_embeddings(self, subj_embs, lens_subj_emb_segments):
442
  # Initialise the newly added placeholder token with the embeddings of the initializer token
443
  # token_embeds: [49412, 768]
444
  token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
445
+ all_encoders_updated_tokens = []
446
+ all_encoders_updated_token_strs = []
447
+ idx = 0
448
+
449
  with torch.no_grad():
450
+ # sum of lens_subj_emb_segments are probably shorter than self.placeholder_token_ids,
451
+ # when some static_img_suffix_embs are disabled.
452
+ for i, encoder_type in enumerate(self.adaface_encoder_types):
453
+ encoder_updated_tokens = []
454
+ if (self.enabled_encoders is not None) and (encoder_type not in self.enabled_encoders):
455
+ idx += lens_subj_emb_segments[i]
456
+ continue
457
+ for j in range(lens_subj_emb_segments[i]):
458
+ placeholder_token = f"{self.subject_string}_{i}_{j}"
459
+ token_id = self.pipeline.tokenizer.convert_tokens_to_ids(placeholder_token)
460
+ token_embeds[token_id] = subj_embs[idx]
461
+ encoder_updated_tokens.append(placeholder_token)
462
+ idx += 1
463
+
464
+ all_encoders_updated_tokens.extend(encoder_updated_tokens)
465
+ all_encoders_updated_token_strs.append(" ".join(encoder_updated_tokens))
466
+
467
+ self.updated_tokens_str = " ".join(all_encoders_updated_token_strs)
468
+ self.all_encoders_updated_token_strs = all_encoders_updated_token_strs
469
+ print(f"Updated {len(all_encoders_updated_tokens)} tokens ({self.updated_tokens_str}) in the text encoder.")
470
 
471
  def update_prompt(self, prompt, placeholder_tokens_pos='append',
472
+ repeat_prompt_for_each_encoder=True,
473
  use_null_placeholders=False):
474
  if prompt is None:
475
  prompt = ""
476
 
477
  if use_null_placeholders:
478
  all_placeholder_tokens_str = self.all_null_placeholder_tokens_str
479
+ if not re.search(r"\b(man|woman|person|child|girl|boy)\b", prompt.lower()):
480
+ all_placeholder_tokens_str = "person " + all_placeholder_tokens_str
481
+ repeat_prompt_for_each_encoder = False
482
  else:
483
+ all_placeholder_tokens_str = self.updated_tokens_str
484
 
485
  # Delete the subject_string from the prompt.
486
  prompt = re.sub(r'\b(a|an|the)\s+' + self.subject_string + r'\b,?', "", prompt)
 
490
  # When we do joint training, seems both work better if they are appended to the prompt.
491
  # Therefore we simply appended all placeholder_tokens_str's to the prompt.
492
  # NOTE: Prepending them hurts compositional prompts.
493
+ if repeat_prompt_for_each_encoder:
494
+ encoder_prompts = []
495
+ for encoder_updated_token_strs in self.all_encoders_updated_token_strs:
496
+ if placeholder_tokens_pos == 'prepend':
497
+ encoder_prompt = encoder_updated_token_strs + " " + prompt
498
+ elif placeholder_tokens_pos == 'append':
499
+ encoder_prompt = prompt + " " + encoder_updated_token_strs
500
+ else:
501
+ breakpoint()
502
+ encoder_prompts.append(encoder_prompt)
503
+ prompt = ", ".join(encoder_prompts)
504
  else:
505
+ if placeholder_tokens_pos == 'prepend':
506
+ prompt = all_placeholder_tokens_str + " " + prompt
507
+ elif placeholder_tokens_pos == 'append':
508
+ prompt = prompt + " " + all_placeholder_tokens_str
509
+ else:
510
+ breakpoint()
511
 
512
  return prompt
513
 
514
+ # NOTE: all_adaface_subj_embs is the input to the CLIP text encoder.
515
+ # ** DO NOT use it as prompt_embeds in the forward() method.
516
  # If face_id_embs is None, then it extracts face_id_embs from the images,
517
  # then map them to ada prompt embeddings.
518
  # avg_at_stage: 'id_emb', 'img_prompt_emb', or None.
 
523
  perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
524
  perturb_std=0, update_text_encoder=True):
525
 
526
+ all_adaface_subj_embs, img_prompt_embs, lens_subj_emb_segments = \
527
  self.id2ada_prompt_encoder.generate_adaface_embeddings(\
528
  image_paths, face_id_embs=face_id_embs,
529
  img_prompt_embs=None,
530
  avg_at_stage=avg_at_stage,
531
  perturb_at_stage=perturb_at_stage,
532
  perturb_std=perturb_std,
533
+ enable_static_img_suffix_embs=self.enable_static_img_suffix_embs)
534
 
535
  if all_adaface_subj_embs is None:
536
  return None
537
 
538
+ self.img_prompt_embs = img_prompt_embs
539
+
540
  if all_adaface_subj_embs.ndim == 4:
541
+ # [1, 1, 20, 768] -> [20, 768]
542
  all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0).squeeze(0)
543
  elif all_adaface_subj_embs.ndim == 3:
544
+ # [1, 20, 768] -> [20, 768]
545
  all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0)
546
 
547
  if update_text_encoder:
548
+ self.update_text_encoder_subj_embeddings(all_adaface_subj_embs, lens_subj_emb_segments)
549
  return all_adaface_subj_embs
550
 
551
  def diffusers_encode_prompts(self, prompt, plain_prompt, negative_prompt, device):
 
595
  else:
596
  breakpoint()
597
  else:
598
+ # "text2img" and "img2img" pipelines.
599
  # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
600
  prompt_embeds_, negative_prompt_embeds_ = \
601
  self.pipeline.encode_prompt(prompt, device=device,
 
606
  return prompt_embeds_, negative_prompt_embeds_, \
607
  pooled_prompt_embeds_, negative_pooled_prompt_embeds_
608
 
609
+ # alt_prompt_embed_type: 'ada-nonmix', 'img'
610
+ def mix_ada_embs_with_other_embs(self, prompt, prompt_embeds,
611
+ alt_prompt_embed_type, alt_prompt_emb_weights):
612
+ # Scan prompt and replace tokens in self.placeholder_token_ids
613
+ # with the corresponding image embeddings.
614
+ prompt_tokens = self.pipeline.tokenizer.tokenize(prompt)
615
+ prompt_embeds2 = prompt_embeds.clone()
616
+ if alt_prompt_embed_type == 'img':
617
+ if self.img_prompt_embs is None:
618
+ print("Unable to find img_prompt_embs. Either prepare_adaface_embeddings() hasn't been called, or faceless images were used.")
619
+ return prompt_embeds
620
+ # self.img_prompt_embs: [1, 20, 768]
621
+ repl_embeddings = self.img_prompt_embs
622
+ elif alt_prompt_embed_type == 'ada-nonmix':
623
+ repl_embeddings_, _, _, _ = self.encode_prompt(prompt, ablate_prompt_only_placeholders=True,
624
+ verbose=True)
625
+ # repl_embeddings_: [1, 77, 768] -> [1, 20, 768]
626
+ repl_embeddings = repl_embeddings_[:, 1:len(self.all_placeholder_tokens)+1]
627
+ else:
628
+ breakpoint()
629
+
630
+ repl_tokens = {}
631
+ for i in range(len(prompt_tokens)):
632
+ if prompt_tokens[i] in self.all_placeholder_tokens:
633
+ encoder_idx = next((j for j, sublist in enumerate(self.encoder_placeholder_tokens) \
634
+ if prompt_tokens[i] in sublist), 0)
635
+ alt_prompt_emb_weight = alt_prompt_emb_weights[encoder_idx]
636
+ prompt_embeds2[:, i] = prompt_embeds2[:, i] * (1 - alt_prompt_emb_weight) \
637
+ + repl_embeddings[:, self.all_placeholder_tokens.index(prompt_tokens[i])] * alt_prompt_emb_weight
638
+ repl_tokens[prompt_tokens[i]] = 1
639
+
640
+ repl_token_count = len(repl_tokens)
641
+ if np.all(np.array(alt_prompt_emb_weights) == 1):
642
+ print(f"Replaced {repl_token_count} tokens with {alt_prompt_embed_type} embeddings.")
643
+ else:
644
+ print(f"Mixed {repl_token_count} tokens with {alt_prompt_embed_type} embeddings, weight {alt_prompt_emb_weights}.")
645
+
646
+ return prompt_embeds2
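# [Editor's note] Illustration only, not part of this commit: the per-token mixing rule
# above linearly interpolates a placeholder token's current embedding with its
# replacement ('img' or 'ada-nonmix') embedding. The shapes below are assumptions.
import torch
w = 0.3                                    # alt_prompt_emb_weight for that encoder
ada_emb  = torch.randn(1, 768)             # current embedding of one placeholder token
repl_emb = torch.randn(1, 768)             # replacement embedding for the same token
mixed = ada_emb * (1 - w) + repl_emb * w   # what prompt_embeds2[:, i] becomes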
647
+
648
+
649
  def encode_prompt(self, prompt, negative_prompt=None,
650
  placeholder_tokens_pos='append',
651
+ ablate_prompt_only_placeholders=False,
652
+ ablate_prompt_no_placeholders=False,
653
+ ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img'
654
+ nonmix_prompt_emb_weight=0,
655
+ repeat_prompt_for_each_encoder=True,
656
  device=None, verbose=False):
657
  if negative_prompt is None:
658
  negative_prompt = self.negative_prompt
 
661
  device = self.device
662
 
663
  plain_prompt = prompt
664
+ if ablate_prompt_only_placeholders:
665
+ prompt = self.updated_tokens_str
666
+ else:
667
+ prompt = self.update_prompt(prompt, placeholder_tokens_pos=placeholder_tokens_pos,
668
+ repeat_prompt_for_each_encoder=repeat_prompt_for_each_encoder,
669
+ use_null_placeholders=ablate_prompt_no_placeholders)
670
+
671
  if verbose:
672
  print(f"Subject prompt:\n{prompt}")
673
 
 
674
  # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
675
  # So we manually move it to GPU here.
676
  self.pipeline.text_encoder.to(device)
677
 
678
  prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
679
  self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)
680
+
681
+ if ablate_prompt_embed_type != 'ada':
682
+ alt_prompt_embed_type = ablate_prompt_embed_type
683
+ alt_prompt_emb_weights = (1, 1)
684
+ elif nonmix_prompt_emb_weight > 0:
685
+ alt_prompt_embed_type = 'ada-nonmix'
686
+ alt_prompt_emb_weights = (nonmix_prompt_emb_weight, nonmix_prompt_emb_weight)
687
+ else:
688
+ alt_prompt_emb_weights = (0, 0)
689
+
690
+ if sum(alt_prompt_emb_weights) > 0:
691
+ prompt_embeds_ = self.mix_ada_embs_with_other_embs(prompt, prompt_embeds_,
692
+ alt_prompt_embed_type, alt_prompt_emb_weights)
693
+
694
  return prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_
695
 
696
  # ref_img_strength is used only in the img2img pipeline.
697
+ def forward(self, noise, prompt, prompt_embeds=None, negative_prompt=None,
698
  placeholder_tokens_pos='append',
 
699
  guidance_scale=6.0, out_image_count=4,
700
+ ref_img_strength=0.8, generator=None,
701
+ ablate_prompt_only_placeholders=False,
702
+ ablate_prompt_no_placeholders=False,
703
+ ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img'
704
+ nonmix_prompt_emb_weight=0,
705
+ repeat_prompt_for_each_encoder=True,
706
+ verbose=False):
707
  noise = noise.to(device=self.device, dtype=torch.float16)
708
+ if self.use_lcm:
709
+ guidance_scale = 0
710
 
711
  if negative_prompt is None:
712
  negative_prompt = self.negative_prompt
713
  # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
714
+ if prompt_embeds is None:
715
+ prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, \
716
+ negative_pooled_prompt_embeds_ = \
717
+ self.encode_prompt(prompt, negative_prompt,
718
+ placeholder_tokens_pos=placeholder_tokens_pos,
719
+ ablate_prompt_only_placeholders=ablate_prompt_only_placeholders,
720
+ ablate_prompt_no_placeholders=ablate_prompt_no_placeholders,
721
+ ablate_prompt_embed_type=ablate_prompt_embed_type,
722
+ nonmix_prompt_emb_weight=nonmix_prompt_emb_weight,
723
+ repeat_prompt_for_each_encoder=repeat_prompt_for_each_encoder,
724
+ device=self.device,
725
+ verbose=verbose)
726
+ else:
727
+ if len(prompt_embeds) == 2:
728
+ prompt_embeds_, negative_prompt_embeds_ = prompt_embeds
729
+ pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = None, None
730
+ elif len(prompt_embeds) == 4:
731
+ prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, \
732
+ negative_pooled_prompt_embeds_ = prompt_embeds
733
+ else:
734
+ breakpoint()
735
+
736
  # Repeat the prompt embeddings for all images in the batch.
737
  prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
738
+
739
  if negative_prompt_embeds_ is not None:
740
  negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
741
 
adaface/diffusers_attn_lora_capture.py ADDED
@@ -0,0 +1,656 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Optional, Tuple, Dict, Any
5
+ from diffusers.models.attention_processor import Attention, AttnProcessor2_0
6
+ from diffusers.utils import logging, is_torch_version, deprecate
7
+ from diffusers.utils.torch_utils import fourier_filter
8
+ # UNet is a diffusers PeftAdapterMixin instance.
9
+ from diffusers.loaders.peft import PeftAdapterMixin
10
+ from peft import LoraConfig, get_peft_model
11
+ import peft.tuners.lora as peft_lora
12
+ from peft.tuners.lora.dora import DoraLinearLayer
13
+ from einops import rearrange
14
+ import math, re
15
+ import numpy as np
16
+ from peft.tuners.tuners_utils import BaseTunerLayer
17
+
18
+
19
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20
+
21
+ def dummy_func(*args, **kwargs):
22
+ pass
23
+
24
+ # Revised from RevGrad, by removing the grad negation.
25
+ class ScaleGrad(torch.autograd.Function):
26
+ @staticmethod
27
+ def forward(ctx, input_, alpha_, debug=False):
28
+ ctx.save_for_backward(alpha_, debug)
29
+ output = input_
30
+ if debug:
31
+ print(f"input: {input_.abs().mean().item()}")
32
+ return output
33
+
34
+ @staticmethod
35
+ def backward(ctx, grad_output): # pragma: no cover
36
+ # saved_tensors returns a tuple of tensors.
37
+ alpha_, debug = ctx.saved_tensors
38
+ if ctx.needs_input_grad[0]:
39
+ grad_output2 = grad_output * alpha_
40
+ if debug:
41
+ print(f"grad_output2: {grad_output2.abs().mean().item()}")
42
+ else:
43
+ grad_output2 = None
44
+ return grad_output2, None, None
45
+
46
+ class GradientScaler(nn.Module):
47
+ def __init__(self, alpha=1., debug=False, *args, **kwargs):
48
+ """
49
+ A gradient scaling layer.
50
+ This layer has no parameters, and simply scales the gradient in the backward pass.
51
+ """
52
+ super().__init__(*args, **kwargs)
53
+
54
+ self._alpha = torch.tensor(alpha, requires_grad=False)
55
+ self._debug = torch.tensor(debug, requires_grad=False)
56
+
57
+ def forward(self, input_):
58
+ _debug = self._debug if hasattr(self, '_debug') else False
59
+ return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)
60
+
61
+ def gen_gradient_scaler(alpha, debug=False):
62
+ if alpha == 1:
63
+ return nn.Identity()
64
+ if alpha > 0:
65
+ return GradientScaler(alpha, debug=debug)
66
+ else:
67
+ assert alpha == 0
68
+ # Don't use lambda function here, otherwise the object can't be pickled.
69
+ return torch.detach
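# [Editor's note] Minimal standalone check, not part of this commit: the scaler's forward
# pass is the identity, while backward gradients are multiplied by alpha.
import torch
x = torch.randn(4, requires_grad=True)
GradientScaler(alpha=0.2)(x).sum().backward()
print(x.grad)   # every element is 0.2: d(sum)/dx == 1, scaled by alpha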
70
+
71
+ def split_indices_by_instance(indices, as_dict=False):
72
+ indices_B, indices_N = indices
73
+ unique_indices_B = torch.unique(indices_B)
74
+ if not as_dict:
75
+ indices_by_instance = [ (indices_B[indices_B == uib], indices_N[indices_B == uib]) for uib in unique_indices_B ]
76
+ else:
77
+ indices_by_instance = { uib.item(): indices_N[indices_B == uib] for uib in unique_indices_B }
78
+ return indices_by_instance
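# [Editor's note] Tiny usage example, not part of this commit: (batch_idx, token_idx)
# index pairs are grouped by batch instance.
example_indices = (torch.tensor([0, 0, 1]), torch.tensor([5, 6, 7]))
print(split_indices_by_instance(example_indices, as_dict=True))
# -> {0: tensor([5, 6]), 1: tensor([7])}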
79
+
80
+ # If do_sum, returned emb_attns is 3D. Otherwise 4D.
81
+ # indices are applied on the first 2 dims of attn_mat.
82
+ def sel_emb_attns_by_indices(attn_mat, indices, all_token_weights=None, do_sum=True, do_mean=False):
83
+ indices_by_instance = split_indices_by_instance(indices)
84
+
85
+ # emb_attns[0]: [1, 9, 8, 64]
86
+ # 8: 8 attention heads. Last dim 64: number of image tokens.
87
+ emb_attns = [ attn_mat[inst_indices].unsqueeze(0) for inst_indices in indices_by_instance ]
88
+ if all_token_weights is not None:
89
+ # all_token_weights: [4, 77].
90
+ # token_weights_by_instance[0]: [1, 9, 1, 1].
91
+ token_weights = [ all_token_weights[inst_indices].reshape(1, -1, 1, 1) for inst_indices in indices_by_instance ]
92
+ else:
93
+ token_weights = [ 1 ] * len(indices_by_instance)
94
+
95
+ # Apply token weights.
96
+ emb_attns = [ emb_attns[i] * token_weights[i] for i in range(len(indices_by_instance)) ]
97
+
98
+ # sum among K_subj_i subj embeddings -> [1, 8, 64]
99
+ if do_sum:
100
+ emb_attns = [ emb_attns[i].sum(dim=1) for i in range(len(indices_by_instance)) ]
101
+ elif do_mean:
102
+ emb_attns = [ emb_attns[i].mean(dim=1) for i in range(len(indices_by_instance)) ]
103
+
104
+ emb_attns = torch.cat(emb_attns, dim=0)
105
+ return emb_attns
106
+
107
+ # Slow implementation equivalent to F.scaled_dot_product_attention.
108
+ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
109
+ shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
110
+ is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
111
+ B, L, S = query.size(0), query.size(-2), key.size(-2)
112
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
113
+ # 1: head (to be broadcasted). L: query length. S: key length.
114
+ attn_bias = torch.zeros(B, 1, L, S, device=query.device, dtype=query.dtype)
115
+ if is_causal:
116
+ assert attn_mask is None
117
+ temp_mask = torch.ones(B, 1, L, S, device=query.device, dtype=torch.bool).tril(diagonal=0)
118
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
119
+ attn_bias.to(query.dtype)
120
+
121
+ if attn_mask is not None:
122
+ if attn_mask.dtype == torch.bool:
123
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
124
+ else:
125
+ attn_bias += attn_mask
126
+
127
+ if enable_gqa:
128
+ key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
129
+ value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
130
+
131
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
132
+
133
+ if shrink_cross_attn:
134
+ cross_attn_scale = cross_attn_shrink_factor
135
+ else:
136
+ cross_attn_scale = 1
137
+
138
+ # attn_bias: [1, 1, 4096, 77], the same size as a single-head attn_weight.
139
+ attn_weight += attn_bias
140
+ attn_score = attn_weight
141
+ attn_weight = torch.softmax(attn_weight, dim=-1)
142
+ # NOTE: After scaling, the "probabilities" of the subject embeddings will sum to < 1.
143
+ # But this is intended, as we want to scale down the impact of the subject embeddings
144
+ # in the computed attention output tensors.
145
+ attn_weight = attn_weight * cross_attn_scale
146
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
147
+ output = attn_weight @ value
148
+ return output, attn_score, attn_weight
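# [Editor's note] Equivalence check, illustration only and not part of this commit: with
# shrink_cross_attn=False and no mask, the slow implementation above should match the
# fused F.scaled_dot_product_attention up to numerical tolerance.
q = torch.randn(2, 8, 16, 40)   # (batch, heads, query_len, head_dim)
k = torch.randn(2, 8, 77, 40)
v = torch.randn(2, 8, 77, 40)
out_slow, _, _ = scaled_dot_product_attention(q, k, v)
out_fast = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out_slow, out_fast, atol=1e-5)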
149
+
150
+ # All layers share the same attention processor instance.
151
+ class AttnProcessor_LoRA_Capture(nn.Module):
152
+ r"""
153
+ Revised from AttnProcessor2_0
154
+ """
155
+ # lora_proj_layers is a dict of lora_layer_name -> lora_proj_layer.
156
+ def __init__(self, capture_ca_activations: bool = False, enable_lora: bool = False,
157
+ lora_uses_dora=True, lora_proj_layers=None,
158
+ lora_rank: int = 192, lora_alpha: float = 16,
159
+ cross_attn_shrink_factor: float = 0.5,
160
+ q_lora_updates_query=False, attn_proc_idx=-1):
161
+ super().__init__()
162
+
163
+ self.global_enable_lora = enable_lora
164
+ self.attn_proc_idx = attn_proc_idx
165
+ # reset_attn_cache_and_flags() sets the local (call-specific) self.enable_lora flag.
166
+ # By default, shrink_cross_attn is False. Later in layers 22, 23, 24 it will be set to True.
167
+ self.reset_attn_cache_and_flags(capture_ca_activations, False, enable_lora)
168
+ self.lora_rank = lora_rank
169
+ self.lora_alpha = lora_alpha
170
+ self.lora_scale = self.lora_alpha / self.lora_rank
171
+ self.cross_attn_shrink_factor = cross_attn_shrink_factor
172
+ self.q_lora_updates_query = q_lora_updates_query
173
+
174
+ self.to_q_lora = self.to_k_lora = self.to_v_lora = self.to_out_lora = None
175
+ if self.global_enable_lora:
176
+ for lora_layer_name, lora_proj_layer in lora_proj_layers.items():
177
+ if lora_layer_name == 'q':
178
+ self.to_q_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
179
+ use_dora=lora_uses_dora, lora_dropout=0.1)
180
+ elif lora_layer_name == 'k':
181
+ self.to_k_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
182
+ use_dora=lora_uses_dora, lora_dropout=0.1)
183
+ elif lora_layer_name == 'v':
184
+ self.to_v_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
185
+ use_dora=lora_uses_dora, lora_dropout=0.1)
186
+ elif lora_layer_name == 'out':
187
+ self.to_out_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
188
+ use_dora=lora_uses_dora, lora_dropout=0.1)
189
+
190
+ # LoRA layers can be enabled/disabled dynamically.
191
+ def reset_attn_cache_and_flags(self, capture_ca_activations, shrink_cross_attn, enable_lora):
192
+ self.capture_ca_activations = capture_ca_activations
193
+ self.shrink_cross_attn = shrink_cross_attn
194
+ self.cached_activations = {}
195
+ # Only enable LoRA for the next call(s) if global_enable_lora is set to True.
196
+ self.enable_lora = enable_lora and self.global_enable_lora
197
+
198
+ def __call__(
199
+ self,
200
+ attn: Attention,
201
+ hidden_states: torch.Tensor,
202
+ encoder_hidden_states: Optional[torch.Tensor] = None,
203
+ attention_mask: Optional[torch.Tensor] = None,
204
+ temb: Optional[torch.Tensor] = None,
205
+ img_mask: Optional[torch.Tensor] = None,
206
+ subj_indices: Optional[Tuple[torch.IntTensor, torch.IntTensor]] = None,
207
+ debug: bool = False,
208
+ *args,
209
+ **kwargs,
210
+ ) -> torch.Tensor:
211
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
212
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
213
+ deprecate("scale", "1.0.0", deprecation_message)
214
+
215
+ # hidden_states: [1, 4096, 320]
216
+ residual = hidden_states
217
+ # attn.spatial_norm is None.
218
+ if attn.spatial_norm is not None:
219
+ hidden_states = attn.spatial_norm(hidden_states, temb)
220
+
221
+ input_ndim = hidden_states.ndim
222
+
223
+ if input_ndim == 4:
224
+ batch_size, channel, height, width = hidden_states.shape
225
+ # Collapse the spatial dimensions to a single token dimension.
226
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
227
+
228
+ batch_size, sequence_length, _ = (
229
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
230
+ )
231
+
232
+ if attention_mask is not None:
233
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
234
+ # scaled_dot_product_attention expects attention_mask shape to be
235
+ # (batch, heads, source_length, target_length)
236
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
237
+
238
+ if attn.group_norm is not None:
239
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
240
+
241
+ query = attn.to_q(hidden_states)
242
+ # NOTE: there's an inconsistency between the q lora and the k, v loras.
243
+ # k, v loras are directly applied to key and value (currently k, v loras are never enabled),
244
+ # while q lora is applied to query2, and we keep the query unchanged.
245
+ if self.enable_lora and self.to_q_lora is not None:
246
+ # query2 will be used in ldm/util.py:calc_elastic_matching_loss() to get more accurate
247
+ # cross attention scores between the latent images of the sc and mc instances.
248
+ query2 = self.to_q_lora(hidden_states)
249
+ # If not q_lora_updates_query, only query2 will be impacted by the LoRA layer.
250
+ # The query, and thus the attention score and attn_out, will be the same
251
+ # as the original ones.
252
+ if self.q_lora_updates_query:
253
+ query = query2
254
+ else:
255
+ query2 = query
256
+
257
+ scale = 1 / math.sqrt(query.size(-1))
258
+
259
+ is_cross_attn = (encoder_hidden_states is not None)
260
+ if (not is_cross_attn) and (img_mask is not None):
261
+ # NOTE: we assume the image is square. But this will fail if the image is not square.
262
+ # hidden_states: [BS, 4096, 320]. img_mask: [BS, 1, 64, 64]
263
+ # Scale the mask to the same size as hidden_states.
264
+ mask_size = int(math.sqrt(hidden_states.shape[-2]))
265
+ img_mask = F.interpolate(img_mask, size=(mask_size, mask_size), mode='nearest')
266
+ if (img_mask.sum(dim=(2, 3)) == 0).any():
267
+ img_mask = None
268
+ else:
269
+ # img_mask: [2, 1, 64, 64] -> [2, 4096]
270
+ img_mask = rearrange(img_mask, 'b ... -> b (...)').contiguous()
271
+ # max_neg_value = -torch.finfo(hidden_states.dtype).max
272
+ # img_mask: [2, 4096] -> [2, 1, 1, 4096]
273
+ img_mask = rearrange(img_mask.bool(), 'b j -> b () () j')
274
+ # attn_score: [16, 4096, 4096]. img_mask will be broadcasted to [16, 4096, 4096].
275
+ # So some rows in dim 1 (e.g. [0, :, 4095]) of attn_score will be masked out (all elements in [0, :, 4095] are -inf).
276
+ # But not all elements in [0, 4095, :] are -inf. Since the softmax is done along dim 2, this is fine.
277
+ # attn_score.masked_fill_(~img_mask, max_neg_value)
278
+ # NOTE: If there's an attention mask, it will be replaced by img_mask.
279
+ attention_mask = img_mask
280
+
281
+ if encoder_hidden_states is None:
282
+ encoder_hidden_states = hidden_states
283
+ elif attn.norm_cross:
284
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
285
+
286
+ if self.enable_lora and self.to_k_lora is not None:
287
+ key = self.to_k_lora(encoder_hidden_states)
288
+ else:
289
+ key = attn.to_k(encoder_hidden_states)
290
+
291
+ if self.enable_lora and self.to_v_lora is not None:
292
+ value = self.to_v_lora(encoder_hidden_states)
293
+ else:
294
+ value = attn.to_v(encoder_hidden_states)
295
+
296
+ if attn.norm_q is not None:
297
+ query = attn.norm_q(query)
298
+ query2 = attn.norm_q(query2)
299
+ if attn.norm_k is not None:
300
+ key = attn.norm_k(key)
301
+
302
+ inner_dim = key.shape[-1]
303
+ head_dim = inner_dim // attn.heads
304
+
305
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
306
+ query2 = query2.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
307
+
308
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
309
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
310
+
311
+ if debug and self.attn_proc_idx >= 0:
312
+ breakpoint()
313
+
314
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
315
+ if is_cross_attn and (self.capture_ca_activations or self.shrink_cross_attn):
316
+ hidden_states, attn_score, attn_prob = \
317
+ scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
318
+ dropout_p=0.0, shrink_cross_attn=self.shrink_cross_attn,
319
+ cross_attn_shrink_factor=self.cross_attn_shrink_factor)
320
+ else:
321
+ # Use the faster implementation of scaled_dot_product_attention
322
+ # when not capturing the activations or suppressing the subject attention.
323
+ hidden_states = \
324
+ F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
325
+ attn_prob = attn_score = None
326
+
327
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
328
+ hidden_states = hidden_states.to(query.dtype)
329
+
330
+ # linear proj
331
+ if self.enable_lora and self.to_out_lora is not None:
332
+ hidden_states = self.to_out_lora(hidden_states)
333
+ else:
334
+ hidden_states = attn.to_out[0](hidden_states)
335
+
336
+ # dropout
337
+ hidden_states = attn.to_out[1](hidden_states)
338
+
339
+ if input_ndim == 4:
340
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
341
+
342
+ if attn.residual_connection:
343
+ hidden_states = hidden_states + residual
344
+
345
+ hidden_states = hidden_states / attn.rescale_output_factor
346
+
347
+ if is_cross_attn and self.capture_ca_activations:
348
+ # cached q will be used in ddpm.py:calc_comp_fg_bg_preserve_loss(), in which two qs will multiply each other.
349
+ # So sqrt(scale) will scale the product of two qs by scale.
350
+ # ANCHOR[id=attention_caching]
351
+ # query: [2, 8, 4096, 40] -> [2, 320, 4096]
352
+ self.cached_activations['q'] = \
353
+ rearrange(query, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
354
+ self.cached_activations['q2'] = \
355
+ rearrange(query2, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
356
+ self.cached_activations['k'] = \
357
+ rearrange(key, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
358
+ self.cached_activations['v'] = \
359
+ rearrange(value, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
360
+ # attn_prob, attn_score: [2, 8, 4096, 77]
361
+ self.cached_activations['attn'] = attn_prob
362
+ self.cached_activations['attnscore'] = attn_score
363
+ # attn_out: [b, n, h * d] -> [b, h * d, n]
364
+ # [2, 4096, 320] -> [2, 320, 4096].
365
+ self.cached_activations['attn_out'] = hidden_states.permute(0, 2, 1).contiguous()
366
+
367
+ return hidden_states
368
+
369
+ def CrossAttnUpBlock2D_forward_capture(
370
+ self,
371
+ hidden_states: torch.Tensor,
372
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
373
+ temb: Optional[torch.Tensor] = None,
374
+ encoder_hidden_states: Optional[torch.Tensor] = None,
375
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
376
+ upsample_size: Optional[int] = None,
377
+ attention_mask: Optional[torch.Tensor] = None,
378
+ encoder_attention_mask: Optional[torch.Tensor] = None,
379
+ ) -> torch.Tensor:
380
+ if cross_attention_kwargs is not None:
381
+ if cross_attention_kwargs.get("scale", None) is not None:
382
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
383
+
384
+ self.cached_outfeats = {}
385
+ res_hidden_states_gradscale = getattr(self, "res_hidden_states_gradscale", 1)
386
+ capture_outfeats = getattr(self, "capture_outfeats", False)
387
+ layer_idx = 0
388
+ res_grad_scaler = gen_gradient_scaler(res_hidden_states_gradscale)
389
+
390
+ for resnet, attn in zip(self.resnets, self.attentions):
391
+ # pop res hidden states
392
+ res_hidden_states = res_hidden_states_tuple[-1]
393
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
394
+
395
+ # Scale down the magnitudes of gradients to res_hidden_states
396
+ # by res_hidden_states_gradscale=0.2, to match the scale of the cross-attn layer outputs.
397
+ res_hidden_states = res_grad_scaler(res_hidden_states)
398
+
399
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
400
+
401
+ if self.training and self.gradient_checkpointing:
402
+ def create_custom_forward(module, return_dict=None):
403
+ def custom_forward(*inputs):
404
+ if return_dict is not None:
405
+ return module(*inputs, return_dict=return_dict)
406
+ else:
407
+ return module(*inputs)
408
+
409
+ return custom_forward
410
+
411
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
412
+ hidden_states = torch.utils.checkpoint.checkpoint(
413
+ create_custom_forward(resnet),
414
+ hidden_states,
415
+ temb,
416
+ **ckpt_kwargs,
417
+ )
418
+ hidden_states = attn(
419
+ hidden_states,
420
+ encoder_hidden_states=encoder_hidden_states,
421
+ cross_attention_kwargs=cross_attention_kwargs,
422
+ attention_mask=attention_mask,
423
+ encoder_attention_mask=encoder_attention_mask,
424
+ return_dict=False,
425
+ )[0]
426
+ else:
427
+ # resnet: ResnetBlock2D instance.
428
+ #LINK diffusers.models.resnet.ResnetBlock2D
429
+ # up_blocks.3.resnets.2.conv_shortcut is a module within ResnetBlock2D,
430
+ # so it does not transform the UNet shortcut features.
431
+ hidden_states = resnet(hidden_states, temb)
432
+ hidden_states = attn(
433
+ hidden_states,
434
+ encoder_hidden_states=encoder_hidden_states,
435
+ cross_attention_kwargs=cross_attention_kwargs,
436
+ attention_mask=attention_mask,
437
+ encoder_attention_mask=encoder_attention_mask,
438
+ return_dict=False,
439
+ )[0]
440
+
441
+ if capture_outfeats:
442
+ self.cached_outfeats[layer_idx] = hidden_states
443
+ layer_idx += 1
444
+
445
+ if self.upsamplers is not None:
446
+ for upsampler in self.upsamplers:
447
+ hidden_states = upsampler(hidden_states, upsample_size)
448
+
449
+ return hidden_states
450
+
451
+
452
+ # Adapted from ConsistentIDPipeline:set_ip_adapter().
453
+ # attn_lora_layer_names: candidates are subsets of ['q', 'k', 'v', 'out'].
454
+ def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k', 'v', 'out'],
455
+ lora_rank=192, lora_scale_down=8, cross_attn_shrink_factor=0.5,
456
+ q_lora_updates_query=False):
457
+ attn_procs = {}
458
+ attn_capture_procs = {}
459
+ unet_modules = dict(unet.named_modules())
460
+ attn_opt_modules = {}
461
+
462
+ attn_proc_idx = 0
463
+
464
+ for name, attn_proc in unet.attn_processors.items():
465
+ # Only capture the activations of the last 3 CA layers.
466
+ if not name.startswith("up_blocks.3"):
467
+ # Not the last 3 CA layers. Don't enable LoRA or capture activations.
468
+ # Then the layer falls back to the original attention mechanism.
469
+ # We still use AttnProcessor_LoRA_Capture, as it can handle img_mask.
470
+ attn_procs[name] = AttnProcessor_LoRA_Capture(
471
+ capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)
472
+ continue
473
+ # cross_attention_dim: 768.
474
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
475
+ if cross_attention_dim is None:
476
+ # Self attention. Don't enable LoRA or capture activations.
477
+ # We replace the default attn_proc with AttnProcessor_LoRA_Capture,
478
+ # so that it can incorporate img_mask into self-attention.
479
+ attn_procs[name] = AttnProcessor_LoRA_Capture(
480
+ capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)
481
+ continue
482
+
483
+ # block_id = 3
484
+ # hidden_size: 320
485
+ # hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
486
+ # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor' ->
487
+ # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q'
488
+ lora_layer_dict = {}
489
+ lora_layer_dict['q'] = unet_modules[name[:-9] + "to_q"]
490
+ lora_layer_dict['k'] = unet_modules[name[:-9] + "to_k"]
491
+ lora_layer_dict['v'] = unet_modules[name[:-9] + "to_v"]
492
+ # to_out is a ModuleList(Linear, Dropout).
493
+ lora_layer_dict['out'] = unet_modules[name[:-9] + "to_out"][0]
494
+
495
+ lora_proj_layers = {}
496
+ # Only apply LoRA to the specified layers.
497
+ for lora_layer_name in attn_lora_layer_names:
498
+ lora_proj_layers[lora_layer_name] = lora_layer_dict[lora_layer_name]
499
+
500
+ attn_capture_proc = AttnProcessor_LoRA_Capture(
501
+ capture_ca_activations=True, enable_lora=use_attn_lora,
502
+ lora_uses_dora=True, lora_proj_layers=lora_proj_layers,
503
+ # LoRA up is initialized to 0. So no need to worry that the LoRA output may be too large.
504
+ lora_rank=lora_rank, lora_alpha=lora_rank // lora_scale_down,
505
+ cross_attn_shrink_factor=cross_attn_shrink_factor,
506
+ q_lora_updates_query=q_lora_updates_query, attn_proc_idx=attn_proc_idx)
507
+
508
+ attn_proc_idx += 1
509
+ # attn_procs has to use the original names.
510
+ attn_procs[name] = attn_capture_proc
511
+ # ModuleDict doesn't allow "." in the key.
512
+ name = name.replace(".", "_")
513
+ attn_capture_procs[name] = attn_capture_proc
514
+
515
+ if use_attn_lora:
516
+ for subname, module in attn_capture_proc.named_modules():
517
+ if isinstance(module, peft_lora.LoraLayer):
518
+ # ModuleDict doesn't allow "." in the key.
519
+ lora_path = name + "_" + subname.replace(".", "_")
520
+ attn_opt_modules[lora_path + "_lora_A"] = module.lora_A
521
+ attn_opt_modules[lora_path + "_lora_B"] = module.lora_B
522
+ # lora_uses_dora is always True, so we don't check it here.
523
+ attn_opt_modules[lora_path + "_lora_magnitude_vector"] = module.lora_magnitude_vector
524
+ # We will manage attn adapters directly. By default, LoraLayer is an instance of BaseTunerLayer,
525
+ # so according to the code logic in diffusers/loaders/peft.py,
526
+ # they will be managed by the diffusers PeftAdapterMixin instance, through the
527
+ # enable_adapters(), and set_adapter() methods.
528
+ # Therefore, we disable these calls on module.
529
+ # disable_adapters() is a property and changing it will cause exceptions.
530
+ module.enable_adapters = dummy_func
531
+ module.set_adapter = dummy_func
532
+
533
+ unet.set_attn_processor(attn_procs)
534
+
535
+ print(f"Set up {len(attn_capture_procs)} CrossAttn processors on {attn_capture_procs.keys()}.")
536
+ print(f"Set up {len(attn_opt_modules)} attn LoRA params: {attn_opt_modules.keys()}.")
537
+ return attn_capture_procs, attn_opt_modules
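# [Editor's note] Hypothetical usage sketch, not part of this commit. The model id and
# argument values are assumptions; it only shows how the processors would be wired into
# an SD-1.5-style diffusers UNet.
from diffusers import UNet2DConditionModel
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
attn_capture_procs, attn_opt_modules = set_up_attn_processors(
    unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
    lora_rank=192, lora_scale_down=8)
# attn_opt_modules holds only the LoRA A/B (and DoRA magnitude) parameters,
# i.e. the subset of weights that would be handed to the optimizer.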
538
+
539
+ # NOTE: cross-attn layers are included in the returned lora_modules.
540
+ def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=False, lora_rank=192, lora_alpha=16):
541
+ # target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
542
+ # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut
543
+ # Cannot set to conv.+ as it will match added adapter module names, including
544
+ # up_blocks.3.resnets.1.conv1.base_layer, up_blocks.3.resnets.1.conv1.lora_dropout
545
+ if target_modules_pat is not None:
546
+ peft_config = LoraConfig(use_dora=lora_uses_dora, inference_mode=False, r=lora_rank,
547
+ lora_alpha=lora_alpha, lora_dropout=0.1,
548
+ target_modules=target_modules_pat)
549
+
550
+ # UNet is a diffusers PeftAdapterMixin instance. Using get_peft_model on it will
551
+ # cause weird errors. Instead, we directly use diffusers peft adapter methods.
552
+ unet.add_adapter(peft_config, "recon_loss")
553
+ unet.add_adapter(peft_config, "unet_distill")
554
+ unet.add_adapter(peft_config, "comp_distill")
555
+ unet.enable_adapters()
556
+
557
+ # lora_layers contain both the LoRA A and B matrices, as well as the original layers.
558
+ # lora_layers are used to set the flag, not used for optimization.
559
+ # lora_modules contain only the LoRA A and B matrices, so they are used for optimization.
560
+ # NOTE: lora_modules contain both ffn and cross-attn lora modules.
561
+ ffn_lora_layers = {}
562
+ ffn_opt_modules = {}
563
+ for name, module in unet.named_modules():
564
+ if isinstance(module, peft_lora.LoraLayer):
565
+ # We don't want to include cross-attn layers in ffn_lora_layers.
566
+ if target_modules_pat is not None and re.search(target_modules_pat, name):
567
+ ffn_lora_layers[name] = module
568
+ # ModuleDict doesn't allow "." in the key.
569
+ name = name.replace(".", "_")
570
+ # Since ModuleDict doesn't allow "." in the key, we manually collect
571
+ # the LoRA matrices in each module.
572
+ # NOTE: We cannot put every sub-module of module into lora_modules,
573
+ # as base_layer is also a sub-module of module, which we shouldn't optimize.
574
+ # Each value in ffn_opt_modules is a ModuleDict:
575
+ '''
576
+ (Pdb) ffn_opt_modules['up_blocks_3_resnets_1_conv1_lora_A']
577
+ ModuleDict(
578
+ (unet_distill): Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
579
+ (recon_loss): Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
580
+ )
581
+ '''
582
+ ffn_opt_modules[name + "_lora_A"] = module.lora_A
583
+ ffn_opt_modules[name + "_lora_B"] = module.lora_B
584
+ if lora_uses_dora:
585
+ ffn_opt_modules[name + "_lora_magnitude_vector"] = module.lora_magnitude_vector
586
+
587
+ print(f"Set up {len(ffn_lora_layers)} FFN LoRA layers: {ffn_lora_layers.keys()}.")
588
+ print(f"Set up {len(ffn_opt_modules)} FFN LoRA params: {ffn_opt_modules.keys()}.")
589
+
590
+ return ffn_lora_layers, ffn_opt_modules
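# [Editor's note] Hypothetical follow-up sketch, not part of this commit: adding the FFN
# LoRAs on the last up-block's resnet convolutions and activating one of the three
# adapters registered above. `unet` is assumed to be an SD-1.5-style diffusers UNet.
ffn_lora_layers, ffn_opt_modules = set_up_ffn_loras(
    unet, target_modules_pat='up_blocks.3.resnets.[12].conv[a-z0-9_]+',
    lora_uses_dora=False, lora_rank=192, lora_alpha=16)
unet.set_adapter("comp_distill")   # or "recon_loss" / "unet_distill"
unet.enable_adapters()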
591
+
592
+ def set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs,
593
+ outfeat_capture_blocks, res_hidden_states_gradscale_blocks,
594
+ use_attn_lora, use_ffn_lora, ffn_lora_adapter_name, capture_ca_activations,
595
+ shrink_cross_attn, res_hidden_states_gradscale):
596
+ # For attn capture procs, capture_ca_activations and use_attn_lora are set in reset_attn_cache_and_flags().
597
+ for attn_capture_proc in attn_capture_procs:
598
+ attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations, shrink_cross_attn, enable_lora=use_attn_lora)
599
+ # outfeat_capture_blocks only contains the last up block, up_blocks[3].
600
+ # It contains 3 FFN layers. We want to capture their output features.
601
+ for block in outfeat_capture_blocks:
602
+ block.capture_outfeats = capture_ca_activations
603
+
604
+ for block in res_hidden_states_gradscale_blocks:
605
+ block.res_hidden_states_gradscale = res_hidden_states_gradscale
606
+
607
+ if not use_ffn_lora:
608
+ unet.disable_adapters()
609
+ else:
610
+ # ffn_lora_adapter_name: 'recon_loss', 'unet_distill', 'comp_distill'.
611
+ if ffn_lora_adapter_name is not None:
612
+ unet.set_adapter(ffn_lora_adapter_name)
613
+ # NOTE: Don't forget to enable_adapters().
614
+ # The adapters are not enabled by default after set_adapter().
615
+ unet.enable_adapters()
616
+ else:
617
+ breakpoint()
618
+
619
+ # During training, disable_adapters() and set_adapter() will set all/inactive adapters with requires_grad=False,
620
+ # which might cause issues during DDP training.
621
+ # So we restore them to requires_grad=True.
622
+ # During test, unet_lora_modules will be passed as None, so this block will be skipped.
623
+ if unet_lora_modules is not None:
624
+ for param in unet_lora_modules.parameters():
625
+ param.requires_grad = True
626
+
627
+ def get_captured_activations(capture_ca_activations, attn_capture_procs, outfeat_capture_blocks,
628
+ captured_layer_indices=[22, 23, 24], out_dtype=torch.float32):
629
+ captured_activations = { k: {} for k in ('outfeat', 'attn', 'attnscore',
630
+ 'q', 'q2', 'k', 'v', 'attn_out') }
631
+
632
+ if not capture_ca_activations:
633
+ return captured_activations
634
+
635
+ all_cached_outfeats = []
636
+ for block in outfeat_capture_blocks:
637
+ all_cached_outfeats.append(block.cached_outfeats)
638
+ # Clear the capture flag and cached outfeats.
639
+ block.cached_outfeats = {}
640
+ block.capture_outfeats = False
641
+
642
+ for layer_idx in captured_layer_indices:
643
+ # Subtract 22 from layer_idx to match the layer index in up_blocks[3].cached_outfeats.
644
+ # 23, 24 -> 1, 2 (!! not 0, 1 !!)
645
+ internal_idx = layer_idx - 22
646
+ for k in captured_activations.keys():
647
+ if k == 'outfeat':
648
+ # Currently we only capture one block, up_blocks.3. So we hard-code the index 0.
649
+ captured_activations['outfeat'][layer_idx] = all_cached_outfeats[0][internal_idx].to(out_dtype)
650
+ else:
651
+ # internal_idx is the index of layers in up_blocks.3.
652
+ # Layers 22, 23 and 24 map to 0, 1 and 2.
653
+ cached_activations = attn_capture_procs[internal_idx].cached_activations
654
+ captured_activations[k][layer_idx] = cached_activations[k].to(out_dtype)
655
+
656
+ return captured_activations
adaface/face_id_to_ada_prompt.py CHANGED
@@ -53,6 +53,8 @@ class FaceID2AdaPrompt(nn.Module):
53
  self.text_to_image_prompt_encoder = None
54
  self.tokenizer = None
55
  self.dtype = kwargs.get('dtype', torch.float16)
 
 
56
 
57
  # Load Img2Ada SubjectBasisGenerator.
58
  self.subject_string = kwargs.get('subject_string', 'z')
@@ -73,12 +75,16 @@ class FaceID2AdaPrompt(nn.Module):
73
 
74
  self.use_clip_embs = False
75
  self.do_contrast_clip_embs_on_bg_features = False
 
 
 
 
76
  # num_id_vecs is the output embeddings of the ID2ImgPrompt module.
77
  # If there's no static image suffix embeddings, then num_id_vecs is also
78
  # the number of ada embeddings returned by the subject basis generator.
79
  # num_id_vecs will be set in each derived class.
80
  self.num_static_img_suffix_embs = kwargs.get('num_static_img_suffix_embs', 0)
81
- print(f'{self.name} Adaface uses {self.num_id_vecs} ID image embeddings and {self.num_static_img_suffix_embs} fixed image embeddings as input.')
82
 
83
  self.id_img_prompt_max_length = 77
84
  self.face_id_dim = 512
@@ -87,36 +93,35 @@ class FaceID2AdaPrompt(nn.Module):
87
  self.clip_embedding_dim = 1024
88
  self.output_dim = 768
89
 
90
- def get_id2img_learnable_modules(self):
91
- raise NotImplementedError
92
-
93
- def load_id2img_learnable_modules(self, id2img_learnable_modules_state_dict_list):
94
- id2img_prompt_encoder_learnable_modules = self.get_id2img_learnable_modules()
95
- for module, state_dict in zip(id2img_prompt_encoder_learnable_modules, id2img_learnable_modules_state_dict_list):
96
- module.load_state_dict(state_dict)
97
- print(f'{len(id2img_prompt_encoder_learnable_modules)} ID2ImgPrompt encoder modules loaded.')
98
-
99
- # init_subj_basis_generator() can only be called after the derived class is initialized,
100
- # when self.num_id_vecs, self.num_static_img_suffix_embs and self.clip_embedding_dim have been set.
101
- def init_subj_basis_generator(self):
102
  self.subj_basis_generator = \
103
- SubjBasisGenerator(num_id_vecs = self.num_id_vecs,
 
104
  num_static_img_suffix_embs = self.num_static_img_suffix_embs,
105
  bg_image_embedding_dim = self.clip_embedding_dim,
106
  output_dim = self.output_dim,
107
  placeholder_is_bg = False,
108
- prompt2token_proj_grad_scale = 1,
109
  bg_prompt_translator_has_to_out_proj=False)
110
 
111
  def load_adaface_ckpt(self, adaface_ckpt_path):
112
- ckpt = torch.load(adaface_ckpt_path, map_location='cpu')
 
 
 
113
  string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
114
  if self.subject_string not in string_to_subj_basis_generator_dict:
115
  print(f"Subject '{self.subject_string}' not found in the embedding manager.")
116
  breakpoint()
117
 
118
  ckpt_subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
119
- ckpt_subj_basis_generator.N_ID = self.num_id_vecs
 
 
 
 
 
120
  # Since we directly use the subject basis generator object from the ckpt,
121
  # fixing the number of static image suffix embeddings is much simpler.
122
  # Otherwise if we want to load the subject basis generator from its state_dict,
@@ -129,7 +134,7 @@ class FaceID2AdaPrompt(nn.Module):
129
  ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.num_static_img_suffix_embs, img_prompt_dim=self.output_dim)
130
  # Fix missing variables in old ckpt.
131
  ckpt_subj_basis_generator.patch_old_subj_basis_generator_ckpt()
132
-
133
  self.subj_basis_generator.extend_prompt2token_proj_attention(\
134
  ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
135
  ret = self.subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict(), strict=False)
@@ -155,6 +160,11 @@ class FaceID2AdaPrompt(nn.Module):
155
 
156
  self.subj_basis_generator.freeze_prompt2token_proj()
157
 
 
 
 
 
 
158
  @torch.no_grad()
159
  def get_clip_neg_features(self, BS):
160
  if self.clip_neg_features is None:
@@ -220,6 +230,7 @@ class FaceID2AdaPrompt(nn.Module):
220
  image_obj, _, _ = pad_image_obj_to_square(image_obj)
221
  image_np = np.array(image_obj.resize(size, Image.NEAREST))
222
  face_info = self.face_app.get(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
 
223
  if len(face_info) > 0:
224
face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
225
  # id_emb: [512,]
@@ -487,12 +498,20 @@ class FaceID2AdaPrompt(nn.Module):
487
  # avg_at_stage == ada_prompt_emb usually produces the worst results.
488
  # avg_at_stage == id_emb is slightly better than img_prompt_emb, but sometimes img_prompt_emb is better.
489
  # p_dropout and return_zero_embs_for_dropped_encoders are only used by Joint_FaceID2AdaPrompt.
 
490
  def generate_adaface_embeddings(self, image_paths, face_id_embs=None, img_prompt_embs=None,
491
  p_dropout=0,
492
  return_zero_embs_for_dropped_encoders=True,
493
  avg_at_stage='id_emb', # id_emb, img_prompt_emb, or None.
494
  perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
495
- perturb_std=0, enable_static_img_suffix_embs=False):
 
 
 
 
 
 
 
496
  if (avg_at_stage is None) or avg_at_stage.lower() == 'none':
497
  img_prompt_avg_at_stage = None
498
  else:
@@ -509,7 +528,7 @@ class FaceID2AdaPrompt(nn.Module):
509
  id_batch_size = len(image_paths)
510
  else:
511
  id_batch_size = 1
512
-
513
  # faceid_embeds: [BS, 512] is a batch of extracted face analysis embeddings. NOT used later.
514
  # NOTE: If face_id_embs, image_paths and image_objs are all None,
515
  # then get_img_prompt_embs() generates random faceid_embeds/img_prompt_embs,
@@ -532,7 +551,7 @@ class FaceID2AdaPrompt(nn.Module):
532
  verbose=True)
533
 
534
  if face_image_count == 0:
535
- return None
536
 
537
  # No matter whether avg_at_stage is id_emb or img_prompt_emb, we average img_prompt_embs.
538
  elif avg_at_stage is not None and avg_at_stage.lower() != 'none':
@@ -545,19 +564,27 @@ class FaceID2AdaPrompt(nn.Module):
545
  out_id_embs_cfg_scale=self.out_id_embs_cfg_scale,
546
  is_face=True,
547
  enable_static_img_suffix_embs=enable_static_img_suffix_embs)
 
 
 
 
548
  # During training, img_prompt_avg_at_stage is None, and BS >= 1.
549
  # During inference, img_prompt_avg_at_stage is 'id_emb' or 'img_prompt_emb', and BS == 1.
550
  if img_prompt_avg_at_stage is not None:
551
  # adaface_subj_embs: [1, 16, 768] -> [16, 768]
552
  adaface_subj_embs = adaface_subj_embs.squeeze(0)
553
 
554
- return adaface_subj_embs
555
 
556
  class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
557
- def __init__(self, *args, **kwargs):
558
- self.name = 'arc2face'
559
- self.num_id_vecs = 16
 
 
 
560
 
 
561
  super().__init__(*args, **kwargs)
562
 
563
  self.clip_image_encoder = CLIPVisionModelWithMask.from_pretrained('openai/clip-vit-large-patch14')
@@ -583,7 +610,7 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
583
  self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
584
  providers=['CPUExecutionProvider'])
585
  self.face_app.prepare(ctx_id=0, det_size=(512, 512))
586
- print(f'Face encoder loaded on CPU.')
587
 
588
  self.text_to_image_prompt_encoder = CLIPTextModelWrapper.from_pretrained(
589
  'models/arc2face', subfolder="encoder",
@@ -594,21 +621,54 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
594
  if self.out_id_embs_cfg_scale == -1:
595
  self.out_id_embs_cfg_scale = 1
596
  #### Arc2Face pipeline specific configs ####
597
- self.gen_neg_img_prompt = False
598
  # bg CLIP features are used by the bg subject basis generator.
599
- self.use_clip_embs = True
600
  self.do_contrast_clip_embs_on_bg_features = True
601
  # self.num_static_img_suffix_embs is initialized in the parent class.
602
- self.id_img_prompt_max_length = 22
603
- self.clip_embedding_dim = 1024
604
 
605
- self.init_subj_basis_generator()
606
  if self.adaface_ckpt_path is not None:
607
  self.load_adaface_ckpt(self.adaface_ckpt_path)
608
 
609
- print(f"{self.name} ada prompt encoder initialized, "
610
- f"ID vecs: {self.num_id_vecs}, static suffix: {self.num_static_img_suffix_embs}.")
 
 
 
 
611
 
 
 
 
612
  # Arc2Face_ID2AdaPrompt never uses clip_features or called_for_neg_img_prompt.
613
  def map_init_id_to_img_prompt_embs(self, init_id_embs,
614
  clip_features=None,
@@ -656,16 +716,17 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
656
  # [N, 22, 768] -> [N, 16, 768]
657
  return prompt_embeds[:, 4:20]
658
 
659
- def get_id2img_learnable_modules(self):
660
- return [ self.text_to_image_prompt_encoder ]
661
-
662
  # ConsistentID_ID2AdaPrompt is just a wrapper of ConsistentIDPipeline, so it's not an nn.Module.
663
  class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
 
 
 
 
 
 
664
  def __init__(self, pipe=None, base_model_path="models/sd15-dste8-vae.safetensors",
665
  *args, **kwargs):
666
- self.name = 'consistentID'
667
- self.num_id_vecs = 4
668
-
669
  super().__init__(*args, **kwargs)
670
  if pipe is None:
671
  # The base_model_path is kind of arbitrary, as the UNet and VAE in the model
@@ -712,13 +773,47 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
712
  self.clip_embedding_dim = 1280
713
  self.s_scale = 1.0
714
  self.shortcut = False
715
-
716
- self.init_subj_basis_generator()
717
  if self.adaface_ckpt_path is not None:
718
  self.load_adaface_ckpt(self.adaface_ckpt_path)
719
 
 
 
 
 
 
 
 
720
  print(f"{self.name} ada prompt encoder initialized, "
721
- f"ID vecs: {self.num_id_vecs}, static suffix: {self.num_static_img_suffix_embs}.")
 
 
 
 
 
 
 
 
 
722
 
723
  def map_init_id_to_img_prompt_embs(self, init_id_embs,
724
  clip_features=None,
@@ -757,26 +852,30 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
757
 
758
  return global_id_embeds
759
 
760
- def get_id2img_learnable_modules(self):
761
- return [ self.image_proj_model ]
762
-
763
  # A wrapper for combining multiple FaceID2AdaPrompt instances.
764
  class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
765
  def __init__(self, adaface_encoder_types, adaface_ckpt_paths,
766
  out_id_embs_cfg_scales=None, enabled_encoders=None,
767
  *args, **kwargs):
768
  self.name = 'jointIDs'
 
769
  assert len(adaface_encoder_types) > 0, "adaface_encoder_types should not be empty."
770
- adaface_encoder_types2num_id_vecs = { 'arc2face': 16, 'consistentID': 4 }
771
- self.encoders_num_id_vecs = [ adaface_encoder_types2num_id_vecs[encoder_type] \
 
 
772
  for encoder_type in adaface_encoder_types ]
773
- self.num_id_vecs = sum(self.encoders_num_id_vecs)
 
 
 
774
  # super() sets self.is_training.
775
  super().__init__(*args, **kwargs)
776
 
777
  self.num_sub_encoders = len(adaface_encoder_types)
778
  self.id2ada_prompt_encoders = nn.ModuleList()
779
  self.encoders_num_static_img_suffix_embs = []
 
780
 
781
  # TODO: apply adaface_encoder_cfg_scales to influence the final prompt embeddings.
782
  # Now they are just placeholders.
@@ -786,10 +885,12 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
786
  self.out_id_embs_cfg_scales = [-1] * self.num_sub_encoders
787
  else:
788
  # Do not normalize the weights, and just use them as is.
789
- self.out_id_embs_cfg_scales = out_id_embs_cfg_scales
790
 
791
  # Note we don't pass the adaface_ckpt_paths to the base class, but instead,
792
  # we load them once and for all in self.load_adaface_ckpt().
 
 
793
  for i, encoder_type in enumerate(adaface_encoder_types):
794
  kwargs['out_id_embs_cfg_scale'] = self.out_id_embs_cfg_scales[i]
795
  if encoder_type == 'arc2face':
@@ -798,8 +899,10 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
798
  encoder = ConsistentID_ID2AdaPrompt(*args, **kwargs)
799
  else:
800
  breakpoint()
 
801
  self.id2ada_prompt_encoders.append(encoder)
802
  self.encoders_num_static_img_suffix_embs.append(encoder.num_static_img_suffix_embs)
 
803
 
804
  self.num_static_img_suffix_embs = sum(self.encoders_num_static_img_suffix_embs)
805
  # No need to set gen_neg_img_prompt, as we don't access it in this class, but rather
@@ -829,7 +932,7 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
829
  self.load_adaface_ckpt(adaface_ckpt_paths)
830
 
831
  print(f"{self.name} ada prompt encoder initialized with {self.num_sub_encoders} sub-encoders. "
832
- f"ID vecs: {self.num_id_vecs}, static suffix embs: {self.num_static_img_suffix_embs}.")
833
 
834
  if enabled_encoders is not None:
835
  self.are_encoders_enabled = \
@@ -845,79 +948,79 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
845
  else:
846
  self.are_encoders_enabled = \
847
  torch.tensor([True] * self.num_sub_encoders)
848
-
849
- for i, encoder in enumerate(self.id2ada_prompt_encoders):
850
- if not (self.is_training and self.are_encoders_enabled[i]):
851
- for param in encoder.parameters():
852
- param.requires_grad = False
853
- else:
854
- for param in encoder.parameters():
855
- param.requires_grad = True
856
-
857
  def load_adaface_ckpt(self, adaface_ckpt_paths):
858
- # If only one adaface ckpt path is provided, then we assume it's the ckpt of the Joint_FaceID2AdaPrompt,
859
- # so we dereference the list to get the actual path and load the subj_basis_generators of all adaface encoders.
860
  if isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)):
861
- if len(adaface_ckpt_paths) == 1 and self.num_sub_encoders > 1:
 
 
 
 
 
 
 
 
862
  adaface_ckpt_paths = adaface_ckpt_paths[0]
863
-
864
- if isinstance(adaface_ckpt_paths, str):
865
- # This is only applicable to newest ckpts of Joint_FaceID2AdaPrompt, where
866
- # the ckpt_subj_basis_generator is an nn.ModuleList of multiple subj_basis_generators.
867
- # Therefore, no need to patch missing variables.
868
- ckpt = torch.load(adaface_ckpt_paths, map_location='cpu')
869
- string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
870
- if self.subject_string not in string_to_subj_basis_generator_dict:
871
- print(f"Subject '{self.subject_string}' not found in the embedding manager.")
872
  breakpoint()
873
 
874
- ckpt_subj_basis_generators = string_to_subj_basis_generator_dict[self.subject_string]
875
- if len(ckpt_subj_basis_generators) != self.num_sub_encoders:
876
- print(f"Number of subj_basis_generators in the ckpt ({len(ckpt_subj_basis_generators)}) "
877
- f"doesn't match the number of adaface encoders ({self.num_sub_encoders}).")
878
- breakpoint()
 
 
 
 
 
879
 
880
- for i, subj_basis_generator in enumerate(self.subj_basis_generator):
881
- ckpt_subj_basis_generator = ckpt_subj_basis_generators[i]
882
- # Handle differences in num_static_img_suffix_embs between the current model and the ckpt.
883
- ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.encoders_num_static_img_suffix_embs[i],
884
- img_prompt_dim=self.output_dim)
885
-
886
- if subj_basis_generator.prompt2token_proj_attention_multipliers \
887
- == [1] * 12:
888
- subj_basis_generator.extend_prompt2token_proj_attention(\
889
- ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
890
- elif subj_basis_generator.prompt2token_proj_attention_multipliers \
891
- != ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers:
892
- raise ValueError("Inconsistent prompt2token_proj_attention_multipliers.")
893
-
894
- assert subj_basis_generator.prompt2token_proj_attention_multipliers \
895
- == ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, \
896
- "Inconsistent prompt2token_proj_attention_multipliers."
897
- subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict())
898
-
899
- # extend_prompt2token_proj_attention_multiplier is an integer >= 1.
900
- # TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers.
901
- # If extend_prompt2token_proj_attention_multiplier > 1, then after loading state_dict,
902
- # extend subj_basis_generator again.
903
- if self.extend_prompt2token_proj_attention_multiplier > 1:
904
- # During this extension, the added noise does change the extra copies of attention weights, since they are not in the ckpt.
905
- # During training, prompt2token_proj_ext_attention_perturb_ratio == 0.1.
906
- # During inference, prompt2token_proj_ext_attention_perturb_ratio == 0.
907
- subj_basis_generator.extend_prompt2token_proj_attention(\
908
- None, -1, -1, self.extend_prompt2token_proj_attention_multiplier,
909
- perturb_std=self.prompt2token_proj_ext_attention_perturb_ratio)
910
-
911
- subj_basis_generator.freeze_prompt2token_proj()
912
-
913
- print(f"{adaface_ckpt_paths}: {len(self.subj_basis_generator)} subj_basis_generators loaded for {self.name}.")
914
-
915
- elif isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)):
916
- for i, ckpt_path in enumerate(adaface_ckpt_paths):
917
- self.id2ada_prompt_encoders[i].load_adaface_ckpt(ckpt_path)
918
- else:
919
  breakpoint()
920
 
 
 
 
 
 
 
 
921
  def extract_init_id_embeds_from_images(self, *args, **kwargs):
922
  total_faceless_img_count = 0
923
  all_id_embs = []
@@ -1055,7 +1158,7 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
1055
 
1056
  N_ID = self.encoders_num_id_vecs[i]
1057
  if all_pos_prompt_embs[i] is None:
1058
- # Both pos_prompt_embs and neg_prompt_embs have N_ID == num_id_vecs embeddings.
1059
  all_pos_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
1060
  if all_neg_prompt_embs[i] is None:
1061
  all_neg_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
@@ -1077,6 +1180,13 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
1077
  # So its .device is the device of its parameters.
1078
  device = self.id2ada_prompt_encoders[0].clip_image_encoder.device
1079
  is_emb_averaged = kwargs.get('avg_at_stage', None) is not None
 
 
 
 
 
 
 
1080
  BS = -1
1081
 
1082
  if face_id_embs is not None:
@@ -1084,13 +1194,17 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
1084
  all_face_id_embs = face_id_embs.split(self.face_id_dims, dim=1)
1085
  else:
1086
  all_face_id_embs = [None] * self.num_sub_encoders
 
1087
  if img_prompt_embs is not None:
1088
  BS = img_prompt_embs.shape[0] if BS == -1 else BS
1089
- if img_prompt_embs.shape[1] != self.num_id_vecs:
1090
  breakpoint()
1091
- all_img_prompt_embs = img_prompt_embs.split(self.encoders_num_id_vecs, dim=1)
 
1092
  else:
1093
  all_img_prompt_embs = [None] * self.num_sub_encoders
 
 
1094
  if image_paths is not None:
1095
  BS = len(image_paths) if BS == -1 else BS
1096
  if BS == -1:
@@ -1116,23 +1230,29 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
1116
  self.curr_are_encoders_enabled = are_encoders_enabled
1117
  all_adaface_subj_embs = []
1118
  num_available_id_vecs = 0
 
1119
 
1120
  for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders):
1121
  if not are_encoders_enabled[i]:
1122
  adaface_subj_embs = None
1123
- print(f"Encoder {id2ada_prompt_encoder.name} is dropped.")
 
 
1124
  else:
 
1125
  # ddpm.embedding_manager.train() -> id2ada_prompt_encoder.train() -> each sub-enconder's train().
1126
  # -> each sub-enconder's subj_basis_generator.train().
1127
  # Therefore grad for the following call is enabled.
1128
- adaface_subj_embs = \
1129
  id2ada_prompt_encoder.generate_adaface_embeddings(image_paths,
1130
  all_face_id_embs[i],
1131
  all_img_prompt_embs[i],
1132
  *args, **kwargs)
1133
 
1134
- # adaface_subj_embs: [16, 768] or [4, 768].
1135
- N_ID = self.encoders_num_id_vecs[i]
 
 
1136
  if adaface_subj_embs is None:
1137
  if not return_zero_embs_for_dropped_encoders:
1138
  continue
@@ -1143,12 +1263,16 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
1143
  all_adaface_subj_embs.append(adaface_subj_embs)
1144
  else:
1145
  all_adaface_subj_embs.append(adaface_subj_embs)
 
 
1146
  num_available_id_vecs += N_ID
1147
 
 
 
1148
  # No faces are found in the images, so return None embeddings.
1149
  # We don't want to return an all-zero embedding, which is useless.
1150
  if num_available_id_vecs == 0:
1151
- return None
1152
 
1153
  # If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
1154
  # during inference, we average across the batch dim.
@@ -1158,7 +1282,12 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
1158
  # all_adaface_subj_embs[0]: [BS, 4, 768]. all_adaface_subj_embs[1]: [BS, 16, 768].
1159
  # all_adaface_subj_embs: [BS, 20, 768].
1160
  all_adaface_subj_embs = torch.cat(all_adaface_subj_embs, dim=-2)
1161
- return all_adaface_subj_embs
1162
 
1163
 
1164
  '''
 
53
  self.text_to_image_prompt_encoder = None
54
  self.tokenizer = None
55
  self.dtype = kwargs.get('dtype', torch.float16)
56
+ self.img2txt_dtype = kwargs.get('img2txt_dtype', torch.float16)
57
+ self.device = torch.device("cpu")
58
 
59
  # Load Img2Ada SubjectBasisGenerator.
60
  self.subject_string = kwargs.get('subject_string', 'z')
 
75
 
76
  self.use_clip_embs = False
77
  self.do_contrast_clip_embs_on_bg_features = False
78
+ # Override the default setting in derived classes.
79
+ if 'enable_static_img_suffix_embs' in kwargs:
80
+ self.default_enable_static_img_suffix_embs = kwargs['enable_static_img_suffix_embs']
81
+
82
  # num_id_vecs is the output embeddings of the ID2ImgPrompt module.
83
  # If there's no static image suffix embeddings, then num_id_vecs is also
84
  # the number of ada embeddings returned by the subject basis generator.
85
  # num_id_vecs will be set in each derived class.
86
  self.num_static_img_suffix_embs = kwargs.get('num_static_img_suffix_embs', 0)
87
+ print(f'{self.name} Adaface uses {self.num_id_vecs} ID image embeddings + {self.num_static_img_suffix_embs} fixed image embeddings as input.')
88
 
89
  self.id_img_prompt_max_length = 77
90
  self.face_id_dim = 512
 
93
  self.clip_embedding_dim = 1024
94
  self.output_dim = 768
95
 
96
+ # init_img2txt_projection() can only be called after the derived class is initialized,
97
+ # when self.num_id_vecs0, self.num_static_img_suffix_embs and self.clip_embedding_dim have been set.
98
+ def init_img2txt_projection(self):
99
  self.subj_basis_generator = \
100
+ SubjBasisGenerator(dtype=self.img2txt_dtype,
101
+ num_id_vecs = self.num_id_vecs0,
102
  num_static_img_suffix_embs = self.num_static_img_suffix_embs,
103
  bg_image_embedding_dim = self.clip_embedding_dim,
104
  output_dim = self.output_dim,
105
  placeholder_is_bg = False,
 
106
  bg_prompt_translator_has_to_out_proj=False)
107
 
108
  def load_adaface_ckpt(self, adaface_ckpt_path):
109
+ if isinstance(adaface_ckpt_path, (list, tuple, ListConfig)):
110
+ adaface_ckpt_path = adaface_ckpt_path[0]
111
+
112
+ ckpt = torch.load(adaface_ckpt_path, map_location='cpu', weights_only=False)
113
  string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
114
  if self.subject_string not in string_to_subj_basis_generator_dict:
115
  print(f"Subject '{self.subject_string}' not found in the embedding manager.")
116
  breakpoint()
117
 
118
  ckpt_subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
119
+ if isinstance(ckpt_subj_basis_generator, nn.ModuleList):
120
+ name2idx = { 'consistentID': 0, 'arc2face': 1 }
121
+ subj_basis_generator_idx = name2idx[self.name]
122
+ ckpt_subj_basis_generator = ckpt_subj_basis_generator[subj_basis_generator_idx]
123
+
124
+ ckpt_subj_basis_generator.N_ID = self.num_id_vecs0
125
  # Since we directly use the subject basis generator object from the ckpt,
126
  # fixing the number of static image suffix embeddings is much simpler.
127
  # Otherwise if we want to load the subject basis generator from its state_dict,
 
134
  ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.num_static_img_suffix_embs, img_prompt_dim=self.output_dim)
135
  # Fix missing variables in old ckpt.
136
  ckpt_subj_basis_generator.patch_old_subj_basis_generator_ckpt()
137
+
138
  self.subj_basis_generator.extend_prompt2token_proj_attention(\
139
  ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
140
  ret = self.subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict(), strict=False)
 
160
 
161
  self.subj_basis_generator.freeze_prompt2token_proj()
162
 
163
+ def set_out_id_embs_cfg_scale(self, out_id_embs_cfg_scale):
164
+ if isinstance(out_id_embs_cfg_scale, (list, tuple, ListConfig)):
165
+ out_id_embs_cfg_scale = out_id_embs_cfg_scale[0]
166
+ self.out_id_embs_cfg_scale = out_id_embs_cfg_scale
167
+
168
  @torch.no_grad()
169
  def get_clip_neg_features(self, BS):
170
  if self.clip_neg_features is None:
 
230
  image_obj, _, _ = pad_image_obj_to_square(image_obj)
231
  image_np = np.array(image_obj.resize(size, Image.NEAREST))
232
  face_info = self.face_app.get(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
233
+
234
  if len(face_info) > 0:
235
  face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
236
  # id_emb: [512,]
 
498
  # avg_at_stage == ada_prompt_emb usually produces the worst results.
499
  # avg_at_stage == id_emb is slightly better than img_prompt_emb, but sometimes img_prompt_emb is better.
500
  # p_dropout and return_zero_embs_for_dropped_encoders are only used by Joint_FaceID2AdaPrompt.
501
+ # enable_static_img_suffix_embs=None: use the default setting.
502
  def generate_adaface_embeddings(self, image_paths, face_id_embs=None, img_prompt_embs=None,
503
  p_dropout=0,
504
  return_zero_embs_for_dropped_encoders=True,
505
  avg_at_stage='id_emb', # id_emb, img_prompt_emb, or None.
506
  perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
507
+ perturb_std=0, enable_static_img_suffix_embs=None):
508
+
509
+ if enable_static_img_suffix_embs is None:
510
+ enable_static_img_suffix_embs = self.default_enable_static_img_suffix_embs
511
+
512
+ lens_subj_emb_segments = [ self.num_id_vecs + enable_static_img_suffix_embs \
513
+ * self.num_static_img_suffix_embs ]
514
+
515
  if (avg_at_stage is None) or avg_at_stage.lower() == 'none':
516
  img_prompt_avg_at_stage = None
517
  else:
 
528
  id_batch_size = len(image_paths)
529
  else:
530
  id_batch_size = 1
531
+
532
  # faceid_embeds: [BS, 512] is a batch of extracted face analysis embeddings. NOT used later.
533
  # NOTE: If face_id_embs, image_paths and image_objs are all None,
534
  # then get_img_prompt_embs() generates random faceid_embeds/img_prompt_embs,
 
551
  verbose=True)
552
 
553
  if face_image_count == 0:
554
+ return None, None, lens_subj_emb_segments
555
 
556
  # No matter whether avg_at_stage is id_emb or img_prompt_emb, we average img_prompt_embs.
557
  elif avg_at_stage is not None and avg_at_stage.lower() != 'none':
 
564
  out_id_embs_cfg_scale=self.out_id_embs_cfg_scale,
565
  is_face=True,
566
  enable_static_img_suffix_embs=enable_static_img_suffix_embs)
567
+
568
+ if self.num_id_vecs < self.num_id_vecs0:
569
+ adaface_subj_embs = adaface_subj_embs[:, :self.num_id_vecs, :]
570
+
571
  # During training, img_prompt_avg_at_stage is None, and BS >= 1.
572
  # During inference, img_prompt_avg_at_stage is 'id_emb' or 'img_prompt_emb', and BS == 1.
573
  if img_prompt_avg_at_stage is not None:
574
  # adaface_subj_embs: [1, 16, 768] -> [16, 768]
575
  adaface_subj_embs = adaface_subj_embs.squeeze(0)
576
 
577
+ return adaface_subj_embs, img_prompt_embs, lens_subj_emb_segments
578
 
579
  class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
580
+ name = 'arc2face'
581
+ num_id_vecs0 = 16
582
+ # first 4 are kept, the rest 12 are averaged to another 4.
583
+ # Then concatenated to [8, 768].
584
+ num_id_vecs = 16
585
+ default_enable_static_img_suffix_embs = False
586
 
587
+ def __init__(self, *args, **kwargs):
588
  super().__init__(*args, **kwargs)
589
 
590
  self.clip_image_encoder = CLIPVisionModelWithMask.from_pretrained('openai/clip-vit-large-patch14')
 
610
  self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
611
  providers=['CPUExecutionProvider'])
612
  self.face_app.prepare(ctx_id=0, det_size=(512, 512))
613
+ print(f'Arc2Face Face encoder loaded on CPU.')
614
 
615
  self.text_to_image_prompt_encoder = CLIPTextModelWrapper.from_pretrained(
616
  'models/arc2face', subfolder="encoder",
 
621
  if self.out_id_embs_cfg_scale == -1:
622
  self.out_id_embs_cfg_scale = 1
623
  #### Arc2Face pipeline specific configs ####
624
+ self.gen_neg_img_prompt = False
625
  # bg CLIP features are used by the bg subject basis generator.
626
+ self.use_clip_embs = True
627
  self.do_contrast_clip_embs_on_bg_features = True
628
  # self.num_static_img_suffix_embs is initialized in the parent class.
629
+ self.id_img_prompt_max_length = 22
630
+ self.clip_embedding_dim = 1024
631
 
632
+ self.init_img2txt_projection()
633
  if self.adaface_ckpt_path is not None:
634
  self.load_adaface_ckpt(self.adaface_ckpt_path)
635
 
636
+ for param in self.clip_image_encoder.parameters():
637
+ param.requires_grad = False
638
+ for param in self.text_to_image_prompt_encoder.parameters():
639
+ param.requires_grad = False
640
+ for param in self.subj_basis_generator.parameters():
641
+ param.requires_grad = self.is_training
642
 
643
+ print(f"{self.name} ada prompt encoder initialized, "
644
+ f"ID vecs: {self.num_id_vecs0}, static suffix: {self.num_static_img_suffix_embs}.")
645
+
646
+ def _apply(self, fn):
647
+ super()._apply(fn) # Call the parent _apply to handle parameters and buffers
648
+ # A dirty hack to get the device of the model, passed from
649
+ # parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
650
+ test_tensor = torch.zeros(1) # Create a test tensor
651
+ transformed_tensor = fn(test_tensor) # Apply `fn()` to test it
652
+ device = transformed_tensor.device # Get the device of the transformed tensor
653
+ # No need to reload face_app on the same device.
654
+ if device == self.device:
655
+ return
656
+
657
+ if str(device) == 'cpu':
658
+ self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
659
+ providers=['CPUExecutionProvider'])
660
+ self.face_app.prepare(ctx_id=0, det_size=(512, 512))
661
+ else:
662
+ device_id = device.index
663
+ self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
664
+ providers=['CUDAExecutionProvider'],
665
+ provider_options=[{"device_id": str(device_id)}])
666
+ self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
667
+
668
+ self.device = device
669
+ print(f'Arc2Face Face encoder reloaded on {device}.')
670
+ return
671
+
672
  # Arc2Face_ID2AdaPrompt never uses clip_features or called_for_neg_img_prompt.
673
  def map_init_id_to_img_prompt_embs(self, init_id_embs,
674
  clip_features=None,
 
716
  # [N, 22, 768] -> [N, 16, 768]
717
  return prompt_embeds[:, 4:20]
718
 
 
 
 
719
  # ConsistentID_ID2AdaPrompt is just a wrapper of ConsistentIDPipeline, so it's not an nn.Module.
720
  class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
721
+ name = 'consistentID'
722
+ num_id_vecs0 = 4
723
+ # No compression for ConsistentID.
724
+ num_id_vecs = 4
725
+ default_enable_static_img_suffix_embs = False
726
+
727
  def __init__(self, pipe=None, base_model_path="models/sd15-dste8-vae.safetensors",
728
  *args, **kwargs):
729
+
 
 
730
  super().__init__(*args, **kwargs)
731
  if pipe is None:
732
  # The base_model_path is kind of arbitrary, as the UNet and VAE in the model
 
773
  self.clip_embedding_dim = 1280
774
  self.s_scale = 1.0
775
  self.shortcut = False
776
+
777
+ self.init_img2txt_projection()
778
  if self.adaface_ckpt_path is not None:
779
  self.load_adaface_ckpt(self.adaface_ckpt_path)
780
 
781
+ for param in self.clip_image_encoder.parameters():
782
+ param.requires_grad = False
783
+ for param in self.image_proj_model.parameters():
784
+ param.requires_grad = False
785
+ for param in self.subj_basis_generator.parameters():
786
+ param.requires_grad = self.is_training
787
+
788
  print(f"{self.name} ada prompt encoder initialized, "
789
+ f"ID vecs: {self.num_id_vecs0}, static suffix: {self.num_static_img_suffix_embs}.")
790
+
791
+ def _apply(self, fn):
792
+ super()._apply(fn) # Call the parent _apply to handle parameters and buffers
793
+ # A dirty hack to get the device of the model, passed from
794
+ # parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
795
+ test_tensor = torch.zeros(1) # Create a test tensor
796
+ transformed_tensor = fn(test_tensor) # Apply `fn()` to test it
797
+ device = transformed_tensor.device # Get the device of the transformed tensor
798
+ # No need to reload face_app on the same device.
799
+ if device == self.device:
800
+ return
801
+
802
+ if str(device) == 'cpu':
803
+ self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
804
+ providers=['CPUExecutionProvider'])
805
+ self.face_app.prepare(ctx_id=0, det_size=(512, 512))
806
+ else:
807
+ device_id = device.index
808
+ self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
809
+ providers=['CUDAExecutionProvider'],
810
+ provider_options=[{"device_id": str(device_id)}])
811
+ self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
812
+
813
+ self.device = device
814
+ self.pipe.face_app = self.face_app
815
+ print(f'ConsistentID Face encoder reloaded on {device}.')
816
+
817
 
818
  def map_init_id_to_img_prompt_embs(self, init_id_embs,
819
  clip_features=None,
 
852
 
853
  return global_id_embeds
854
 
 
 
 
855
  # A wrapper for combining multiple FaceID2AdaPrompt instances.
856
  class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
857
  def __init__(self, adaface_encoder_types, adaface_ckpt_paths,
858
  out_id_embs_cfg_scales=None, enabled_encoders=None,
859
  *args, **kwargs):
860
  self.name = 'jointIDs'
861
+ name2class = { 'arc2face': Arc2Face_ID2AdaPrompt, 'consistentID': ConsistentID_ID2AdaPrompt }
862
  assert len(adaface_encoder_types) > 0, "adaface_encoder_types should not be empty."
863
+ adaface_encoder_types2num_id_vecs0 = { name: name2class[name].num_id_vecs0 for name in adaface_encoder_types }
864
+ adaface_encoder_types2num_id_vecs = { name: name2class[name].num_id_vecs for name in adaface_encoder_types }
865
+ # self.num_id_vecs0 is used in the parent class. So we need to initialize it here first.
866
+ self.encoders_num_id_vecs0 = [ adaface_encoder_types2num_id_vecs0[encoder_type] \
867
  for encoder_type in adaface_encoder_types ]
868
+ self.encoders_num_id_vecs = [ adaface_encoder_types2num_id_vecs[encoder_type] \
869
+ for encoder_type in adaface_encoder_types ]
870
+ self.num_id_vecs0 = sum(self.encoders_num_id_vecs0)
871
+ self.num_id_vecs = sum(self.encoders_num_id_vecs)
872
  # super() sets self.is_training.
873
  super().__init__(*args, **kwargs)
874
 
875
  self.num_sub_encoders = len(adaface_encoder_types)
876
  self.id2ada_prompt_encoders = nn.ModuleList()
877
  self.encoders_num_static_img_suffix_embs = []
878
+ self.default_enable_static_img_suffix_embs = []
879
 
880
  # TODO: apply adaface_encoder_cfg_scales to influence the final prompt embeddings.
881
  # Now they are just placeholders.
 
885
  self.out_id_embs_cfg_scales = [-1] * self.num_sub_encoders
886
  else:
887
  # Do not normalize the weights, and just use them as is.
888
+ self.out_id_embs_cfg_scales = list(out_id_embs_cfg_scales)
889
 
890
  # Note we don't pass the adaface_ckpt_paths to the base class, but instead,
891
  # we load them once and for all in self.load_adaface_ckpt().
892
+ # NOTE: during inference, num_static_img_suffix_embs is fixed to be 4 for each encoder.
893
+ # But we can still disable static_img_suffix_embs by setting enable_static_img_suffix_embs to False.
894
  for i, encoder_type in enumerate(adaface_encoder_types):
895
  kwargs['out_id_embs_cfg_scale'] = self.out_id_embs_cfg_scales[i]
896
  if encoder_type == 'arc2face':
 
899
  encoder = ConsistentID_ID2AdaPrompt(*args, **kwargs)
900
  else:
901
  breakpoint()
902
+
903
  self.id2ada_prompt_encoders.append(encoder)
904
  self.encoders_num_static_img_suffix_embs.append(encoder.num_static_img_suffix_embs)
905
+ self.default_enable_static_img_suffix_embs.append(encoder.default_enable_static_img_suffix_embs)
906
 
907
  self.num_static_img_suffix_embs = sum(self.encoders_num_static_img_suffix_embs)
908
  # No need to set gen_neg_img_prompt, as we don't access it in this class, but rather
 
932
  self.load_adaface_ckpt(adaface_ckpt_paths)
933
 
934
  print(f"{self.name} ada prompt encoder initialized with {self.num_sub_encoders} sub-encoders. "
935
+ f"ID vecs: {self.num_id_vecs0}, static suffix embs: {self.num_static_img_suffix_embs}.")
936
 
937
  if enabled_encoders is not None:
938
  self.are_encoders_enabled = \
 
948
  else:
949
  self.are_encoders_enabled = \
950
  torch.tensor([True] * self.num_sub_encoders)
951
+
952
  def load_adaface_ckpt(self, adaface_ckpt_paths):
 
 
953
  if isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)):
954
+ # If multiple adaface ckpt paths are provided, then we assume they are the
955
+ # ckpts of the sub-encoders.
956
+ if len(adaface_ckpt_paths) == self.num_sub_encoders:
957
+ for i, ckpt_path in enumerate(adaface_ckpt_paths):
958
+ self.id2ada_prompt_encoders[i].load_adaface_ckpt(ckpt_path)
959
+ return
960
+ # If only one adaface ckpt path is provided, then we assume it's the ckpt of the Joint_FaceID2AdaPrompt,
961
+ # so we dereference the list to get the actual path and load the subj_basis_generators of all adaface encoders.
962
+ elif len(adaface_ckpt_paths) == 1 and self.num_sub_encoders > 1:
963
  adaface_ckpt_paths = adaface_ckpt_paths[0]
964
+ else:
965
  breakpoint()
966
 
967
+ adaface_ckpt_path = adaface_ckpt_paths
968
+ assert isinstance(adaface_ckpt_path, str), "adaface_ckpt_path should be a string."
969
+ # This is only applicable to the newest ckpts of Joint_FaceID2AdaPrompt, where
970
+ # the ckpt_subj_basis_generator is an nn.ModuleList of multiple subj_basis_generators.
971
+ # Therefore, no need to patch missing variables.
972
+ ckpt = torch.load(adaface_ckpt_paths, map_location='cpu', weights_only=False)
973
+ string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
974
+ if self.subject_string not in string_to_subj_basis_generator_dict:
975
+ print(f"Subject '{self.subject_string}' not found in the embedding manager.")
976
+ breakpoint()
977
 
978
+ ckpt_subj_basis_generators = string_to_subj_basis_generator_dict[self.subject_string]
979
+ if len(ckpt_subj_basis_generators) != self.num_sub_encoders:
980
+ print(f"Number of subj_basis_generators in the ckpt ({len(ckpt_subj_basis_generators)}) "
981
+ f"doesn't match the number of adaface encoders ({self.num_sub_encoders}).")
982
  breakpoint()
983
 
984
+ for i, subj_basis_generator in enumerate(self.subj_basis_generator):
985
+ ckpt_subj_basis_generator = ckpt_subj_basis_generators[i]
986
+ # Handle differences in num_static_img_suffix_embs between the current model and the ckpt.
987
+ ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.encoders_num_static_img_suffix_embs[i],
988
+ img_prompt_dim=self.output_dim)
989
+
990
+ if subj_basis_generator.prompt2token_proj_attention_multipliers \
991
+ == [1] * 12:
992
+ subj_basis_generator.extend_prompt2token_proj_attention(\
993
+ ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
994
+ elif subj_basis_generator.prompt2token_proj_attention_multipliers \
995
+ != ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers:
996
+ raise ValueError("Inconsistent prompt2token_proj_attention_multipliers.")
997
+
998
+ assert subj_basis_generator.prompt2token_proj_attention_multipliers \
999
+ == ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, \
1000
+ "Inconsistent prompt2token_proj_attention_multipliers."
1001
+ subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict())
1002
+
1003
+ # extend_prompt2token_proj_attention_multiplier is an integer >= 1.
1004
+ # TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers.
1005
+ # If extend_prompt2token_proj_attention_multiplier > 1, then after loading state_dict,
1006
+ # extend subj_basis_generator again.
1007
+ if self.extend_prompt2token_proj_attention_multiplier > 1:
1008
+ # During this extension, the added noise does change the extra copies of attention weights, since they are not in the ckpt.
1009
+ # During training, prompt2token_proj_ext_attention_perturb_ratio == 0.1.
1010
+ # During inference, prompt2token_proj_ext_attention_perturb_ratio == 0.
1011
+ subj_basis_generator.extend_prompt2token_proj_attention(\
1012
+ None, -1, -1, self.extend_prompt2token_proj_attention_multiplier,
1013
+ perturb_std=self.prompt2token_proj_ext_attention_perturb_ratio)
1014
+
1015
+ subj_basis_generator.freeze_prompt2token_proj()
1016
+
1017
+ print(f"{adaface_ckpt_paths}: {len(self.subj_basis_generator)} subj_basis_generators loaded for {self.name}.")
1018
+
1019
+ def set_out_id_embs_cfg_scale(self, out_id_embs_cfg_scales):
1020
+ self.out_id_embs_cfg_scales = list(out_id_embs_cfg_scales)
1021
+ for i, out_id_embs_cfg_scale in enumerate(out_id_embs_cfg_scales):
1022
+ self.id2ada_prompt_encoders[i].set_out_id_embs_cfg_scale(out_id_embs_cfg_scale)
1023
+
1024
  def extract_init_id_embeds_from_images(self, *args, **kwargs):
1025
  total_faceless_img_count = 0
1026
  all_id_embs = []
 
1158
 
1159
  N_ID = self.encoders_num_id_vecs[i]
1160
  if all_pos_prompt_embs[i] is None:
1161
+ # Both pos_prompt_embs and neg_prompt_embs have N_ID == num_id_vecs0 embeddings.
1162
  all_pos_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
1163
  if all_neg_prompt_embs[i] is None:
1164
  all_neg_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
 
1180
  # So its .device is the device of its parameters.
1181
  device = self.id2ada_prompt_encoders[0].clip_image_encoder.device
1182
  is_emb_averaged = kwargs.get('avg_at_stage', None) is not None
1183
+ if kwargs.get('enable_static_img_suffix_embs', None) is None:
1184
+ enable_static_img_suffix_embs = self.default_enable_static_img_suffix_embs
1185
+ else:
1186
+ enable_static_img_suffix_embs = kwargs['enable_static_img_suffix_embs']
1187
+ if isinstance(enable_static_img_suffix_embs, bool):
1188
+ enable_static_img_suffix_embs = [enable_static_img_suffix_embs] * self.num_sub_encoders
1189
+
1190
  BS = -1
1191
 
1192
  if face_id_embs is not None:
 
1194
  all_face_id_embs = face_id_embs.split(self.face_id_dims, dim=1)
1195
  else:
1196
  all_face_id_embs = [None] * self.num_sub_encoders
1197
+
1198
  if img_prompt_embs is not None:
1199
  BS = img_prompt_embs.shape[0] if BS == -1 else BS
1200
+ if img_prompt_embs.shape[1] != self.num_id_vecs0:
1201
  breakpoint()
1202
+ all_img_prompt_embs = img_prompt_embs.split(self.encoders_num_id_vecs0, dim=1)
1203
+ img_prompt_embs_provided = True
1204
  else:
1205
  all_img_prompt_embs = [None] * self.num_sub_encoders
1206
+ img_prompt_embs_provided = False
1207
+
1208
  if image_paths is not None:
1209
  BS = len(image_paths) if BS == -1 else BS
1210
  if BS == -1:
 
1230
  self.curr_are_encoders_enabled = are_encoders_enabled
1231
  all_adaface_subj_embs = []
1232
  num_available_id_vecs = 0
1233
+ lens_subj_emb_segments = []
1234
 
1235
  for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders):
1236
  if not are_encoders_enabled[i]:
1237
  adaface_subj_embs = None
1238
+ print(f"Encoder {id2ada_prompt_encoder.name} is disabled.")
1239
+ N_ID = id2ada_prompt_encoder.num_id_vecs + enable_static_img_suffix_embs[i] \
1240
+ * id2ada_prompt_encoder.num_static_img_suffix_embs
1241
  else:
1242
+ kwargs['enable_static_img_suffix_embs'] = enable_static_img_suffix_embs[i]
1243
  # ddpm.embedding_manager.train() -> id2ada_prompt_encoder.train() -> each sub-encoder's train().
1244
  # -> each sub-encoder's subj_basis_generator.train().
1245
  # Therefore grad for the following call is enabled.
1246
+ adaface_subj_embs, img_prompt_embs, encoder_lens_subj_emb_segments = \
1247
  id2ada_prompt_encoder.generate_adaface_embeddings(image_paths,
1248
  all_face_id_embs[i],
1249
  all_img_prompt_embs[i],
1250
  *args, **kwargs)
1251
 
1252
+ # adaface_subj_embs: arc2face [16, 768] or consistentID [4, 768],
1253
+ # or arc2face [20, 768] or consistentID [8, 768] if enable_static_img_suffix_embs=True.
1254
+ N_ID = encoder_lens_subj_emb_segments[0]
1255
+
1256
  if adaface_subj_embs is None:
1257
  if not return_zero_embs_for_dropped_encoders:
1258
  continue
 
1263
  all_adaface_subj_embs.append(adaface_subj_embs)
1264
  else:
1265
  all_adaface_subj_embs.append(adaface_subj_embs)
1266
+ if not img_prompt_embs_provided:
1267
+ all_img_prompt_embs[i] = img_prompt_embs
1268
  num_available_id_vecs += N_ID
1269
 
1270
+ lens_subj_emb_segments.append(N_ID)
1271
+
1272
  # No faces are found in the images, so return None embeddings.
1273
  # We don't want to return an all-zero embedding, which is useless.
1274
  if num_available_id_vecs == 0:
1275
+ return None, [0]
1276
 
1277
  # If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
1278
  # during inference, we average across the batch dim.
 
1282
  # all_adaface_subj_embs[0]: [BS, 4, 768]. all_adaface_subj_embs[1]: [BS, 16, 768].
1283
  # all_adaface_subj_embs: [BS, 20, 768].
1284
  all_adaface_subj_embs = torch.cat(all_adaface_subj_embs, dim=-2)
1285
+ # Check if some of the img_prompt_embs are None.
1286
+ if None in all_img_prompt_embs:
1287
+ all_img_prompt_embs = None
1288
+ else:
1289
+ all_img_prompt_embs = torch.cat(all_img_prompt_embs, dim=-2)
1290
+ return all_adaface_subj_embs, all_img_prompt_embs, lens_subj_emb_segments
1291
 
1292
 
1293
  '''
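A minimal illustrative sketch (not from the repository) of consuming the new three-value return of Joint_FaceID2AdaPrompt.generate_adaface_embeddings(): the concatenated subject embeddings can be split back into per-encoder segments with lens_subj_emb_segments. The batch size and the [consistentID, arc2face] ordering below are assumptions.

    import torch

    # Stand-ins for the values returned by generate_adaface_embeddings();
    # 4 consistentID + 16 arc2face ID vectors, no static suffix embeddings (assumed).
    all_adaface_subj_embs = torch.randn(1, 20, 768)
    lens_subj_emb_segments = [4, 16]
    per_encoder_embs = torch.split(all_adaface_subj_embs, lens_subj_emb_segments, dim=-2)
    assert [e.shape[-2] for e in per_encoder_embs] == lens_subj_emb_segments
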
adaface/subj_basis_generator.py CHANGED
@@ -9,7 +9,7 @@ import torch
9
  from torch import nn
10
  from einops import rearrange
11
  from einops.layers.torch import Rearrange
12
- from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
13
 
14
  from torch import einsum
15
  from adaface.util import gen_gradient_scaler
@@ -57,7 +57,25 @@ class IP_MLPProjModel(nn.Module):
57
  x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
58
  x = self.norm(x)
59
  return x
60
-
61
  # group_dim: the tensor dimension that corresponds to the multiple groups.
62
  class LearnedSoftAggregate(nn.Module):
63
  def __init__(self, num_feat, group_dim, keepdim=False):
@@ -349,23 +367,26 @@ class CrossAttention(nn.Module):
349
  else:
350
  return out
351
 
 
352
  class ImgPrompt2TextPrompt(nn.Module):
353
- def __init__(self, placeholder_is_bg, num_id_vecs, dtype=torch.float32, *args, **kwargs):
 
354
  super().__init__()
355
  self.N_ID = num_id_vecs
356
  # If not placeholder_is_bg, then N_SFX will be updated in initialize_text_components().
357
  self.N_SFX = 0
 
358
 
359
  if not placeholder_is_bg:
360
- self.initialize_text_components(*args, **kwargs)
 
361
 
362
  # prompt2token_proj: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
363
  # prompt2token_proj is with the same architecture as the original arc2face text encoder,
364
  # but retrained to do inverse mapping.
365
  # To be initialized in the subclass.
366
  self.prompt2token_proj = None
367
- self.dtype = dtype
368
-
369
  def initialize_static_img_suffix_embs(self, num_static_img_suffix_embs, img_prompt_dim=768):
370
  self.N_SFX = num_static_img_suffix_embs
371
  # We always take the first num_static_img_suffix_embs embeddings out of static_img_suffix_embs.
@@ -376,11 +397,11 @@ class ImgPrompt2TextPrompt(nn.Module):
376
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs ({self.N_SFX} required). Skip initialization.")
377
  elif self.static_img_suffix_embs.shape[1] < self.N_SFX:
378
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (< {self.N_SFX} required). Reinitialize.")
379
- self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim))
380
  elif self.N_SFX > 0:
381
  # self.static_img_suffix_embs.shape[1] > self.N_SFX > 0.
382
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (> {self.N_SFX} required). Truncate.")
383
- self.static_img_suffix_embs = nn.Parameter(self.static_img_suffix_embs[:, :self.N_SFX])
384
  else:
385
  # self.static_img_suffix_embs.shape[1] > self.N_SFX == 0.
386
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (0 required). Erase.")
@@ -391,7 +412,7 @@ class ImgPrompt2TextPrompt(nn.Module):
391
  # or it's initialized but has fewer than num_static_img_suffix_embs embeddings (this situation should be very rare,
392
  # so we don't consider to reuse and extend a shorter static_img_suffix_embs).
393
  # So we reinitialize it.
394
- self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim))
395
  else:
396
  # If static_img_suffix_embs had been initialized, then it will be set to None, i.e., erased from the SubjBasisGenerator instance.
397
  self.static_img_suffix_embs = None
@@ -399,9 +420,7 @@ class ImgPrompt2TextPrompt(nn.Module):
399
  # Implement a separate initialization function, so that it can be called from SubjBasisGenerator
400
  # after the SubjBasisGenerator is initialized. This can be used to fix old SubjBasisGenerator
401
  # ckpts which were not subclassed from ImgPrompt2TextPrompt.
402
- def initialize_text_components(self, max_prompt_length=77, num_id_vecs=16,
403
- num_static_img_suffix_embs=0, img_prompt_dim=768):
404
- self.initialize_static_img_suffix_embs(num_static_img_suffix_embs, img_prompt_dim)
405
  self.max_prompt_length = max_prompt_length
406
  self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
407
  # clip_text_embeddings: CLIPTextEmbeddings instance.
@@ -416,7 +435,7 @@ class ImgPrompt2TextPrompt(nn.Module):
416
  # pad_embeddings is still on CPU. But should be moved to GPU automatically.
417
  # Note: detach pad_embeddings from the computation graph, otherwise
418
  # deepcopy() in embedding_manager.py:make_frozen_copy_of_subj_basis_generators() will fail.
419
- self.pad_embeddings = clip_text_embeddings(pad_tokens)[0].detach()
420
 
421
  # image prompt space -> text prompt space.
422
  # return_emb_types: a list of strings, each string is among
@@ -439,7 +458,7 @@ class ImgPrompt2TextPrompt(nn.Module):
439
  else:
440
  breakpoint()
441
  else:
442
- # len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_comp_prompt_distillation.
443
  # But list_extra_words always corresponds to the actual batch size. So we only take the first element.
444
  list_extra_words = list_extra_words[:1]
445
 
@@ -466,7 +485,7 @@ class ImgPrompt2TextPrompt(nn.Module):
466
  face_prompt_embs_orig_dtype = face_prompt_embs.dtype
467
  face_prompt_embs = face_prompt_embs.to(self.dtype)
468
 
469
- ID_END = 4 + self.N_ID
470
  PAD_BEGIN = ID_END + self.N_SFX + 2
471
 
472
  # token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
@@ -545,6 +564,7 @@ class ImgPrompt2TextPrompt(nn.Module):
545
  class SubjBasisGenerator(ImgPrompt2TextPrompt):
546
  def __init__(
547
  self,
 
548
  # number of cross-attention heads of the bg prompt translator.
549
  # Taken as a half of the number of heads 12 of OpenAI clip-vit-large-patch14:
550
  # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
@@ -553,22 +573,25 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
553
  # or number of background input identity vectors (no matter the subject is face or not).
554
  # 257: 257 CLIP tokens.
555
  num_nonface_in_id_vecs={ 'subj': 77, 'bg': 257 },
 
556
  num_id_vecs=16, # num_id_vecs: subj: 16. bg: 4.
557
  num_static_img_suffix_embs: int = 0, # Number of extra static learnable image embeddings appended to translated ID embeddings.
558
  bg_image_embedding_dim=1024, # CLIP image hidden layer feature dimension, as per config.json above.
559
  obj_embedding_dim=384, # DINO object feature dimension for objects.
560
  output_dim=768, # CLIP text embedding input dimension.
 
561
  placeholder_is_bg: bool = False, # Whether the placeholder is for the image background tokens.
562
- prompt2token_proj_grad_scale: float = 0.4, # Gradient scale for prompt2token_proj.
563
  learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer.
564
- bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection.
565
  ):
566
 
567
  # If not placeholder_is_bg, then it calls initialize_text_components() in the superclass.
568
- super().__init__(placeholder_is_bg=placeholder_is_bg, num_id_vecs=num_id_vecs, max_prompt_length=77,
569
- num_static_img_suffix_embs=num_static_img_suffix_embs, img_prompt_dim=output_dim)
 
570
 
571
  self.placeholder_is_bg = placeholder_is_bg
 
572
  self.num_out_embs = self.N_ID + self.N_SFX
573
  self.output_dim = output_dim
574
  # num_nonface_in_id_vecs should be the number of core ID embs, 16.
@@ -586,14 +609,18 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
586
  # self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings) or [1, 16, 768] (without paddings).
587
  # If self.placeholder_is_bg: prompt2token_proj is set to None.
588
  # Use an attention dropout of 0.2 to increase robustness.
589
- clip_dropout_config = None #CLIPTextConfig.from_pretrained('openai/clip-vit-large-patch14', attention_dropout=0.05, dropout=0.05)
590
- self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14',
591
- config=clip_dropout_config)
592
- self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale
593
- self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale)
594
- print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.")
595
- # If prompt2token_proj_grad_scale is 0, freeze all params in prompt2token_proj.
596
- # Otherwise, only freeze token and positional embeddings of the original CLIPTextModel.
 
 
 
 
597
  self.freeze_prompt2token_proj()
598
 
599
  # These multipliers are relative to the original CLIPTextModel.
@@ -631,6 +658,9 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
631
  identity_to_out=identity_to_out,
632
  out_has_skip=out_has_skip)
633
 
 
 
 
634
  self.output_scale = output_dim ** -0.5
635
 
636
  '''
@@ -686,21 +716,20 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
686
  hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
687
 
688
  # faceid2img_prompt_embs -> ada_id_embs: image prompt space -> text prompt space.
689
- with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0):
690
- # If list_extra_words is not None, then ada_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens
691
- # and (at most) two extra words in adaface_prompt_embs, without BOS and EOS.
692
- # If list_extra_words is None, then ada_id_embs: [BS, 16, 768], the 16 identity tokens in adaface_prompt_embs.
693
- # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
694
- # ada_id_embs: [BS, 16, 768].
695
- # return_emb_types: a list of strings, each string is among
696
- # ['full', 'core', 'full_pad', 'full_half_pad'].
697
- ada_id_embs, = \
698
- self.inverse_img_prompt_embs(faceid2img_prompt_embs,
699
- list_extra_words=None,
700
- return_emb_types=['core'],
701
- hidden_state_layer_weights=hidden_state_layer_weights,
702
- enable_static_img_suffix_embs=enable_static_img_suffix_embs)
703
- ada_id_embs = self.prompt2token_proj_grad_scaler(ada_id_embs)
704
  elif raw_id_embs is not None:
705
  # id_embs: [BS, 384] -> [BS, 18, 768].
706
  # obj_proj_in is expected to project the DINO object features to
@@ -726,14 +755,15 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
726
 
727
  adaface_out_embs = id_embs_out * self.output_scale # * 0.036
728
  else:
729
- adaface_out_embs = ada_id_embs
 
730
  # If out_id_embs_cfg_scale < 1, adaface_out_embs is a mix of adaface_out_embs and pad_embeddings.
731
  if out_id_embs_cfg_scale != 1:
732
- # pad_embeddings: [77, 768] -> [16, 768] -> [1, 16, 768].
733
  # NOTE: Never do cfg on static image suffix embeddings.
734
  # So we take self.N_ID embeddings, instead of self.N_ID + self.N_SFX,
735
  # even if enable_static_img_suffix_embs=True.
736
- pad_embeddings = self.pad_embeddings[4:4+self.N_ID].unsqueeze(0).to(ada_id_embs.device)
737
  adaface_out_embs[:, :self.N_ID] = ada_id_embs[:, :self.N_ID] * out_id_embs_cfg_scale \
738
  + pad_embeddings * (1 - out_id_embs_cfg_scale)
739
 
@@ -812,37 +842,37 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
812
  # Only applicable to fg basis generator.
813
  if self.placeholder_is_bg:
814
  return
815
- # If bg, then prompt2token_proj is set to None. Therefore no need to freeze it.
816
- # Then we don't have to check whether it's for subj or bg.
817
- if self.prompt2token_proj_grad_scale == 0:
818
- frozen_components_name = 'all'
819
- frozen_param_set = self.prompt2token_proj.named_parameters()
820
- else:
821
- frozen_components_name = 'token_pos_embeddings'
822
- frozen_param_set = self.prompt2token_proj.text_model.embeddings.named_parameters()
823
-
824
  if self.prompt2token_proj is not None:
825
  frozen_param_names = []
826
- for param_name, param in frozen_param_set:
827
  if param.requires_grad:
828
  param.requires_grad = False
829
  frozen_param_names.append(param_name)
830
  # If param is already frozen, then no need to freeze it again.
831
- print(f"{frozen_components_name} {len(frozen_param_names)} params in Subj prompt2token_proj is frozen.")
832
  #print(f"Frozen parameters:\n{frozen_param_names}")
833
 
834
  def patch_old_subj_basis_generator_ckpt(self):
835
  # Fix compatability with the previous version.
836
  if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
837
  self.bg_prompt_translator_has_to_out_proj = False
838
- if not hasattr(self, 'num_out_embs'):
839
- self.num_out_embs = -1
840
  if hasattr(self, 'num_id_vecs') and not hasattr(self, 'N_ID'):
841
  self.N_ID = self.num_id_vecs
 
 
 
842
  if not hasattr(self, 'num_nonface_in_id_vecs') and hasattr(self, 'N_ID'):
843
  self.num_nonface_in_id_vecs = self.N_ID
844
  if not hasattr(self, 'dtype'):
845
- self.dtype = torch.float32
846
 
847
  if self.placeholder_is_bg:
848
  if not hasattr(self, 'pos_embs') or self.pos_embs is None:
@@ -860,6 +890,14 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
860
  num_static_img_suffix_embs=self.N_SFX,
861
  img_prompt_dim=self.output_dim)
862
 
863
  def __repr__(self):
864
  type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
865
 
 
9
  from torch import nn
10
  from einops import rearrange
11
  from einops.layers.torch import Rearrange
12
+ from transformers import CLIPTokenizer, CLIPTextModel
13
 
14
  from torch import einsum
15
  from adaface.util import gen_gradient_scaler
 
57
  x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
58
  x = self.norm(x)
59
  return x
60
+
61
+ class LayerwiseMLPProjWithSkip(nn.Module):
62
+ def __init__(self, id_embeddings_dim=768, num_layers=16, dim_mult=2):
63
+ super().__init__()
64
+
65
+ self.proj = nn.Sequential(
66
+ nn.Linear(id_embeddings_dim, id_embeddings_dim*dim_mult*num_layers),
67
+ Rearrange('b n (l d) -> b n l d', l=num_layers, d=id_embeddings_dim*dim_mult),
68
+ nn.GELU(),
69
+ nn.Linear(id_embeddings_dim*dim_mult, id_embeddings_dim),
70
+ )
71
+ self.norm = nn.LayerNorm(id_embeddings_dim)
72
+
73
+ def forward(self, id_embeds):
74
+ # B N D -> B N L D + B N L D -> B N L D
75
+ x = self.proj(id_embeds) + id_embeds.unsqueeze(1)
76
+ x = self.norm(x)
77
+ return x
78
+
79
  # group_dim: the tensor dimension that corresponds to the multiple groups.
80
  class LearnedSoftAggregate(nn.Module):
81
  def __init__(self, num_feat, group_dim, keepdim=False):
 
367
  else:
368
  return out
369
 
370
+
371
  class ImgPrompt2TextPrompt(nn.Module):
372
+ def __init__(self, placeholder_is_bg, num_id_vecs, num_static_img_suffix_embs,
373
+ max_prompt_length=77, img_prompt_dim=768, dtype=torch.float16):
374
  super().__init__()
375
  self.N_ID = num_id_vecs
376
  # If not placeholder_is_bg, then N_SFX will be updated in initialize_text_components().
377
  self.N_SFX = 0
378
+ self.dtype = dtype
379
 
380
  if not placeholder_is_bg:
381
+ self.initialize_static_img_suffix_embs(num_static_img_suffix_embs, img_prompt_dim)
382
+ self.initialize_text_components(max_prompt_length)
383
 
384
  # prompt2token_proj: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
385
  # prompt2token_proj is with the same architecture as the original arc2face text encoder,
386
  # but retrained to do inverse mapping.
387
  # To be initialized in the subclass.
388
  self.prompt2token_proj = None
389
+
 
390
  def initialize_static_img_suffix_embs(self, num_static_img_suffix_embs, img_prompt_dim=768):
391
  self.N_SFX = num_static_img_suffix_embs
392
  # We always take the first num_static_img_suffix_embs embeddings out of static_img_suffix_embs.
 
397
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs ({self.N_SFX} required). Skip initialization.")
398
  elif self.static_img_suffix_embs.shape[1] < self.N_SFX:
399
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (< {self.N_SFX} required). Reinitialize.")
400
+ self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim, dtype=self.dtype))
401
  elif self.N_SFX > 0:
402
  # self.static_img_suffix_embs.shape[1] > self.N_SFX > 0.
403
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (> {self.N_SFX} required). Truncate.")
404
+ self.static_img_suffix_embs = nn.Parameter(self.static_img_suffix_embs[:, :self.N_SFX].to(dtype=self.dtype))
405
  else:
406
  # self.static_img_suffix_embs.shape[1] > self.N_SFX == 0.
407
  print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (0 required). Erase.")
 
412
  # or it's initialized but has fewer than num_static_img_suffix_embs embeddings (this situation should be very rare,
413
  # so we don't consider to reuse and extend a shorter static_img_suffix_embs).
414
  # So we reinitialize it.
415
+ self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim, dtype=self.dtype))
416
  else:
417
  # If static_img_suffix_embs had been initialized, then it will be set to None, i.e., erased from the SubjBasisGenerator instance.
418
  self.static_img_suffix_embs = None
 
420
  # Implement a separate initialization function, so that it can be called from SubjBasisGenerator
421
  # after the SubjBasisGenerator is initialized. This can be used to fix old SubjBasisGenerator
422
  # ckpts which were not subclassed from ImgPrompt2TextPrompt.
423
+ def initialize_text_components(self, max_prompt_length=77):
 
 
424
  self.max_prompt_length = max_prompt_length
425
  self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
426
  # clip_text_embeddings: CLIPTextEmbeddings instance.
 
435
  # pad_embeddings is still on CPU. But should be moved to GPU automatically.
436
  # Note: detach pad_embeddings from the computation graph, otherwise
437
  # deepcopy() in embedding_manager.py:make_frozen_copy_of_subj_basis_generators() will fail.
438
+ self.pad_embeddings = clip_text_embeddings(pad_tokens)[0].detach().to(self.dtype)
439
 
440
  # image prompt space -> text prompt space.
441
  # return_emb_types: a list of strings, each string is among
 
458
  else:
459
  breakpoint()
460
  else:
461
+ # len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_feat_distill_on_comp_prompt.
462
  # But list_extra_words always corresponds to the actual batch size. So we only take the first element.
463
  list_extra_words = list_extra_words[:1]
464
 
 
485
  face_prompt_embs_orig_dtype = face_prompt_embs.dtype
486
  face_prompt_embs = face_prompt_embs.to(self.dtype)
487
 
488
+ ID_END = 4 + self.N_ID
489
  PAD_BEGIN = ID_END + self.N_SFX + 2
490
 
491
  # token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
 
564
  class SubjBasisGenerator(ImgPrompt2TextPrompt):
565
  def __init__(
566
  self,
567
+ dtype=torch.float16,
568
  # number of cross-attention heads of the bg prompt translator.
569
  # Taken as a half of the number of heads 12 of OpenAI clip-vit-large-patch14:
570
  # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
 
573
  # or number of background input identity vectors (no matter the subject is face or not).
574
  # 257: 257 CLIP tokens.
575
  num_nonface_in_id_vecs={ 'subj': 77, 'bg': 257 },
576
+ num_ca_layers=16,
577
  num_id_vecs=16, # num_id_vecs: subj: 16. bg: 4.
578
  num_static_img_suffix_embs: int = 0, # Number of extra static learnable image embeddings appended to translated ID embeddings.
579
  bg_image_embedding_dim=1024, # CLIP image hidden layer feature dimension, as per config.json above.
580
  obj_embedding_dim=384, # DINO object feature dimension for objects.
581
  output_dim=768, # CLIP text embedding input dimension.
582
+ use_layerwise_proj: bool = False, # Whether to use layerwise projection.
583
  placeholder_is_bg: bool = False, # Whether the placeholder is for the image background tokens.
 
584
  learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer.
585
+ bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection.
586
  ):
587
 
588
  # If not placeholder_is_bg, then it calls initialize_text_components() in the superclass.
589
+ super().__init__(placeholder_is_bg=placeholder_is_bg, num_id_vecs=num_id_vecs,
590
+ num_static_img_suffix_embs=num_static_img_suffix_embs,
591
+ max_prompt_length=77, img_prompt_dim=output_dim, dtype=dtype)
592
 
593
  self.placeholder_is_bg = placeholder_is_bg
594
+ self.num_ca_layers = num_ca_layers
595
  self.num_out_embs = self.N_ID + self.N_SFX
596
  self.output_dim = output_dim
597
  # num_nonface_in_id_vecs should be the number of core ID embs, 16.
 
609
  # self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings) or [1, 16, 768] (without paddings).
610
  # If self.placeholder_is_bg: prompt2token_proj is set to None.
611
  # Use an attention dropout of 0.2 to increase robustness.
612
+ self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14')
613
+ self.prompt2token_proj.to(dtype=self.dtype)
614
+
615
+ if use_layerwise_proj:
616
+ # MLPProjWithSkip: MLP with skip connection.
617
+ # [BS, 4, 768] -> [BS, 16, 4, 768]. Extra 16: 16 layers.
618
+ self.layerwise_proj = LayerwiseMLPProjWithSkip(output_dim, dim_mult=2)
619
+ else:
620
+ self.layerwise_proj = nn.Identity() #Rearrange('b n d -> b l n d', l=16)
621
+
622
+ print(f"Subj prompt2token_proj initialized.")
623
+ # Only freeze token and positional embeddings of the original CLIPTextModel.
624
  self.freeze_prompt2token_proj()
625
 
626
  # These multipliers are relative to the original CLIPTextModel.
 
658
  identity_to_out=identity_to_out,
659
  out_has_skip=out_has_skip)
660
 
661
+ if self.dtype == torch.float16:
662
+ self.prompt_translator = self.prompt_translator.half()
663
+
664
  self.output_scale = output_dim ** -0.5
665
 
666
  '''
 
716
  hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
717
 
718
  # faceid2img_prompt_embs -> ada_id_embs: image prompt space -> text prompt space.
719
+ # inverse_img_prompt_embs() applies self.prompt2token_proj to faceid2img_prompt_embs.
720
+ # If list_extra_words is not None, then ada_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens
721
+ # and (at most) two extra words in adaface_prompt_embs, without BOS and EOS.
722
+ # If list_extra_words is None, then ada_id_embs: [BS, 16, 768], the 16 identity tokens in adaface_prompt_embs.
723
+ # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
724
+ # ada_id_embs: [BS, 16, 768].
725
+ # return_emb_types: a list of strings, each string is among
726
+ # ['full', 'core', 'full_pad', 'full_half_pad'].
727
+ ada_id_embs, = \
728
+ self.inverse_img_prompt_embs(faceid2img_prompt_embs,
729
+ list_extra_words=None,
730
+ return_emb_types=['core'],
731
+ hidden_state_layer_weights=hidden_state_layer_weights,
732
+ enable_static_img_suffix_embs=enable_static_img_suffix_embs)
 
733
  elif raw_id_embs is not None:
734
  # id_embs: [BS, 384] -> [BS, 18, 768].
735
  # obj_proj_in is expected to project the DINO object features to
 
755
 
756
  adaface_out_embs = id_embs_out * self.output_scale # * 0.036
757
  else:
758
+ # [BS, 16, 768] -> [BS, layers=16, tokens=16, 768]
759
+ adaface_out_embs = self.layerwise_proj(ada_id_embs)
760
  # If out_id_embs_cfg_scale < 1, adaface_out_embs is a mix of adaface_out_embs and pad_embeddings.
761
  if out_id_embs_cfg_scale != 1:
762
+ # pad_embeddings: [77, 768] -> [16, 768] -> [1, 1, 16, 768].
763
  # NOTE: Never do cfg on static image suffix embeddings.
764
  # So we take self.N_ID embeddings, instead of self.N_ID + self.N_SFX,
765
  # even if enable_static_img_suffix_embs=True.
766
+ pad_embeddings = self.pad_embeddings[4:4+self.N_ID].unsqueeze(0).unsqueeze(1).to(ada_id_embs.device)
767
  adaface_out_embs[:, :self.N_ID] = ada_id_embs[:, :self.N_ID] * out_id_embs_cfg_scale \
768
  + pad_embeddings * (1 - out_id_embs_cfg_scale)
769
 
 
842
  # Only applicable to fg basis generator.
843
  if self.placeholder_is_bg:
844
  return
845
+
 
846
  if self.prompt2token_proj is not None:
847
  frozen_param_names = []
848
+ for param_name, param in self.prompt2token_proj.text_model.embeddings.named_parameters():
849
  if param.requires_grad:
850
  param.requires_grad = False
851
  frozen_param_names.append(param_name)
852
  # If param is already frozen, then no need to freeze it again.
853
+ print(f"{len(frozen_param_names)} params of token_pos_embeddings in Subj prompt2token_proj are frozen.")
854
  #print(f"Frozen parameters:\n{frozen_param_names}")
855
 
856
  def patch_old_subj_basis_generator_ckpt(self):
857
  # Fix compatibility with the previous version.
858
  if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
859
  self.bg_prompt_translator_has_to_out_proj = False
 
 
860
  if hasattr(self, 'num_id_vecs') and not hasattr(self, 'N_ID'):
861
  self.N_ID = self.num_id_vecs
862
+ # Update the number of output embeddings.
863
+ self.num_out_embs = self.N_ID + self.N_SFX
864
+
865
  if not hasattr(self, 'num_nonface_in_id_vecs') and hasattr(self, 'N_ID'):
866
  self.num_nonface_in_id_vecs = self.N_ID
867
  if not hasattr(self, 'dtype'):
868
+ self.dtype = torch.float16
869
+ if not self.placeholder_is_bg:
870
+ self.prompt2token_proj.to(dtype=self.dtype)
871
+ else:
872
+ self.prompt_translator.half()
873
+
874
+ if not hasattr(self, 'num_ca_layers'):
875
+ self.num_ca_layers = 16
876
 
877
  if self.placeholder_is_bg:
878
  if not hasattr(self, 'pos_embs') or self.pos_embs is None:
 
890
  num_static_img_suffix_embs=self.N_SFX,
891
  img_prompt_dim=self.output_dim)
892
 
893
+ if not hasattr(self, 'use_layerwise_proj'):
894
+ self.use_layerwise_proj = False
895
+ if not hasattr(self, 'layerwise_proj'):
896
+ if self.use_layerwise_proj:
897
+ self.layerwise_proj = LayerwiseMLPProjWithSkip(self.output_dim, dim_mult=2)
898
+ else:
899
+ self.layerwise_proj = nn.Identity()
900
+
901
  def __repr__(self):
902
  type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
903
 
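A minimal sketch, with dummy tensors, of the out_id_embs_cfg_scale mixing performed in SubjBasisGenerator above: when the scale is below 1, the translated ID embeddings are interpolated toward the CLIP pad embeddings (static image suffix embeddings are never mixed). Shapes and the scale value here are assumptions for illustration.

    import torch

    N_ID, out_id_embs_cfg_scale = 16, 0.8
    ada_id_embs    = torch.randn(1, N_ID, 768)   # stand-in for the translated ID embeddings
    pad_embeddings = torch.randn(1, N_ID, 768)   # stand-in for self.pad_embeddings[4:4+N_ID]
    # scale == 1 keeps ada_id_embs unchanged; smaller scales blend in the pad embeddings.
    mixed = ada_id_embs * out_id_embs_cfg_scale + pad_embeddings * (1 - out_id_embs_cfg_scale)
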
adaface/unet_teachers.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
 
2
  import numpy as np
3
- import pytorch_lightning as pl
4
  from diffusers import UNet2DConditionModel
5
  from adaface.util import UNetEnsemble, create_consistentid_pipeline
6
  from diffusers import UNet2DConditionModel
@@ -12,9 +12,9 @@ def create_unet_teacher(teacher_type, device='cpu', **kwargs):
12
  teacher_type = teacher_type[0]
13
 
14
  if teacher_type == "arc2face":
15
- return Arc2FaceTeacher(**kwargs)
16
  elif teacher_type == "unet_ensemble":
17
- # unet, extra_unet_dirpaths and unet_weights are passed in kwargs.
18
  # Even if we distill from unet_ensemble, we still need to load arc2face for generating
19
  # arc2face embeddings.
20
  # The first (optional) ctor param of UNetEnsembleTeacher is an instantiated unet,
@@ -22,20 +22,24 @@ def create_unet_teacher(teacher_type, device='cpu', **kwargs):
22
  # However, since the __call__ method of the ddpm unet takes different formats of params,
23
  # for simplicity, we still use the diffusers unet.
24
  # unet_teacher is put on CPU first, then moved to GPU when DDPM is moved to GPU.
25
- return UNetEnsembleTeacher(device=device, **kwargs)
26
  elif teacher_type == "consistentID":
27
- return ConsistentIDTeacher(**kwargs)
28
  elif teacher_type == "simple_unet":
29
- return SimpleUNetTeacher(**kwargs)
30
  # Since we've dereferenced the list if it has only one element,
31
  # this holding implies the list has more than one element. Therefore it's UNetEnsembleTeacher.
32
  elif isinstance(teacher_type, (tuple, list, ListConfig)):
33
  # teacher_type is a list of teacher types. So it's UNetEnsembleTeacher.
34
- return UNetEnsembleTeacher(unet_types=teacher_type, device=device, **kwargs)
35
  else:
36
  raise NotImplementedError(f"Teacher type {teacher_type} not implemented.")
37
 
38
- class UNetTeacher(pl.LightningModule):
 
 
 
 
39
  def __init__(self, **kwargs):
40
  super().__init__()
41
  self.name = None
@@ -56,9 +60,10 @@ class UNetTeacher(pl.LightningModule):
56
  # to be initialized, which will unnecessarily complicate the code.
57
  # noise: the initial noise for the first iteration.
58
  # t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t.
59
- # uses_same_t: when sampling t, use the same t for all instances.
60
- def forward(self, ddpm_model, x_start, noise, t, teacher_context,
61
- num_denoising_steps=1, uses_same_t=False):
 
62
  assert num_denoising_steps <= 10
63
 
64
  if self.p_uses_cfg > 0:
@@ -71,27 +76,22 @@ class UNetTeacher(pl.LightningModule):
71
 
72
  if self.uses_cfg:
73
  print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
74
  else:
75
  self.cfg_scale = 1
76
  print("Teacher does not use CFG.")
77
 
78
- # If p_uses_cfg > 0, we always pass both pos_context and neg_context to the teacher.
79
- # But the neg_context is only used when self.uses_cfg is True and cfg_scale > 1.
80
- # So we manually split the teacher_context into pos_context and neg_context, and only keep pos_context.
81
- if self.name == 'unet_ensemble':
82
- teacher_pos_contexts = []
83
- # teacher_context is a list of teacher contexts.
84
- for teacher_context_i in teacher_context:
85
- pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
86
- if pos_context.shape[0] != x_start.shape[0]:
87
- breakpoint()
88
- teacher_pos_contexts.append(pos_context)
89
- teacher_context = teacher_pos_contexts
90
- else:
91
- pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
92
- if pos_context.shape[0] != x_start.shape[0]:
93
- breakpoint()
94
- teacher_context = pos_context
95
  else:
96
  # p_uses_cfg = 0. Never use CFG.
97
  self.uses_cfg = False
@@ -102,15 +102,21 @@ class UNetTeacher(pl.LightningModule):
102
  # in case someday we want to switch from CFG to non-CFG during runtime.
103
  self.cfg_scale = 1
104
 
 
105
  if self.name == 'unet_ensemble':
106
  # teacher_context is a list of teacher contexts.
107
  for teacher_context_i in teacher_context:
108
- if teacher_context_i.shape[0] != x_start.shape[0] * (1 + self.uses_cfg):
109
  breakpoint()
110
  else:
111
- if teacher_context.shape[0] != x_start.shape[0] * (1 + self.uses_cfg):
112
  breakpoint()
113
-
114
  # Initially, x_starts only contains the original x_start.
115
  x_starts = [ x_start ]
116
  noises = [ noise ]
@@ -125,24 +131,35 @@ class UNetTeacher(pl.LightningModule):
125
  # sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise
126
  x_noisy = ddpm_model.q_sample(x_start, t, noise)
127
 
128
- if self.uses_cfg:
129
  x_noisy2 = x_noisy.repeat(2, 1, 1, 1)
130
  t2 = t.repeat(2)
131
  else:
132
  x_noisy2 = x_noisy
133
- t2 = t
134
 
135
  # If do_arc2face_distill, then pos_context is [BS=6, 21, 768].
136
  noise_pred = self.unet(sample=x_noisy2, timestep=t2, encoder_hidden_states=teacher_context,
137
  return_dict=False)[0]
138
  if self.uses_cfg and self.cfg_scale > 1:
139
- pos_noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0)
 
 
140
  noise_pred = pos_noise_pred * self.cfg_scale - neg_noise_pred * (self.cfg_scale - 1)
141
 
142
- # sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise
143
- pred_x0 = ddpm_model.predict_start_from_noise(x_noisy, t, noise_pred)
144
  noise_preds.append(noise_pred)
145
-
 
146
  # The predicted x0 is used as the x_start for the next denoising step.
147
  x_starts.append(pred_x0)
148
 
@@ -157,20 +174,43 @@ class UNetTeacher(pl.LightningModule):
157
  # of the current timestep.
158
  t_lb = t * np.power(0.5, np.power(num_denoising_steps - 1, -0.3))
159
  t_ub = t * np.power(0.7, np.power(num_denoising_steps - 1, -0.3))
 
 
160
  earlier_timesteps = (t_ub - t_lb) * relative_ts + t_lb
161
  earlier_timesteps = earlier_timesteps.long()
 
162
 
163
- if uses_same_t:
164
- # If uses_same_t, we use the same earlier_timesteps for all instances.
165
  earlier_timesteps = earlier_timesteps[0].repeat(x_start.shape[0])
 
166
 
167
  # earlier_timesteps = ts[i+1] < ts[i].
168
  ts.append(earlier_timesteps)
169
-
170
- noise = torch.randn_like(pred_x0)
171
  noises.append(noise)
172
 
173
  return noise_preds, x_starts, noises, ts
 
 
174
 
175
  class Arc2FaceTeacher(UNetTeacher):
176
  def __init__(self, **kwargs):
@@ -185,11 +225,11 @@ class Arc2FaceTeacher(UNetTeacher):
185
  self.cfg_scale_range = [1, 1]
186
 
187
  class UNetEnsembleTeacher(UNetTeacher):
188
- # unet_weights are not model weights, but scalar weights for individual unets.
189
- def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights=None, device='cuda', **kwargs):
190
  super().__init__(**kwargs)
191
  self.name = "unet_ensemble"
192
- self.unet = UNetEnsemble(unets, unet_types, extra_unet_dirpaths, unet_weights, device)
193
 
194
  class ConsistentIDTeacher(UNetTeacher):
195
  def __init__(self, base_model_path="models/sd15-dste8-vae.safetensors", **kwargs):
@@ -199,12 +239,9 @@ class ConsistentIDTeacher(UNetTeacher):
199
  # In contrast to Arc2FaceTeacher or UNetEnsembleTeacher, ConsistentIDPipeline is not a torch.nn.Module.
200
  # We couldn't initialize the ConsistentIDPipeline on CPU first and wait for it to be automatically moved to GPU.
201
  # Instead, we have to initialize it to GPU directly.
202
- pipe = create_consistentid_pipeline(base_model_path)
203
- # Compatible with the UNetTeacher interface.
204
- self.unet = pipe.unet
205
- # Release VAE and text_encoder to save memory. UNet is still needed for denoising
206
  # (the unet is implemented in diffusers in fp16, so probably faster than the LDM unet).
207
- pipe.release_components(["vae", "text_encoder"])
208
 
209
  # We use the default cfg_scale_range=[1.3, 2] for SimpleUNetTeacher.
210
  # Note p_uses_cfg=0.5 will also be passed in in kwargs.
 
1
  import torch
2
+ from torch import nn
3
  import numpy as np
 
4
  from diffusers import UNet2DConditionModel
5
  from adaface.util import UNetEnsemble, create_consistentid_pipeline
6
  from diffusers import UNet2DConditionModel
 
12
  teacher_type = teacher_type[0]
13
 
14
  if teacher_type == "arc2face":
15
+ teacher = Arc2FaceTeacher(**kwargs)
16
  elif teacher_type == "unet_ensemble":
17
+ # unet, extra_unet_dirpaths and unet_weights_in_ensemble are passed in kwargs.
18
  # Even if we distill from unet_ensemble, we still need to load arc2face for generating
19
  # arc2face embeddings.
20
  # The first (optional) ctor param of UNetEnsembleTeacher is an instantiated unet,
 
22
  # However, since the __call__ method of the ddpm unet takes different formats of params,
23
  # for simplicity, we still use the diffusers unet.
24
  # unet_teacher is put on CPU first, then moved to GPU when DDPM is moved to GPU.
25
+ teacher = UNetEnsembleTeacher(device=device, **kwargs)
26
  elif teacher_type == "consistentID":
27
+ teacher = ConsistentIDTeacher(**kwargs)
28
  elif teacher_type == "simple_unet":
29
+ teacher = SimpleUNetTeacher(**kwargs)
30
  # Since we've dereferenced the list if it has only one element,
31
  # reaching this branch implies the list has more than one element. Therefore it's a UNetEnsembleTeacher.
32
  elif isinstance(teacher_type, (tuple, list, ListConfig)):
33
  # teacher_type is a list of teacher types. So it's UNetEnsembleTeacher.
34
+ teacher = UNetEnsembleTeacher(unet_types=teacher_type, device=device, **kwargs)
35
  else:
36
  raise NotImplementedError(f"Teacher type {teacher_type} not implemented.")
37
 
38
+ for param in teacher.parameters():
39
+ param.requires_grad = False
40
+ return teacher
41
+
42
+ class UNetTeacher(nn.Module):
43
  def __init__(self, **kwargs):
44
  super().__init__()
45
  self.name = None
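For reference, a minimal runnable sketch of the frozen-teacher pattern that make_unet_teacher now applies before returning: the teacher is a plain nn.Module (a hypothetical ToyTeacher stands in for the real teachers here) whose parameters are all set to requires_grad=False, so distillation gradients flow only into the student.

from torch import nn

class ToyTeacher(nn.Module):
    # Hypothetical stand-in for Arc2FaceTeacher / UNetEnsembleTeacher / ConsistentIDTeacher.
    def __init__(self):
        super().__init__()
        self.unet = nn.Conv2d(4, 4, 3, padding=1)

    def forward(self, x_noisy):
        return self.unet(x_noisy)

def make_frozen_teacher():
    teacher = ToyTeacher()
    # Same freezing loop as in the diff above: the teacher is never trained.
    for param in teacher.parameters():
        param.requires_grad = False
    return teacher

teacher = make_frozen_teacher()
print(all(not p.requires_grad for p in teacher.parameters()))  # True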
 
60
  # to be initialized, which will unnecessarily complicate the code.
61
  # noise: the initial noise for the first iteration.
62
  # t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t.
63
+ # same_t_noise_across_instances: when sampling t and noise, use the same t and noise for all instances.
64
+ def forward(self, ddpm_model, x_start, noise, t, teacher_context, negative_context=None,
65
+ num_denoising_steps=1, same_t_noise_across_instances=False,
66
+ global_t_lb=0, global_t_ub=1000):
67
  assert num_denoising_steps <= 10
68
 
69
  if self.p_uses_cfg > 0:
 
76
 
77
  if self.uses_cfg:
78
  print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
79
+ if negative_context is not None:
80
+ negative_context = negative_context[:1].repeat(x_start.shape[0], 1, 1)
81
+
82
+ # if negative_context is None, then teacher_context is a combination of
83
+ # (one or multiple if unet_ensemble) pos_context and neg_context.
84
+ # If negative_context is not None, then teacher_context is only pos_context.
85
  else:
86
  self.cfg_scale = 1
87
  print("Teacher does not use CFG.")
88
 
89
+ # If negative_context is None, then teacher_context is a combination of
90
+ # (one or multiple if unet_ensemble) pos_context and neg_context.
91
+ # Since not uses_cfg, we only need pos_context.
92
+ # If negative_context is not None, then teacher_context is only pos_context.
93
+ if negative_context is None:
94
+ teacher_context = self.extract_pos_context(teacher_context, x_start.shape[0])
 
 
95
  else:
96
  # p_uses_cfg = 0. Never use CFG.
97
  self.uses_cfg = False
 
102
  # in case someday we want to switch from CFG to non-CFG during runtime.
103
  self.cfg_scale = 1
104
 
105
+ is_context_doubled = 2 if (self.uses_cfg and negative_context is None) else 1
106
  if self.name == 'unet_ensemble':
107
  # teacher_context is a list of teacher contexts.
108
  for teacher_context_i in teacher_context:
109
+ if teacher_context_i.shape[0] != x_start.shape[0] * is_context_doubled:
110
  breakpoint()
111
  else:
112
+ if teacher_context.shape[0] != x_start.shape[0] * is_context_doubled:
113
  breakpoint()
114
+
115
+ if same_t_noise_across_instances:
116
+ # If same_t_noise_across_instances, we use the same t and noise for all instances.
117
+ t = t[0].repeat(x_start.shape[0])
118
+ noise = noise[:1].repeat(x_start.shape[0], 1, 1, 1)
119
+
120
  # Initially, x_starts only contains the original x_start.
121
  x_starts = [ x_start ]
122
  noises = [ noise ]
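A tiny sketch (batch size and shapes are assumed for illustration) of what the new same_t_noise_across_instances option does: the first timestep and the first noise map are broadcast to every instance in the batch.

import torch

BS = 4
t = torch.randint(0, 1000, (BS,))
noise = torch.randn(BS, 4, 64, 64)

# Broadcast the first t and the first noise map to the whole batch, as in forward().
t = t[0].repeat(BS)
noise = noise[:1].repeat(BS, 1, 1, 1)
assert bool((t == t[0]).all()) and torch.equal(noise[0], noise[-1])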
 
131
  # sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise
132
  x_noisy = ddpm_model.q_sample(x_start, t, noise)
133
 
134
+ if self.uses_cfg and self.cfg_scale > 1 and negative_context is None:
135
  x_noisy2 = x_noisy.repeat(2, 1, 1, 1)
136
  t2 = t.repeat(2)
137
  else:
138
  x_noisy2 = x_noisy
139
+ t2 = t
140
 
141
  # If do_arc2face_distill, then pos_context is [BS=6, 21, 768].
142
  noise_pred = self.unet(sample=x_noisy2, timestep=t2, encoder_hidden_states=teacher_context,
143
  return_dict=False)[0]
144
  if self.uses_cfg and self.cfg_scale > 1:
145
+ if negative_context is None:
146
+ pos_noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0)
147
+ else:
148
+ # If negative_context is not None, then teacher_context is only pos_context.
149
+ pos_noise_pred = noise_pred
150
+ with torch.no_grad():
151
+ if self.name == 'unet_ensemble':
152
+ neg_noise_pred = self.unet.unets[0](sample=x_noisy, timestep=t,
153
+ encoder_hidden_states=negative_context, return_dict=False)[0]
154
+ else:
155
+ neg_noise_pred = self.unet(sample=x_noisy, timestep=t,
156
+ encoder_hidden_states=negative_context, return_dict=False)[0]
157
+
158
  noise_pred = pos_noise_pred * self.cfg_scale - neg_noise_pred * (self.cfg_scale - 1)
159
 
 
 
160
  noise_preds.append(noise_pred)
161
+ # sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise
162
+ pred_x0 = ddpm_model.predict_start_from_noise(x_noisy, t, noise_pred)
163
  # The predicted x0 is used as the x_start for the next denoising step.
164
  x_starts.append(pred_x0)
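As a worked example of the CFG combination used above (pure tensor arithmetic, no diffusion model needed): noise_pred = pos * s - neg * (s - 1) is algebraically the same as the usual classifier-free-guidance form neg + s * (pos - neg).

import torch

def cfg_blend(pos_noise_pred, neg_noise_pred, cfg_scale):
    return pos_noise_pred * cfg_scale - neg_noise_pred * (cfg_scale - 1)

pos = torch.randn(2, 4, 64, 64)
neg = torch.randn(2, 4, 64, 64)
blended = cfg_blend(pos, neg, cfg_scale=2.0)
# Equivalent formulation: start from the negative prediction and push towards the positive one.
assert torch.allclose(blended, neg + 2.0 * (pos - neg), atol=1e-6)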
165
 
 
174
  # of the current timestep.
175
  t_lb = t * np.power(0.5, np.power(num_denoising_steps - 1, -0.3))
176
  t_ub = t * np.power(0.7, np.power(num_denoising_steps - 1, -0.3))
177
+ t_lb = torch.clamp(t_lb, min=global_t_lb)
178
+ t_ub = torch.clamp(t_ub, max=global_t_ub)
179
  earlier_timesteps = (t_ub - t_lb) * relative_ts + t_lb
180
  earlier_timesteps = earlier_timesteps.long()
181
+ noise = torch.randn_like(pred_x0)
182
 
183
+ if same_t_noise_across_instances:
184
+ # If same_t_noise_across_instances, we use the same earlier_timesteps and noise for all instances.
185
  earlier_timesteps = earlier_timesteps[0].repeat(x_start.shape[0])
186
+ noise = noise[:1].repeat(x_start.shape[0], 1, 1, 1)
187
 
188
  # earlier_timesteps = ts[i+1] < ts[i].
189
  ts.append(earlier_timesteps)
 
 
190
  noises.append(noise)
191
 
192
  return noise_preds, x_starts, noises, ts
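A self-contained sketch of how the next, smaller timestep is drawn above; relative_ts is assumed here to be a fresh uniform sample per instance (the real forward() computes it elsewhere), and after this commit the bounds are additionally clamped to [global_t_lb, global_t_ub].

import numpy as np
import torch

def sample_earlier_timesteps(t, num_denoising_steps, global_t_lb=0, global_t_ub=1000):
    relative_ts = torch.rand(t.shape[0])  # assumed uniform sample in [0, 1)
    t_lb = t * np.power(0.5, np.power(num_denoising_steps - 1, -0.3))
    t_ub = t * np.power(0.7, np.power(num_denoising_steps - 1, -0.3))
    t_lb = torch.clamp(t_lb, min=global_t_lb)
    t_ub = torch.clamp(t_ub, max=global_t_ub)
    return ((t_ub - t_lb) * relative_ts + t_lb).long()

t = torch.randint(600, 1000, (4,))
print(sample_earlier_timesteps(t, num_denoising_steps=3))  # elementwise smaller than t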
193
+
194
+ def extract_pos_context(self, teacher_context, BS):
195
+ # If p_uses_cfg > 0, we always pass both pos_context and neg_context to the teacher.
196
+ # But the neg_context is only used when self.uses_cfg is True and cfg_scale > 1.
197
+ # So we manually split the teacher_context into pos_context and neg_context, and only keep pos_context.
198
+ if self.name == 'unet_ensemble':
199
+ teacher_pos_contexts = []
200
+ # teacher_context is a list of teacher contexts.
201
+ for teacher_context_i in teacher_context:
202
+ pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
203
+ if pos_context.shape[0] != BS:
204
+ breakpoint()
205
+ teacher_pos_contexts.append(pos_context)
206
+ teacher_context = teacher_pos_contexts
207
+ else:
208
+ pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
209
+ if pos_context.shape[0] != BS:
210
+ breakpoint()
211
+ teacher_context = pos_context
212
+
213
+ return teacher_context
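A small sketch of the split performed by extract_pos_context (the embedding shape is assumed for illustration): when CFG is prepared, teacher_context is the batch-wise concatenation [pos; neg], so torch.chunk along dim 0 recovers the positive half.

import torch

BS = 4
pos = torch.randn(BS, 77, 768)  # assumed prompt-embedding shape
neg = torch.randn(BS, 77, 768)
teacher_context = torch.cat([pos, neg], dim=0)  # [2 * BS, 77, 768]

pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
assert pos_context.shape[0] == BS and torch.equal(pos_context, pos)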
214
 
215
  class Arc2FaceTeacher(UNetTeacher):
216
  def __init__(self, **kwargs):
 
225
  self.cfg_scale_range = [1, 1]
226
 
227
  class UNetEnsembleTeacher(UNetTeacher):
228
+ # unet_weights_in_ensemble are not model weights, but scalar weights for individual unets.
229
+ def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights_in_ensemble=None, device='cuda', **kwargs):
230
  super().__init__(**kwargs)
231
  self.name = "unet_ensemble"
232
+ self.unet = UNetEnsemble(unets, unet_types, extra_unet_dirpaths, unet_weights_in_ensemble, device)
233
 
234
  class ConsistentIDTeacher(UNetTeacher):
235
  def __init__(self, base_model_path="models/sd15-dste8-vae.safetensors", **kwargs):
 
239
  # In contrast to Arc2FaceTeacher or UNetEnsembleTeacher, ConsistentIDPipeline is not a torch.nn.Module.
240
  # We couldn't initialize the ConsistentIDPipeline on CPU first and wait for it to be automatically moved to GPU.
241
  # Instead, we have to initialize it to GPU directly.
242
+ # Release VAE and text_encoder to save memory. UNet is needed for denoising
 
243
  # (the unet is implemented in diffusers in fp16, so probably faster than the LDM unet).
244
+ self.unet = create_consistentid_pipeline(base_model_path, unet_only=True)
245
 
246
  # We use the default cfg_scale_range=[1.3, 2] for SimpleUNetTeacher.
247
  # Note p_uses_cfg=0.5 will also be passed in in kwargs.
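The two DDPM identities quoted in the comments above (q_sample and predict_start_from_noise) are inverses of each other. An illustrative, self-contained sketch with a toy noise schedule (not the project's DDPM class):

import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def q_sample(x_start, t, noise):
    # sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a.sqrt() * x_start + (1 - a).sqrt() * noise

def predict_start_from_noise(x_t, t, noise):
    # sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    return (1.0 / a).sqrt() * x_t - ((1 - a) / a).sqrt() * noise

x0 = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(x0)
t = torch.tensor([10, 500])
assert torch.allclose(predict_start_from_noise(q_sample(x0, t, noise), t, noise), x0, atol=1e-4)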
adaface/util.py CHANGED
@@ -57,7 +57,7 @@ def perturb_np_array(np_array, perturb_std, perturb_std_is_relative=True, std_di
57
  ts = perturb_tensor(ts, perturb_std, perturb_std_is_relative, std_dim=std_dim)
58
  return ts.numpy().astype(np_array.dtype)
59
 
60
- def calc_stats(emb_name, embeddings, mean_dim=0):
61
  print("%s:" %emb_name)
62
  repeat_count = [1] * embeddings.ndim
63
  repeat_count[mean_dim] = embeddings.shape[mean_dim]
@@ -153,13 +153,14 @@ def pad_image_obj_to_square(image_obj, new_size=-1):
153
 
154
  class UNetEnsemble(nn.Module):
155
  # The first unet is the unet already loaded in a pipeline.
156
- def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights=None, device='cuda', torch_dtype=torch.float16):
157
  super().__init__()
158
 
159
- self.unets = nn.ModuleList()
160
  if unets is not None:
161
- self.unets += [ unet.to(device) for unet in unets ]
162
-
 
 
163
  if unet_types is not None:
164
  for unet_type in unet_types:
165
  if unet_type == "arc2face":
@@ -169,25 +170,27 @@ class UNetEnsemble(nn.Module):
169
  unet = create_consistentid_pipeline(unet_only=True)
170
  else:
171
  breakpoint()
172
- self.unets.append(unet.to(device=device))
173
 
174
  if extra_unet_dirpaths is not None:
175
  for unet_path in extra_unet_dirpaths:
176
  unet = UNet2DConditionModel.from_pretrained(unet_path, torch_dtype=torch_dtype)
177
- self.unets.append(unet.to(device=device))
178
 
179
- if unet_weights is None:
180
- unet_weights = [1.] * len(self.unets)
181
- elif len(self.unets) < len(unet_weights):
182
- unet_weights = unet_weights[:len(self.unets)]
183
- elif len(self.unets) > len(unet_weights):
184
  breakpoint()
185
 
186
- unet_weights = torch.tensor(unet_weights, dtype=torch_dtype)
187
- unet_weights = unet_weights / unet_weights.sum()
188
- self.unet_weights = nn.Parameter(unet_weights, requires_grad=False)
189
 
190
- print(f"UNetEnsemble: {len(self.unets)} UNets loaded with weights: {self.unet_weights.data.cpu().numpy()}")
 
191
  # Set these fields to be compatible with diffusers.
192
  self.dtype = self.unets[0].dtype
193
  self.device = self.unets[0].device
@@ -215,8 +218,8 @@ class UNetEnsemble(nn.Module):
215
  samples.append(sample)
216
 
217
  samples = torch.stack(samples, dim=0)
218
- unet_weights = self.unet_weights.reshape(-1, *([1] * (samples.ndim - 1)))
219
- sample = (samples * unet_weights).sum(dim=0)
220
 
221
  if not return_dict:
222
  return (sample,)
 
57
  ts = perturb_tensor(ts, perturb_std, perturb_std_is_relative, std_dim=std_dim)
58
  return ts.numpy().astype(np_array.dtype)
59
 
60
+ def calc_stats(emb_name, embeddings, mean_dim=-1):
61
  print("%s:" %emb_name)
62
  repeat_count = [1] * embeddings.ndim
63
  repeat_count[mean_dim] = embeddings.shape[mean_dim]
 
153
 
154
  class UNetEnsemble(nn.Module):
155
  # The first unet is the unet already loaded in a pipeline.
156
+ def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights_in_ensemble=None, device='cuda', torch_dtype=torch.float16):
157
  super().__init__()
158
 
 
159
  if unets is not None:
160
+ unets = [ unet.to(device) for unet in unets ]
161
+ else:
162
+ unets = []
163
+
164
  if unet_types is not None:
165
  for unet_type in unet_types:
166
  if unet_type == "arc2face":
 
170
  unet = create_consistentid_pipeline(unet_only=True)
171
  else:
172
  breakpoint()
173
+ unets.append(unet.to(device=device))
174
 
175
  if extra_unet_dirpaths is not None:
176
  for unet_path in extra_unet_dirpaths:
177
  unet = UNet2DConditionModel.from_pretrained(unet_path, torch_dtype=torch_dtype)
178
+ unets.append(unet.to(device=device))
179
 
180
+ if unet_weights_in_ensemble is None:
181
+ unet_weights_in_ensemble = [1.] * len(unets)
182
+ elif len(unets) < len(unet_weights_in_ensemble):
183
+ unet_weights_in_ensemble = unet_weights_in_ensemble[:len(unets)]
184
+ elif len(unets) > len(unet_weights_in_ensemble):
185
  breakpoint()
186
 
187
+ unet_weights_in_ensemble = torch.tensor(unet_weights_in_ensemble, dtype=torch_dtype)
188
+ unet_weights_in_ensemble = unet_weights_in_ensemble / unet_weights_in_ensemble.sum()
 
189
 
190
+ self.unets = nn.ModuleList(unets)
191
+ # Put the weights in a Parameter so that they will be moved to the same device as the model.
192
+ self.unet_weights_in_ensemble = nn.Parameter(unet_weights_in_ensemble, requires_grad=False)
193
+ print(f"UNetEnsemble: {len(self.unets)} UNets loaded with weights: {self.unet_weights_in_ensemble.data.cpu().numpy()}")
194
  # Set these fields to be compatible with diffusers.
195
  self.dtype = self.unets[0].dtype
196
  self.device = self.unets[0].device
 
218
  samples.append(sample)
219
 
220
  samples = torch.stack(samples, dim=0)
221
+ unet_weights_in_ensemble = self.unet_weights_in_ensemble.reshape(-1, *([1] * (samples.ndim - 1)))
222
+ sample = (samples * unet_weights_in_ensemble).sum(dim=0)
223
 
224
  if not return_dict:
225
  return (sample,)
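A minimal sketch of the weighted averaging that UNetEnsemble.forward performs with the renamed unet_weights_in_ensemble: the scalar weights are normalized, kept in a frozen nn.Parameter so they travel with the module across devices, and broadcast over the stacked member outputs.

import torch
from torch import nn

weights = torch.tensor([2.0, 1.0, 1.0])
weights = weights / weights.sum()                     # -> [0.50, 0.25, 0.25]
weights = nn.Parameter(weights, requires_grad=False)  # moves together with the module

# Pretend these are the noise predictions of three member UNets.
samples = torch.stack([torch.full((1, 4, 8, 8), float(i)) for i in range(3)], dim=0)
w = weights.reshape(-1, *([1] * (samples.ndim - 1)))  # [3, 1, 1, 1, 1]
blended = (samples * w).sum(dim=0)
print(float(blended.flatten()[0]))                    # 0*0.50 + 1*0.25 + 2*0.25 = 0.75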
app.py CHANGED
@@ -5,40 +5,63 @@ from adaface.adaface_wrapper import AdaFaceWrapper
5
  import torch
6
  import numpy as np
7
  import random
8
-
 
9
  import gradio as gr
10
  import spaces
 
 
11
  import argparse
12
  parser = argparse.ArgumentParser()
13
  parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
14
  choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
15
- parser.add_argument('--adaface_ckpt_path', type=str, default='models/adaface/VGGface2_HQ_masks2024-10-14T16-09-24_zero3-ada-3500.pt',
16
- help="Paths to the checkpoints of the ID2Ada prompt encoders")
17
  # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
18
- parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None,
19
  help="Scales for the ID2Ada prompt encoders")
20
  parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
21
  choices=["arc2face", "consistentID"],
22
  help="List of enabled encoders (among the list of adaface_encoder_types). Default: None (all enabled)")
23
- parser.add_argument('--model_style_type', type=str, default='realistic',
24
  choices=["realistic", "anime", "photorealistic"], help="Type of the base model")
25
- parser.add_argument('--extra_unet_dirpaths', type=str, nargs="*", default=[],
26
- help="Extra paths to the checkpoints of the UNet models")
27
- parser.add_argument('--unet_weights', type=float, nargs="+", default=[1],
28
- help="Weights for the UNet models")
29
- parser.add_argument("--guidance_scale", type=float, default=8.0,
30
- help="The guidance scale for the diffusion model. Default: 8.0")
31
- parser.add_argument("--do_neg_id_prompt_weight", type=float, default=0.0,
32
- help="The weight of added ID prompt embeddings into the negative prompt. Default: 0, disabled.")
33
-
 
 
 
 
 
 
 
34
  parser.add_argument('--gpu', type=int, default=None)
35
  parser.add_argument('--ip', type=str, default="0.0.0.0")
36
  args = parser.parse_args()
37
 
 
 
 
 
38
  model_style_type2base_model_path = {
39
  "realistic": "models/rv51/realisticVisionV51_v51VAE_dste8.safetensors",
40
  "anime": "models/aingdiffusion/aingdiffusion_v170_ar.safetensors",
41
- "photorealistic": "models/sar/sar.safetensors" # LDM format. Needs to be converted.
42
  }
43
  base_model_path = model_style_type2base_model_path[args.model_style_type]
44
 
@@ -48,13 +71,20 @@ device = "cuda" if args.gpu is None else f"cuda:{args.gpu}"
48
  print(f"Device: {device}")
49
 
50
  global adaface
51
- adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path,
52
- adaface_encoder_types=args.adaface_encoder_types,
53
- adaface_ckpt_paths=args.adaface_ckpt_path,
54
- adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales,
55
- enabled_encoders=args.enabled_encoders,
56
- unet_types=None, extra_unet_dirpaths=args.extra_unet_dirpaths,
57
- unet_weights=args.unet_weights, device='cpu')
 
58
 
59
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
60
  if randomize_seed:
@@ -71,12 +101,14 @@ def remove_back_to_files():
71
  # Hide uploaded_files_gallery, show clear_button_column, hide files, reset init_img_selected_idx
72
  # Or:
73
  # Hide uploaded_init_img_gallery, hide init_clear_button_column, show init_img_files, reset init_img_selected_idx
74
- return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True)
 
75
 
76
  @spaces.GPU
77
- def generate_image(image_paths, guidance_scale, do_neg_id_prompt_weight, perturb_std,
78
- num_images, prompt, negative_prompt, enhance_face,
79
- seed, progress=gr.Progress(track_tqdm=True)):
 
80
 
81
  global adaface
82
 
@@ -85,6 +117,9 @@ def generate_image(image_paths, guidance_scale, do_neg_id_prompt_weight, perturb
85
  if image_paths is None or len(image_paths) == 0:
86
  raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
87
 
 
 
 
88
  if prompt is None:
89
  prompt = ""
90
 
@@ -100,38 +135,128 @@ def generate_image(image_paths, guidance_scale, do_neg_id_prompt_weight, perturb
100
  # Sometimes the pipeline is on CPU, although we've put it on CUDA (due to some offloading mechanism).
101
  # Therefore we set the generator to the correct device.
102
  generator = torch.Generator(device=device).manual_seed(seed)
103
- print(f"Manual seed: {seed}. do_neg_id_prompt_weight: {do_neg_id_prompt_weight}.")
104
  # Generate two images each time for the user to select from.
105
  noise = torch.randn(num_images, 3, 512, 512, device=device, generator=generator)
106
  #print(noise.abs().sum())
107
  # samples: A list of PIL Image instances.
108
- if enhance_face and "face portrait" not in prompt:
109
- if "portrait" in prompt:
110
- # Enhance the face features by replacing "portrait" with "face portrait".
 
111
  prompt = prompt.replace("portrait", "face portrait")
 
 
112
  else:
113
- prompt = "face portrait, " + prompt
114
 
115
  generator = torch.Generator(device=adaface.pipeline._execution_device).manual_seed(seed)
116
- samples = adaface(noise, prompt, negative_prompt,
117
- do_neg_id_prompt_weight=do_neg_id_prompt_weight,
118
  guidance_scale=guidance_scale,
119
- out_image_count=num_images, generator=generator, verbose=True)
120
- return samples
 
 
 
121
 
 
122
 
123
- def check_prompt_and_model_type(prompt, model_style_type):
124
  global adaface
125
 
126
  model_style_type = model_style_type.lower()
127
- base_model_path = model_style_type2base_model_path[model_style_type]
128
  # If the base model type is changed, reload the model.
129
- if model_style_type != args.model_style_type:
130
- adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path,
131
- adaface_encoder_types=args.adaface_encoder_types,
132
- adaface_ckpt_paths=args.adaface_ckpt_path, device='cpu')
133
- # Update base model type.
134
- args.model_style_type = model_style_type
 
 
 
135
 
136
  if not prompt:
137
  raise gr.Error("Prompt cannot be blank")
@@ -145,13 +270,12 @@ description = r"""
145
  <b>Official demo</b> for our working paper <b>AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization</b>.<br>
146
 
147
  ❗️**What's New**❗️
148
- - Support switching between two model styles: **Realistic** and **Anime**.
149
  - If you just changed the model style, the first image/video generation will take extra 20~30 seconds for loading new model weight.
150
 
151
  ❗️**Tips**❗️
152
  1. Upload one or more images of a person. If multiple faces are detected, we use the largest one.
153
- 2. Check "Enhance Face" to highlight fine facial features.
154
- 3. If the face dominates the image, try increasing 'Weight of ID prompt in the negative prompt'.
155
  4. AdaFace Text-to-Video: <a href="https://huggingface.co/spaces/adaface-neurips/adaface-animate" style="display: inline-flex; align-items: center;">
156
  AdaFace-Animate
157
  <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow" alt="Hugging Face Spaces" style="margin-left: 5px;">
@@ -162,13 +286,18 @@ description = r"""
162
  """
163
 
164
  css = '''
165
- .gradio-container {width: 95% !important},
166
  .custom-gallery {
167
- height: 800px;
168
  width: 100%;
169
  margin: 10px auto;
170
- padding: 10px;
171
- overflow-y: auto;
 
172
  }
173
  '''
174
  with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
@@ -187,53 +316,108 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
187
  file_types=["image"],
188
  file_count="multiple"
189
  )
190
- uploaded_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=1, height=300)
 
191
  with gr.Column(visible=False) as clear_button_column:
192
- remove_and_reupload = gr.ClearButton(value="Remove and upload subject images", components=img_files, size="sm")
193
-
194
- prompt = gr.Dropdown(label="Prompt",
195
- info="Try something like 'walking on the beach'. If the face is not in focus, try checking 'enhance face'.",
196
- value="portrait, ((best quality)), ((masterpiece)), ((realistic)), highlighted hair, futuristic silver armor suit, confident stance, high-resolution, living room, smiling, head tilted, perfect smooth skin",
197
- allow_custom_value=True,
198
- filterable=False,
199
- choices=[
200
- "portrait, ((best quality)), ((masterpiece)), ((realistic)), highlighted hair, futuristic silver armor suit, confident stance, high-resolution, living room, smiling, head tilted, perfect smooth skin",
201
- "portrait, walking on the beach, sunset, orange sky",
202
- "portrait, in a white apron and chef hat, garnishing a gourmet dish",
203
- "portrait, dancing pose among folks in a park, waving hands",
204
- "portrait, in iron man costume, the sky ablaze with hues of orange and purple",
205
- "portrait, jedi wielding a lightsaber, star wars, eye level shot",
206
- "portrait, playing guitar on a boat, ocean waves",
207
- "portrait, with a passion for reading, curled up with a book in a cozy nook near a window",
208
- "portrait, running pose in a park, eye level shot",
209
- "portrait, in superman costume, the sky ablaze with hues of orange and purple"
210
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
- enhance_face = gr.Checkbox(label="Enhance face", value=False,
213
- info="Enhance the face features by prepending 'face portrait' to the prompt")
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  submit = gr.Button("Submit", variant="primary")
216
 
217
  negative_prompt = gr.Textbox(
218
  label="Negative Prompt",
219
- value="flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, nude, naked, nsfw, topless, bare breasts",
 
 
 
 
 
220
  )
221
 
222
  guidance_scale = gr.Slider(
223
  label="Guidance scale",
224
  minimum=1.0,
225
- maximum=12.0,
226
- step=1.0,
227
  value=args.guidance_scale,
228
  )
229
 
230
- do_neg_id_prompt_weight = gr.Slider(
231
- label="Weight of ID prompt in the negative prompt",
232
- minimum=0.0,
233
- maximum=0.3,
234
- step=0.1,
235
- value=args.do_neg_id_prompt_weight,
236
- visible=True,
237
  )
238
 
239
  model_style_type = gr.Dropdown(
@@ -256,7 +440,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
256
  num_images = gr.Slider(
257
  label="Number of output images",
258
  minimum=1,
259
- maximum=6,
260
  step=1,
261
  value=4,
262
  )
@@ -267,27 +451,41 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
267
  step=1,
268
  value=0,
269
  )
270
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True, info="Uncheck for reproducible results")
271
 
272
  with gr.Column():
273
- out_gallery = gr.Gallery(label="Generated Images", interactive=False, columns=2, rows=2, height=800,
274
  elem_classes="custom-gallery")
275
 
276
- img_files.upload(fn=swap_to_gallery, inputs=img_files, outputs=[uploaded_files_gallery, clear_button_column, img_files])
277
- remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery, clear_button_column, img_files])
278
-
279
- submit.click(fn=check_prompt_and_model_type,
280
- inputs=[prompt, model_style_type],outputs=None).success(
281
- fn=randomize_seed_fn,
282
- inputs=[seed, randomize_seed],
283
- outputs=seed,
284
- queue=False,
285
- api_name=False,
286
- ).then(
287
- fn=generate_image,
288
- inputs=[img_files, guidance_scale, do_neg_id_prompt_weight, perturb_std, num_images,
289
- prompt, negative_prompt, enhance_face, seed],
290
- outputs=[out_gallery]
291
- )
 
 
 
 
 
 
 
 
 
 
292
 
293
  demo.launch(share=True, server_name=args.ip, ssl_verify=False)
 
5
  import torch
6
  import numpy as np
7
  import random
8
+ import os, re
9
+ import time
10
  import gradio as gr
11
  import spaces
12
+
13
+ def str2bool(v):
14
+ if isinstance(v, bool):
15
+ return v
16
+ if v.lower() in ("yes", "true", "t", "y", "1"):
17
+ return True
18
+ elif v.lower() in ("no", "false", "f", "n", "0"):
19
+ return False
20
+ else:
21
+ raise argparse.ArgumentTypeError("Boolean value expected.")
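Hypothetical usage of the str2bool helper introduced above (the helper is repeated here so the snippet runs on its own): combined with nargs="?" and const=True, a bare --flag means True, "--flag no" means False, and omitting the flag falls back to default=False.

import argparse

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
parser.add_argument("--unet_uses_attn_lora", type=str2bool, nargs="?", const=True, default=False)
print(parser.parse_args([]).unet_uses_attn_lora)                               # False
print(parser.parse_args(["--unet_uses_attn_lora"]).unet_uses_attn_lora)        # True
print(parser.parse_args(["--unet_uses_attn_lora", "no"]).unet_uses_attn_lora)  # False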
22
+
23
  import argparse
24
  parser = argparse.ArgumentParser()
25
  parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
26
  choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
27
+ parser.add_argument('--adaface_ckpt_path', type=str, default='models/adaface/VGGface2_HQ_masks2025-03-06T03-31-21_zero3-ada-1000.pt',
28
+ help="Path to the checkpoint of the ID2Ada prompt encoders")
29
  # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
30
+ parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=[6.0, 1.0],
31
  help="Scales for the ID2Ada prompt encoders")
32
  parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
33
  choices=["arc2face", "consistentID"],
34
  help="List of enabled encoders (among the list of adaface_encoder_types). Default: None (all enabled)")
35
+ parser.add_argument('--model_style_type', type=str, default='photorealistic',
36
  choices=["realistic", "anime", "photorealistic"], help="Type of the base model")
37
+ parser.add_argument("--guidance_scale", type=float, default=5.0,
38
+ help="The guidance scale for the diffusion model. Default: 5.0")
39
+ parser.add_argument("--unet_uses_attn_lora", type=str2bool, nargs="?", const=True, default=False,
40
+ help="Whether to use LoRA in the Diffusers UNet model")
41
+ # --attn_lora_layer_names and --q_lora_updates_query are only effective
42
+ # when --unet_uses_attn_lora is set to True.
43
+ parser.add_argument("--attn_lora_layer_names", type=str, nargs="*", default=['q', 'k', 'v', 'out'],
44
+ choices=['q', 'k', 'v', 'out'], help="Names of the cross-attn components to apply LoRA on")
45
+ parser.add_argument("--q_lora_updates_query", type=str2bool, nargs="?", const=True, default=False,
46
+ help="Whether the q LoRA updates the query in the Diffusers UNet model. "
47
+ "If False, the q lora only updates query2.")
48
+ parser.add_argument("--show_disable_adaface_checkbox", type=str2bool, nargs="?", const=True, default=False,
49
+ help="Whether to show the checkbox for disabling AdaFace")
50
+ parser.add_argument('--extra_save_dir', type=str, default=None, help="Directory to save the generated images")
51
+ parser.add_argument('--test_ui_only', type=str2bool, nargs="?", const=True, default=False,
52
+ help="Only test the UI layout, and skip loading the adaface model")
53
  parser.add_argument('--gpu', type=int, default=None)
54
  parser.add_argument('--ip', type=str, default="0.0.0.0")
55
  args = parser.parse_args()
56
 
57
+ from huggingface_hub import snapshot_download
58
+ large_files = ["models/*", "models/**/*"]
59
+ snapshot_download(repo_id="adaface-neurips/adaface-models", repo_type="model", allow_patterns=large_files, local_dir=".")
60
+
61
  model_style_type2base_model_path = {
62
  "realistic": "models/rv51/realisticVisionV51_v51VAE_dste8.safetensors",
63
  "anime": "models/aingdiffusion/aingdiffusion_v170_ar.safetensors",
64
+ "photorealistic": "models/sar/sar.safetensors", # LDM format. Needs to be converted.
65
  }
66
  base_model_path = model_style_type2base_model_path[args.model_style_type]
67
 
 
71
  print(f"Device: {device}")
72
 
73
  global adaface
74
+ adaface = None
75
+
76
+ if not args.test_ui_only:
77
+ adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path,
78
+ adaface_encoder_types=args.adaface_encoder_types,
79
+ adaface_ckpt_paths=args.adaface_ckpt_path,
80
+ adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales,
81
+ enabled_encoders=args.enabled_encoders,
82
+ unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
83
+ unet_uses_attn_lora=args.unet_uses_attn_lora,
84
+ attn_lora_layer_names=args.attn_lora_layer_names,
85
+ shrink_cross_attn=False,
86
+ q_lora_updates_query=args.q_lora_updates_query,
87
+ device='cpu')
88
 
89
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
90
  if randomize_seed:
 
101
  # Hide uploaded_files_gallery, show clear_button_column, hide files, reset init_img_selected_idx
102
  # Or:
103
  # Hide uploaded_init_img_gallery, hide init_clear_button_column, show init_img_files, reset init_img_selected_idx
104
+ return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True), \
105
+ gr.update(value=""), gr.update(value="(none)")
106
 
107
  @spaces.GPU
108
+ def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,
109
+ num_images, prompt, negative_prompt, gender, highlight_face,
110
+ ablate_prompt_embed_type, nonmix_prompt_emb_weight,
111
+ composition_level, seed, disable_adaface, subj_name_sig, progress=gr.Progress(track_tqdm=True)):
112
 
113
  global adaface
114
 
 
117
  if image_paths is None or len(image_paths) == 0:
118
  raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
119
 
120
+ if image_paths2 is not None and len(image_paths2) > 0:
121
+ image_paths = image_paths + image_paths2
122
+
123
  if prompt is None:
124
  prompt = ""
125
 
 
135
  # Sometimes the pipeline is on CPU, although we've put it on CUDA (due to some offloading mechanism).
136
  # Therefore we set the generator to the correct device.
137
  generator = torch.Generator(device=device).manual_seed(seed)
138
+ print(f"Manual seed: {seed}.")
139
  # Generate two images each time for the user to select from.
140
  noise = torch.randn(num_images, 3, 512, 512, device=device, generator=generator)
141
  #print(noise.abs().sum())
142
  # samples: A list of PIL Image instances.
143
+ if highlight_face:
144
+ if "portrait" not in prompt:
145
+ prompt = "face portrait, " + prompt
146
+ else:
147
  prompt = prompt.replace("portrait", "face portrait")
148
+ if composition_level >= 2:
149
+ if "full body" not in prompt:
150
+ prompt = prompt + ", full body view"
151
+
152
+ if gender != "(none)":
153
+ if "portrait" in prompt:
154
+ prompt = prompt.replace("portrait, ", f"portrait, {gender} ")
155
  else:
156
+ prompt = gender + ", " + prompt
157
 
158
  generator = torch.Generator(device=adaface.pipeline._execution_device).manual_seed(seed)
159
+ samples = adaface(noise, prompt, negative_prompt=negative_prompt,
 
160
  guidance_scale=guidance_scale,
161
+ out_image_count=num_images, generator=generator,
162
+ repeat_prompt_for_each_encoder=(composition_level >= 1),
163
+ ablate_prompt_no_placeholders=disable_adaface,
164
+ ablate_prompt_embed_type=ablate_prompt_embed_type,
165
+ nonmix_prompt_emb_weight=nonmix_prompt_emb_weight,
166
+ verbose=True)
167
+
168
+ session_signature = ",".join(image_paths + [prompt, str(seed)])
169
+ temp_folder = os.path.join("/tmp/gradio", f"{hash(session_signature)}")
170
+ os.makedirs(temp_folder, exist_ok=True)
171
+
172
+ saved_image_paths = []
173
+ if "models/adaface/" in args.adaface_ckpt_path:
174
+ # The model is loaded from within the project.
175
+ # models/adaface/VGGface2_HQ_masks2024-10-14T16-09-24_zero3-ada-3500.pt
176
+ matches = re.search(r"models/adaface/\w+\d{4}-(\d{2})-(\d{2})T(\d{2})-\d{2}-\d{2}_zero3-ada-(\d+).pt", args.adaface_ckpt_path)
177
+ else:
178
+ # The model is loaded from the adaprompt folder.
179
+ # adaface_ckpt_path = "VGGface2_HQ_masks2024-11-28T13-13-20_zero3-ada/checkpoints/embeddings_gs-2000.pt"
180
+ matches = re.search(r"\d{4}-(\d{2})-(\d{2})T(\d{2})-\d{2}-\d{2}_zero3-ada/checkpoints/embeddings_gs-(\d+).pt", args.adaface_ckpt_path)
181
+
182
+ # Extract the checkpoint signature as 112813-2000
183
+ ckpt_sig = f"{matches.group(1)}{matches.group(2)}{matches.group(3)}-{matches.group(4)}"
184
+
185
+ prompt_keywords = ['armor', 'beach', 'chef', 'dancing', 'iron man', 'jedi',
186
+ 'street', 'guitar', 'reading', 'running', 'superman', 'new year', 'mars']
187
+ keywords_reduction = { 'iron man': 'ironman', 'dancing': 'dance',
188
+ 'running': 'run', 'reading': 'read', 'new year': 'newyear' }
189
+
190
+ prompt_sig = None
191
+ for keyword in prompt_keywords:
192
+ if keyword in prompt.lower():
193
+ prompt_sig = keywords_reduction.get(keyword, keyword)
194
+ break
195
+
196
+ if prompt_sig is None:
197
+ prompt_parts = prompt.lower().split(",")
198
+ # Remove the view/shot parts (full body view, long shot, etc.) from the prompt.
199
+ prompt_parts = [ part for part in prompt_parts if not re.search(r"\W(view|shot)(\W|$)", part) ]
200
+ if len(prompt_parts) > 0:
201
+ # Use the last word of the prompt as the signature.
202
+ prompt_sig = prompt_parts[-1].split()[-1]
203
+ else:
204
+ prompt_sig = "person"
205
+
206
+ if len(prompt_sig) > 0:
207
+ prompt_sig = "-" + prompt_sig
208
+
209
+ extra_save_dir = args.extra_save_dir
210
+ if extra_save_dir is not None:
211
+ os.makedirs(extra_save_dir, exist_ok=True)
212
+
213
+ for i, sample in enumerate(samples):
214
+ filename = f"adaface{ckpt_sig}{prompt_sig}-{i+1}.png"
215
+ if len(subj_name_sig) > 0:
216
+ filename = f"{subj_name_sig.lower()}-{filename}"
217
+ filepath = os.path.join(temp_folder, filename)
218
+ # Save the image
219
+ sample.save(filepath) # Adjust to your image saving method
220
+ saved_image_paths.append(filepath)
221
+
222
+ if extra_save_dir is not None:
223
+ extra_filepath = os.path.join(extra_save_dir, filename)
224
+ sample.save(extra_filepath)
225
+ print(extra_filepath)
226
+
227
+ # Solution suggested by o1 to force the client browser to reload images
228
+ # when we change guidance scales only.
229
+ saved_image_paths = [f"{url}?t={int(time.time())}" for url in saved_image_paths]
230
 
231
+ return saved_image_paths
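A worked example of the ckpt_sig extraction above, using the older checkpoint name quoted in the comment (the path is only an example): the month, day and hour fields plus the step count are pulled out of the filename and joined into a short signature.

import re

ckpt_path = "models/adaface/VGGface2_HQ_masks2024-10-14T16-09-24_zero3-ada-3500.pt"
matches = re.search(r"models/adaface/\w+\d{4}-(\d{2})-(\d{2})T(\d{2})-\d{2}-\d{2}_zero3-ada-(\d+).pt", ckpt_path)
ckpt_sig = f"{matches.group(1)}{matches.group(2)}{matches.group(3)}-{matches.group(4)}"
print(ckpt_sig)  # 101416-3500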
232
 
233
+ def check_prompt_and_model_type(prompt, model_style_type, adaface_encoder_cfg_scale1):
234
  global adaface
235
 
236
  model_style_type = model_style_type.lower()
 
237
  # If the base model type is changed, reload the model.
238
+ if model_style_type != args.model_style_type or adaface_encoder_cfg_scale1 != args.adaface_encoder_cfg_scales[0]:
239
+ if model_style_type != args.model_style_type:
240
+ # Update base model type.
241
+ args.model_style_type = model_style_type
242
+ print(f"Switching to the base model type: {model_style_type}.")
243
+
244
+ adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=model_style_type2base_model_path[model_style_type],
245
+ adaface_encoder_types=args.adaface_encoder_types,
246
+ adaface_ckpt_paths=args.adaface_ckpt_path,
247
+ adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales,
248
+ enabled_encoders=args.enabled_encoders,
249
+ unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
250
+ unet_uses_attn_lora=args.unet_uses_attn_lora,
251
+ attn_lora_layer_names=args.attn_lora_layer_names,
252
+ shrink_cross_attn=False,
253
+ q_lora_updates_query=args.q_lora_updates_query,
254
+ device='cpu')
255
+
256
+ if adaface_encoder_cfg_scale1 != args.adaface_encoder_cfg_scales[0]:
257
+ args.adaface_encoder_cfg_scales[0] = adaface_encoder_cfg_scale1
258
+ adaface.set_adaface_encoder_cfg_scales(args.adaface_encoder_cfg_scales)
259
+ print(f"Updating the scale for consistentID encoder to {adaface_encoder_cfg_scale1}.")
260
 
261
  if not prompt:
262
  raise gr.Error("Prompt cannot be blank")
 
270
  <b>Official demo</b> for our working paper <b>AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization</b>.<br>
271
 
272
  ❗️**What's New**❗️
273
+ - Support switching between three model styles: **Photorealistic**, **Realistic** and **Anime**.
274
  - If you just changed the model style, the first image/video generation will take extra 20~30 seconds for loading new model weight.
275
 
276
  ❗️**Tips**❗️
277
  1. Upload one or more images of a person. If multiple faces are detected, we use the largest one.
278
+ 2. Check "Highlight face" to highlight fine facial features.
 
279
  4. AdaFace Text-to-Video: <a href="https://huggingface.co/spaces/adaface-neurips/adaface-animate" style="display: inline-flex; align-items: center;">
280
  AdaFace-Animate
281
  <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow" alt="Hugging Face Spaces" style="margin-left: 5px;">
 
286
  """
287
 
288
  css = '''
289
+ .gradio-container {width: 95% !important}
290
  .custom-gallery {
291
+ height: 800px !important;
292
  width: 100%;
293
  margin: 10px auto;
294
+ padding: 0px;
295
+ overflow-y: auto !important;
296
+ }
297
+ .tight-row {
298
+ gap: 0 !important; /* removes the horizontal gap between columns */
299
+ margin: 0 !important; /* remove any extra margin if needed */
300
+ padding: 0 !important; /* remove any extra padding if needed */
301
  }
302
  '''
303
  with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
 
316
  file_types=["image"],
317
  file_count="multiple"
318
  )
319
+ # When files are uploaded, show the images in the gallery and hide the file uploader.
320
+ uploaded_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=1, height=300)
321
  with gr.Column(visible=False) as clear_button_column:
322
+ remove_and_reupload = gr.ClearButton(value="Remove and upload subject images",
323
+ components=img_files, size="sm")
324
+
325
+ with gr.Accordion("Second Subject (Optional)", open=False):
326
+ img_files2 = gr.File(
327
+ label="Drag / Select 1 or more photos of second subject's face (optional)",
328
+ file_types=["image"],
329
+ file_count="multiple"
330
+ )
331
+
332
+ uploaded_files_gallery2 = gr.Gallery(label="2nd Subject images (optional)", visible=False, columns=3, rows=1, height=300)
333
+ with gr.Column(visible=False) as clear_button_column2:
334
+ remove_and_reupload2 = gr.ClearButton(value="Remove and upload 2nd Subject images",
335
+ components=img_files2, size="sm")
336
+
337
+ with gr.Row(elem_classes="tight-row"):
338
+ with gr.Column(scale=1, min_width=100):
339
+ gender = gr.Dropdown(label="Gender", value="(none)",
340
+ info="Gender prefix. Select only when the model errs.",
341
+ container=False,
342
+ choices=[ "(none)", "person", "man", "woman", "girl", "boy" ])
343
+
344
+ with gr.Column(scale=100):
345
+ prompt = gr.Dropdown(label="Prompt",
346
+ info="Try something like 'walking on the beach'. If the face is not in focus, try checking 'Highlight face'.",
347
+ value="portrait, highlighted hair, futuristic silver armor suit, confident stance, living room, smiling, head tilted, perfect smooth skin",
348
+ allow_custom_value=True,
349
+ choices=[
350
+ "portrait, highlighted hair, futuristic silver armor suit, confident stance, living room, smiling, head tilted, perfect smooth skin",
351
+ "portrait, walking on the beach, sunset, orange sky, front view",
352
+ "portrait, in a white apron and chef hat, garnishing a gourmet dish",
353
+ "portrait, waving hands, dancing pose among folks in a park",
354
+ "portrait, in iron man costume, the sky ablaze with hues of orange and purple",
355
+ "portrait, jedi wielding a lightsaber, star wars",
356
+ "portrait, night view of tokyo street, neon light",
357
+ "portrait, playing guitar on a boat, ocean waves",
358
+ "portrait, with a passion for reading, curled up with a book in a cozy nook near a window, front view",
359
+ "portrait, celebrating new year, fireworks",
360
+ "portrait, running pose in a park",
361
+ "portrait, in space suit, space helmet, walking on mars",
362
+ "portrait, in superman costume, the sky ablaze with hues of orange and purple",
363
+ "in a wheelchair",
364
+ "on a horse"
365
+ ])
366
 
367
+ highlight_face = gr.Checkbox(label="Highlight face", value=False,
368
+ info="Enhance the facial features by prepending 'face portrait' to the prompt")
369
+ composition_level = \
370
+ gr.Slider(label="Composition Level", visible=True,
371
+ info="The degree of overall composition, 0~2. Challenging prompts like 'In a wheelchair' and 'on a horse' need level 2",
372
+ minimum=0, maximum=2, step=1, value=0)
373
+
374
+ ablate_prompt_embed_type = gr.Dropdown(label="Ablate prompt embeddings type",
375
+ choices=["ada", "ada-nonmix", "img"], value="ada", visible=False,
376
+ info="Use this type of prompt embeddings for ablation study")
377
+
378
+ nonmix_prompt_emb_weight = gr.Slider(label="Weight of ada-nonmix ID embeddings",
379
+ minimum=0.0, maximum=0.5, step=0.1, value=0,
380
+ info="Weight of ada-nonmix ID embeddings in the prompt embeddings",
381
+ visible=False)
382
+
383
+
384
+ subj_name_sig = gr.Textbox(
385
+ label="Nickname of Subject (optional; used to name saved images)",
386
+ value="",
387
+ )
388
+ subj_name_sig2 = gr.Textbox(
389
+ label="Nickname of 2nd Subject (optional; used to name saved images)",
390
+ value="",
391
+ visible=False,
392
+ )
393
 
394
  submit = gr.Button("Submit", variant="primary")
395
 
396
  negative_prompt = gr.Textbox(
397
  label="Negative Prompt",
398
+ value="sagging face, sagging cheeks, wrinkles, flaws in the eyes, flaws in the face, lowres, "
399
+ "non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, "
400
+ "mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, "
401
+ "deformed eyeballs, cross-eyed, extra legs, extra arms, blurry, mutation, duplicate, "
402
+ "out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, "
403
+ "nude, naked, nsfw, topless, bare breasts",
404
  )
405
 
406
  guidance_scale = gr.Slider(
407
  label="Guidance scale",
408
  minimum=1.0,
409
+ maximum=8.0,
410
+ step=0.5,
411
  value=args.guidance_scale,
412
  )
413
 
414
+ adaface_encoder_cfg_scale1 = gr.Slider(
415
+ label="Scale for consistentID encoder",
416
+ minimum=1.0,
417
+ maximum=12.0,
418
+ step=1.0,
419
+ value=args.adaface_encoder_cfg_scales[0],
420
+ visible=False,
421
  )
422
 
423
  model_style_type = gr.Dropdown(
 
440
  num_images = gr.Slider(
441
  label="Number of output images",
442
  minimum=1,
443
+ maximum=8,
444
  step=1,
445
  value=4,
446
  )
 
451
  step=1,
452
  value=0,
453
  )
454
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True,
455
+ info="Uncheck for reproducible results")
456
+ disable_adaface = gr.Checkbox(label="Disable AdaFace", value=False,
457
+ info="Disable AdaFace for ablation. If checked, the results are no longer personalized.",
458
+ visible=args.show_disable_adaface_checkbox)
459
 
460
  with gr.Column():
461
+ out_gallery = gr.Gallery(label="Generated Images", interactive=False, columns=2, rows=4, height=800,
462
  elem_classes="custom-gallery")
463
 
464
+ img_files.upload(fn=swap_to_gallery, inputs=img_files, outputs=[uploaded_files_gallery, clear_button_column, img_files])
465
+ img_files2.upload(fn=swap_to_gallery, inputs=img_files2, outputs=[uploaded_files_gallery2, clear_button_column2, img_files2])
466
+ remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery, clear_button_column,
467
+ img_files, subj_name_sig, gender])
468
+ remove_and_reupload2.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery2, clear_button_column2,
469
+ img_files2, subj_name_sig2, gender])
470
+
471
+ check_prompt_and_model_type_call_dict = {
472
+ 'fn': check_prompt_and_model_type,
473
+ 'inputs': [prompt, model_style_type, adaface_encoder_cfg_scale1],
474
+ 'outputs': None
475
+ }
476
+ randomize_seed_fn_call_dict = {
477
+ 'fn': randomize_seed_fn,
478
+ 'inputs': [seed, randomize_seed],
479
+ 'outputs': seed
480
+ }
481
+ generate_image_call_dict = {
482
+ 'fn': generate_image,
483
+ 'inputs': [img_files, img_files2, guidance_scale, perturb_std, num_images, prompt,
484
+ negative_prompt, gender, highlight_face, ablate_prompt_embed_type,
485
+ nonmix_prompt_emb_weight, composition_level, seed, disable_adaface, subj_name_sig],
486
+ 'outputs': [out_gallery]
487
+ }
488
+ submit.click(**check_prompt_and_model_type_call_dict).success(**randomize_seed_fn_call_dict).then(**generate_image_call_dict)
489
+ subj_name_sig.submit(**check_prompt_and_model_type_call_dict).success(**randomize_seed_fn_call_dict).then(**generate_image_call_dict)
490
 
491
  demo.launch(share=True, server_name=args.ip, ssl_verify=False)
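A gradio-free sketch of the event-wiring pattern used above: each stage's call spec is packed into a dict and unpacked with **, so the Submit button and the nickname textbox can trigger the identical check -> randomize-seed -> generate chain without repeating the argument lists.

class FakeTrigger:
    # Stand-in for a gradio component's event registrar; each call returns self so
    # .click(...).success(...).then(...) chains like the real API.
    def __init__(self, name):
        self.name = name
    def click(self, fn, inputs=None, outputs=None):
        print(f"{self.name}: wired {fn.__name__}")
        return self
    success = then = submit = click

def check_prompt_and_model_type(*args): pass
def randomize_seed_fn(*args): pass
def generate_image(*args): pass

check_call = {"fn": check_prompt_and_model_type, "inputs": [], "outputs": None}
seed_call = {"fn": randomize_seed_fn, "inputs": [], "outputs": None}
gen_call = {"fn": generate_image, "inputs": [], "outputs": []}

for trigger in (FakeTrigger("submit"), FakeTrigger("subj_name_sig")):
    trigger.click(**check_call).success(**seed_call).then(**gen_call)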