Commit · 936cd75
Parent(s): 40ca865

update code
Files changed:
- .gitignore +1 -7
- ConsistentID/app.py +2 -2
- adaface/adaface_infer.py +10 -15
- adaface/adaface_translate.py +53 -33
- adaface/adaface_wrapper.py +366 -72
- adaface/diffusers_attn_lora_capture.py +656 -0
- adaface/face_id_to_ada_prompt.py +253 -124
- adaface/subj_basis_generator.py +97 -59
- adaface/unet_teachers.py +86 -49
- adaface/util.py +21 -18
- app.py +299 -101
.gitignore
CHANGED
@@ -1,10 +1,4 @@
-models/awportrait/*
-models/awportrait
 __pycache__/*
 __pycache__
-samples-ada/*
-samples-ada
-models/ensemble/awp14-unet/*
-models/ensemble/awp14-unet
 .gradio/certificate.pem
-
+models/*
ConsistentID/app.py
CHANGED
@@ -26,8 +26,8 @@ pipe = ConsistentIDPipeline.from_pretrained(
 
 ### Load consistentID_model checkpoint
 pipe.load_ConsistentID_model(
-    consistentID_weight_path="./models/ConsistentID-v1.bin",
-    bise_net_weight_path="./models/BiSeNet_pretrained_for_ConsistentID.pth",
+    consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin",
+    bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth",
 )
 pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
 pipe = pipe.to(device, torch.float16)
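Since the ConsistentID checkpoints now live one directory deeper, a quick existence check before loading can save a confusing stack trace. The paths below mirror the diff; the check itself is our own sketch, not code from the repo.

import os

# Paths as used by ConsistentID/app.py after this commit.
consistentID_weight_path = "./models/ConsistentID/ConsistentID-v1.bin"
bise_net_weight_path = "./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth"

for path in (consistentID_weight_path, bise_net_weight_path):
    if not os.path.isfile(path):
        # Hypothetical guard: download the checkpoints into models/ConsistentID/ first.
        raise FileNotFoundError(f"Expected checkpoint at {path}")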
adaface/adaface_infer.py
CHANGED
@@ -45,8 +45,7 @@ def parse_args():
                         help="Type of pipeline to use (default: txt2img)")
     parser.add_argument("--base_model_path", type=str, default=None,
                         help="Type of checkpoints to use (default: None, using the official model)")
-    parser.add_argument('--…
-                        default=['models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt'])
+    parser.add_argument('--adaface_ckpt_path', type=str, required=True)
     parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
                         choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
     parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
@@ -60,23 +59,18 @@ def parse_args():
     parser.add_argument("--extra_unet_dirpaths", type=str, nargs="*",
                         default=[],
                         help="Extra paths to the checkpoints of the UNet models")
-    parser.add_argument('--…
+    parser.add_argument('--unet_weights_in_ensemble', type=float, nargs="+", default=[1],
                         help="Weights for the UNet models")
     parser.add_argument("--subject", type=str)
     parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use")
     parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate")
     parser.add_argument("--prompt", type=str, default="a woman z in superman costume")
-    parser.add_argument("--…
+    parser.add_argument("--perturb_std", type=float, default=0)
     parser.add_argument("--randface", action="store_true")
     parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
                         help="Guidance scale for the diffusion model")
-    parser.add_argument("--subject_string",
-                        type=str, default="z",
-                        help="Subject placeholder string used in prompts to denote the concept.")
     parser.add_argument("--num_images_per_row", type=int, default=4,
                         help="Number of images to display in a row in the output grid image.")
-    parser.add_argument("--num_inference_steps", type=int, default=50,
-                        help="Number of inference steps")
     parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
     parser.add_argument("--seed", type=int, default=42,
                         help="the seed (for reproducible sampling). Set to -1 to disable.")
@@ -95,16 +89,15 @@ if __name__ == "__main__":
 
     if args.pipeline not in ["text2img", "img2img"]:
         args.extra_unet_dirpaths = None
-        args.…
+        args.unet_weights_in_ensemble = None
 
     adaface = AdaFaceWrapper(args.pipeline, args.base_model_path,
-                             args.adaface_encoder_types, args.…
+                             args.adaface_encoder_types, args.adaface_ckpt_path,
                              args.adaface_encoder_cfg_scales, args.enabled_encoders,
-                             args.subject_string, args.num_inference_steps,
                              unet_types=None,
                              main_unet_filepath=args.main_unet_filepath,
                              extra_unet_dirpaths=args.extra_unet_dirpaths,
-                             …
+                             unet_weights_in_ensemble=args.unet_weights_in_ensemble, device=args.device)
 
     if not args.randface:
         image_folder = args.subject
@@ -143,7 +136,7 @@ if __name__ == "__main__":
         rand_init_id_embs = torch.randn(1, 512)
 
     init_id_embs = rand_init_id_embs if args.randface else None
-    …
+    init_noise = torch.randn(args.out_image_count, 4, 64, 64).cuda()
     # args.perturb_std: the *relative* std of the noise added to the face embeddings.
     # A noise level of 0.08 could change gender, but 0.06 is usually safe.
     # adaface_subj_embs is not used. It is generated for the purpose of updating the text encoder (within this function call).
@@ -151,5 +144,7 @@ if __name__ == "__main__":
     adaface.prepare_adaface_embeddings(image_paths, init_id_embs,
                                        perturb_at_stage='img_prompt_emb',
                                        perturb_std=args.perturb_std, update_text_encoder=True)
-    images = adaface(…
+    images = adaface(init_noise, args.prompt, None, None,
+                     'append', args.guidance_scale,
+                     args.out_image_count, verbose=True)
     save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.perturb_std)
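For orientation, here is a minimal sketch of the calling convention that adaface_infer.py now uses after this commit: the checkpoint path and ensemble weights go into the wrapper constructor, and the forward call takes the initial latent noise explicitly. Only the argument structure is taken from the diff; the checkpoint path, image path, and encoder CFG scales are illustrative placeholders.

import torch
from adaface.adaface_wrapper import AdaFaceWrapper

adaface = AdaFaceWrapper(
    "text2img", None,                      # pipeline name, base model path (None -> built-in default)
    ["consistentID", "arc2face"],          # --adaface_encoder_types
    "models/adaface/adaface.pt",           # --adaface_ckpt_path (illustrative path)
    [6.0, 6.0], None,                      # encoder CFG scales (illustrative), enabled_encoders
    unet_types=None, main_unet_filepath=None, extra_unet_dirpaths=[],
    unet_weights_in_ensemble=[1], device="cuda")

# Map the subject photos to ada prompt embeddings and patch the text encoder.
adaface.prepare_adaface_embeddings(["subjects/alice/1.jpg"], None,
                                   perturb_at_stage='img_prompt_emb',
                                   perturb_std=0, update_text_encoder=True)

# The forward call now takes the initial latent noise as its first argument.
init_noise = torch.randn(4, 4, 64, 64).cuda()
images = adaface(init_noise, "a woman z in superman costume", None, None,
                 'append', 4.0, 4, verbose=True)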
adaface/adaface_translate.py
CHANGED
@@ -25,10 +25,9 @@ def seed_everything(seed):
 
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--base_model_path", type=str, default='models/…
-                        help="Path to the UNet checkpoint (…
-    parser.add_argument('--…
-                        default=['models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt'])
+    parser.add_argument("--base_model_path", type=str, default='models/sar/sar.safetensors',
+                        help="Path to the UNet checkpoint (Default: SAR)")
+    parser.add_argument('--adaface_ckpt_path', type=str, required=True)
     parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
                         choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
     parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
@@ -40,9 +39,11 @@ def parse_args():
     parser.add_argument('--extra_unet_dirpaths', type=str, nargs="*",
                         default=[],
                         help="Extra paths to the checkpoints of the UNet models")
-    parser.add_argument('--…
+    parser.add_argument('--unet_weights_in_ensemble', type=float, nargs="+", default=[1],
                         help="Weights for the UNet models")
     parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images")
+    parser.add_argument("--restore_image", type=str, default=None,
+                        help="Path to the image to be restored")
     # If True, the input folder contains images of mixed subjects.
     # If False, the input folder contains multiple subfolders, each of which contains images of the same subject.
     parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?",
@@ -52,19 +53,14 @@ def parse_args():
     parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images")
     parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image")
     parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder")
-    parser.add_argument("--…
+    parser.add_argument("--perturb_std", type=float, default=0)
     parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
                         help="Guidance scale for the diffusion model")
     parser.add_argument("--ref_img_strength", type=float, default=0.8,
                         help="Strength of the reference image in the output image.")
-    parser.add_argument("--subject_string",
-                        type=str, default="z",
-                        help="Subject placeholder string used in prompts to denote the concept.")
     parser.add_argument("--prompt", type=str, default="a person z")
     parser.add_argument("--num_images_per_row", type=int, default=4,
                         help="Number of images to display in a row in the output grid image.")
-    parser.add_argument("--num_inference_steps", type=int, default=50,
-                        help="Number of DDIM inference steps")
     parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.")
     parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
     parser.add_argument("--seed", type=int, default=42,
@@ -93,15 +89,16 @@ if __name__ == "__main__":
         process_index = 0
 
     adaface = AdaFaceWrapper("img2img", args.base_model_path,
-                             args.adaface_encoder_types, args.…
+                             args.adaface_encoder_types, args.adaface_ckpt_path,
                              args.adaface_encoder_cfg_scales, args.enabled_encoders,
-                             args.subject_string, args.num_inference_steps,
                              unet_types=None,
-                             extra_unet_dirpaths=args.extra_unet_dirpaths,
+                             extra_unet_dirpaths=args.extra_unet_dirpaths,
+                             unet_weights_in_ensemble=args.unet_weights_in_ensemble,
                              device=args.device)
 
     in_folder = args.in_folder
     if os.path.isfile(in_folder):
+        args.in_folder = os.path.dirname(args.in_folder)
         subject_folders = [ os.path.dirname(in_folder) ]
         images_by_subject = [[in_folder]]
     else:
@@ -157,6 +154,24 @@ if __name__ == "__main__":
         images_by_subject = images_by_subject[process_index::args.num_gpus]
         #subject_folders, images_by_subject = distributed_state.split_between_processes(zip(subject_folders, images_by_subject))
 
+    if args.restore_image is not None:
+        in_images = []
+        for image_path in [args.restore_image]:
+            image = Image.open(image_path).convert("RGB").resize((512, 512))
+            # [512, 512, 3] -> [3, 512, 512].
+            image = np.array(image).transpose(2, 0, 1)
+            # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
+            image = torch.tensor(image).unsqueeze(0).float().cuda()
+            in_images.append(image)
+
+        # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
+        # NOTE: For simplicity, we do not check overly large batch sizes.
+        in_images = torch.cat(in_images, dim=0)
+        # in_images: [5, 3, 512, 512].
+        # Normalize the pixel values to [0, 1].
+        in_images = in_images / 255.0
+        num_out_images = len(in_images) * args.out_count_per_input_image
+
     for (subject_folder, image_paths) in zip(subject_folders, images_by_subject):
         # If is_mix_subj_folder, then image_paths only contains 1 image, and we use the file name as the signature of the image.
         # Otherwise, we use the folder name as the signature of the images.
@@ -176,29 +191,32 @@ if __name__ == "__main__":
             os.makedirs(subject_out_folder)
         print(f"Output images will be saved to {subject_out_folder}")
 
-        …
+        if args.restore_image is None:
+            in_images = []
+            for image_path in image_paths:
+                image = Image.open(image_path).convert("RGB").resize((512, 512))
+                # [512, 512, 3] -> [3, 512, 512].
+                image = np.array(image).transpose(2, 0, 1)
+                # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
+                image = torch.tensor(image).unsqueeze(0).float().cuda()
+                in_images.append(image)
+
+            # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
+            # NOTE: For simplicity, we do not check overly large batch sizes.
+            in_images = torch.cat(in_images, dim=0)
+            # in_images: [5, 3, 512, 512].
+            # Normalize the pixel values to [0, 1].
+            in_images = in_images / 255.0
+            num_out_images = len(in_images) * args.out_count_per_input_image
 
         with torch.no_grad():
             # args.perturb_std: the *relative* std of the noise added to the face embeddings.
             # A noise level of 0.08 could change gender, but 0.06 is usually safe.
             # The returned adaface_subj_embs are already incorporated in the text encoder, and not used explicitly.
            # NOTE: We assume out_count_per_input_image == 1, so that the output images are of the same number as the input images.
-            out_images = adaface(in_images, args.prompt, None,…
+            out_images = adaface(in_images, args.prompt, None, None,
+                                 'append', args.guidance_scale, num_out_images,
+                                 ref_img_strength=args.ref_img_strength)
 
         for img_i, img in enumerate(out_images):
             # out_images: subj_1, subj_2, ..., subj_n, subj_1, subj_2, ..., subj_n, ...
@@ -206,9 +224,11 @@ if __name__ == "__main__":
             copy_i = img_i // len(in_images)
             image_filename_stem, image_fileext = os.path.splitext(os.path.basename(image_paths[subj_i]))
             if copy_i == 0:
-                …
+                save_path = os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}")
             else:
-                …
+                save_path = os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}")
+            img.save(save_path)
+            print(f"Saved {save_path}")
 
             if args.copy_masks:
                 mask_path = image_paths[subj_i].replace(image_fileext, "_mask.png")
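The image preprocessing added in both branches of the new restore_image logic is identical: load, resize to 512x512, transpose HWC to CHW, stack into a batch, and normalize to [0, 1]. A standalone sketch of just that step is below; the helper name load_images_as_batch is ours, not the repo's.

import numpy as np
import torch
from PIL import Image

def load_images_as_batch(image_paths, size=512, device="cuda"):
    """Load images and stack them into a float batch in [0, 1] of shape [N, 3, size, size]."""
    batch = []
    for path in image_paths:
        img = Image.open(path).convert("RGB").resize((size, size))
        # [H, W, 3] -> [3, H, W]
        arr = np.array(img).transpose(2, 0, 1)
        batch.append(torch.tensor(arr).unsqueeze(0).float())
    return torch.cat(batch, dim=0).to(device) / 255.0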
adaface/adaface_wrapper.py
CHANGED
@@ -8,22 +8,29 @@ from diffusers import (
|
|
8 |
StableDiffusion3Pipeline,
|
9 |
#FluxPipeline,
|
10 |
DDIMScheduler,
|
|
|
|
|
11 |
AutoencoderKL,
|
|
|
12 |
)
|
13 |
from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
|
14 |
from adaface.util import UNetEnsemble
|
15 |
from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
|
|
|
16 |
from safetensors.torch import load_file as safetensors_load_file
|
17 |
import re, os
|
18 |
import numpy as np
|
|
|
19 |
|
20 |
class AdaFaceWrapper(nn.Module):
|
21 |
def __init__(self, pipeline_name, base_model_path, adaface_encoder_types,
|
22 |
adaface_ckpt_paths, adaface_encoder_cfg_scales=None,
|
23 |
-
enabled_encoders=None,
|
24 |
-
subject_string='z',
|
25 |
use_840k_vae=False, use_ds_text_encoder=False,
|
26 |
-
main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None,
|
|
|
|
|
27 |
device='cuda', is_training=False):
|
28 |
'''
|
29 |
pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
|
@@ -38,15 +45,23 @@ class AdaFaceWrapper(nn.Module):
|
|
38 |
self.adaface_ckpt_paths = adaface_ckpt_paths
|
39 |
self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales
|
40 |
self.enabled_encoders = enabled_encoders
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
self.subject_string = subject_string
|
|
|
42 |
|
43 |
-
self.
|
|
|
44 |
self.use_840k_vae = use_840k_vae
|
45 |
self.use_ds_text_encoder = use_ds_text_encoder
|
46 |
self.main_unet_filepath = main_unet_filepath
|
47 |
self.unet_types = unet_types
|
48 |
self.extra_unet_dirpaths = extra_unet_dirpaths
|
49 |
-
self.
|
50 |
self.device = device
|
51 |
self.is_training = is_training
|
52 |
|
@@ -62,7 +77,14 @@ class AdaFaceWrapper(nn.Module):
|
|
62 |
self.initialize_pipeline()
|
63 |
# During inference, we never use static image suffix embeddings.
|
64 |
# So num_id_vecs is the length of the returned adaface embeddings for each encoder.
|
65 |
-
self.encoders_num_id_vecs = self.id2ada_prompt_encoder.encoders_num_id_vecs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
self.extend_tokenizer_and_text_encoder()
|
67 |
|
68 |
def to(self, device):
|
@@ -76,7 +98,8 @@ class AdaFaceWrapper(nn.Module):
|
|
76 |
self.id2ada_prompt_encoder = create_id2ada_prompt_encoder(self.adaface_encoder_types,
|
77 |
self.adaface_ckpt_paths,
|
78 |
self.adaface_encoder_cfg_scales,
|
79 |
-
self.enabled_encoders
|
|
|
80 |
|
81 |
self.id2ada_prompt_encoder.to(self.device)
|
82 |
print(f"adaface_encoder_cfg_scales: {self.adaface_encoder_cfg_scales}")
|
@@ -118,10 +141,10 @@ class AdaFaceWrapper(nn.Module):
|
|
118 |
|
119 |
if self.base_model_path is None:
|
120 |
base_model_path_dict = {
|
121 |
-
'text2img':
|
122 |
-
'text2imgxl':
|
123 |
-
'text2img3':
|
124 |
-
'flux':
|
125 |
}
|
126 |
self.base_model_path = base_model_path_dict[self.pipeline_name]
|
127 |
|
@@ -137,6 +160,20 @@ class AdaFaceWrapper(nn.Module):
|
|
137 |
safety_checker=None
|
138 |
)
|
139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
if self.main_unet_filepath is not None:
|
141 |
print(f"Replacing the UNet with the UNet from {self.main_unet_filepath}.")
|
142 |
ret = pipeline.unet.load_state_dict(self.load_unet_from_file(self.main_unet_filepath, device='cpu'))
|
@@ -147,12 +184,19 @@ class AdaFaceWrapper(nn.Module):
|
|
147 |
|
148 |
if (self.unet_types is not None and len(self.unet_types) > 0) \
|
149 |
or (self.extra_unet_dirpaths is not None and len(self.extra_unet_dirpaths) > 0):
|
150 |
-
unet_ensemble = UNetEnsemble([pipeline.unet], self.unet_types, self.extra_unet_dirpaths, self.
|
151 |
device=self.device, torch_dtype=torch.float16)
|
152 |
pipeline.unet = unet_ensemble
|
153 |
|
154 |
print(f"Loaded pipeline from {self.base_model_path}.")
|
155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
if self.use_840k_vae:
|
157 |
pipeline.vae = vae
|
158 |
print("Replaced the VAE with the 840k-step VAE.")
|
@@ -167,19 +211,56 @@ class AdaFaceWrapper(nn.Module):
|
|
167 |
pipeline.vae = None
|
168 |
print("Removed UNet and VAE from the pipeline.")
|
169 |
|
170 |
-
if self.pipeline_name not in ["text2imgxl", "text2img3", "flux"]:
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
pipeline.scheduler = noise_scheduler
|
180 |
-
# Otherwise, pipeline.scheduler == FlowMatchEulerDiscreteScheduler
|
|
|
181 |
self.pipeline = pipeline.to(self.device)
|
182 |
|
|
|
|
|
|
|
|
|
183 |
def load_unet_from_file(self, unet_path, device=None):
|
184 |
if os.path.isfile(unet_path):
|
185 |
if unet_path.endswith(".safetensors"):
|
@@ -208,7 +289,109 @@ class AdaFaceWrapper(nn.Module):
|
|
208 |
else:
|
209 |
raise ValueError(f"UNet path {unet_path} is not a file.")
|
210 |
return unet_state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
def extend_tokenizer_and_text_encoder(self):
|
213 |
if np.sum(self.encoders_num_id_vecs) < 1:
|
214 |
raise ValueError(f"encoders_num_id_vecs has to be larger or equal to 1, but is {self.encoders_num_id_vecs}")
|
@@ -218,6 +401,7 @@ class AdaFaceWrapper(nn.Module):
|
|
218 |
# We add z_0_0, z_0_1, z_0_2, ..., z_0_15, z_1_0, z_1_1, z_1_2, z_1_3 to the tokenizer.
|
219 |
self.all_placeholder_tokens = []
|
220 |
self.placeholder_tokens_strs = []
|
|
|
221 |
for i in range(len(self.adaface_encoder_types)):
|
222 |
placeholder_tokens = []
|
223 |
for j in range(self.encoders_num_id_vecs[i]):
|
@@ -225,9 +409,11 @@ class AdaFaceWrapper(nn.Module):
|
|
225 |
placeholder_tokens_str = " ".join(placeholder_tokens)
|
226 |
|
227 |
self.all_placeholder_tokens.extend(placeholder_tokens)
|
|
|
228 |
self.placeholder_tokens_strs.append(placeholder_tokens_str)
|
229 |
|
230 |
self.all_placeholder_tokens_str = " ".join(self.placeholder_tokens_strs)
|
|
|
231 |
# all_null_placeholder_tokens_str: ", , , , ..." (20 times).
|
232 |
# It just contains the commas and spaces with the same length, but no actual tokens.
|
233 |
self.all_null_placeholder_tokens_str = " ".join([", "] * len(self.all_placeholder_tokens))
|
@@ -241,7 +427,7 @@ class AdaFaceWrapper(nn.Module):
|
|
241 |
|
242 |
print(f"Added {num_added_tokens} tokens ({self.all_placeholder_tokens_str}) to the tokenizer.")
|
243 |
|
244 |
-
# placeholder_token_ids: [49408, ...,
|
245 |
self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.all_placeholder_tokens)
|
246 |
#print("New tokens:", self.placeholder_token_ids)
|
247 |
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
@@ -252,24 +438,49 @@ class AdaFaceWrapper(nn.Module):
|
|
252 |
|
253 |
# Extend pipeline.text_encoder with the adaface subject emeddings.
|
254 |
# subj_embs: [16, 768].
|
255 |
-
def update_text_encoder_subj_embeddings(self, subj_embs):
|
256 |
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
257 |
# token_embeds: [49412, 768]
|
258 |
token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
|
|
|
|
|
|
|
|
|
259 |
with torch.no_grad():
|
260 |
-
|
261 |
-
|
262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
|
264 |
def update_prompt(self, prompt, placeholder_tokens_pos='append',
|
|
|
265 |
use_null_placeholders=False):
|
266 |
if prompt is None:
|
267 |
prompt = ""
|
268 |
|
269 |
if use_null_placeholders:
|
270 |
all_placeholder_tokens_str = self.all_null_placeholder_tokens_str
|
|
|
|
|
|
|
271 |
else:
|
272 |
-
all_placeholder_tokens_str = self.
|
273 |
|
274 |
# Delete the subject_string from the prompt.
|
275 |
prompt = re.sub(r'\b(a|an|the)\s+' + self.subject_string + r'\b,?', "", prompt)
|
@@ -279,15 +490,29 @@ class AdaFaceWrapper(nn.Module):
|
|
279 |
# When we do joint training, seems both work better if they are appended to the prompt.
|
280 |
# Therefore we simply appended all placeholder_tokens_str's to the prompt.
|
281 |
# NOTE: Prepending them hurts compositional prompts.
|
282 |
-
if
|
283 |
-
|
284 |
-
|
285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
else:
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
288 |
|
289 |
return prompt
|
290 |
|
|
|
|
|
291 |
# If face_id_embs is None, then it extracts face_id_embs from the images,
|
292 |
# then map them to ada prompt embeddings.
|
293 |
# avg_at_stage: 'id_emb', 'img_prompt_emb', or None.
|
@@ -298,27 +523,29 @@ class AdaFaceWrapper(nn.Module):
|
|
298 |
perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
|
299 |
perturb_std=0, update_text_encoder=True):
|
300 |
|
301 |
-
all_adaface_subj_embs = \
|
302 |
self.id2ada_prompt_encoder.generate_adaface_embeddings(\
|
303 |
image_paths, face_id_embs=face_id_embs,
|
304 |
img_prompt_embs=None,
|
305 |
avg_at_stage=avg_at_stage,
|
306 |
perturb_at_stage=perturb_at_stage,
|
307 |
perturb_std=perturb_std,
|
308 |
-
enable_static_img_suffix_embs=
|
309 |
|
310 |
if all_adaface_subj_embs is None:
|
311 |
return None
|
312 |
|
|
|
|
|
313 |
if all_adaface_subj_embs.ndim == 4:
|
314 |
-
# [1, 1,
|
315 |
all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0).squeeze(0)
|
316 |
elif all_adaface_subj_embs.ndim == 3:
|
317 |
-
# [1,
|
318 |
all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0)
|
319 |
|
320 |
if update_text_encoder:
|
321 |
-
self.update_text_encoder_subj_embeddings(all_adaface_subj_embs)
|
322 |
return all_adaface_subj_embs
|
323 |
|
324 |
def diffusers_encode_prompts(self, prompt, plain_prompt, negative_prompt, device):
|
@@ -368,6 +595,7 @@ class AdaFaceWrapper(nn.Module):
|
|
368 |
else:
|
369 |
breakpoint()
|
370 |
else:
|
|
|
371 |
# prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
|
372 |
prompt_embeds_, negative_prompt_embeds_ = \
|
373 |
self.pipeline.encode_prompt(prompt, device=device,
|
@@ -378,9 +606,53 @@ class AdaFaceWrapper(nn.Module):
|
|
378 |
return prompt_embeds_, negative_prompt_embeds_, \
|
379 |
pooled_prompt_embeds_, negative_pooled_prompt_embeds_
|
380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
381 |
def encode_prompt(self, prompt, negative_prompt=None,
|
382 |
placeholder_tokens_pos='append',
|
383 |
-
|
|
|
|
|
|
|
|
|
384 |
device=None, verbose=False):
|
385 |
if negative_prompt is None:
|
386 |
negative_prompt = self.negative_prompt
|
@@ -389,59 +661,81 @@ class AdaFaceWrapper(nn.Module):
|
|
389 |
device = self.device
|
390 |
|
391 |
plain_prompt = prompt
|
392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
393 |
if verbose:
|
394 |
print(f"Subject prompt:\n{prompt}")
|
395 |
|
396 |
-
if do_neg_id_prompt_weight > 0:
|
397 |
-
# Use 'prepend' for the negative prompt, since it's long and we want to make sure
|
398 |
-
# the placeholder tokens are not cut off.
|
399 |
-
negative_prompt0 = negative_prompt
|
400 |
-
negative_prompt = self.update_prompt(negative_prompt0, placeholder_tokens_pos='prepend')
|
401 |
-
null_negative_prompt = self.update_prompt(negative_prompt0, placeholder_tokens_pos='prepend',
|
402 |
-
use_null_placeholders=True)
|
403 |
-
''' if verbose:
|
404 |
-
print(f"Negative prompt:\n{negative_prompt}")
|
405 |
-
print(f"Null negative prompt:\n{null_negative_prompt}")
|
406 |
-
|
407 |
-
'''
|
408 |
-
else:
|
409 |
-
null_negative_prompt = None
|
410 |
-
|
411 |
# For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
|
412 |
# So we manually move it to GPU here.
|
413 |
self.pipeline.text_encoder.to(device)
|
414 |
|
415 |
prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
|
416 |
self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)
|
417 |
-
|
418 |
-
if
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
return prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_
|
425 |
|
426 |
# ref_img_strength is used only in the img2img pipeline.
|
427 |
-
def forward(self, noise, prompt, negative_prompt=None,
|
428 |
placeholder_tokens_pos='append',
|
429 |
-
do_neg_id_prompt_weight=0,
|
430 |
guidance_scale=6.0, out_image_count=4,
|
431 |
-
ref_img_strength=0.8, generator=None,
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
noise = noise.to(device=self.device, dtype=torch.float16)
|
|
|
|
|
433 |
|
434 |
if negative_prompt is None:
|
435 |
negative_prompt = self.negative_prompt
|
436 |
# prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
443 |
# Repeat the prompt embeddings for all images in the batch.
|
444 |
prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
|
|
|
445 |
if negative_prompt_embeds_ is not None:
|
446 |
negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
|
447 |
|
|
|
8 |
StableDiffusion3Pipeline,
|
9 |
#FluxPipeline,
|
10 |
DDIMScheduler,
|
11 |
+
PNDMScheduler,
|
12 |
+
DPMSolverSinglestepScheduler,
|
13 |
AutoencoderKL,
|
14 |
+
LCMScheduler,
|
15 |
)
|
16 |
from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
|
17 |
from adaface.util import UNetEnsemble
|
18 |
from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
|
19 |
+
from adaface.diffusers_attn_lora_capture import set_up_attn_processors, set_up_ffn_loras, set_lora_and_capture_flags
|
20 |
from safetensors.torch import load_file as safetensors_load_file
|
21 |
import re, os
|
22 |
import numpy as np
|
23 |
+
from peft.utils.constants import DUMMY_TARGET_MODULES
|
24 |
|
25 |
class AdaFaceWrapper(nn.Module):
|
26 |
def __init__(self, pipeline_name, base_model_path, adaface_encoder_types,
|
27 |
adaface_ckpt_paths, adaface_encoder_cfg_scales=None,
|
28 |
+
enabled_encoders=None, use_lcm=False, default_scheduler_name='ddim',
|
29 |
+
num_inference_steps=50, subject_string='z', negative_prompt=None,
|
30 |
use_840k_vae=False, use_ds_text_encoder=False,
|
31 |
+
main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
|
32 |
+
enable_static_img_suffix_embs=None, unet_uses_attn_lora=False,
|
33 |
+
attn_lora_layer_names=['q', 'k', 'v', 'out'], shrink_cross_attn=False, q_lora_updates_query=False,
|
34 |
device='cuda', is_training=False):
|
35 |
'''
|
36 |
pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None.
|
|
|
45 |
self.adaface_ckpt_paths = adaface_ckpt_paths
|
46 |
self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales
|
47 |
self.enabled_encoders = enabled_encoders
|
48 |
+
# None, or a list of two bools for two encoders. If None, both are disabled.
|
49 |
+
self.enable_static_img_suffix_embs = enable_static_img_suffix_embs
|
50 |
+
self.unet_uses_attn_lora = unet_uses_attn_lora
|
51 |
+
self.attn_lora_layer_names = attn_lora_layer_names
|
52 |
+
self.q_lora_updates_query = q_lora_updates_query
|
53 |
+
self.use_lcm = use_lcm
|
54 |
self.subject_string = subject_string
|
55 |
+
self.shrink_cross_attn = shrink_cross_attn
|
56 |
|
57 |
+
self.default_scheduler_name = default_scheduler_name
|
58 |
+
self.num_inference_steps = num_inference_steps if not use_lcm else 4
|
59 |
self.use_840k_vae = use_840k_vae
|
60 |
self.use_ds_text_encoder = use_ds_text_encoder
|
61 |
self.main_unet_filepath = main_unet_filepath
|
62 |
self.unet_types = unet_types
|
63 |
self.extra_unet_dirpaths = extra_unet_dirpaths
|
64 |
+
self.unet_weights_in_ensemble = unet_weights_in_ensemble
|
65 |
self.device = device
|
66 |
self.is_training = is_training
|
67 |
|
|
|
77 |
self.initialize_pipeline()
|
78 |
# During inference, we never use static image suffix embeddings.
|
79 |
# So num_id_vecs is the length of the returned adaface embeddings for each encoder.
|
80 |
+
self.encoders_num_id_vecs = np.array(self.id2ada_prompt_encoder.encoders_num_id_vecs)
|
81 |
+
self.encoders_num_static_img_suffix_embs = np.array(self.id2ada_prompt_encoder.encoders_num_static_img_suffix_embs)
|
82 |
+
if self.enable_static_img_suffix_embs is not None:
|
83 |
+
assert len(self.enable_static_img_suffix_embs) == len(self.encoders_num_id_vecs)
|
84 |
+
self.encoders_num_static_img_suffix_embs *= np.array(self.enable_static_img_suffix_embs)
|
85 |
+
self.encoders_num_id_vecs += self.encoders_num_static_img_suffix_embs
|
86 |
+
|
87 |
+
self.img_prompt_embs = None
|
88 |
self.extend_tokenizer_and_text_encoder()
|
89 |
|
90 |
def to(self, device):
|
|
|
98 |
self.id2ada_prompt_encoder = create_id2ada_prompt_encoder(self.adaface_encoder_types,
|
99 |
self.adaface_ckpt_paths,
|
100 |
self.adaface_encoder_cfg_scales,
|
101 |
+
self.enabled_encoders,
|
102 |
+
num_static_img_suffix_embs=4)
|
103 |
|
104 |
self.id2ada_prompt_encoder.to(self.device)
|
105 |
print(f"adaface_encoder_cfg_scales: {self.adaface_encoder_cfg_scales}")
|
|
|
141 |
|
142 |
if self.base_model_path is None:
|
143 |
base_model_path_dict = {
|
144 |
+
'text2img': 'models/sd15-dste8-vae.safetensors',
|
145 |
+
'text2imgxl': 'stabilityai/stable-diffusion-xl-base-1.0',
|
146 |
+
'text2img3': 'stabilityai/stable-diffusion-3-medium-diffusers',
|
147 |
+
'flux': 'black-forest-labs/FLUX.1-schnell',
|
148 |
}
|
149 |
self.base_model_path = base_model_path_dict[self.pipeline_name]
|
150 |
|
|
|
160 |
safety_checker=None
|
161 |
)
|
162 |
|
163 |
+
if self.use_lcm:
|
164 |
+
lcm_path_dict = {
|
165 |
+
'text2img': 'latent-consistency/lcm-lora-sdv1-5',
|
166 |
+
'text2imgxl': 'latent-consistency/lcm-lora-sdxl',
|
167 |
+
}
|
168 |
+
if self.pipeline_name not in lcm_path_dict:
|
169 |
+
raise ValueError(f"Pipeline {self.pipeline_name} does not support LCM.")
|
170 |
+
|
171 |
+
lcm_path = lcm_path_dict[self.pipeline_name]
|
172 |
+
pipeline.load_lora_weights(lcm_path)
|
173 |
+
pipeline.fuse_lora()
|
174 |
+
print(f"Loaded LCM weights from {lcm_path}.")
|
175 |
+
pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
|
176 |
+
|
177 |
if self.main_unet_filepath is not None:
|
178 |
print(f"Replacing the UNet with the UNet from {self.main_unet_filepath}.")
|
179 |
ret = pipeline.unet.load_state_dict(self.load_unet_from_file(self.main_unet_filepath, device='cpu'))
|
|
|
184 |
|
185 |
if (self.unet_types is not None and len(self.unet_types) > 0) \
|
186 |
or (self.extra_unet_dirpaths is not None and len(self.extra_unet_dirpaths) > 0):
|
187 |
+
unet_ensemble = UNetEnsemble([pipeline.unet], self.unet_types, self.extra_unet_dirpaths, self.unet_weights_in_ensemble,
|
188 |
device=self.device, torch_dtype=torch.float16)
|
189 |
pipeline.unet = unet_ensemble
|
190 |
|
191 |
print(f"Loaded pipeline from {self.base_model_path}.")
|
192 |
+
if not remove_unet and (self.unet_uses_attn_lora or self.shrink_cross_attn):
|
193 |
+
unet2 = self.load_unet_lora_weights(pipeline.unet, use_attn_lora=self.unet_uses_attn_lora,
|
194 |
+
attn_lora_layer_names=self.attn_lora_layer_names,
|
195 |
+
shrink_cross_attn=self.shrink_cross_attn,
|
196 |
+
q_lora_updates_query=self.q_lora_updates_query)
|
197 |
+
|
198 |
+
pipeline.unet = unet2
|
199 |
+
|
200 |
if self.use_840k_vae:
|
201 |
pipeline.vae = vae
|
202 |
print("Replaced the VAE with the 840k-step VAE.")
|
|
|
211 |
pipeline.vae = None
|
212 |
print("Removed UNet and VAE from the pipeline.")
|
213 |
|
214 |
+
if self.pipeline_name not in ["text2imgxl", "text2img3", "flux"] and not self.use_lcm:
|
215 |
+
if self.default_scheduler_name == 'ddim':
|
216 |
+
noise_scheduler = DDIMScheduler(
|
217 |
+
num_train_timesteps=1000,
|
218 |
+
beta_start=0.00085,
|
219 |
+
beta_end=0.012,
|
220 |
+
beta_schedule="scaled_linear",
|
221 |
+
clip_sample=False,
|
222 |
+
set_alpha_to_one=False,
|
223 |
+
steps_offset=1,
|
224 |
+
timestep_spacing="leading",
|
225 |
+
rescale_betas_zero_snr=False,
|
226 |
+
)
|
227 |
+
elif self.default_scheduler_name == 'pndm':
|
228 |
+
noise_scheduler = PNDMScheduler(
|
229 |
+
num_train_timesteps=1000,
|
230 |
+
beta_start=0.00085,
|
231 |
+
beta_end=0.012,
|
232 |
+
beta_schedule="scaled_linear",
|
233 |
+
set_alpha_to_one=False,
|
234 |
+
steps_offset=1,
|
235 |
+
timestep_spacing="leading",
|
236 |
+
skip_prk_steps=True,
|
237 |
+
)
|
238 |
+
elif self.default_scheduler_name == 'dpm++':
|
239 |
+
noise_scheduler = DPMSolverSinglestepScheduler(
|
240 |
+
beta_start=0.00085,
|
241 |
+
beta_end=0.012,
|
242 |
+
beta_schedule="scaled_linear",
|
243 |
+
prediction_type="epsilon",
|
244 |
+
num_train_timesteps=1000,
|
245 |
+
trained_betas=None,
|
246 |
+
thresholding=False,
|
247 |
+
algorithm_type="dpmsolver++",
|
248 |
+
solver_type="midpoint",
|
249 |
+
lower_order_final=True,
|
250 |
+
use_karras_sigmas=True,
|
251 |
+
)
|
252 |
+
else:
|
253 |
+
breakpoint()
|
254 |
+
|
255 |
pipeline.scheduler = noise_scheduler
|
256 |
+
# Otherwise, if not use_lcm, pipeline.scheduler == FlowMatchEulerDiscreteScheduler
|
257 |
+
# if use_lcm, pipeline.scheduler == LCMScheduler
|
258 |
self.pipeline = pipeline.to(self.device)
|
259 |
|
260 |
+
def set_adaface_encoder_cfg_scales(self, adaface_encoder_cfg_scales):
|
261 |
+
self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales
|
262 |
+
self.id2ada_prompt_encoder.set_out_id_embs_cfg_scale(adaface_encoder_cfg_scales)
|
263 |
+
|
264 |
def load_unet_from_file(self, unet_path, device=None):
|
265 |
if os.path.isfile(unet_path):
|
266 |
if unet_path.endswith(".safetensors"):
|
|
|
289 |
else:
|
290 |
raise ValueError(f"UNet path {unet_path} is not a file.")
|
291 |
return unet_state_dict
|
292 |
+
|
293 |
+
# Adapted from ConsistentIDPipeline:set_ip_adapter().
|
294 |
+
def load_unet_loras(self, unet, unet_lora_modules_state_dict,
|
295 |
+
use_attn_lora=True, use_ffn_lora=False,
|
296 |
+
attn_lora_layer_names=['q', 'k', 'v', 'out'],
|
297 |
+
shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
|
298 |
+
q_lora_updates_query=False):
|
299 |
+
attn_capture_procs, attn_opt_modules = \
|
300 |
+
set_up_attn_processors(unet, use_attn_lora=True, attn_lora_layer_names=attn_lora_layer_names,
|
301 |
+
lora_rank=192, lora_scale_down=8,
|
302 |
+
cross_attn_shrink_factor=cross_attn_shrink_factor,
|
303 |
+
q_lora_updates_query=q_lora_updates_query)
|
304 |
+
# up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut. [12] matches 1 or 2.
|
305 |
+
if use_ffn_lora:
|
306 |
+
target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
|
307 |
+
else:
|
308 |
+
# A special pattern, "dummy-target-modules" tells PEFT to add loras on NONE of the layers.
|
309 |
+
# We couldn't simply skip PEFT initialization (converting unet to a PEFT model),
|
310 |
+
# otherwise the attn lora layers will cause nan quickly during a fp16 training.
|
311 |
+
target_modules_pat = DUMMY_TARGET_MODULES
|
312 |
+
|
313 |
+
unet, ffn_lora_layers, ffn_opt_modules = \
|
314 |
+
set_up_ffn_loras(unet, target_modules_pat=target_modules_pat, lora_uses_dora=True)
|
315 |
+
|
316 |
+
# self.attn_capture_procs and ffn_lora_layers will be used in set_lora_and_capture_flags().
|
317 |
+
self.attn_capture_procs = list(attn_capture_procs.values())
|
318 |
+
self.ffn_lora_layers = list(ffn_lora_layers.values())
|
319 |
+
# Combine attn_opt_modules and ffn_opt_modules into unet_lora_modules.
|
320 |
+
# unet_lora_modules is for optimization and loading/saving.
|
321 |
+
unet_lora_modules = {}
|
322 |
+
# attn_opt_modules and ffn_opt_modules have different depths of keys.
|
323 |
+
# attn_opt_modules:
|
324 |
+
# up_blocks_3_attentions_1_transformer_blocks_0_attn2_processor_std_shrink_factor,
|
325 |
+
# up_blocks_3_attentions_1_transformer_blocks_0_attn2_processor_to_q_lora_lora_A, ...
|
326 |
+
# ffn_opt_modules:
|
327 |
+
# base_model_model_up_blocks_3_resnets_1_conv1_lora_A, ...
|
328 |
+
# with the prefix 'base_model_model_'. Because ffn_opt_modules are extracted from the peft-wrapped model,
|
329 |
+
# and attn_opt_modules are extracted from the original unet model.
|
330 |
+
# To be compatible with old param keys, we append 'base_model_model_' to the keys of attn_opt_modules.
|
331 |
+
unet_lora_modules.update({ f'base_model_model_{k}': v for k, v in attn_opt_modules.items() })
|
332 |
+
unet_lora_modules.update(ffn_opt_modules)
|
333 |
+
# ParameterDict can contain both Parameter and nn.Module.
|
334 |
+
# TODO: maybe in the future, we couldn't put nn.Module in nn.ParameterDict.
|
335 |
+
self.unet_lora_modules = torch.nn.ParameterDict(unet_lora_modules)
|
336 |
+
|
337 |
+
missing, unexpected = self.unet_lora_modules.load_state_dict(unet_lora_modules_state_dict, strict=False)
|
338 |
+
if len(missing) > 0:
|
339 |
+
print(f"Missing Keys: {missing}")
|
340 |
+
if len(unexpected) > 0:
|
341 |
+
print(f"Unexpected Keys: {unexpected}")
|
342 |
+
|
343 |
+
print(f"Loaded {len(unet_lora_modules_state_dict)} LoRA weights on the UNet:\n{unet_lora_modules.keys()}")
|
344 |
+
self.outfeat_capture_blocks.append(unet.up_blocks[3])
|
345 |
+
|
346 |
+
# If shrink_cross_attn is True and use_attn_lora is False, we load all these params from ckpt,
|
347 |
+
# but since we set use_attn_lora to False, attn loras won't be used during inference nonetheless.
|
348 |
+
set_lora_and_capture_flags(unet, None, self.attn_capture_procs, self.outfeat_capture_blocks,
|
349 |
+
use_attn_lora, use_ffn_lora, 'recon_loss', capture_ca_activations=False,
|
350 |
+
shrink_cross_attn=shrink_cross_attn)
|
351 |
+
|
352 |
+
return unet
|
353 |
+
|
354 |
+
def load_unet_lora_weights(self, unet, use_attn_lora=True, attn_lora_layer_names=['q', 'k', 'v', 'out'],
|
355 |
+
shrink_cross_attn=False, q_lora_updates_query=False):
|
356 |
+
unet_lora_weight_found = False
|
357 |
+
if isinstance(self.adaface_ckpt_paths, str):
|
358 |
+
adaface_ckpt_paths = [self.adaface_ckpt_paths]
|
359 |
+
else:
|
360 |
+
adaface_ckpt_paths = self.adaface_ckpt_paths
|
361 |
+
|
362 |
+
for adaface_ckpt_path in adaface_ckpt_paths:
|
363 |
+
ckpt_dict = torch.load(adaface_ckpt_path, map_location='cpu')
|
364 |
+
if 'unet_lora_modules' in ckpt_dict:
|
365 |
+
unet_lora_modules_state_dict = ckpt_dict['unet_lora_modules']
|
366 |
+
print(f"{len(unet_lora_modules_state_dict)} LoRA weights found in {adaface_ckpt_path}.")
|
367 |
+
unet_lora_weight_found = True
|
368 |
+
break
|
369 |
+
|
370 |
+
# Since unet lora weights are not found in the adaface ckpt, we give up on loading unet attn processors.
|
371 |
+
if not unet_lora_weight_found:
|
372 |
+
print(f"LoRA weights not found in {self.adaface_ckpt_paths}.")
|
373 |
+
return unet
|
374 |
|
375 |
+
self.outfeat_capture_blocks = []
|
376 |
+
|
377 |
+
if isinstance(unet, UNetEnsemble):
|
378 |
+
for i, unet_ in enumerate(unet.unets):
|
379 |
+
unet_ = self.load_unet_loras(unet_, unet_lora_modules_state_dict,
|
380 |
+
use_attn_lora=use_attn_lora,
|
381 |
+
attn_lora_layer_names=attn_lora_layer_names,
|
382 |
+
shrink_cross_attn=shrink_cross_attn,
|
383 |
+
q_lora_updates_query=q_lora_updates_query)
|
384 |
+
unet.unets[i] = unet_
|
385 |
+
print(f"Loaded LoRA processors on UNetEnsemble of {len(unet.unets)} UNets.")
|
386 |
+
else:
|
387 |
+
unet = self.load_unet_loras(unet, unet_lora_modules_state_dict,
|
388 |
+
use_attn_lora=use_attn_lora,
|
389 |
+
attn_lora_layer_names=attn_lora_layer_names,
|
390 |
+
shrink_cross_attn=shrink_cross_attn,
|
391 |
+
q_lora_updates_query=q_lora_updates_query)
|
392 |
+
|
393 |
+
return unet
|
394 |
+
|
395 |
def extend_tokenizer_and_text_encoder(self):
|
396 |
if np.sum(self.encoders_num_id_vecs) < 1:
|
397 |
raise ValueError(f"encoders_num_id_vecs has to be larger or equal to 1, but is {self.encoders_num_id_vecs}")
|
|
|
401 |
# We add z_0_0, z_0_1, z_0_2, ..., z_0_15, z_1_0, z_1_1, z_1_2, z_1_3 to the tokenizer.
|
402 |
self.all_placeholder_tokens = []
|
403 |
self.placeholder_tokens_strs = []
|
404 |
+
self.encoder_placeholder_tokens = []
|
405 |
for i in range(len(self.adaface_encoder_types)):
|
406 |
placeholder_tokens = []
|
407 |
for j in range(self.encoders_num_id_vecs[i]):
|
|
|
409 |
placeholder_tokens_str = " ".join(placeholder_tokens)
|
410 |
|
411 |
self.all_placeholder_tokens.extend(placeholder_tokens)
|
412 |
+
self.encoder_placeholder_tokens.append(placeholder_tokens)
|
413 |
self.placeholder_tokens_strs.append(placeholder_tokens_str)
|
414 |
|
415 |
self.all_placeholder_tokens_str = " ".join(self.placeholder_tokens_strs)
|
416 |
+
self.updated_tokens_str = self.all_placeholder_tokens_str
|
417 |
# all_null_placeholder_tokens_str: ", , , , ..." (20 times).
|
418 |
# It just contains the commas and spaces with the same length, but no actual tokens.
|
419 |
self.all_null_placeholder_tokens_str = " ".join([", "] * len(self.all_placeholder_tokens))
|
|
|
427 |
|
428 |
print(f"Added {num_added_tokens} tokens ({self.all_placeholder_tokens_str}) to the tokenizer.")
|
429 |
|
430 |
+
# placeholder_token_ids: [49408, ..., 49427].
|
431 |
self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.all_placeholder_tokens)
|
432 |
#print("New tokens:", self.placeholder_token_ids)
|
433 |
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
|
|
438 |
|
439 |
# Extend pipeline.text_encoder with the adaface subject emeddings.
|
440 |
# subj_embs: [16, 768].
|
441 |
+
def update_text_encoder_subj_embeddings(self, subj_embs, lens_subj_emb_segments):
|
442 |
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
443 |
# token_embeds: [49412, 768]
|
444 |
token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
|
445 |
+
all_encoders_updated_tokens = []
|
446 |
+
all_encoders_updated_token_strs = []
|
447 |
+
idx = 0
|
448 |
+
|
449 |
with torch.no_grad():
|
450 |
+
# sum of lens_subj_emb_segments are probably shorter than self.placeholder_token_ids,
|
451 |
+
# when some static_img_suffix_embs are disabled.
|
452 |
+
for i, encoder_type in enumerate(self.adaface_encoder_types):
|
453 |
+
encoder_updated_tokens = []
|
454 |
+
if (self.enabled_encoders is not None) and (encoder_type not in self.enabled_encoders):
|
455 |
+
idx += lens_subj_emb_segments[i]
|
456 |
+
continue
|
457 |
+
for j in range(lens_subj_emb_segments[i]):
|
458 |
+
placeholder_token = f"{self.subject_string}_{i}_{j}"
|
459 |
+
token_id = self.pipeline.tokenizer.convert_tokens_to_ids(placeholder_token)
|
460 |
+
token_embeds[token_id] = subj_embs[idx]
|
461 |
+
encoder_updated_tokens.append(placeholder_token)
|
462 |
+
idx += 1
|
463 |
+
|
464 |
+
all_encoders_updated_tokens.extend(encoder_updated_tokens)
|
465 |
+
all_encoders_updated_token_strs.append(" ".join(encoder_updated_tokens))
|
466 |
+
|
467 |
+
self.updated_tokens_str = " ".join(all_encoders_updated_token_strs)
|
468 |
+
self.all_encoders_updated_token_strs = all_encoders_updated_token_strs
|
469 |
+
print(f"Updated {len(all_encoders_updated_tokens)} tokens ({self.updated_tokens_str}) in the text encoder.")
|
470 |
|
471 |
def update_prompt(self, prompt, placeholder_tokens_pos='append',
|
472 |
+
repeat_prompt_for_each_encoder=True,
|
473 |
use_null_placeholders=False):
|
474 |
if prompt is None:
|
475 |
prompt = ""
|
476 |
|
477 |
if use_null_placeholders:
|
478 |
all_placeholder_tokens_str = self.all_null_placeholder_tokens_str
|
479 |
+
if not re.search(r"\b(man|woman|person|child|girl|boy)\b", prompt.lower()):
|
480 |
+
all_placeholder_tokens_str = "person " + all_placeholder_tokens_str
+            repeat_prompt_for_each_encoder = False
         else:
+            all_placeholder_tokens_str = self.updated_tokens_str

         # Delete the subject_string from the prompt.
         prompt = re.sub(r'\b(a|an|the)\s+' + self.subject_string + r'\b,?', "", prompt)
@@ ... @@
         # When we do joint training, seems both work better if they are appended to the prompt.
         # Therefore we simply appended all placeholder_tokens_str's to the prompt.
         # NOTE: Prepending them hurts compositional prompts.
+        if repeat_prompt_for_each_encoder:
+            encoder_prompts = []
+            for encoder_updated_token_strs in self.all_encoders_updated_token_strs:
+                if placeholder_tokens_pos == 'prepend':
+                    encoder_prompt = encoder_updated_token_strs + " " + prompt
+                elif placeholder_tokens_pos == 'append':
+                    encoder_prompt = prompt + " " + encoder_updated_token_strs
+                else:
+                    breakpoint()
+                encoder_prompts.append(encoder_prompt)
+            prompt = ", ".join(encoder_prompts)
         else:
+            if placeholder_tokens_pos == 'prepend':
+                prompt = all_placeholder_tokens_str + " " + prompt
+            elif placeholder_tokens_pos == 'append':
+                prompt = prompt + " " + all_placeholder_tokens_str
+            else:
+                breakpoint()

         return prompt

+    # NOTE: all_adaface_subj_embs is the input to the CLIP text encoder.
+    # ** DO NOT use it as prompt_embeds in the forward() method.
     # If face_id_embs is None, then it extracts face_id_embs from the images,
     # then map them to ada prompt embeddings.
     # avg_at_stage: 'id_emb', 'img_prompt_emb', or None.
@@ ... @@
                                    perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
                                    perturb_std=0, update_text_encoder=True):

+        all_adaface_subj_embs, img_prompt_embs, lens_subj_emb_segments = \
             self.id2ada_prompt_encoder.generate_adaface_embeddings(\
                 image_paths, face_id_embs=face_id_embs,
                 img_prompt_embs=None,
                 avg_at_stage=avg_at_stage,
                 perturb_at_stage=perturb_at_stage,
                 perturb_std=perturb_std,
+                enable_static_img_suffix_embs=self.enable_static_img_suffix_embs)

         if all_adaface_subj_embs is None:
             return None

+        self.img_prompt_embs = img_prompt_embs
+
         if all_adaface_subj_embs.ndim == 4:
+            # [1, 1, 20, 768] -> [20, 768]
             all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0).squeeze(0)
         elif all_adaface_subj_embs.ndim == 3:
+            # [1, 20, 768] -> [20, 768]
             all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0)

         if update_text_encoder:
+            self.update_text_encoder_subj_embeddings(all_adaface_subj_embs, lens_subj_emb_segments)
         return all_adaface_subj_embs

     def diffusers_encode_prompts(self, prompt, plain_prompt, negative_prompt, device):
@@ ... @@
             else:
                 breakpoint()
         else:
+            # "text2img" and "img2img" pipelines.
             # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
             prompt_embeds_, negative_prompt_embeds_ = \
                 self.pipeline.encode_prompt(prompt, device=device,
@@ ... @@
         return prompt_embeds_, negative_prompt_embeds_, \
                pooled_prompt_embeds_, negative_pooled_prompt_embeds_

+    # alt_prompt_embed_type: 'ada-nonmix', 'img'
+    def mix_ada_embs_with_other_embs(self, prompt, prompt_embeds,
+                                     alt_prompt_embed_type, alt_prompt_emb_weights):
+        # Scan prompt and replace tokens in self.placeholder_token_ids
+        # with the corresponding image embeddings.
+        prompt_tokens = self.pipeline.tokenizer.tokenize(prompt)
+        prompt_embeds2 = prompt_embeds.clone()
+        if alt_prompt_embed_type == 'img':
+            if self.img_prompt_embs is None:
+                print("Unable to find img_prompt_embs. Either prepare_adaface_embeddings() hasn't been called, or faceless images were used.")
+                return prompt_embeds
+            # self.img_prompt_embs: [1, 20, 768]
+            repl_embeddings = self.img_prompt_embs
+        elif alt_prompt_embed_type == 'ada-nonmix':
+            repl_embeddings_, _, _, _ = self.encode_prompt(prompt, ablate_prompt_only_placeholders=True,
+                                                           verbose=True)
+            # repl_embeddings_: [1, 77, 768] -> [1, 20, 768]
+            repl_embeddings = repl_embeddings_[:, 1:len(self.all_placeholder_tokens)+1]
+        else:
+            breakpoint()
+
+        repl_tokens = {}
+        for i in range(len(prompt_tokens)):
+            if prompt_tokens[i] in self.all_placeholder_tokens:
+                encoder_idx = next((i for i, sublist in enumerate(self.encoder_placeholder_tokens) \
+                                    if prompt_tokens[i] in sublist), 0)
+                alt_prompt_emb_weight = alt_prompt_emb_weights[encoder_idx]
+                prompt_embeds2[:, i] = prompt_embeds2[:, i] * (1 - alt_prompt_emb_weight) \
+                                       + repl_embeddings[:, self.all_placeholder_tokens.index(prompt_tokens[i])] * alt_prompt_emb_weight
+                repl_tokens[prompt_tokens[i]] = 1
+
+        repl_token_count = len(repl_tokens)
+        if np.all(np.array(alt_prompt_emb_weights) == 1):
+            print(f"Replaced {repl_token_count} tokens with {alt_prompt_embed_type} embeddings.")
+        else:
+            print(f"Mixed {repl_token_count} tokens with {alt_prompt_embed_type} embeddings, weight {alt_prompt_emb_weights}.")
+
+        return prompt_embeds2
+
+
     def encode_prompt(self, prompt, negative_prompt=None,
                       placeholder_tokens_pos='append',
+                      ablate_prompt_only_placeholders=False,
+                      ablate_prompt_no_placeholders=False,
+                      ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img'
+                      nonmix_prompt_emb_weight=0,
+                      repeat_prompt_for_each_encoder=True,
                       device=None, verbose=False):
         if negative_prompt is None:
             negative_prompt = self.negative_prompt
@@ ... @@
             device = self.device

         plain_prompt = prompt
+        if ablate_prompt_only_placeholders:
+            prompt = self.updated_tokens_str
+        else:
+            prompt = self.update_prompt(prompt, placeholder_tokens_pos=placeholder_tokens_pos,
+                                        repeat_prompt_for_each_encoder=repeat_prompt_for_each_encoder,
+                                        use_null_placeholders=ablate_prompt_no_placeholders)
+
         if verbose:
             print(f"Subject prompt:\n{prompt}")

         # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
         # So we manually move it to GPU here.
         self.pipeline.text_encoder.to(device)

         prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
             self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device)
+
+        if ablate_prompt_embed_type != 'ada':
+            alt_prompt_embed_type = ablate_prompt_embed_type
+            alt_prompt_emb_weights = (1, 1)
+        elif nonmix_prompt_emb_weight > 0:
+            alt_prompt_embed_type = 'ada-nonmix'
+            alt_prompt_emb_weights = (nonmix_prompt_emb_weight, nonmix_prompt_emb_weight)
+        else:
+            alt_prompt_emb_weights = (0, 0)
+
+        if sum(alt_prompt_emb_weights) > 0:
+            prompt_embeds_ = self.mix_ada_embs_with_other_embs(prompt, prompt_embeds_,
+                                                               alt_prompt_embed_type, alt_prompt_emb_weights)
+
         return prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_

     # ref_img_strength is used only in the img2img pipeline.
+    def forward(self, noise, prompt, prompt_embeds=None, negative_prompt=None,
                 placeholder_tokens_pos='append',
@@ ... @@
                 guidance_scale=6.0, out_image_count=4,
+                ref_img_strength=0.8, generator=None,
+                ablate_prompt_only_placeholders=False,
+                ablate_prompt_no_placeholders=False,
+                ablate_prompt_embed_type='ada', # 'ada', 'ada-nonmix', 'img'
+                nonmix_prompt_emb_weight=0,
+                repeat_prompt_for_each_encoder=True,
+                verbose=False):
         noise = noise.to(device=self.device, dtype=torch.float16)
+        if self.use_lcm:
+            guidance_scale = 0

         if negative_prompt is None:
             negative_prompt = self.negative_prompt
         # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
+        if prompt_embeds is None:
+            prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, \
+                negative_pooled_prompt_embeds_ = \
+                    self.encode_prompt(prompt, negative_prompt,
+                                       placeholder_tokens_pos=placeholder_tokens_pos,
+                                       ablate_prompt_only_placeholders=ablate_prompt_only_placeholders,
+                                       ablate_prompt_no_placeholders=ablate_prompt_no_placeholders,
+                                       ablate_prompt_embed_type=ablate_prompt_embed_type,
+                                       nonmix_prompt_emb_weight=nonmix_prompt_emb_weight,
+                                       repeat_prompt_for_each_encoder=repeat_prompt_for_each_encoder,
+                                       device=self.device,
+                                       verbose=verbose)
+        else:
+            if len(prompt_embeds) == 2:
+                prompt_embeds_, negative_prompt_embeds_ = prompt_embeds
+                pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = None, None
+            elif len(prompt_embeds) == 4:
+                prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, \
+                    negative_pooled_prompt_embeds_ = prompt_embeds
+            else:
+                breakpoint()
+
         # Repeat the prompt embeddings for all images in the batch.
         prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
+
         if negative_prompt_embeds_ is not None:
             negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)

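Editor's note: the per-token mixing that `mix_ada_embs_with_other_embs` performs in the hunk above is a simple convex blend between the ada embedding at a placeholder position and a replacement embedding. A minimal standalone sketch of that step (hypothetical helper and shapes, not the committed API):

```python
import torch

def blend_placeholder_embs(prompt_embeds, repl_embeddings, placeholder_positions, weight):
    # prompt_embeds: [1, 77, 768] text-encoder embeddings of the full prompt.
    # repl_embeddings: [1, K, 768] alternative embeddings for the K placeholder tokens.
    # placeholder_positions: K indices into the 77-token sequence.
    blended = prompt_embeds.clone()
    for k, pos in enumerate(placeholder_positions):
        # weight=0 keeps the ada embedding, weight=1 swaps in the alternative embedding.
        blended[:, pos] = (1 - weight) * prompt_embeds[:, pos] + weight * repl_embeddings[:, k]
    return blended
```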
adaface/diffusers_attn_lora_capture.py
ADDED
@@ -0,0 +1,656 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple, Dict, Any
+from diffusers.models.attention_processor import Attention, AttnProcessor2_0
+from diffusers.utils import logging, is_torch_version, deprecate
+from diffusers.utils.torch_utils import fourier_filter
+# UNet is a diffusers PeftAdapterMixin instance.
+from diffusers.loaders.peft import PeftAdapterMixin
+from peft import LoraConfig, get_peft_model
+import peft.tuners.lora as peft_lora
+from peft.tuners.lora.dora import DoraLinearLayer
+from einops import rearrange
+import math, re
+import numpy as np
+from peft.tuners.tuners_utils import BaseTunerLayer
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+def dummy_func(*args, **kwargs):
+    pass
+
+# Revised from RevGrad, by removing the grad negation.
+class ScaleGrad(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input_, alpha_, debug=False):
+        ctx.save_for_backward(alpha_, debug)
+        output = input_
+        if debug:
+            print(f"input: {input_.abs().mean().item()}")
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):  # pragma: no cover
+        # saved_tensors returns a tuple of tensors.
+        alpha_, debug = ctx.saved_tensors
+        if ctx.needs_input_grad[0]:
+            grad_output2 = grad_output * alpha_
+            if debug:
+                print(f"grad_output2: {grad_output2.abs().mean().item()}")
+        else:
+            grad_output2 = None
+        return grad_output2, None, None
+
+class GradientScaler(nn.Module):
+    def __init__(self, alpha=1., debug=False, *args, **kwargs):
+        """
+        A gradient scaling layer.
+        This layer has no parameters, and simply scales the gradient in the backward pass.
+        """
+        super().__init__(*args, **kwargs)
+
+        self._alpha = torch.tensor(alpha, requires_grad=False)
+        self._debug = torch.tensor(debug, requires_grad=False)
+
+    def forward(self, input_):
+        _debug = self._debug if hasattr(self, '_debug') else False
+        return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)
+
+def gen_gradient_scaler(alpha, debug=False):
+    if alpha == 1:
+        return nn.Identity()
+    if alpha > 0:
+        return GradientScaler(alpha, debug=debug)
+    else:
+        assert alpha == 0
+        # Don't use lambda function here, otherwise the object can't be pickled.
+        return torch.detach
+
+def split_indices_by_instance(indices, as_dict=False):
+    indices_B, indices_N = indices
+    unique_indices_B = torch.unique(indices_B)
+    if not as_dict:
+        indices_by_instance = [ (indices_B[indices_B == uib], indices_N[indices_B == uib]) for uib in unique_indices_B ]
+    else:
+        indices_by_instance = { uib.item(): indices_N[indices_B == uib] for uib in unique_indices_B }
+    return indices_by_instance
+
+# If do_sum, returned emb_attns is 3D. Otherwise 4D.
+# indices are applied on the first 2 dims of attn_mat.
+def sel_emb_attns_by_indices(attn_mat, indices, all_token_weights=None, do_sum=True, do_mean=False):
+    indices_by_instance = split_indices_by_instance(indices)
+
+    # emb_attns[0]: [1, 9, 8, 64]
+    # 8: 8 attention heads. Last dim 64: number of image tokens.
+    emb_attns = [ attn_mat[inst_indices].unsqueeze(0) for inst_indices in indices_by_instance ]
+    if all_token_weights is not None:
+        # all_token_weights: [4, 77].
+        # token_weights_by_instance[0]: [1, 9, 1, 1].
+        token_weights = [ all_token_weights[inst_indices].reshape(1, -1, 1, 1) for inst_indices in indices_by_instance ]
+    else:
+        token_weights = [ 1 ] * len(indices_by_instance)
+
+    # Apply token weights.
+    emb_attns = [ emb_attns[i] * token_weights[i] for i in range(len(indices_by_instance)) ]
+
+    # sum among K_subj_i subj embeddings -> [1, 8, 64]
+    if do_sum:
+        emb_attns = [ emb_attns[i].sum(dim=1) for i in range(len(indices_by_instance)) ]
+    elif do_mean:
+        emb_attns = [ emb_attns[i].mean(dim=1) for i in range(len(indices_by_instance)) ]
+
+    emb_attns = torch.cat(emb_attns, dim=0)
+    return emb_attns
+
+# Slow implementation equivalent to F.scaled_dot_product_attention.
+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
+                                 shrink_cross_attn=False, cross_attn_shrink_factor=0.5,
+                                 is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
+    B, L, S = query.size(0), query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+    # 1: head (to be broadcasted). L: query length. S: key length.
+    attn_bias = torch.zeros(B, 1, L, S, device=query.device, dtype=query.dtype)
+    if is_causal:
+        assert attn_mask is None
+        temp_mask = torch.ones(B, 1, L, S, device=query.device, dtype=torch.bool).tril(diagonal=0)
+        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+        attn_bias.to(query.dtype)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+        else:
+            attn_bias += attn_mask
+
+    if enable_gqa:
+        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
+        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
+
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+
+    if shrink_cross_attn:
+        cross_attn_scale = cross_attn_shrink_factor
+    else:
+        cross_attn_scale = 1
+
+    # attn_bias: [1, 1, 4096, 77], the same size as a single-head attn_weight.
+    attn_weight += attn_bias
+    attn_score = attn_weight
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    # NOTE: After scaling, the "probabilities" of the subject embeddings will sum to < 1.
+    # But this is intended, as we want to scale down the impact of the subject embeddings
+    # in the computed attention output tensors.
+    attn_weight = attn_weight * cross_attn_scale
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+    output = attn_weight @ value
+    return output, attn_score, attn_weight
+
+# All layers share the same attention processor instance.
+class AttnProcessor_LoRA_Capture(nn.Module):
+    r"""
+    Revised from AttnProcessor2_0
+    """
+    # lora_proj_layers is a dict of lora_layer_name -> lora_proj_layer.
+    def __init__(self, capture_ca_activations: bool = False, enable_lora: bool = False,
+                 lora_uses_dora=True, lora_proj_layers=None,
+                 lora_rank: int = 192, lora_alpha: float = 16,
+                 cross_attn_shrink_factor: float = 0.5,
+                 q_lora_updates_query=False, attn_proc_idx=-1):
+        super().__init__()
+
+        self.global_enable_lora = enable_lora
+        self.attn_proc_idx = attn_proc_idx
+        # reset_attn_cache_and_flags() sets the local (call-specific) self.enable_lora flag.
+        # By default, shrink_cross_attn is False. Later in layers 22, 23, 24 it will be set to True.
+        self.reset_attn_cache_and_flags(capture_ca_activations, False, enable_lora)
+        self.lora_rank = lora_rank
+        self.lora_alpha = lora_alpha
+        self.lora_scale = self.lora_alpha / self.lora_rank
+        self.cross_attn_shrink_factor = cross_attn_shrink_factor
+        self.q_lora_updates_query = q_lora_updates_query
+
+        self.to_q_lora = self.to_k_lora = self.to_v_lora = self.to_out_lora = None
+        if self.global_enable_lora:
+            for lora_layer_name, lora_proj_layer in lora_proj_layers.items():
+                if lora_layer_name == 'q':
+                    self.to_q_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
+                                                      use_dora=lora_uses_dora, lora_dropout=0.1)
+                elif lora_layer_name == 'k':
+                    self.to_k_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
+                                                      use_dora=lora_uses_dora, lora_dropout=0.1)
+                elif lora_layer_name == 'v':
+                    self.to_v_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
+                                                      use_dora=lora_uses_dora, lora_dropout=0.1)
+                elif lora_layer_name == 'out':
+                    self.to_out_lora = peft_lora.Linear(lora_proj_layer, 'default', r=lora_rank, lora_alpha=lora_alpha,
+                                                        use_dora=lora_uses_dora, lora_dropout=0.1)
+
+    # LoRA layers can be enabled/disabled dynamically.
+    def reset_attn_cache_and_flags(self, capture_ca_activations, shrink_cross_attn, enable_lora):
+        self.capture_ca_activations = capture_ca_activations
+        self.shrink_cross_attn = shrink_cross_attn
+        self.cached_activations = {}
+        # Only enable LoRA for the next call(s) if global_enable_lora is set to True.
+        self.enable_lora = enable_lora and self.global_enable_lora
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        img_mask: Optional[torch.Tensor] = None,
+        subj_indices: Optional[Tuple[torch.IntTensor, torch.IntTensor]] = None,
+        debug: bool = False,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        # hidden_states: [1, 4096, 320]
+        residual = hidden_states
+        # attn.spatial_norm is None.
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            # Collapse the spatial dimensions to a single token dimension.
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        # NOTE: there's a inconsistency between q lora and k, v loras.
+        # k, v loras are directly applied to key and value (currently k, v loras are never enabled),
+        # while q lora is applied to query2, and we keep the query unchanged.
+        if self.enable_lora and self.to_q_lora is not None:
+            # query2 will be used in ldm/util.py:calc_elastic_matching_loss() to get more accurate
+            # cross attention scores between the latent images of the sc and mc instances.
+            query2 = self.to_q_lora(hidden_states)
+            # If not q_lora_updates_query, only query2 will be impacted by the LoRA layer.
+            # The query, and thus the attention score and attn_out, will be the same
+            # as the original ones.
+            if self.q_lora_updates_query:
+                query = query2
+        else:
+            query2 = query
+
+        scale = 1 / math.sqrt(query.size(-1))
+
+        is_cross_attn = (encoder_hidden_states is not None)
+        if (not is_cross_attn) and (img_mask is not None):
+            # NOTE: we assume the image is square. But this will fail if the image is not square.
+            # hidden_states: [BS, 4096, 320]. img_mask: [BS, 1, 64, 64]
+            # Scale the mask to the same size as hidden_states.
+            mask_size = int(math.sqrt(hidden_states.shape[-2]))
+            img_mask = F.interpolate(img_mask, size=(mask_size, mask_size), mode='nearest')
+            if (img_mask.sum(dim=(2, 3)) == 0).any():
+                img_mask = None
+            else:
+                # img_mask: [2, 1, 64, 64] -> [2, 4096]
+                img_mask = rearrange(img_mask, 'b ... -> b (...)').contiguous()
+                # max_neg_value = -torch.finfo(hidden_states.dtype).max
+                # img_mask: [2, 4096] -> [2, 1, 1, 4096]
+                img_mask = rearrange(img_mask.bool(), 'b j -> b () () j')
+                # attn_score: [16, 4096, 4096]. img_mask will be broadcasted to [16, 4096, 4096].
+                # So some rows in dim 1 (e.g. [0, :, 4095]) of attn_score will be masked out (all elements in [0, :, 4095] is -inf).
+                # But not all elements in [0, 4095, :] is -inf. Since the softmax is done along dim 2, this is fine.
+                # attn_score.masked_fill_(~img_mask, max_neg_value)
+                # NOTE: If there's an attention mask, it will be replaced by img_mask.
+                attention_mask = img_mask
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.enable_lora and self.to_k_lora is not None:
+            key = self.to_k_lora(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+
+        if self.enable_lora and self.to_v_lora is not None:
+            value = self.to_v_lora(encoder_hidden_states)
+        else:
+            value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+            query2 = attn.norm_q(query2)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        query2 = query2.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if debug and self.attn_proc_idx >= 0:
+            breakpoint()
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        if is_cross_attn and (self.capture_ca_activations or self.shrink_cross_attn):
+            hidden_states, attn_score, attn_prob = \
+                scaled_dot_product_attention(query, key, value, attn_mask=attention_mask,
+                                             dropout_p=0.0, shrink_cross_attn=self.shrink_cross_attn,
+                                             cross_attn_shrink_factor=self.cross_attn_shrink_factor)
+        else:
+            # Use the faster implementation of scaled_dot_product_attention
+            # when not capturing the activations or suppressing the subject attention.
+            hidden_states = \
+                F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
+            attn_prob = attn_score = None
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        if self.enable_lora and self.to_out_lora is not None:
+            hidden_states = self.to_out_lora(hidden_states)
+        else:
+            hidden_states = attn.to_out[0](hidden_states)
+
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        if is_cross_attn and self.capture_ca_activations:
+            # cached q will be used in ddpm.py:calc_comp_fg_bg_preserve_loss(), in which two qs will multiply each other.
+            # So sqrt(scale) will scale the product of two qs by scale.
+            # ANCHOR[id=attention_caching]
+            # query: [2, 8, 4096, 40] -> [2, 320, 4096]
+            self.cached_activations['q'] = \
+                rearrange(query, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
+            self.cached_activations['q2'] = \
+                rearrange(query2, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
+            self.cached_activations['k'] = \
+                rearrange(key, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
+            self.cached_activations['v'] = \
+                rearrange(value, 'b h n d -> b (h d) n').contiguous() * math.sqrt(scale)
+            # attn_prob, attn_score: [2, 8, 4096, 77]
+            self.cached_activations['attn'] = attn_prob
+            self.cached_activations['attnscore'] = attn_score
+            # attn_out: [b, n, h * d] -> [b, h * d, n]
+            # [2, 4096, 320] -> [2, 320, 4096].
+            self.cached_activations['attn_out'] = hidden_states.permute(0, 2, 1).contiguous()
+
+        return hidden_states
+
+def CrossAttnUpBlock2D_forward_capture(
+    self,
+    hidden_states: torch.Tensor,
+    res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+    temb: Optional[torch.Tensor] = None,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    upsample_size: Optional[int] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    encoder_attention_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if cross_attention_kwargs is not None:
+        if cross_attention_kwargs.get("scale", None) is not None:
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+    self.cached_outfeats = {}
+    res_hidden_states_gradscale = getattr(self, "res_hidden_states_gradscale", 1)
+    capture_outfeats = getattr(self, "capture_outfeats", False)
+    layer_idx = 0
+    res_grad_scaler = gen_gradient_scaler(res_hidden_states_gradscale)
+
+    for resnet, attn in zip(self.resnets, self.attentions):
+        # pop res hidden states
+        res_hidden_states = res_hidden_states_tuple[-1]
+        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+        # Scale down the magnitudes of gradients to res_hidden_states
+        # by res_hidden_states_gradscale=0.2, to match the scale of the cross-attn layer outputs.
+        res_hidden_states = res_grad_scaler(res_hidden_states)
+
+        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+        if self.training and self.gradient_checkpointing:
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(resnet),
+                hidden_states,
+                temb,
+                **ckpt_kwargs,
+            )
+            hidden_states = attn(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]
+        else:
+            # resnet: ResnetBlock2D instance.
+            #LINK diffusers.models.resnet.ResnetBlock2D
+            # up_blocks.3.resnets.2.conv_shortcut is a module within ResnetBlock2D,
+            # it's not transforming the UNet shortcut features.
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]
+
+        if capture_outfeats:
+            self.cached_outfeats[layer_idx] = hidden_states
+            layer_idx += 1
+
+    if self.upsamplers is not None:
+        for upsampler in self.upsamplers:
+            hidden_states = upsampler(hidden_states, upsample_size)
+
+    return hidden_states
+
+
+# Adapted from ConsistentIDPipeline:set_ip_adapter().
+# attn_lora_layer_names: candidates are subsets of ['q', 'k', 'v', 'out'].
+def set_up_attn_processors(unet, use_attn_lora, attn_lora_layer_names=['q', 'k', 'v', 'out'],
+                           lora_rank=192, lora_scale_down=8, cross_attn_shrink_factor=0.5,
+                           q_lora_updates_query=False):
+    attn_procs = {}
+    attn_capture_procs = {}
+    unet_modules = dict(unet.named_modules())
+    attn_opt_modules = {}
+
+    attn_proc_idx = 0
+
+    for name, attn_proc in unet.attn_processors.items():
+        # Only capture the activations of the last 3 CA layers.
+        if not name.startswith("up_blocks.3"):
+            # Not the last 3 CA layers. Don't enable LoRA or capture activations.
+            # Then the layer falls back to the original attention mechanism.
+            # We still use AttnProcessor_LoRA_Capture, as it can handle img_mask.
+            attn_procs[name] = AttnProcessor_LoRA_Capture(
+                capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)
+            continue
+        # cross_attention_dim: 768.
+        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+        if cross_attention_dim is None:
+            # Self attention. Don't enable LoRA or capture activations.
+            # We replace the default attn_proc with AttnProcessor_LoRA_Capture,
+            # so that it can incorporate img_mask into self-attention.
+            attn_procs[name] = AttnProcessor_LoRA_Capture(
+                capture_ca_activations=False, enable_lora=False, attn_proc_idx=-1)
+            continue
+
+        # block_id = 3
+        # hidden_size: 320
+        # hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+        # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor' ->
+        # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q'
+        lora_layer_dict = {}
+        lora_layer_dict['q'] = unet_modules[name[:-9] + "to_q"]
+        lora_layer_dict['k'] = unet_modules[name[:-9] + "to_k"]
+        lora_layer_dict['v'] = unet_modules[name[:-9] + "to_v"]
+        # to_out is a ModuleList(Linear, Dropout).
+        lora_layer_dict['out'] = unet_modules[name[:-9] + "to_out"][0]
+
+        lora_proj_layers = {}
+        # Only apply LoRA to the specified layers.
+        for lora_layer_name in attn_lora_layer_names:
+            lora_proj_layers[lora_layer_name] = lora_layer_dict[lora_layer_name]
+
+        attn_capture_proc = AttnProcessor_LoRA_Capture(
+            capture_ca_activations=True, enable_lora=use_attn_lora,
+            lora_uses_dora=True, lora_proj_layers=lora_proj_layers,
+            # LoRA up is initialized to 0. So no need to worry that the LoRA output may be too large.
+            lora_rank=lora_rank, lora_alpha=lora_rank // lora_scale_down,
+            cross_attn_shrink_factor=cross_attn_shrink_factor,
+            q_lora_updates_query=q_lora_updates_query, attn_proc_idx=attn_proc_idx)
+
+        attn_proc_idx += 1
+        # attn_procs has to use the original names.
+        attn_procs[name] = attn_capture_proc
+        # ModuleDict doesn't allow "." in the key.
+        name = name.replace(".", "_")
+        attn_capture_procs[name] = attn_capture_proc
+
+        if use_attn_lora:
+            for subname, module in attn_capture_proc.named_modules():
+                if isinstance(module, peft_lora.LoraLayer):
+                    # ModuleDict doesn't allow "." in the key.
+                    lora_path = name + "_" + subname.replace(".", "_")
+                    attn_opt_modules[lora_path + "_lora_A"] = module.lora_A
+                    attn_opt_modules[lora_path + "_lora_B"] = module.lora_B
+                    # lora_uses_dora is always True, so we don't check it here.
+                    attn_opt_modules[lora_path + "_lora_magnitude_vector"] = module.lora_magnitude_vector
+                    # We will manage attn adapters directly. By default, LoraLayer is an instance of BaseTunerLayer,
+                    # so according to the code logic in diffusers/loaders/peft.py,
+                    # they will be managed by the diffusers PeftAdapterMixin instance, through the
+                    # enable_adapters(), and set_adapter() methods.
+                    # Therefore, we disable these calls on module.
+                    # disable_adapters() is a property and changing it will cause exceptions.
+                    module.enable_adapters = dummy_func
+                    module.set_adapter = dummy_func
+
+    unet.set_attn_processor(attn_procs)
+
+    print(f"Set up {len(attn_capture_procs)} CrossAttn processors on {attn_capture_procs.keys()}.")
+    print(f"Set up {len(attn_opt_modules)} attn LoRA params: {attn_opt_modules.keys()}.")
+    return attn_capture_procs, attn_opt_modules
+
+# NOTE: cross-attn layers are included in the returned lora_modules.
+def set_up_ffn_loras(unet, target_modules_pat, lora_uses_dora=False, lora_rank=192, lora_alpha=16):
+    # target_modules_pat = 'up_blocks.3.resnets.[12].conv[a-z0-9_]+'
+    # up_blocks.3.resnets.[1~2].conv1, conv2, conv_shortcut
+    # Cannot set to conv.+ as it will match added adapter module names, including
+    # up_blocks.3.resnets.1.conv1.base_layer, up_blocks.3.resnets.1.conv1.lora_dropout
+    if target_modules_pat is not None:
+        peft_config = LoraConfig(use_dora=lora_uses_dora, inference_mode=False, r=lora_rank,
+                                 lora_alpha=lora_alpha, lora_dropout=0.1,
+                                 target_modules=target_modules_pat)
+
+        # UNet is a diffusers PeftAdapterMixin instance. Using get_peft_model on it will
+        # cause weird errors. Instead, we directly use diffusers peft adapter methods.
+        unet.add_adapter(peft_config, "recon_loss")
+        unet.add_adapter(peft_config, "unet_distill")
+        unet.add_adapter(peft_config, "comp_distill")
+        unet.enable_adapters()
+
+    # lora_layers contain both the LoRA A and B matrices, as well as the original layers.
+    # lora_layers are used to set the flag, not used for optimization.
+    # lora_modules contain only the LoRA A and B matrices, so they are used for optimization.
+    # NOTE: lora_modules contain both ffn and cross-attn lora modules.
+    ffn_lora_layers = {}
+    ffn_opt_modules = {}
+    for name, module in unet.named_modules():
+        if isinstance(module, peft_lora.LoraLayer):
+            # We don't want to include cross-attn layers in ffn_lora_layers.
+            if target_modules_pat is not None and re.search(target_modules_pat, name):
+                ffn_lora_layers[name] = module
+                # ModuleDict doesn't allow "." in the key.
+                name = name.replace(".", "_")
+                # Since ModuleDict doesn't allow "." in the key, we manually collect
+                # the LoRA matrices in each module.
+                # NOTE: We cannot put every sub-module of module into lora_modules,
+                # as base_layer is also a sub-module of module, which we shouldn't optimize.
+                # Each value in ffn_opt_modules is a ModuleDict:
+                '''
+                (Pdb) ffn_opt_modules['up_blocks_3_resnets_1_conv1_lora_A']
+                ModuleDict(
+                  (unet_distill): Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+                  (recon_loss): Conv2d(640, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+                )
+                '''
+                ffn_opt_modules[name + "_lora_A"] = module.lora_A
+                ffn_opt_modules[name + "_lora_B"] = module.lora_B
+                if lora_uses_dora:
+                    ffn_opt_modules[name + "_lora_magnitude_vector"] = module.lora_magnitude_vector
+
+    print(f"Set up {len(ffn_lora_layers)} FFN LoRA layers: {ffn_lora_layers.keys()}.")
+    print(f"Set up {len(ffn_opt_modules)} FFN LoRA params: {ffn_opt_modules.keys()}.")
+
+    return ffn_lora_layers, ffn_opt_modules
+
+def set_lora_and_capture_flags(unet, unet_lora_modules, attn_capture_procs,
+                               outfeat_capture_blocks, res_hidden_states_gradscale_blocks,
+                               use_attn_lora, use_ffn_lora, ffn_lora_adapter_name, capture_ca_activations,
+                               shrink_cross_attn, res_hidden_states_gradscale):
+    # For attn capture procs, capture_ca_activations and use_attn_lora are set in reset_attn_cache_and_flags().
+    for attn_capture_proc in attn_capture_procs:
+        attn_capture_proc.reset_attn_cache_and_flags(capture_ca_activations, shrink_cross_attn, enable_lora=use_attn_lora)
+    # outfeat_capture_blocks only contains the last up block, up_blocks[3].
+    # It contains 3 FFN layers. We want to capture their output features.
+    for block in outfeat_capture_blocks:
+        block.capture_outfeats = capture_ca_activations
+
+    for block in res_hidden_states_gradscale_blocks:
+        block.res_hidden_states_gradscale = res_hidden_states_gradscale
+
+    if not use_ffn_lora:
+        unet.disable_adapters()
+    else:
+        # ffn_lora_adapter_name: 'recon_loss', 'unet_distill', 'comp_distill'.
+        if ffn_lora_adapter_name is not None:
+            unet.set_adapter(ffn_lora_adapter_name)
+            # NOTE: Don't forget to enable_adapters().
+            # The adapters are not enabled by default after set_adapter().
+            unet.enable_adapters()
+        else:
+            breakpoint()
+
+    # During training, disable_adapters() and set_adapter() will set all/inactive adapters with requires_grad=False,
+    # which might cause issues during DDP training.
+    # So we restore them to requires_grad=True.
+    # During test, unet_lora_modules will be passed as None, so this block will be skipped.
+    if unet_lora_modules is not None:
+        for param in unet_lora_modules.parameters():
+            param.requires_grad = True
+
+def get_captured_activations(capture_ca_activations, attn_capture_procs, outfeat_capture_blocks,
+                             captured_layer_indices=[22, 23, 24], out_dtype=torch.float32):
+    captured_activations = { k: {} for k in ('outfeat', 'attn', 'attnscore',
+                                             'q', 'q2', 'k', 'v', 'attn_out') }
+
+    if not capture_ca_activations:
+        return captured_activations
+
+    all_cached_outfeats = []
+    for block in outfeat_capture_blocks:
+        all_cached_outfeats.append(block.cached_outfeats)
+        # Clear the capture flag and cached outfeats.
+        block.cached_outfeats = {}
+        block.capture_outfeats = False
+
+    for layer_idx in captured_layer_indices:
+        # Subtract 22 to ca_layer_idx to match the layer index in up_blocks[3].cached_outfeats.
+        # 23, 24 -> 1, 2 (!! not 0, 1 !!)
+        internal_idx = layer_idx - 22
+        for k in captured_activations.keys():
+            if k == 'outfeat':
+                # Currently we only capture one block, up_blocks.3. So we hard-code the index 0.
+                captured_activations['outfeat'][layer_idx] = all_cached_outfeats[0][internal_idx].to(out_dtype)
+            else:
+                # internal_idx is the index of layers in up_blocks.3.
+                # Layers 22, 23 and 24 map to 0, 1 and 2.
+                cached_activations = attn_capture_procs[internal_idx].cached_activations
+                captured_activations[k][layer_idx] = cached_activations[k].to(out_dtype)
+
+    return captured_activations
adaface/face_id_to_ada_prompt.py
CHANGED
@@ -53,6 +53,8 @@ class FaceID2AdaPrompt(nn.Module):
|
|
53 |
self.text_to_image_prompt_encoder = None
|
54 |
self.tokenizer = None
|
55 |
self.dtype = kwargs.get('dtype', torch.float16)
|
|
|
|
|
56 |
|
57 |
# Load Img2Ada SubjectBasisGenerator.
|
58 |
self.subject_string = kwargs.get('subject_string', 'z')
|
@@ -73,12 +75,16 @@ class FaceID2AdaPrompt(nn.Module):
|
|
73 |
|
74 |
self.use_clip_embs = False
|
75 |
self.do_contrast_clip_embs_on_bg_features = False
|
|
|
|
|
|
|
|
|
76 |
# num_id_vecs is the output embeddings of the ID2ImgPrompt module.
|
77 |
# If there's no static image suffix embeddings, then num_id_vecs is also
|
78 |
# the number of ada embeddings returned by the subject basis generator.
|
79 |
# num_id_vecs will be set in each derived class.
|
80 |
self.num_static_img_suffix_embs = kwargs.get('num_static_img_suffix_embs', 0)
|
81 |
-
print(f'{self.name} Adaface uses {self.num_id_vecs} ID image embeddings
|
82 |
|
83 |
self.id_img_prompt_max_length = 77
|
84 |
self.face_id_dim = 512
|
@@ -87,36 +93,35 @@ class FaceID2AdaPrompt(nn.Module):
|
|
87 |
self.clip_embedding_dim = 1024
|
88 |
self.output_dim = 768
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
def load_id2img_learnable_modules(self, id2img_learnable_modules_state_dict_list):
|
94 |
-
id2img_prompt_encoder_learnable_modules = self.get_id2img_learnable_modules()
|
95 |
-
for module, state_dict in zip(id2img_prompt_encoder_learnable_modules, id2img_learnable_modules_state_dict_list):
|
96 |
-
module.load_state_dict(state_dict)
|
97 |
-
print(f'{len(id2img_prompt_encoder_learnable_modules)} ID2ImgPrompt encoder modules loaded.')
|
98 |
-
|
99 |
-
# init_subj_basis_generator() can only be called after the derived class is initialized,
|
100 |
-
# when self.num_id_vecs, self.num_static_img_suffix_embs and self.clip_embedding_dim have been set.
|
101 |
-
def init_subj_basis_generator(self):
|
102 |
self.subj_basis_generator = \
|
103 |
-
SubjBasisGenerator(
|
|
|
104 |
num_static_img_suffix_embs = self.num_static_img_suffix_embs,
|
105 |
bg_image_embedding_dim = self.clip_embedding_dim,
|
106 |
output_dim = self.output_dim,
|
107 |
placeholder_is_bg = False,
|
108 |
-
prompt2token_proj_grad_scale = 1,
|
109 |
bg_prompt_translator_has_to_out_proj=False)
|
110 |
|
111 |
def load_adaface_ckpt(self, adaface_ckpt_path):
|
112 |
-
|
|
|
|
|
|
|
113 |
string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
|
114 |
if self.subject_string not in string_to_subj_basis_generator_dict:
|
115 |
print(f"Subject '{self.subject_string}' not found in the embedding manager.")
|
116 |
breakpoint()
|
117 |
|
118 |
ckpt_subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
|
119 |
-
ckpt_subj_basis_generator
|
|
|
|
|
|
|
|
|
|
|
120 |
# Since we directly use the subject basis generator object from the ckpt,
|
121 |
# fixing the number of static image suffix embeddings is much simpler.
|
122 |
# Otherwise if we want to load the subject basis generator from its state_dict,
|
@@ -129,7 +134,7 @@ class FaceID2AdaPrompt(nn.Module):
|
|
129 |
ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.num_static_img_suffix_embs, img_prompt_dim=self.output_dim)
|
130 |
# Fix missing variables in old ckpt.
|
131 |
ckpt_subj_basis_generator.patch_old_subj_basis_generator_ckpt()
|
132 |
-
|
133 |
self.subj_basis_generator.extend_prompt2token_proj_attention(\
|
134 |
ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
|
135 |
ret = self.subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict(), strict=False)
|
@@ -155,6 +160,11 @@ class FaceID2AdaPrompt(nn.Module):
|
|
155 |
|
156 |
self.subj_basis_generator.freeze_prompt2token_proj()
|
157 |
|
|
|
|
|
|
|
|
|
|
|
158 |
@torch.no_grad()
|
159 |
def get_clip_neg_features(self, BS):
|
160 |
if self.clip_neg_features is None:
|
@@ -220,6 +230,7 @@ class FaceID2AdaPrompt(nn.Module):
|
|
220 |
image_obj, _, _ = pad_image_obj_to_square(image_obj)
|
221 |
image_np = np.array(image_obj.resize(size, Image.NEAREST))
|
222 |
face_info = self.face_app.get(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
|
|
|
223 |
if len(face_info) > 0:
|
224 |
face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face
|
225 |
# id_emb: [512,]
|
@@ -487,12 +498,20 @@ class FaceID2AdaPrompt(nn.Module):
|
|
487 |
# avg_at_stage == ada_prompt_emb usually produces the worst results.
|
488 |
# avg_at_stage == id_emb is slightly better than img_prompt_emb, but sometimes img_prompt_emb is better.
|
489 |
# p_dropout and return_zero_embs_for_dropped_encoders are only used by Joint_FaceID2AdaPrompt.
|
|
|
490 |
def generate_adaface_embeddings(self, image_paths, face_id_embs=None, img_prompt_embs=None,
|
491 |
p_dropout=0,
|
492 |
return_zero_embs_for_dropped_encoders=True,
|
493 |
avg_at_stage='id_emb', # id_emb, img_prompt_emb, or None.
|
494 |
perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
|
495 |
-
perturb_std=0, enable_static_img_suffix_embs=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
if (avg_at_stage is None) or avg_at_stage.lower() == 'none':
|
497 |
img_prompt_avg_at_stage = None
|
498 |
else:
|
@@ -509,7 +528,7 @@ class FaceID2AdaPrompt(nn.Module):
|
|
509 |
id_batch_size = len(image_paths)
|
510 |
else:
|
511 |
id_batch_size = 1
|
512 |
-
|
513 |
# faceid_embeds: [BS, 512] is a batch of extracted face analysis embeddings. NOT used later.
|
514 |
# NOTE: If face_id_embs, image_paths and image_objs are all None,
|
515 |
# then get_img_prompt_embs() generates random faceid_embeds/img_prompt_embs,
|
@@ -532,7 +551,7 @@ class FaceID2AdaPrompt(nn.Module):
|
|
532 |
verbose=True)
|
533 |
|
534 |
if face_image_count == 0:
|
535 |
-
return None
|
536 |
|
537 |
# No matter whether avg_at_stage is id_emb or img_prompt_emb, we average img_prompt_embs.
|
538 |
elif avg_at_stage is not None and avg_at_stage.lower() != 'none':
|
@@ -545,19 +564,27 @@ class FaceID2AdaPrompt(nn.Module):
|
|
545 |
out_id_embs_cfg_scale=self.out_id_embs_cfg_scale,
|
546 |
is_face=True,
|
547 |
enable_static_img_suffix_embs=enable_static_img_suffix_embs)
|
|
|
|
|
|
|
|
|
548 |
# During training, img_prompt_avg_at_stage is None, and BS >= 1.
|
549 |
# During inference, img_prompt_avg_at_stage is 'id_emb' or 'img_prompt_emb', and BS == 1.
|
550 |
if img_prompt_avg_at_stage is not None:
|
551 |
# adaface_subj_embs: [1, 16, 768] -> [16, 768]
|
552 |
adaface_subj_embs = adaface_subj_embs.squeeze(0)
|
553 |
|
554 |
-
return adaface_subj_embs
|
555 |
|
556 |
class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
|
557 |
-
|
558 |
-
|
559 |
-
|
|
|
|
|
|
|
560 |
|
|
|
561 |
super().__init__(*args, **kwargs)
|
562 |
|
563 |
self.clip_image_encoder = CLIPVisionModelWithMask.from_pretrained('openai/clip-vit-large-patch14')
|
@@ -583,7 +610,7 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
583 |
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
584 |
providers=['CPUExecutionProvider'])
|
585 |
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
586 |
-
print(f'Face encoder loaded on CPU.')
|
587 |
|
588 |
self.text_to_image_prompt_encoder = CLIPTextModelWrapper.from_pretrained(
|
589 |
'models/arc2face', subfolder="encoder",
|
@@ -594,21 +621,54 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
594 |
if self.out_id_embs_cfg_scale == -1:
|
595 |
self.out_id_embs_cfg_scale = 1
|
596 |
#### Arc2Face pipeline specific configs ####
|
597 |
-
self.gen_neg_img_prompt
|
598 |
# bg CLIP features are used by the bg subject basis generator.
|
599 |
-
self.use_clip_embs
|
600 |
self.do_contrast_clip_embs_on_bg_features = True
|
601 |
# self.num_static_img_suffix_embs is initialized in the parent class.
|
602 |
-
self.id_img_prompt_max_length
|
603 |
-
self.clip_embedding_dim
|
604 |
|
605 |
-
self.
|
606 |
if self.adaface_ckpt_path is not None:
|
607 |
self.load_adaface_ckpt(self.adaface_ckpt_path)
|
608 |
|
609 |
-
|
610 |
-
|
|
|
|
|
|
|
|
|
611 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
612 |
# Arc2Face_ID2AdaPrompt never uses clip_features or called_for_neg_img_prompt.
|
613 |
def map_init_id_to_img_prompt_embs(self, init_id_embs,
|
614 |
clip_features=None,
|
@@ -656,16 +716,17 @@ class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
656 |
# [N, 22, 768] -> [N, 16, 768]
|
657 |
return prompt_embeds[:, 4:20]
|
658 |
|
659 |
-
def get_id2img_learnable_modules(self):
|
660 |
-
return [ self.text_to_image_prompt_encoder ]
|
661 |
-
|
662 |
# ConsistentID_ID2AdaPrompt is just a wrapper of ConsistentIDPipeline, so it's not an nn.Module.
|
663 |
class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
|
|
|
|
|
|
|
|
|
|
|
664 |
def __init__(self, pipe=None, base_model_path="models/sd15-dste8-vae.safetensors",
|
665 |
*args, **kwargs):
|
666 |
-
|
667 |
-
self.num_id_vecs = 4
|
668 |
-
|
669 |
super().__init__(*args, **kwargs)
|
670 |
if pipe is None:
|
671 |
# The base_model_path is kind of arbitrary, as the UNet and VAE in the model
|
@@ -712,13 +773,47 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
712 |
self.clip_embedding_dim = 1280
|
713 |
self.s_scale = 1.0
|
714 |
self.shortcut = False
|
715 |
-
|
716 |
-
self.
|
717 |
if self.adaface_ckpt_path is not None:
|
718 |
self.load_adaface_ckpt(self.adaface_ckpt_path)
|
719 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
720 |
print(f"{self.name} ada prompt encoder initialized, "
|
721 |
-
f"ID vecs: {self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
722 |
|
723 |
def map_init_id_to_img_prompt_embs(self, init_id_embs,
|
724 |
clip_features=None,
|
@@ -757,26 +852,30 @@ class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
|
|
757 |
|
758 |
return global_id_embeds
|
759 |
|
760 |
-
def get_id2img_learnable_modules(self):
|
761 |
-
return [ self.image_proj_model ]
|
762 |
-
|
763 |
# A wrapper for combining multiple FaceID2AdaPrompt instances.
|
764 |
class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
765 |
def __init__(self, adaface_encoder_types, adaface_ckpt_paths,
|
766 |
out_id_embs_cfg_scales=None, enabled_encoders=None,
|
767 |
*args, **kwargs):
|
768 |
self.name = 'jointIDs'
|
|
|
769 |
assert len(adaface_encoder_types) > 0, "adaface_encoder_types should not be empty."
|
770 |
-
|
771 |
-
|
|
|
|
|
772 |
for encoder_type in adaface_encoder_types ]
|
773 |
-
self.
|
|
|
|
|
|
|
774 |
# super() sets self.is_training.
|
775 |
super().__init__(*args, **kwargs)
|
776 |
|
777 |
self.num_sub_encoders = len(adaface_encoder_types)
|
778 |
self.id2ada_prompt_encoders = nn.ModuleList()
|
779 |
self.encoders_num_static_img_suffix_embs = []
|
|
|
780 |
|
781 |
# TODO: apply adaface_encoder_cfg_scales to influence the final prompt embeddings.
|
782 |
# Now they are just placeholders.
|
@@ -786,10 +885,12 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
786 |
self.out_id_embs_cfg_scales = [-1] * self.num_sub_encoders
|
787 |
else:
|
788 |
# Do not normalize the weights, and just use them as is.
|
789 |
-
self.out_id_embs_cfg_scales = out_id_embs_cfg_scales
|
790 |
|
791 |
# Note we don't pass the adaface_ckpt_paths to the base class, but instead,
|
792 |
# we load them once and for all in self.load_adaface_ckpt().
|
|
|
|
|
793 |
for i, encoder_type in enumerate(adaface_encoder_types):
|
794 |
kwargs['out_id_embs_cfg_scale'] = self.out_id_embs_cfg_scales[i]
|
795 |
if encoder_type == 'arc2face':
|
@@ -798,8 +899,10 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
798 |
encoder = ConsistentID_ID2AdaPrompt(*args, **kwargs)
|
799 |
else:
|
800 |
breakpoint()
|
|
|
801 |
self.id2ada_prompt_encoders.append(encoder)
|
802 |
self.encoders_num_static_img_suffix_embs.append(encoder.num_static_img_suffix_embs)
|
|
|
803 |
|
804 |
self.num_static_img_suffix_embs = sum(self.encoders_num_static_img_suffix_embs)
|
805 |
# No need to set gen_neg_img_prompt, as we don't access it in this class, but rather
|
@@ -829,7 +932,7 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
829 |
self.load_adaface_ckpt(adaface_ckpt_paths)
|
830 |
|
831 |
print(f"{self.name} ada prompt encoder initialized with {self.num_sub_encoders} sub-encoders. "
|
832 |
-
f"ID vecs: {self.
|
833 |
|
834 |
if enabled_encoders is not None:
|
835 |
self.are_encoders_enabled = \
|
@@ -845,79 +948,79 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
845 |
else:
|
846 |
self.are_encoders_enabled = \
|
847 |
torch.tensor([True] * self.num_sub_encoders)
|
848 |
-
|
849 |
-
for i, encoder in enumerate(self.id2ada_prompt_encoders):
|
850 |
-
if not (self.is_training and self.are_encoders_enabled[i]):
|
851 |
-
for param in encoder.parameters():
|
852 |
-
param.requires_grad = False
|
853 |
-
else:
|
854 |
-
for param in encoder.parameters():
|
855 |
-
param.requires_grad = True
|
856 |
-
|
857 |
def load_adaface_ckpt(self, adaface_ckpt_paths):
|
858 |
-
# If only one adaface ckpt path is provided, then we assume it's the ckpt of the Joint_FaceID2AdaPrompt,
|
859 |
-
# so we dereference the list to get the actual path and load the subj_basis_generators of all adaface encoders.
|
860 |
if isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)):
|
861 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
862 |
adaface_ckpt_paths = adaface_ckpt_paths[0]
|
863 |
-
|
864 |
-
if isinstance(adaface_ckpt_paths, str):
|
865 |
-
# This is only applicable to newest ckpts of Joint_FaceID2AdaPrompt, where
|
866 |
-
# the ckpt_subj_basis_generator is an nn.ModuleList of multiple subj_basis_generators.
|
867 |
-
# Therefore, no need to patch missing variables.
|
868 |
-
ckpt = torch.load(adaface_ckpt_paths, map_location='cpu')
|
869 |
-
string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
|
870 |
-
if self.subject_string not in string_to_subj_basis_generator_dict:
|
871 |
-
print(f"Subject '{self.subject_string}' not found in the embedding manager.")
|
872 |
breakpoint()
|
873 |
|
874 |
-
|
875 |
-
|
876 |
-
|
877 |
-
|
878 |
-
|
|
|
|
|
|
|
|
|
|
|
879 |
|
880 |
-
|
881 |
-
|
882 |
-
|
883 |
-
|
884 |
-
img_prompt_dim=self.output_dim)
|
885 |
-
|
886 |
-
if subj_basis_generator.prompt2token_proj_attention_multipliers \
|
887 |
-
== [1] * 12:
|
888 |
-
subj_basis_generator.extend_prompt2token_proj_attention(\
|
889 |
-
ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
|
890 |
-
elif subj_basis_generator.prompt2token_proj_attention_multipliers \
|
891 |
-
!= ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers:
|
892 |
-
raise ValueError("Inconsistent prompt2token_proj_attention_multipliers.")
|
893 |
-
|
894 |
-
assert subj_basis_generator.prompt2token_proj_attention_multipliers \
|
895 |
-
== ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, \
|
896 |
-
"Inconsistent prompt2token_proj_attention_multipliers."
|
897 |
-
subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict())
|
898 |
-
|
899 |
-
# extend_prompt2token_proj_attention_multiplier is an integer >= 1.
|
900 |
-
# TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers.
|
901 |
-
# If extend_prompt2token_proj_attention_multiplier > 1, then after loading state_dict,
|
902 |
-
# extend subj_basis_generator again.
|
903 |
-
if self.extend_prompt2token_proj_attention_multiplier > 1:
|
904 |
-
# During this extension, the added noise does change the extra copies of attention weights, since they are not in the ckpt.
|
905 |
-
# During training, prompt2token_proj_ext_attention_perturb_ratio == 0.1.
|
906 |
-
# During inference, prompt2token_proj_ext_attention_perturb_ratio == 0.
|
907 |
-
subj_basis_generator.extend_prompt2token_proj_attention(\
|
908 |
-
None, -1, -1, self.extend_prompt2token_proj_attention_multiplier,
|
909 |
-
perturb_std=self.prompt2token_proj_ext_attention_perturb_ratio)
|
910 |
-
|
911 |
-
subj_basis_generator.freeze_prompt2token_proj()
|
912 |
-
|
913 |
-
print(f"{adaface_ckpt_paths}: {len(self.subj_basis_generator)} subj_basis_generators loaded for {self.name}.")
|
914 |
-
|
915 |
-
elif isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)):
|
916 |
-
for i, ckpt_path in enumerate(adaface_ckpt_paths):
|
917 |
-
self.id2ada_prompt_encoders[i].load_adaface_ckpt(ckpt_path)
|
918 |
-
else:
|
919 |
breakpoint()
|
920 |
|
|
|
|
|
921 |
def extract_init_id_embeds_from_images(self, *args, **kwargs):
|
922 |
total_faceless_img_count = 0
|
923 |
all_id_embs = []
|
@@ -1055,7 +1158,7 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
1055 |
|
1056 |
N_ID = self.encoders_num_id_vecs[i]
|
1057 |
if all_pos_prompt_embs[i] is None:
|
1058 |
-
# Both pos_prompt_embs and neg_prompt_embs have N_ID ==
|
1059 |
all_pos_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
|
1060 |
if all_neg_prompt_embs[i] is None:
|
1061 |
all_neg_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
|
@@ -1077,6 +1180,13 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
1077 |
# So its .device is the device of its parameters.
|
1078 |
device = self.id2ada_prompt_encoders[0].clip_image_encoder.device
|
1079 |
is_emb_averaged = kwargs.get('avg_at_stage', None) is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1080 |
BS = -1
|
1081 |
|
1082 |
if face_id_embs is not None:
|
@@ -1084,13 +1194,17 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
1084 |
all_face_id_embs = face_id_embs.split(self.face_id_dims, dim=1)
|
1085 |
else:
|
1086 |
all_face_id_embs = [None] * self.num_sub_encoders
|
|
|
1087 |
if img_prompt_embs is not None:
|
1088 |
BS = img_prompt_embs.shape[0] if BS == -1 else BS
|
1089 |
-
if img_prompt_embs.shape[1] != self.
|
1090 |
breakpoint()
|
1091 |
-
all_img_prompt_embs = img_prompt_embs.split(self.
|
|
|
1092 |
else:
|
1093 |
all_img_prompt_embs = [None] * self.num_sub_encoders
|
|
|
|
|
1094 |
if image_paths is not None:
|
1095 |
BS = len(image_paths) if BS == -1 else BS
|
1096 |
if BS == -1:
|
@@ -1116,23 +1230,29 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
1116 |
self.curr_are_encoders_enabled = are_encoders_enabled
|
1117 |
all_adaface_subj_embs = []
|
1118 |
num_available_id_vecs = 0
|
|
|
1119 |
|
1120 |
for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders):
|
1121 |
if not are_encoders_enabled[i]:
|
1122 |
adaface_subj_embs = None
|
1123 |
-
print(f"Encoder {id2ada_prompt_encoder.name} is
|
|
|
|
|
1124 |
else:
|
|
|
1125 |
# ddpm.embedding_manager.train() -> id2ada_prompt_encoder.train() -> each sub-encoder's train().
|
1126 |
# -> each sub-encoder's subj_basis_generator.train().
|
1127 |
# Therefore grad for the following call is enabled.
|
1128 |
-
adaface_subj_embs = \
|
1129 |
id2ada_prompt_encoder.generate_adaface_embeddings(image_paths,
|
1130 |
all_face_id_embs[i],
|
1131 |
all_img_prompt_embs[i],
|
1132 |
*args, **kwargs)
|
1133 |
|
1134 |
-
|
1135 |
-
|
|
|
|
|
1136 |
if adaface_subj_embs is None:
|
1137 |
if not return_zero_embs_for_dropped_encoders:
|
1138 |
continue
|
@@ -1143,12 +1263,16 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
1143 |
all_adaface_subj_embs.append(adaface_subj_embs)
|
1144 |
else:
|
1145 |
all_adaface_subj_embs.append(adaface_subj_embs)
|
|
|
|
|
1146 |
num_available_id_vecs += N_ID
|
1147 |
|
|
|
|
|
1148 |
# No faces are found in the images, so return None embeddings.
|
1149 |
# We don't want to return an all-zero embedding, which is useless.
|
1150 |
if num_available_id_vecs == 0:
|
1151 |
-
return None
|
1152 |
|
1153 |
# If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
|
1154 |
# during inference, we average across the batch dim.
|
@@ -1158,7 +1282,12 @@ class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
|
1158 |
# all_adaface_subj_embs[0]: [BS, 4, 768]. all_adaface_subj_embs[1]: [BS, 16, 768].
|
1159 |
# all_adaface_subj_embs: [BS, 20, 768].
|
1160 |
all_adaface_subj_embs = torch.cat(all_adaface_subj_embs, dim=-2)
|
1161 |
-
|
|
|
|
|
|
|
|
|
|
|
1162 |
|
1163 |
|
1164 |
'''
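The hunks above rework how Joint_FaceID2AdaPrompt merges its sub-encoders. A minimal, self-contained sketch of the concatenation pattern they implement (the encoder names, per-encoder vector counts and the 768-dim width come from this diff; the helper itself is illustrative only):

import torch

# Per-encoder subject-embedding widths taken from the diff:
# consistentID -> 4 ID vectors, arc2face -> 16 ID vectors, both 768-dim.
encoders_num_id_vecs = {"consistentID": 4, "arc2face": 16}

def join_subj_embs(per_encoder_embs, batch_size=1, emb_dim=768):
    # Concatenate per-encoder subject embeddings along the token dim.
    # Disabled encoders (None) become zero blocks so the total token count stays
    # fixed, mirroring return_zero_embs_for_dropped_encoders.
    blocks = []
    for name, n_id in encoders_num_id_vecs.items():
        embs = per_encoder_embs.get(name)
        if embs is None:
            embs = torch.zeros(batch_size, n_id, emb_dim)
        blocks.append(embs)
    # [BS, 4, 768] + [BS, 16, 768] -> [BS, 20, 768]
    return torch.cat(blocks, dim=-2)

joint = join_subj_embs({"consistentID": torch.randn(1, 4, 768), "arc2face": None})
print(joint.shape)  # torch.Size([1, 20, 768])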
|
|
|
53 |
self.text_to_image_prompt_encoder = None
|
54 |
self.tokenizer = None
|
55 |
self.dtype = kwargs.get('dtype', torch.float16)
|
56 |
+
self.img2txt_dtype = kwargs.get('img2txt_dtype', torch.float16)
|
57 |
+
self.device = torch.device("cpu")
|
58 |
|
59 |
# Load Img2Ada SubjectBasisGenerator.
|
60 |
self.subject_string = kwargs.get('subject_string', 'z')
|
|
|
75 |
|
76 |
self.use_clip_embs = False
|
77 |
self.do_contrast_clip_embs_on_bg_features = False
|
78 |
+
# Override the default setting in derived classes.
|
79 |
+
if 'enable_static_img_suffix_embs' in kwargs:
|
80 |
+
self.default_enable_static_img_suffix_embs = kwargs['enable_static_img_suffix_embs']
|
81 |
+
|
82 |
# num_id_vecs is the output embeddings of the ID2ImgPrompt module.
|
83 |
# If there's no static image suffix embeddings, then num_id_vecs is also
|
84 |
# the number of ada embeddings returned by the subject basis generator.
|
85 |
# num_id_vecs will be set in each derived class.
|
86 |
self.num_static_img_suffix_embs = kwargs.get('num_static_img_suffix_embs', 0)
|
87 |
+
print(f'{self.name} Adaface uses {self.num_id_vecs} ID image embeddings + {self.num_static_img_suffix_embs} fixed image embeddings as input.')
|
88 |
|
89 |
self.id_img_prompt_max_length = 77
|
90 |
self.face_id_dim = 512
|
|
|
93 |
self.clip_embedding_dim = 1024
|
94 |
self.output_dim = 768
|
95 |
|
96 |
+
# init_img2txt_projection() can only be called after the derived class is initialized,
|
97 |
+
# when self.num_id_vecs0, self.num_static_img_suffix_embs and self.clip_embedding_dim have been set.
|
98 |
+
def init_img2txt_projection(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
self.subj_basis_generator = \
|
100 |
+
SubjBasisGenerator(dtype=self.img2txt_dtype,
|
101 |
+
num_id_vecs = self.num_id_vecs0,
|
102 |
num_static_img_suffix_embs = self.num_static_img_suffix_embs,
|
103 |
bg_image_embedding_dim = self.clip_embedding_dim,
|
104 |
output_dim = self.output_dim,
|
105 |
placeholder_is_bg = False,
|
|
|
106 |
bg_prompt_translator_has_to_out_proj=False)
|
107 |
|
108 |
def load_adaface_ckpt(self, adaface_ckpt_path):
|
109 |
+
if isinstance(adaface_ckpt_path, (list, tuple, ListConfig)):
|
110 |
+
adaface_ckpt_path = adaface_ckpt_path[0]
|
111 |
+
|
112 |
+
ckpt = torch.load(adaface_ckpt_path, map_location='cpu', weights_only=False)
|
113 |
string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
|
114 |
if self.subject_string not in string_to_subj_basis_generator_dict:
|
115 |
print(f"Subject '{self.subject_string}' not found in the embedding manager.")
|
116 |
breakpoint()
|
117 |
|
118 |
ckpt_subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
|
119 |
+
if isinstance(ckpt_subj_basis_generator, nn.ModuleList):
|
120 |
+
name2idx = { 'consistentID': 0, 'arc2face': 1 }
|
121 |
+
subj_basis_generator_idx = name2idx[self.name]
|
122 |
+
ckpt_subj_basis_generator = ckpt_subj_basis_generator[subj_basis_generator_idx]
|
123 |
+
|
124 |
+
ckpt_subj_basis_generator.N_ID = self.num_id_vecs0
|
125 |
# Since we directly use the subject basis generator object from the ckpt,
|
126 |
# fixing the number of static image suffix embeddings is much simpler.
|
127 |
# Otherwise if we want to load the subject basis generator from its state_dict,
|
|
|
134 |
ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.num_static_img_suffix_embs, img_prompt_dim=self.output_dim)
|
135 |
# Fix missing variables in old ckpt.
|
136 |
ckpt_subj_basis_generator.patch_old_subj_basis_generator_ckpt()
|
137 |
+
|
138 |
self.subj_basis_generator.extend_prompt2token_proj_attention(\
|
139 |
ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
|
140 |
ret = self.subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict(), strict=False)
|
|
|
160 |
|
161 |
self.subj_basis_generator.freeze_prompt2token_proj()
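load_adaface_ckpt above selects one subj_basis_generator out of a joint checkpoint when the stored entry is an nn.ModuleList. A small sketch of that dispatch, assuming a checkpoint entry shaped like the one in the diff (ToyGenerator is a stand-in, not the real SubjBasisGenerator):

import torch.nn as nn

class ToyGenerator(nn.Module):  # stand-in for SubjBasisGenerator
    def __init__(self, dim=768):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

def pick_sub_generator(ckpt_entry, encoder_name):
    # Joint checkpoints store an nn.ModuleList; per-encoder checkpoints store a single module.
    name2idx = {"consistentID": 0, "arc2face": 1}
    if isinstance(ckpt_entry, nn.ModuleList):
        return ckpt_entry[name2idx[encoder_name]]
    return ckpt_entry

joint_entry = nn.ModuleList([ToyGenerator(), ToyGenerator()])
gen = ToyGenerator()
gen.load_state_dict(pick_sub_generator(joint_entry, "arc2face").state_dict())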
|
162 |
|
163 |
+
def set_out_id_embs_cfg_scale(self, out_id_embs_cfg_scale):
|
164 |
+
if isinstance(out_id_embs_cfg_scale, (list, tuple, ListConfig)):
|
165 |
+
out_id_embs_cfg_scale = out_id_embs_cfg_scale[0]
|
166 |
+
self.out_id_embs_cfg_scale = out_id_embs_cfg_scale
|
167 |
+
|
168 |
@torch.no_grad()
|
169 |
def get_clip_neg_features(self, BS):
|
170 |
if self.clip_neg_features is None:
|
|
|
230 |
image_obj, _, _ = pad_image_obj_to_square(image_obj)
|
231 |
image_np = np.array(image_obj.resize(size, Image.NEAREST))
|
232 |
face_info = self.face_app.get(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
|
233 |
+
|
234 |
if len(face_info) > 0:
|
235 |
face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face
|
236 |
# id_emb: [512,]
|
|
|
498 |
# avg_at_stage == ada_prompt_emb usually produces the worst results.
|
499 |
# avg_at_stage == id_emb is slightly better than img_prompt_emb, but sometimes img_prompt_emb is better.
|
500 |
# p_dropout and return_zero_embs_for_dropped_encoders are only used by Joint_FaceID2AdaPrompt.
|
501 |
+
# enable_static_img_suffix_embs=None: use the default setting.
|
502 |
def generate_adaface_embeddings(self, image_paths, face_id_embs=None, img_prompt_embs=None,
|
503 |
p_dropout=0,
|
504 |
return_zero_embs_for_dropped_encoders=True,
|
505 |
avg_at_stage='id_emb', # id_emb, img_prompt_emb, or None.
|
506 |
perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
|
507 |
+
perturb_std=0, enable_static_img_suffix_embs=None):
|
508 |
+
|
509 |
+
if enable_static_img_suffix_embs is None:
|
510 |
+
enable_static_img_suffix_embs = self.default_enable_static_img_suffix_embs
|
511 |
+
|
512 |
+
lens_subj_emb_segments = [ self.num_id_vecs + enable_static_img_suffix_embs \
|
513 |
+
* self.num_static_img_suffix_embs ]
|
514 |
+
|
515 |
if (avg_at_stage is None) or avg_at_stage.lower() == 'none':
|
516 |
img_prompt_avg_at_stage = None
|
517 |
else:
|
|
|
528 |
id_batch_size = len(image_paths)
|
529 |
else:
|
530 |
id_batch_size = 1
|
531 |
+
|
532 |
# faceid_embeds: [BS, 512] is a batch of extracted face analysis embeddings. NOT used later.
|
533 |
# NOTE: If face_id_embs, image_paths and image_objs are all None,
|
534 |
# then get_img_prompt_embs() generates random faceid_embeds/img_prompt_embs,
|
|
|
551 |
verbose=True)
|
552 |
|
553 |
if face_image_count == 0:
|
554 |
+
return None, None, lens_subj_emb_segments
|
555 |
|
556 |
# No matter whether avg_at_stage is id_emb or img_prompt_emb, we average img_prompt_embs.
|
557 |
elif avg_at_stage is not None and avg_at_stage.lower() != 'none':
|
|
|
564 |
out_id_embs_cfg_scale=self.out_id_embs_cfg_scale,
|
565 |
is_face=True,
|
566 |
enable_static_img_suffix_embs=enable_static_img_suffix_embs)
|
567 |
+
|
568 |
+
if self.num_id_vecs < self.num_id_vecs0:
|
569 |
+
adaface_subj_embs = adaface_subj_embs[:, :self.num_id_vecs, :]
|
570 |
+
|
571 |
# During training, img_prompt_avg_at_stage is None, and BS >= 1.
|
572 |
# During inference, img_prompt_avg_at_stage is 'id_emb' or 'img_prompt_emb', and BS == 1.
|
573 |
if img_prompt_avg_at_stage is not None:
|
574 |
# adaface_subj_embs: [1, 16, 768] -> [16, 768]
|
575 |
adaface_subj_embs = adaface_subj_embs.squeeze(0)
|
576 |
|
577 |
+
return adaface_subj_embs, img_prompt_embs, lens_subj_emb_segments
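The lens_subj_emb_segments bookkeeping above leans on Python treating a bool as 0/1. A tiny illustration with the per-encoder counts used in this file (the value 4 for num_static_img_suffix_embs is just an example):

def subj_emb_segment_len(num_id_vecs, num_static_img_suffix_embs, enable_static_img_suffix_embs):
    # bool * int: the static suffix embeddings only count when they are enabled.
    return num_id_vecs + enable_static_img_suffix_embs * num_static_img_suffix_embs

# arc2face: 16 ID vectors; consistentID: 4 ID vectors.
print(subj_emb_segment_len(16, 4, False))  # 16
print(subj_emb_segment_len(16, 4, True))   # 20
print(subj_emb_segment_len(4, 4, True))    # 8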
|
578 |
|
579 |
class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt):
|
580 |
+
name = 'arc2face'
|
581 |
+
num_id_vecs0 = 16
|
582 |
+
# first 4 are kept, the remaining 12 are averaged into another 4.
|
583 |
+
# Then concatenated to [8, 768].
|
584 |
+
num_id_vecs = 16
|
585 |
+
default_enable_static_img_suffix_embs = False
|
586 |
|
587 |
+
def __init__(self, *args, **kwargs):
|
588 |
super().__init__(*args, **kwargs)
|
589 |
|
590 |
self.clip_image_encoder = CLIPVisionModelWithMask.from_pretrained('openai/clip-vit-large-patch14')
|
|
|
610 |
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
611 |
providers=['CPUExecutionProvider'])
|
612 |
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
613 |
+
print(f'Arc2Face Face encoder loaded on CPU.')
|
614 |
|
615 |
self.text_to_image_prompt_encoder = CLIPTextModelWrapper.from_pretrained(
|
616 |
'models/arc2face', subfolder="encoder",
|
|
|
621 |
if self.out_id_embs_cfg_scale == -1:
|
622 |
self.out_id_embs_cfg_scale = 1
|
623 |
#### Arc2Face pipeline specific configs ####
|
624 |
+
self.gen_neg_img_prompt = False
|
625 |
# bg CLIP features are used by the bg subject basis generator.
|
626 |
+
self.use_clip_embs = True
|
627 |
self.do_contrast_clip_embs_on_bg_features = True
|
628 |
# self.num_static_img_suffix_embs is initialized in the parent class.
|
629 |
+
self.id_img_prompt_max_length = 22
|
630 |
+
self.clip_embedding_dim = 1024
|
631 |
|
632 |
+
self.init_img2txt_projection()
|
633 |
if self.adaface_ckpt_path is not None:
|
634 |
self.load_adaface_ckpt(self.adaface_ckpt_path)
|
635 |
|
636 |
+
for param in self.clip_image_encoder.parameters():
|
637 |
+
param.requires_grad = False
|
638 |
+
for param in self.text_to_image_prompt_encoder.parameters():
|
639 |
+
param.requires_grad = False
|
640 |
+
for param in self.subj_basis_generator.parameters():
|
641 |
+
param.requires_grad = self.is_training
|
642 |
|
643 |
+
print(f"{self.name} ada prompt encoder initialized, "
|
644 |
+
f"ID vecs: {self.num_id_vecs0}, static suffix: {self.num_static_img_suffix_embs}.")
|
645 |
+
|
646 |
+
def _apply(self, fn):
|
647 |
+
super()._apply(fn) # Call the parent _apply to handle parameters and buffers
|
648 |
+
# A dirty hack to get the device of the model, passed from
|
649 |
+
# parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
|
650 |
+
test_tensor = torch.zeros(1) # Create a test tensor
|
651 |
+
transformed_tensor = fn(test_tensor) # Apply `fn()` to test it
|
652 |
+
device = transformed_tensor.device # Get the device of the transformed tensor
|
653 |
+
# No need to reload face_app on the same device.
|
654 |
+
if device == self.device:
|
655 |
+
return
|
656 |
+
|
657 |
+
if str(device) == 'cpu':
|
658 |
+
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
659 |
+
providers=['CPUExecutionProvider'])
|
660 |
+
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
661 |
+
else:
|
662 |
+
device_id = device.index
|
663 |
+
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
|
664 |
+
providers=['CUDAExecutionProvider'],
|
665 |
+
provider_options=[{"device_id": str(device_id)}])
|
666 |
+
self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
|
667 |
+
|
668 |
+
self.device = device
|
669 |
+
print(f'Arc2Face Face encoder reloaded on {device}.')
|
670 |
+
return
|
671 |
+
|
672 |
# Arc2Face_ID2AdaPrompt never uses clip_features or called_for_neg_img_prompt.
|
673 |
def map_init_id_to_img_prompt_embs(self, init_id_embs,
|
674 |
clip_features=None,
|
|
|
716 |
# [N, 22, 768] -> [N, 16, 768]
|
717 |
return prompt_embeds[:, 4:20]
|
718 |
|
|
|
|
|
|
|
719 |
# ConsistentID_ID2AdaPrompt is just a wrapper of ConsistentIDPipeline, so it's not an nn.Module.
|
720 |
class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt):
|
721 |
+
name = 'consistentID'
|
722 |
+
num_id_vecs0 = 4
|
723 |
+
# No compression for ConsistentID.
|
724 |
+
num_id_vecs = 4
|
725 |
+
default_enable_static_img_suffix_embs = False
|
726 |
+
|
727 |
def __init__(self, pipe=None, base_model_path="models/sd15-dste8-vae.safetensors",
|
728 |
*args, **kwargs):
|
729 |
+
|
|
|
|
|
730 |
super().__init__(*args, **kwargs)
|
731 |
if pipe is None:
|
732 |
# The base_model_path is kind of arbitrary, as the UNet and VAE in the model
|
|
|
773 |
self.clip_embedding_dim = 1280
|
774 |
self.s_scale = 1.0
|
775 |
self.shortcut = False
|
776 |
+
|
777 |
+
self.init_img2txt_projection()
|
778 |
if self.adaface_ckpt_path is not None:
|
779 |
self.load_adaface_ckpt(self.adaface_ckpt_path)
|
780 |
|
781 |
+
for param in self.clip_image_encoder.parameters():
|
782 |
+
param.requires_grad = False
|
783 |
+
for param in self.image_proj_model.parameters():
|
784 |
+
param.requires_grad = False
|
785 |
+
for param in self.subj_basis_generator.parameters():
|
786 |
+
param.requires_grad = self.is_training
|
787 |
+
|
788 |
print(f"{self.name} ada prompt encoder initialized, "
|
789 |
+
f"ID vecs: {self.num_id_vecs0}, static suffix: {self.num_static_img_suffix_embs}.")
|
790 |
+
|
791 |
+
def _apply(self, fn):
|
792 |
+
super()._apply(fn) # Call the parent _apply to handle parameters and buffers
|
793 |
+
# A dirty hack to get the device of the model, passed from
|
794 |
+
# parent.model.to(self.root_device) => parent._apply(convert) => module._apply(fn)
|
795 |
+
test_tensor = torch.zeros(1) # Create a test tensor
|
796 |
+
transformed_tensor = fn(test_tensor) # Apply `fn()` to test it
|
797 |
+
device = transformed_tensor.device # Get the device of the transformed tensor
|
798 |
+
# No need to reload face_app on the same device.
|
799 |
+
if device == self.device:
|
800 |
+
return
|
801 |
+
|
802 |
+
if str(device) == 'cpu':
|
803 |
+
self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
|
804 |
+
providers=['CPUExecutionProvider'])
|
805 |
+
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
806 |
+
else:
|
807 |
+
device_id = device.index
|
808 |
+
self.face_app = FaceAnalysis(name='buffalo_l', root='models/insightface',
|
809 |
+
providers=['CUDAExecutionProvider'],
|
810 |
+
provider_options=[{"device_id": str(device_id)}])
|
811 |
+
self.face_app.prepare(ctx_id=device_id, det_size=(512, 512))
|
812 |
+
|
813 |
+
self.device = device
|
814 |
+
self.pipe.face_app = self.face_app
|
815 |
+
print(f'ConsistentID Face encoder reloaded on {device}.')
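The _apply override above is a device-tracking trick: probe fn() with a dummy tensor to learn where .to() is sending the module, then rebuild the InsightFace session on that device. A stripped-down sketch of the same trick with a placeholder resource instead of FaceAnalysis:

import torch
import torch.nn as nn

class DeviceFollower(nn.Module):
    # Sketch only: the "resource" string stands in for a device-bound object
    # (such as an ONNX Runtime session) that .to() cannot move by itself.
    def __init__(self):
        super().__init__()
        self.dummy = nn.Linear(4, 4)
        self.device = torch.device("cpu")
        self.resource = "resource-on-cpu"

    def _apply(self, fn):
        super()._apply(fn)                      # move parameters/buffers as usual
        device = fn(torch.zeros(1)).device      # where .to()/.cuda() is sending tensors
        if device != self.device:
            self.device = device
            self.resource = f"resource-on-{device}"  # re-create the device-bound resource
        return self

m = DeviceFollower().to("cpu")
print(m.device, m.resource)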
|
816 |
+
|
817 |
|
818 |
def map_init_id_to_img_prompt_embs(self, init_id_embs,
|
819 |
clip_features=None,
|
|
|
852 |
|
853 |
return global_id_embeds
|
854 |
|
|
|
|
|
|
|
855 |
# A wrapper for combining multiple FaceID2AdaPrompt instances.
|
856 |
class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt):
|
857 |
def __init__(self, adaface_encoder_types, adaface_ckpt_paths,
|
858 |
out_id_embs_cfg_scales=None, enabled_encoders=None,
|
859 |
*args, **kwargs):
|
860 |
self.name = 'jointIDs'
|
861 |
+
name2class = { 'arc2face': Arc2Face_ID2AdaPrompt, 'consistentID': ConsistentID_ID2AdaPrompt }
|
862 |
assert len(adaface_encoder_types) > 0, "adaface_encoder_types should not be empty."
|
863 |
+
adaface_encoder_types2num_id_vecs0 = { name: name2class[name].num_id_vecs0 for name in adaface_encoder_types }
|
864 |
+
adaface_encoder_types2num_id_vecs = { name: name2class[name].num_id_vecs for name in adaface_encoder_types }
|
865 |
+
# self.num_id_vecs0 is used in the parent class. So we need to initialize it here first.
|
866 |
+
self.encoders_num_id_vecs0 = [ adaface_encoder_types2num_id_vecs0[encoder_type] \
|
867 |
for encoder_type in adaface_encoder_types ]
|
868 |
+
self.encoders_num_id_vecs = [ adaface_encoder_types2num_id_vecs[encoder_type] \
|
869 |
+
for encoder_type in adaface_encoder_types ]
|
870 |
+
self.num_id_vecs0 = sum(self.encoders_num_id_vecs0)
|
871 |
+
self.num_id_vecs = sum(self.encoders_num_id_vecs)
|
872 |
# super() sets self.is_training.
|
873 |
super().__init__(*args, **kwargs)
|
874 |
|
875 |
self.num_sub_encoders = len(adaface_encoder_types)
|
876 |
self.id2ada_prompt_encoders = nn.ModuleList()
|
877 |
self.encoders_num_static_img_suffix_embs = []
|
878 |
+
self.default_enable_static_img_suffix_embs = []
|
879 |
|
880 |
# TODO: apply adaface_encoder_cfg_scales to influence the final prompt embeddings.
|
881 |
# Now they are just placeholders.
|
|
|
885 |
self.out_id_embs_cfg_scales = [-1] * self.num_sub_encoders
|
886 |
else:
|
887 |
# Do not normalize the weights, and just use them as is.
|
888 |
+
self.out_id_embs_cfg_scales = list(out_id_embs_cfg_scales)
|
889 |
|
890 |
# Note we don't pass the adaface_ckpt_paths to the base class, but instead,
|
891 |
# we load them once and for all in self.load_adaface_ckpt().
|
892 |
+
# NOTE: during inference, num_static_img_suffix_embs is fixed to be 4 for each encoder.
|
893 |
+
# But we can still disable static_img_suffix_embs by setting enable_static_img_suffix_embs to False.
|
894 |
for i, encoder_type in enumerate(adaface_encoder_types):
|
895 |
kwargs['out_id_embs_cfg_scale'] = self.out_id_embs_cfg_scales[i]
|
896 |
if encoder_type == 'arc2face':
|
|
|
899 |
encoder = ConsistentID_ID2AdaPrompt(*args, **kwargs)
|
900 |
else:
|
901 |
breakpoint()
|
902 |
+
|
903 |
self.id2ada_prompt_encoders.append(encoder)
|
904 |
self.encoders_num_static_img_suffix_embs.append(encoder.num_static_img_suffix_embs)
|
905 |
+
self.default_enable_static_img_suffix_embs.append(encoder.default_enable_static_img_suffix_embs)
|
906 |
|
907 |
self.num_static_img_suffix_embs = sum(self.encoders_num_static_img_suffix_embs)
|
908 |
# No need to set gen_neg_img_prompt, as we don't access it in this class, but rather
|
|
|
932 |
self.load_adaface_ckpt(adaface_ckpt_paths)
|
933 |
|
934 |
print(f"{self.name} ada prompt encoder initialized with {self.num_sub_encoders} sub-encoders. "
|
935 |
+
f"ID vecs: {self.num_id_vecs0}, static suffix embs: {self.num_static_img_suffix_embs}.")
|
936 |
|
937 |
if enabled_encoders is not None:
|
938 |
self.are_encoders_enabled = \
|
|
|
948 |
else:
|
949 |
self.are_encoders_enabled = \
|
950 |
torch.tensor([True] * self.num_sub_encoders)
|
951 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
952 |
def load_adaface_ckpt(self, adaface_ckpt_paths):
|
|
|
|
|
953 |
if isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)):
|
954 |
+
# If multiple adaface ckpt paths are provided, then we assume they are the
|
955 |
+
# ckpts of the sub-encoders.
|
956 |
+
if len(adaface_ckpt_paths) == self.num_sub_encoders:
|
957 |
+
for i, ckpt_path in enumerate(adaface_ckpt_paths):
|
958 |
+
self.id2ada_prompt_encoders[i].load_adaface_ckpt(ckpt_path)
|
959 |
+
return
|
960 |
+
# If only one adaface ckpt path is provided, then we assume it's the ckpt of the Joint_FaceID2AdaPrompt,
|
961 |
+
# so we dereference the list to get the actual path and load the subj_basis_generators of all adaface encoders.
|
962 |
+
elif len(adaface_ckpt_paths) == 1 and self.num_sub_encoders > 1:
|
963 |
adaface_ckpt_paths = adaface_ckpt_paths[0]
|
964 |
+
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
965 |
breakpoint()
|
966 |
|
967 |
+
adaface_ckpt_path = adaface_ckpt_paths
|
968 |
+
assert isinstance(adaface_ckpt_path, str), "adaface_ckpt_path should be a string."
|
969 |
+
# This is only applicable to newest ckpts of Joint_FaceID2AdaPrompt, where
|
970 |
+
# the ckpt_subj_basis_generator is an nn.ModuleList of multiple subj_basis_generators.
|
971 |
+
# Therefore, no need to patch missing variables.
|
972 |
+
ckpt = torch.load(adaface_ckpt_paths, map_location='cpu', weights_only=False)
|
973 |
+
string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
|
974 |
+
if self.subject_string not in string_to_subj_basis_generator_dict:
|
975 |
+
print(f"Subject '{self.subject_string}' not found in the embedding manager.")
|
976 |
+
breakpoint()
|
977 |
|
978 |
+
ckpt_subj_basis_generators = string_to_subj_basis_generator_dict[self.subject_string]
|
979 |
+
if len(ckpt_subj_basis_generators) != self.num_sub_encoders:
|
980 |
+
print(f"Number of subj_basis_generators in the ckpt ({len(ckpt_subj_basis_generators)}) "
|
981 |
+
f"doesn't match the number of adaface encoders ({self.num_sub_encoders}).")
|
|
|
|
|
|
982 |
breakpoint()
|
983 |
|
984 |
+
for i, subj_basis_generator in enumerate(self.subj_basis_generator):
|
985 |
+
ckpt_subj_basis_generator = ckpt_subj_basis_generators[i]
|
986 |
+
# Handle differences in num_static_img_suffix_embs between the current model and the ckpt.
|
987 |
+
ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.encoders_num_static_img_suffix_embs[i],
|
988 |
+
img_prompt_dim=self.output_dim)
|
989 |
+
|
990 |
+
if subj_basis_generator.prompt2token_proj_attention_multipliers \
|
991 |
+
== [1] * 12:
|
992 |
+
subj_basis_generator.extend_prompt2token_proj_attention(\
|
993 |
+
ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0)
|
994 |
+
elif subj_basis_generator.prompt2token_proj_attention_multipliers \
|
995 |
+
!= ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers:
|
996 |
+
raise ValueError("Inconsistent prompt2token_proj_attention_multipliers.")
|
997 |
+
|
998 |
+
assert subj_basis_generator.prompt2token_proj_attention_multipliers \
|
999 |
+
== ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, \
|
1000 |
+
"Inconsistent prompt2token_proj_attention_multipliers."
|
1001 |
+
subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict())
|
1002 |
+
|
1003 |
+
# extend_prompt2token_proj_attention_multiplier is an integer >= 1.
|
1004 |
+
# TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers.
|
1005 |
+
# If extend_prompt2token_proj_attention_multiplier > 1, then after loading state_dict,
|
1006 |
+
# extend subj_basis_generator again.
|
1007 |
+
if self.extend_prompt2token_proj_attention_multiplier > 1:
|
1008 |
+
# During this extension, the added noise does change the extra copies of attention weights, since they are not in the ckpt.
|
1009 |
+
# During training, prompt2token_proj_ext_attention_perturb_ratio == 0.1.
|
1010 |
+
# During inference, prompt2token_proj_ext_attention_perturb_ratio == 0.
|
1011 |
+
subj_basis_generator.extend_prompt2token_proj_attention(\
|
1012 |
+
None, -1, -1, self.extend_prompt2token_proj_attention_multiplier,
|
1013 |
+
perturb_std=self.prompt2token_proj_ext_attention_perturb_ratio)
|
1014 |
+
|
1015 |
+
subj_basis_generator.freeze_prompt2token_proj()
|
1016 |
+
|
1017 |
+
print(f"{adaface_ckpt_paths}: {len(self.subj_basis_generator)} subj_basis_generators loaded for {self.name}.")
|
1018 |
+
|
1019 |
+
def set_out_id_embs_cfg_scale(self, out_id_embs_cfg_scales):
|
1020 |
+
self.out_id_embs_cfg_scales = list(out_id_embs_cfg_scales)
|
1021 |
+
for i, out_id_embs_cfg_scale in enumerate(out_id_embs_cfg_scales):
|
1022 |
+
self.id2ada_prompt_encoders[i].set_out_id_embs_cfg_scale(out_id_embs_cfg_scale)
|
1023 |
+
|
1024 |
def extract_init_id_embeds_from_images(self, *args, **kwargs):
|
1025 |
total_faceless_img_count = 0
|
1026 |
all_id_embs = []
|
|
|
1158 |
|
1159 |
N_ID = self.encoders_num_id_vecs[i]
|
1160 |
if all_pos_prompt_embs[i] is None:
|
1161 |
+
# Both pos_prompt_embs and neg_prompt_embs have N_ID == num_id_vecs0 embeddings.
|
1162 |
all_pos_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
|
1163 |
if all_neg_prompt_embs[i] is None:
|
1164 |
all_neg_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device)
|
|
|
1180 |
# So its .device is the device of its parameters.
|
1181 |
device = self.id2ada_prompt_encoders[0].clip_image_encoder.device
|
1182 |
is_emb_averaged = kwargs.get('avg_at_stage', None) is not None
|
1183 |
+
if kwargs.get('enable_static_img_suffix_embs', None) is None:
|
1184 |
+
enable_static_img_suffix_embs = self.default_enable_static_img_suffix_embs
|
1185 |
+
else:
|
1186 |
+
enable_static_img_suffix_embs = kwargs['enable_static_img_suffix_embs']
|
1187 |
+
if isinstance(enable_static_img_suffix_embs, bool):
|
1188 |
+
enable_static_img_suffix_embs = [enable_static_img_suffix_embs] * self.num_sub_encoders
|
1189 |
+
|
1190 |
BS = -1
|
1191 |
|
1192 |
if face_id_embs is not None:
|
|
|
1194 |
all_face_id_embs = face_id_embs.split(self.face_id_dims, dim=1)
|
1195 |
else:
|
1196 |
all_face_id_embs = [None] * self.num_sub_encoders
|
1197 |
+
|
1198 |
if img_prompt_embs is not None:
|
1199 |
BS = img_prompt_embs.shape[0] if BS == -1 else BS
|
1200 |
+
if img_prompt_embs.shape[1] != self.num_id_vecs0:
|
1201 |
breakpoint()
|
1202 |
+
all_img_prompt_embs = img_prompt_embs.split(self.encoders_num_id_vecs0, dim=1)
|
1203 |
+
img_prompt_embs_provided = True
|
1204 |
else:
|
1205 |
all_img_prompt_embs = [None] * self.num_sub_encoders
|
1206 |
+
img_prompt_embs_provided = False
|
1207 |
+
|
1208 |
if image_paths is not None:
|
1209 |
BS = len(image_paths) if BS == -1 else BS
|
1210 |
if BS == -1:
|
|
|
1230 |
self.curr_are_encoders_enabled = are_encoders_enabled
|
1231 |
all_adaface_subj_embs = []
|
1232 |
num_available_id_vecs = 0
|
1233 |
+
lens_subj_emb_segments = []
|
1234 |
|
1235 |
for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders):
|
1236 |
if not are_encoders_enabled[i]:
|
1237 |
adaface_subj_embs = None
|
1238 |
+
print(f"Encoder {id2ada_prompt_encoder.name} is disabled.")
|
1239 |
+
N_ID = id2ada_prompt_encoder.num_id_vecs + enable_static_img_suffix_embs[i] \
|
1240 |
+
* id2ada_prompt_encoder.num_static_img_suffix_embs
|
1241 |
else:
|
1242 |
+
kwargs['enable_static_img_suffix_embs'] = enable_static_img_suffix_embs[i]
|
1243 |
# ddpm.embedding_manager.train() -> id2ada_prompt_encoder.train() -> each sub-encoder's train().
|
1244 |
# -> each sub-encoder's subj_basis_generator.train().
|
1245 |
# Therefore grad for the following call is enabled.
|
1246 |
+
adaface_subj_embs, img_prompt_embs, encoder_lens_subj_emb_segments = \
|
1247 |
id2ada_prompt_encoder.generate_adaface_embeddings(image_paths,
|
1248 |
all_face_id_embs[i],
|
1249 |
all_img_prompt_embs[i],
|
1250 |
*args, **kwargs)
|
1251 |
|
1252 |
+
# adaface_subj_embs: arc2face [16, 768] or consistentID [4, 768],
|
1253 |
+
# or arc2face [20, 768] or consistentID [8, 768] if enable_static_img_suffix_embs=True.
|
1254 |
+
N_ID = encoder_lens_subj_emb_segments[0]
|
1255 |
+
|
1256 |
if adaface_subj_embs is None:
|
1257 |
if not return_zero_embs_for_dropped_encoders:
|
1258 |
continue
|
|
|
1263 |
all_adaface_subj_embs.append(adaface_subj_embs)
|
1264 |
else:
|
1265 |
all_adaface_subj_embs.append(adaface_subj_embs)
|
1266 |
+
if not img_prompt_embs_provided:
|
1267 |
+
all_img_prompt_embs[i] = img_prompt_embs
|
1268 |
num_available_id_vecs += N_ID
|
1269 |
|
1270 |
+
lens_subj_emb_segments.append(N_ID)
|
1271 |
+
|
1272 |
# No faces are found in the images, so return None embeddings.
|
1273 |
# We don't want to return an all-zero embedding, which is useless.
|
1274 |
if num_available_id_vecs == 0:
|
1275 |
+
return None, [0]
|
1276 |
|
1277 |
# If id2ada_prompt_encoders are ["arc2face", "consistentID"], then
|
1278 |
# during inference, we average across the batch dim.
|
|
|
1282 |
# all_adaface_subj_embs[0]: [BS, 4, 768]. all_adaface_subj_embs[1]: [BS, 16, 768].
|
1283 |
# all_adaface_subj_embs: [BS, 20, 768].
|
1284 |
all_adaface_subj_embs = torch.cat(all_adaface_subj_embs, dim=-2)
|
1285 |
+
# Check if some of the img_prompt_embs are None.
|
1286 |
+
if None in all_img_prompt_embs:
|
1287 |
+
all_img_prompt_embs = None
|
1288 |
+
else:
|
1289 |
+
all_img_prompt_embs = torch.cat(all_img_prompt_embs, dim=-2)
|
1290 |
+
return all_adaface_subj_embs, all_img_prompt_embs, lens_subj_emb_segments
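generate_adaface_embeddings above splits a jointly stacked image-prompt tensor back into per-encoder chunks and re-joins the resulting subject embeddings. A short shape walkthrough with the vector counts from this diff:

import torch

# Widths taken from the diff: consistentID contributes 4 image-prompt vectors,
# arc2face contributes 16; a joint tensor stacks them along the token dim.
encoders_num_id_vecs0 = [4, 16]

img_prompt_embs = torch.randn(2, sum(encoders_num_id_vecs0), 768)  # [BS, 20, 768]
per_encoder = img_prompt_embs.split(encoders_num_id_vecs0, dim=1)  # [BS, 4, 768], [BS, 16, 768]
print([tuple(t.shape) for t in per_encoder])

# After each sub-encoder maps its chunk, the subject embeddings are re-joined:
rejoined = torch.cat(per_encoder, dim=-2)
assert rejoined.shape == img_prompt_embs.shape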
|
1291 |
|
1292 |
|
1293 |
'''
|
adaface/subj_basis_generator.py
CHANGED
@@ -9,7 +9,7 @@ import torch
|
|
9 |
from torch import nn
|
10 |
from einops import rearrange
|
11 |
from einops.layers.torch import Rearrange
|
12 |
-
from transformers import CLIPTokenizer, CLIPTextModel
|
13 |
|
14 |
from torch import einsum
|
15 |
from adaface.util import gen_gradient_scaler
|
@@ -57,7 +57,25 @@ class IP_MLPProjModel(nn.Module):
|
|
57 |
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
58 |
x = self.norm(x)
|
59 |
return x
|
60 |
-
|
|
|
|
|
|
61 |
# group_dim: the tensor dimension that corresponds to the multiple groups.
|
62 |
class LearnedSoftAggregate(nn.Module):
|
63 |
def __init__(self, num_feat, group_dim, keepdim=False):
|
@@ -349,23 +367,26 @@ class CrossAttention(nn.Module):
|
|
349 |
else:
|
350 |
return out
|
351 |
|
|
|
352 |
class ImgPrompt2TextPrompt(nn.Module):
|
353 |
-
def __init__(self, placeholder_is_bg, num_id_vecs,
|
|
|
354 |
super().__init__()
|
355 |
self.N_ID = num_id_vecs
|
356 |
# If not placeholder_is_bg, then N_SFX will be updated in initialize_text_components().
|
357 |
self.N_SFX = 0
|
|
|
358 |
|
359 |
if not placeholder_is_bg:
|
360 |
-
self.
|
|
|
361 |
|
362 |
# prompt2token_proj: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
|
363 |
# prompt2token_proj is with the same architecture as the original arc2face text encoder,
|
364 |
# but retrained to do inverse mapping.
|
365 |
# To be initialized in the subclass.
|
366 |
self.prompt2token_proj = None
|
367 |
-
|
368 |
-
|
369 |
def initialize_static_img_suffix_embs(self, num_static_img_suffix_embs, img_prompt_dim=768):
|
370 |
self.N_SFX = num_static_img_suffix_embs
|
371 |
# We always take the first num_static_img_suffix_embs embeddings out of static_img_suffix_embs.
|
@@ -376,11 +397,11 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
376 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs ({self.N_SFX} required). Skip initialization.")
|
377 |
elif self.static_img_suffix_embs.shape[1] < self.N_SFX:
|
378 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (< {self.N_SFX} required). Reinitialize.")
|
379 |
-
self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim))
|
380 |
elif self.N_SFX > 0:
|
381 |
# self.static_img_suffix_embs.shape[1] > self.N_SFX > 0.
|
382 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (> {self.N_SFX} required). Truncate.")
|
383 |
-
self.static_img_suffix_embs = nn.Parameter(self.static_img_suffix_embs[:, :self.N_SFX])
|
384 |
else:
|
385 |
# self.static_img_suffix_embs.shape[1] > self.N_SFX == 0.
|
386 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (0 required). Erase.")
|
@@ -391,7 +412,7 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
391 |
# or it's initialized but has fewer than num_static_img_suffix_embs embeddings (this situation should be very rare,
|
392 |
# so we don't consider to reuse and extend a shorter static_img_suffix_embs).
|
393 |
# So we reinitialize it.
|
394 |
-
self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim))
|
395 |
else:
|
396 |
# If static_img_suffix_embs had been initialized, then it will be set to None, i.e., erased from the SubjBasisGenerator instance.
|
397 |
self.static_img_suffix_embs = None
|
@@ -399,9 +420,7 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
399 |
# Implement a separate initialization function, so that it can be called from SubjBasisGenerator
|
400 |
# after the SubjBasisGenerator is initialized. This can be used to fix old SubjBasisGenerator
|
401 |
# ckpts which were not subclassed from ImgPrompt2TextPrompt.
|
402 |
-
def initialize_text_components(self, max_prompt_length=77
|
403 |
-
num_static_img_suffix_embs=0, img_prompt_dim=768):
|
404 |
-
self.initialize_static_img_suffix_embs(num_static_img_suffix_embs, img_prompt_dim)
|
405 |
self.max_prompt_length = max_prompt_length
|
406 |
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
407 |
# clip_text_embeddings: CLIPTextEmbeddings instance.
|
@@ -416,7 +435,7 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
416 |
# pad_embeddings is still on CPU. But should be moved to GPU automatically.
|
417 |
# Note: detach pad_embeddings from the computation graph, otherwise
|
418 |
# deepcopy() in embedding_manager.py:make_frozen_copy_of_subj_basis_generators() will fail.
|
419 |
-
self.pad_embeddings = clip_text_embeddings(pad_tokens)[0].detach()
|
420 |
|
421 |
# image prompt space -> text prompt space.
|
422 |
# return_emb_types: a list of strings, each string is among
|
@@ -439,7 +458,7 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
439 |
else:
|
440 |
breakpoint()
|
441 |
else:
|
442 |
-
# len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in
|
443 |
# But list_extra_words always corresponds to the actual batch size. So we only take the first element.
|
444 |
list_extra_words = list_extra_words[:1]
|
445 |
|
@@ -466,7 +485,7 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
466 |
face_prompt_embs_orig_dtype = face_prompt_embs.dtype
|
467 |
face_prompt_embs = face_prompt_embs.to(self.dtype)
|
468 |
|
469 |
-
ID_END = 4
|
470 |
PAD_BEGIN = ID_END + self.N_SFX + 2
|
471 |
|
472 |
# token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
|
@@ -545,6 +564,7 @@ class ImgPrompt2TextPrompt(nn.Module):
|
|
545 |
class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
546 |
def __init__(
|
547 |
self,
|
|
|
548 |
# number of cross-attention heads of the bg prompt translator.
|
549 |
# Taken as a half of the number of heads 12 of OpenAI clip-vit-large-patch14:
|
550 |
# https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
|
@@ -553,22 +573,25 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
553 |
# or number of background input identity vectors (no matter the subject is face or not).
|
554 |
# 257: 257 CLIP tokens.
|
555 |
num_nonface_in_id_vecs={ 'subj': 77, 'bg': 257 },
|
|
|
556 |
num_id_vecs=16, # num_id_vecs: subj: 16. bg: 4.
|
557 |
num_static_img_suffix_embs: int = 0, # Number of extra static learnable image embeddings appended to translated ID embeddings.
|
558 |
bg_image_embedding_dim=1024, # CLIP image hidden layer feature dimension, as per config.json above.
|
559 |
obj_embedding_dim=384, # DINO object feature dimension for objects.
|
560 |
output_dim=768, # CLIP text embedding input dimension.
|
|
|
561 |
placeholder_is_bg: bool = False, # Whether the placeholder is for the image background tokens.
|
562 |
-
prompt2token_proj_grad_scale: float = 0.4, # Gradient scale for prompt2token_proj.
|
563 |
learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer.
|
564 |
-
bg_prompt_translator_has_to_out_proj:
|
565 |
):
|
566 |
|
567 |
# If not placeholder_is_bg, then it calls initialize_text_components() in the superclass.
|
568 |
-
super().__init__(placeholder_is_bg=placeholder_is_bg, num_id_vecs=num_id_vecs,
|
569 |
-
num_static_img_suffix_embs=num_static_img_suffix_embs,
|
|
|
570 |
|
571 |
self.placeholder_is_bg = placeholder_is_bg
|
|
|
572 |
self.num_out_embs = self.N_ID + self.N_SFX
|
573 |
self.output_dim = output_dim
|
574 |
# num_nonface_in_id_vecs should be the number of core ID embs, 16.
|
@@ -586,14 +609,18 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
586 |
# self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings) or [1, 16, 768] (without paddings).
|
587 |
# If self.placeholder_is_bg: prompt2token_proj is set to None.
|
588 |
# Use an attention dropout of 0.2 to increase robustness.
|
589 |
-
|
590 |
-
self.prompt2token_proj
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
|
|
|
|
|
|
|
|
597 |
self.freeze_prompt2token_proj()
|
598 |
|
599 |
# These multipliers are relative to the original CLIPTextModel.
|
@@ -631,6 +658,9 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
631 |
identity_to_out=identity_to_out,
|
632 |
out_has_skip=out_has_skip)
|
633 |
|
|
|
|
|
|
|
634 |
self.output_scale = output_dim ** -0.5
|
635 |
|
636 |
'''
|
@@ -686,21 +716,20 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
686 |
hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
|
687 |
|
688 |
# faceid2img_prompt_embs -> ada_id_embs: image prompt space -> text prompt space.
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
ada_id_embs = self.prompt2token_proj_grad_scaler(ada_id_embs)
|
704 |
elif raw_id_embs is not None:
|
705 |
# id_embs: [BS, 384] -> [BS, 18, 768].
|
706 |
# obj_proj_in is expected to project the DINO object features to
|
@@ -726,14 +755,15 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
726 |
|
727 |
adaface_out_embs = id_embs_out * self.output_scale # * 0.036
|
728 |
else:
|
729 |
-
|
|
|
730 |
# If out_id_embs_cfg_scale < 1, adaface_out_embs is a mix of adaface_out_embs and pad_embeddings.
|
731 |
if out_id_embs_cfg_scale != 1:
|
732 |
-
# pad_embeddings: [77, 768] -> [16, 768] -> [1, 16, 768].
|
733 |
# NOTE: Never do cfg on static image suffix embeddings.
|
734 |
# So we take self.N_ID embeddings, instead of self.N_ID + self.N_SFX,
|
735 |
# even if enable_static_img_suffix_embs=True.
|
736 |
-
pad_embeddings = self.pad_embeddings[4:4+self.N_ID].unsqueeze(0).to(ada_id_embs.device)
|
737 |
adaface_out_embs[:, :self.N_ID] = ada_id_embs[:, :self.N_ID] * out_id_embs_cfg_scale \
|
738 |
+ pad_embeddings * (1 - out_id_embs_cfg_scale)
|
739 |
|
@@ -812,37 +842,37 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
812 |
# Only applicable to fg basis generator.
|
813 |
if self.placeholder_is_bg:
|
814 |
return
|
815 |
-
|
816 |
-
# Then we don't have to check whether it's for subj or bg.
|
817 |
-
if self.prompt2token_proj_grad_scale == 0:
|
818 |
-
frozen_components_name = 'all'
|
819 |
-
frozen_param_set = self.prompt2token_proj.named_parameters()
|
820 |
-
else:
|
821 |
-
frozen_components_name = 'token_pos_embeddings'
|
822 |
-
frozen_param_set = self.prompt2token_proj.text_model.embeddings.named_parameters()
|
823 |
-
|
824 |
if self.prompt2token_proj is not None:
|
825 |
frozen_param_names = []
|
826 |
-
for param_name, param in
|
827 |
if param.requires_grad:
|
828 |
param.requires_grad = False
|
829 |
frozen_param_names.append(param_name)
|
830 |
# If param is already frozen, then no need to freeze it again.
|
831 |
-
print(f"{
|
832 |
#print(f"Frozen parameters:\n{frozen_param_names}")
|
833 |
|
834 |
def patch_old_subj_basis_generator_ckpt(self):
|
835 |
# Fix compatibility with the previous version.
|
836 |
if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
|
837 |
self.bg_prompt_translator_has_to_out_proj = False
|
838 |
-
if not hasattr(self, 'num_out_embs'):
|
839 |
-
self.num_out_embs = -1
|
840 |
if hasattr(self, 'num_id_vecs') and not hasattr(self, 'N_ID'):
|
841 |
self.N_ID = self.num_id_vecs
|
|
|
|
|
|
|
842 |
if not hasattr(self, 'num_nonface_in_id_vecs') and hasattr(self, 'N_ID'):
|
843 |
self.num_nonface_in_id_vecs = self.N_ID
|
844 |
if not hasattr(self, 'dtype'):
|
845 |
-
self.dtype = torch.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
846 |
|
847 |
if self.placeholder_is_bg:
|
848 |
if not hasattr(self, 'pos_embs') or self.pos_embs is None:
|
@@ -860,6 +890,14 @@ class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
|
860 |
num_static_img_suffix_embs=self.N_SFX,
|
861 |
img_prompt_dim=self.output_dim)
|
862 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
863 |
def __repr__(self):
|
864 |
type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
|
865 |
|
|
|
9 |
from torch import nn
|
10 |
from einops import rearrange
|
11 |
from einops.layers.torch import Rearrange
|
12 |
+
from transformers import CLIPTokenizer, CLIPTextModel
|
13 |
|
14 |
from torch import einsum
|
15 |
from adaface.util import gen_gradient_scaler
|
|
|
57 |
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
58 |
x = self.norm(x)
|
59 |
return x
|
60 |
+
|
61 |
+
class LayerwiseMLPProjWithSkip(nn.Module):
|
62 |
+
def __init__(self, id_embeddings_dim=768, num_layers=16, dim_mult=2):
|
63 |
+
super().__init__()
|
64 |
+
|
65 |
+
self.proj = nn.Sequential(
|
66 |
+
nn.Linear(id_embeddings_dim, id_embeddings_dim*dim_mult*num_layers),
|
67 |
+
Rearrange('b n (l d) -> b n l d', l=num_layers, d=id_embeddings_dim*dim_mult),
|
68 |
+
nn.GELU(),
|
69 |
+
nn.Linear(id_embeddings_dim*dim_mult, id_embeddings_dim),
|
70 |
+
)
|
71 |
+
self.norm = nn.LayerNorm(id_embeddings_dim)
|
72 |
+
|
73 |
+
def forward(self, id_embeds):
|
74 |
+
# B N D -> B N L D + B N L D -> B N L D
|
75 |
+
x = self.proj(id_embeds) + id_embeds.unsqueeze(1)
|
76 |
+
x = self.norm(x)
|
77 |
+
return x
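LayerwiseMLPProjWithSkip above follows a project-then-add-skip-then-normalize pattern. A simplified sketch of that pattern without the per-layer expansion, so the shapes stay [B, N, D] throughout (this is not the class from the diff, only the underlying idea):

import torch
import torch.nn as nn

class MLPProjWithSkip(nn.Module):
    # Residual MLP + LayerNorm: project, add the input back, then normalize.
    def __init__(self, dim=768, dim_mult=2):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * dim_mult),
            nn.GELU(),
            nn.Linear(dim * dim_mult, dim),
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(self.proj(x) + x)

print(MLPProjWithSkip()(torch.randn(1, 4, 768)).shape)  # torch.Size([1, 4, 768])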
|
78 |
+
|
79 |
# group_dim: the tensor dimension that corresponds to the multiple groups.
|
80 |
class LearnedSoftAggregate(nn.Module):
|
81 |
def __init__(self, num_feat, group_dim, keepdim=False):
|
|
|
367 |
else:
|
368 |
return out
|
369 |
|
370 |
+
|
371 |
class ImgPrompt2TextPrompt(nn.Module):
|
372 |
+
def __init__(self, placeholder_is_bg, num_id_vecs, num_static_img_suffix_embs,
|
373 |
+
max_prompt_length=77, img_prompt_dim=768, dtype=torch.float16):
|
374 |
super().__init__()
|
375 |
self.N_ID = num_id_vecs
|
376 |
# If not placeholder_is_bg, then N_SFX will be updated in initialize_text_components().
|
377 |
self.N_SFX = 0
|
378 |
+
self.dtype = dtype
|
379 |
|
380 |
if not placeholder_is_bg:
|
381 |
+
self.initialize_static_img_suffix_embs(num_static_img_suffix_embs, img_prompt_dim)
|
382 |
+
self.initialize_text_components(max_prompt_length)
|
383 |
|
384 |
# prompt2token_proj: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
|
385 |
# prompt2token_proj is with the same architecture as the original arc2face text encoder,
|
386 |
# but retrained to do inverse mapping.
|
387 |
# To be initialized in the subclass.
|
388 |
self.prompt2token_proj = None
|
389 |
+
|
|
|
390 |
def initialize_static_img_suffix_embs(self, num_static_img_suffix_embs, img_prompt_dim=768):
|
391 |
self.N_SFX = num_static_img_suffix_embs
|
392 |
# We always take the first num_static_img_suffix_embs embeddings out of static_img_suffix_embs.
|
|
|
397 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs ({self.N_SFX} required). Skip initialization.")
|
398 |
elif self.static_img_suffix_embs.shape[1] < self.N_SFX:
|
399 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (< {self.N_SFX} required). Reinitialize.")
|
400 |
+
self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim, dtype=self.dtype))
|
401 |
elif self.N_SFX > 0:
|
402 |
# self.static_img_suffix_embs.shape[1] > self.N_SFX > 0.
|
403 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (> {self.N_SFX} required). Truncate.")
|
404 |
+
self.static_img_suffix_embs = nn.Parameter(self.static_img_suffix_embs[:, :self.N_SFX].to(dtype=self.dtype))
|
405 |
else:
|
406 |
# self.static_img_suffix_embs.shape[1] > self.N_SFX == 0.
|
407 |
print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (0 required). Erase.")
|
|
|
412 |
# or it's initialized but has fewer than num_static_img_suffix_embs embeddings (this situation should be very rare,
|
413 |
# so we don't consider to reuse and extend a shorter static_img_suffix_embs).
|
414 |
# So we reinitialize it.
|
415 |
+
self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim, dtype=self.dtype))
|
416 |
else:
|
417 |
# If static_img_suffix_embs had been initialized, then it will be set to None, i.e., erased from the SubjBasisGenerator instance.
|
418 |
self.static_img_suffix_embs = None
|
|
|
420 |
# Implement a separate initialization function, so that it can be called from SubjBasisGenerator
|
421 |
# after the SubjBasisGenerator is initialized. This can be used to fix old SubjBasisGenerator
|
422 |
# ckpts which were not subclassed from ImgPrompt2TextPrompt.
|
423 |
+
def initialize_text_components(self, max_prompt_length=77):
|
|
|
|
|
424 |
self.max_prompt_length = max_prompt_length
|
425 |
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
426 |
# clip_text_embeddings: CLIPTextEmbeddings instance.
|
|
|
435 |
# pad_embeddings is still on CPU. But should be moved to GPU automatically.
|
436 |
# Note: detach pad_embeddings from the computation graph, otherwise
|
437 |
# deepcopy() in embedding_manager.py:make_frozen_copy_of_subj_basis_generators() will fail.
|
438 |
+
self.pad_embeddings = clip_text_embeddings(pad_tokens)[0].detach().to(self.dtype)
|
439 |
|
440 |
# image prompt space -> text prompt space.
|
441 |
# return_emb_types: a list of strings, each string is among
|
|
|
458 |
else:
|
459 |
breakpoint()
|
460 |
else:
|
461 |
+
# len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_feat_distill_on_comp_prompt.
|
462 |
# But list_extra_words always corresponds to the actual batch size. So we only take the first element.
|
463 |
list_extra_words = list_extra_words[:1]
|
464 |
|
|
|
485 |
face_prompt_embs_orig_dtype = face_prompt_embs.dtype
|
486 |
face_prompt_embs = face_prompt_embs.to(self.dtype)
|
487 |
|
488 |
+
ID_END = 4 + self.N_ID
|
489 |
PAD_BEGIN = ID_END + self.N_SFX + 2
|
490 |
|
491 |
# token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
|
|
|
564 |
class SubjBasisGenerator(ImgPrompt2TextPrompt):
|
565 |
def __init__(
|
566 |
self,
|
567 |
+
dtype=torch.float16,
|
568 |
# number of cross-attention heads of the bg prompt translator.
|
569 |
# Taken as a half of the number of heads 12 of OpenAI clip-vit-large-patch14:
|
570 |
# https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
|
|
|
573 |
# or number of background input identity vectors (no matter the subject is face or not).
|
574 |
# 257: 257 CLIP tokens.
|
575 |
num_nonface_in_id_vecs={ 'subj': 77, 'bg': 257 },
|
576 |
+
num_ca_layers=16,
|
577 |
num_id_vecs=16, # num_id_vecs: subj: 16. bg: 4.
|
578 |
num_static_img_suffix_embs: int = 0, # Number of extra static learnable image embeddings appended to translated ID embeddings.
|
579 |
bg_image_embedding_dim=1024, # CLIP image hidden layer feature dimension, as per config.json above.
|
580 |
obj_embedding_dim=384, # DINO object feature dimension for objects.
|
581 |
output_dim=768, # CLIP text embedding input dimension.
|
582 |
+
use_layerwise_proj: bool = False, # Whether to use layerwise projection.
|
583 |
placeholder_is_bg: bool = False, # Whether the placeholder is for the image background tokens.
|
|
|
584 |
learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer.
|
585 |
+
bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection.
|
586 |
):
|
587 |
|
588 |
# If not placeholder_is_bg, then it calls initialize_text_components() in the superclass.
|
589 |
+
super().__init__(placeholder_is_bg=placeholder_is_bg, num_id_vecs=num_id_vecs,
|
590 |
+
num_static_img_suffix_embs=num_static_img_suffix_embs,
|
591 |
+
max_prompt_length=77, img_prompt_dim=output_dim, dtype=dtype)
|
592 |
|
593 |
self.placeholder_is_bg = placeholder_is_bg
|
594 |
+
self.num_ca_layers = num_ca_layers
|
595 |
self.num_out_embs = self.N_ID + self.N_SFX
|
596 |
self.output_dim = output_dim
|
597 |
# num_nonface_in_id_vecs should be the number of core ID embs, 16.
|
|
|
609 |
# self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings) or [1, 16, 768] (without paddings).
|
610 |
# If self.placeholder_is_bg: prompt2token_proj is set to None.
|
611 |
# Use an attention dropout of 0.2 to increase robustness.
|
612 |
+
self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14')
|
613 |
+
self.prompt2token_proj.to(dtype=self.dtype)
|
614 |
+
|
615 |
+
if use_layerwise_proj:
|
616 |
+
# MLPProjWithSkip: MLP with skip connection.
|
617 |
+
# [BS, 4, 768] -> [BS, 16, 4, 768]. Extra 16: 16 layers.
|
618 |
+
self.layerwise_proj = LayerwiseMLPProjWithSkip(output_dim, dim_mult=2)
|
619 |
+
else:
|
620 |
+
self.layerwise_proj = nn.Identity() #Rearrange('b n d -> b l n d', l=16)
|
621 |
+
|
622 |
+
print(f"Subj prompt2token_proj initialized.")
|
623 |
+
# Only freeze token and positional embeddings of the original CLIPTextModel.
|
624 |
self.freeze_prompt2token_proj()
|
625 |
|
626 |
# These multipliers are relative to the original CLIPTextModel.
|
|
|
658 |
identity_to_out=identity_to_out,
|
659 |
out_has_skip=out_has_skip)
|
660 |
|
661 |
+
if self.dtype == torch.float16:
|
662 |
+
self.prompt_translator = self.prompt_translator.half()
|
663 |
+
|
664 |
self.output_scale = output_dim ** -0.5
|
665 |
|
666 |
'''
|
|
|
716 |
hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
|
717 |
|
718 |
# faceid2img_prompt_embs -> ada_id_embs: image prompt space -> text prompt space.
|
719 |
+
# inverse_img_prompt_embs() applies self.prompt2token_proj to faceid2img_prompt_embs.
|
720 |
+
# If list_extra_words is not None, then ada_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens
|
721 |
+
# and (at most) two extra words in adaface_prompt_embs, without BOS and EOS.
|
722 |
+
# If list_extra_words is None, then ada_id_embs: [BS, 16, 768], the 16 identity tokens in adaface_prompt_embs.
|
723 |
+
# hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
|
724 |
+
# ada_id_embs: [BS, 16, 768].
|
725 |
+
# return_emb_types: a list of strings, each string is among
|
726 |
+
# ['full', 'core', 'full_pad', 'full_half_pad'].
|
727 |
+
ada_id_embs, = \
|
728 |
+
self.inverse_img_prompt_embs(faceid2img_prompt_embs,
|
729 |
+
list_extra_words=None,
|
730 |
+
return_emb_types=['core'],
|
731 |
+
hidden_state_layer_weights=hidden_state_layer_weights,
|
732 |
+
enable_static_img_suffix_embs=enable_static_img_suffix_embs)
|
|
|
733 |
elif raw_id_embs is not None:
|
734 |
# id_embs: [BS, 384] -> [BS, 18, 768].
|
735 |
# obj_proj_in is expected to project the DINO object features to
|
|
|
755 |
|
756 |
adaface_out_embs = id_embs_out * self.output_scale # * 0.036
|
757 |
else:
|
758 |
+
# [BS, 16, 768] -> [BS, layers=16, tokens=16, 768]
|
759 |
+
adaface_out_embs = self.layerwise_proj(ada_id_embs)
|
760 |
# If out_id_embs_cfg_scale < 1, adaface_out_embs is a mix of adaface_out_embs and pad_embeddings.
|
761 |
if out_id_embs_cfg_scale != 1:
|
762 |
+
# pad_embeddings: [77, 768] -> [16, 768] -> [1, 1, 16, 768].
|
763 |
# NOTE: Never do cfg on static image suffix embeddings.
|
764 |
# So we take self.N_ID embeddings, instead of self.N_ID + self.N_SFX,
|
765 |
# even if enable_static_img_suffix_embs=True.
|
766 |
+
pad_embeddings = self.pad_embeddings[4:4+self.N_ID].unsqueeze(0).unsqueeze(1).to(ada_id_embs.device)
|
767 |
adaface_out_embs[:, :self.N_ID] = ada_id_embs[:, :self.N_ID] * out_id_embs_cfg_scale \
|
768 |
+ pad_embeddings * (1 - out_id_embs_cfg_scale)
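The assignment above is a CFG-style interpolation: with out_id_embs_cfg_scale below 1, the ID embeddings are pulled toward the neutral pad embeddings, weakening the identity signal. A minimal sketch (zeros stand in for the real CLIP pad-token embeddings):

import torch

def blend_with_pad(ada_id_embs, pad_embeddings, cfg_scale):
    # Linear interpolation between the ID embeddings and the pad embeddings.
    return ada_id_embs * cfg_scale + pad_embeddings * (1 - cfg_scale)

ada_id_embs = torch.randn(1, 16, 768)
pad_embeddings = torch.zeros(1, 16, 768)  # stand-in for the CLIP pad-token embeddings
print(blend_with_pad(ada_id_embs, pad_embeddings, 0.8).shape)  # torch.Size([1, 16, 768])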
|
769 |
|
|
|
842 |
# Only applicable to fg basis generator.
|
843 |
if self.placeholder_is_bg:
|
844 |
return
|
845 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
846 |
if self.prompt2token_proj is not None:
|
847 |
frozen_param_names = []
|
848 |
+
for param_name, param in self.prompt2token_proj.text_model.embeddings.named_parameters():
|
849 |
if param.requires_grad:
|
850 |
param.requires_grad = False
|
851 |
frozen_param_names.append(param_name)
|
852 |
# If param is already frozen, then no need to freeze it again.
|
853 |
+
print(f"{len(frozen_param_names)} params of token_pos_embeddings in Subj prompt2token_proj is frozen.")
|
854 |
#print(f"Frozen parameters:\n{frozen_param_names}")
|
855 |
|
856 |
def patch_old_subj_basis_generator_ckpt(self):
|
857 |
# Fix compatability with the previous version.
|
858 |
if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
|
859 |
self.bg_prompt_translator_has_to_out_proj = False
|
|
|
|
|
860 |
if hasattr(self, 'num_id_vecs') and not hasattr(self, 'N_ID'):
|
861 |
self.N_ID = self.num_id_vecs
|
862 |
+
# Update the number of output embeddings.
|
863 |
+
self.num_out_embs = self.N_ID + self.N_SFX
|
864 |
+
|
865 |
if not hasattr(self, 'num_nonface_in_id_vecs') and hasattr(self, 'N_ID'):
|
866 |
self.num_nonface_in_id_vecs = self.N_ID
|
867 |
if not hasattr(self, 'dtype'):
|
868 |
+
self.dtype = torch.float16
|
869 |
+
if not self.placeholder_is_bg:
|
870 |
+
self.prompt2token_proj.to(dtype=self.dtype)
|
871 |
+
else:
|
872 |
+
self.prompt_translator.half()
|
873 |
+
|
874 |
+
if not hasattr(self, 'num_ca_layers'):
|
875 |
+
self.num_ca_layers = 16
|
876 |
|
877 |
if self.placeholder_is_bg:
|
878 |
if not hasattr(self, 'pos_embs') or self.pos_embs is None:
|
|
|
890 |
num_static_img_suffix_embs=self.N_SFX,
|
891 |
img_prompt_dim=self.output_dim)
|
892 |
|
893 |
+
if not hasattr(self, 'use_layerwise_proj'):
|
894 |
+
self.use_layerwise_proj = False
|
895 |
+
if not hasattr(self, 'layerwise_proj'):
|
896 |
+
if self.use_layerwise_proj:
|
897 |
+
self.layerwise_proj = LayerwiseMLPProjWithSkip(self.output_dim, dim_mult=2)
|
898 |
+
else:
|
899 |
+
self.layerwise_proj = nn.Identity()
|
900 |
+
|
901 |
def __repr__(self):
|
902 |
type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
|
903 |
|
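The subj_basis_generator.py hunks above blend the ada ID embeddings with CLIP pad embeddings whenever out_id_embs_cfg_scale != 1. Below is a minimal standalone sketch of that interpolation with dummy tensors, using the shapes mentioned in the diff comments ([BS, 16, 768]); the variable names are illustrative only, not the module's API.

import torch

BS, N_ID, D = 2, 16, 768
ada_id_embs    = torch.randn(BS, N_ID, D)   # stand-in for the generator's ID embeddings
pad_embeddings = torch.randn(N_ID, D)       # stand-in for the CLIP pad-token embeddings

out_id_embs_cfg_scale = 0.8
# Linear interpolation toward the pad embeddings, as in the hunk above:
# the ID embeddings keep weight `scale`, the pad embeddings fill in (1 - scale).
mixed = ada_id_embs * out_id_embs_cfg_scale + pad_embeddings.unsqueeze(0) * (1 - out_id_embs_cfg_scale)
print(mixed.shape)   # torch.Size([2, 16, 768])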
adaface/unet_teachers.py
CHANGED
@@ -1,6 +1,6 @@
 import torch
+from torch import nn
 import numpy as np
-import pytorch_lightning as pl
 from diffusers import UNet2DConditionModel
 from adaface.util import UNetEnsemble, create_consistentid_pipeline
 from diffusers import UNet2DConditionModel

@@ -12,9 +12,9 @@ def create_unet_teacher(teacher_type, device='cpu', **kwargs):
         teacher_type = teacher_type[0]
 
     if teacher_type == "arc2face":
+        teacher = Arc2FaceTeacher(**kwargs)
     elif teacher_type == "unet_ensemble":
+        # unet, extra_unet_dirpaths and unet_weights_in_ensemble are passed in kwargs.
         # Even if we distill from unet_ensemble, we still need to load arc2face for generating
         # arc2face embeddings.
         # The first (optional) ctor param of UNetEnsembleTeacher is an instantiated unet,

@@ -22,20 +22,24 @@ def create_unet_teacher(teacher_type, device='cpu', **kwargs):
         # However, since the __call__ method of the ddpm unet takes different formats of params,
         # for simplicity, we still use the diffusers unet.
         # unet_teacher is put on CPU first, then moved to GPU when DDPM is moved to GPU.
+        teacher = UNetEnsembleTeacher(device=device, **kwargs)
     elif teacher_type == "consistentID":
+        teacher = ConsistentIDTeacher(**kwargs)
     elif teacher_type == "simple_unet":
+        teacher = SimpleUNetTeacher(**kwargs)
     # Since we've dereferenced the list if it has only one element,
     # this holding implies the list has more than one element. Therefore it's UNetEnsembleTeacher.
     elif isinstance(teacher_type, (tuple, list, ListConfig)):
         # teacher_type is a list of teacher types. So it's UNetEnsembleTeacher.
+        teacher = UNetEnsembleTeacher(unet_types=teacher_type, device=device, **kwargs)
     else:
         raise NotImplementedError(f"Teacher type {teacher_type} not implemented.")
 
+    for param in teacher.parameters():
+        param.requires_grad = False
+    return teacher
+
+class UNetTeacher(nn.Module):
     def __init__(self, **kwargs):
         super().__init__()
         self.name = None

@@ -56,9 +60,10 @@ class UNetTeacher(pl.LightningModule):
     # to be initialized, which will unnecessarily complicate the code.
     # noise: the initial noise for the first iteration.
     # t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t.
-    def forward(self, ddpm_model, x_start, noise, t, teacher_context,
-                num_denoising_steps=1,
+    # same_t_noise_across_instances: when sampling t and noise, use the same t and noise for all instances.
+    def forward(self, ddpm_model, x_start, noise, t, teacher_context, negative_context=None,
+                num_denoising_steps=1, same_t_noise_across_instances=False,
+                global_t_lb=0, global_t_ub=1000):
         assert num_denoising_steps <= 10
 
         if self.p_uses_cfg > 0:

@@ -71,27 +76,22 @@ class UNetTeacher(pl.LightningModule):
 
             if self.uses_cfg:
                 print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
+                if negative_context is not None:
+                    negative_context = negative_context[:1].repeat(x_start.shape[0], 1, 1)
+
+                # if negative_context is None, then teacher_context is a combination of
+                # (one or multiple if unet_ensemble) pos_context and neg_context.
+                # If negative_context is not None, then teacher_context is only pos_context.
             else:
                 self.cfg_scale = 1
                 print("Teacher does not use CFG.")
 
-                for teacher_context_i in teacher_context:
-                    pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
-                    if pos_context.shape[0] != x_start.shape[0]:
-                        breakpoint()
-                    teacher_pos_contexts.append(pos_context)
-                teacher_context = teacher_pos_contexts
-            else:
-                pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
-                if pos_context.shape[0] != x_start.shape[0]:
-                    breakpoint()
-                teacher_context = pos_context
+            # If negative_context is None, then teacher_context is a combination of
+            # (one or multiple if unet_ensemble) pos_context and neg_context.
+            # Since not uses_cfg, we only need pos_context.
+            # If negative_context is not None, then teacher_context is only pos_context.
+            if negative_context is None:
+                teacher_context = self.extract_pos_context(teacher_context, x_start.shape[0])
         else:
             # p_uses_cfg = 0. Never use CFG.
             self.uses_cfg = False

@@ -102,15 +102,21 @@ class UNetTeacher(pl.LightningModule):
             # in case someday we want to switch from CFG to non-CFG during runtime.
             self.cfg_scale = 1
 
+        is_context_doubled = 2 if (self.uses_cfg and negative_context is None) else 1
         if self.name == 'unet_ensemble':
             # teacher_context is a list of teacher contexts.
             for teacher_context_i in teacher_context:
+                if teacher_context_i.shape[0] != x_start.shape[0] * is_context_doubled:
                     breakpoint()
         else:
+            if teacher_context.shape[0] != x_start.shape[0] * is_context_doubled:
                 breakpoint()
+
+        if same_t_noise_across_instances:
+            # If same_t_noise_across_instances, we use the same t and noise for all instances.
+            t = t[0].repeat(x_start.shape[0])
+            noise = noise[:1].repeat(x_start.shape[0], 1, 1, 1)
+
         # Initially, x_starts only contains the original x_start.
         x_starts = [ x_start ]
         noises = [ noise ]

@@ -125,24 +131,35 @@ class UNetTeacher(pl.LightningModule):
                 # sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise
                 x_noisy = ddpm_model.q_sample(x_start, t, noise)
 
-                if self.uses_cfg:
+                if self.uses_cfg and self.cfg_scale > 1 and negative_context is None:
                     x_noisy2 = x_noisy.repeat(2, 1, 1, 1)
                     t2 = t.repeat(2)
                 else:
                     x_noisy2 = x_noisy
+                    t2 = t
 
                 # If do_arc2face_distill, then pos_context is [BS=6, 21, 768].
                 noise_pred = self.unet(sample=x_noisy2, timestep=t2, encoder_hidden_states=teacher_context,
                                        return_dict=False)[0]
                 if self.uses_cfg and self.cfg_scale > 1:
+                    if negative_context is None:
+                        pos_noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0)
+                    else:
+                        # If negative_context is not None, then teacher_context is only pos_context.
+                        pos_noise_pred = noise_pred
+                        with torch.no_grad():
+                            if self.name == 'unet_ensemble':
+                                neg_noise_pred = self.unet.unets[0](sample=x_noisy, timestep=t,
+                                                                    encoder_hidden_states=negative_context, return_dict=False)[0]
+                            else:
+                                neg_noise_pred = self.unet(sample=x_noisy, timestep=t,
+                                                           encoder_hidden_states=negative_context, return_dict=False)[0]
+
                     noise_pred = pos_noise_pred * self.cfg_scale - neg_noise_pred * (self.cfg_scale - 1)
 
-                # sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise
-                pred_x0 = ddpm_model.predict_start_from_noise(x_noisy, t, noise_pred)
                 noise_preds.append(noise_pred)
+                # sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise
+                pred_x0 = ddpm_model.predict_start_from_noise(x_noisy, t, noise_pred)
                 # The predicted x0 is used as the x_start for the next denoising step.
                 x_starts.append(pred_x0)
 

@@ -157,20 +174,43 @@ class UNetTeacher(pl.LightningModule):
                 # of the current timestep.
                 t_lb = t * np.power(0.5, np.power(num_denoising_steps - 1, -0.3))
                 t_ub = t * np.power(0.7, np.power(num_denoising_steps - 1, -0.3))
+                t_lb = torch.clamp(t_lb, min=global_t_lb)
+                t_ub = torch.clamp(t_ub, max=global_t_ub)
                 earlier_timesteps = (t_ub - t_lb) * relative_ts + t_lb
                 earlier_timesteps = earlier_timesteps.long()
+                noise = torch.randn_like(pred_x0)
 
+                if same_t_noise_across_instances:
+                    # If same_t_noise_across_instances, we use the same earlier_timesteps and noise for all instances.
                     earlier_timesteps = earlier_timesteps[0].repeat(x_start.shape[0])
+                    noise = noise[:1].repeat(x_start.shape[0], 1, 1, 1)
 
                 # earlier_timesteps = ts[i+1] < ts[i].
                 ts.append(earlier_timesteps)
-                noise = torch.randn_like(pred_x0)
                 noises.append(noise)
 
         return noise_preds, x_starts, noises, ts
+
+    def extract_pos_context(self, teacher_context, BS):
+        # If p_uses_cfg > 0, we always pass both pos_context and neg_context to the teacher.
+        # But the neg_context is only used when self.uses_cfg is True and cfg_scale > 1.
+        # So we manually split the teacher_context into pos_context and neg_context, and only keep pos_context.
+        if self.name == 'unet_ensemble':
+            teacher_pos_contexts = []
+            # teacher_context is a list of teacher contexts.
+            for teacher_context_i in teacher_context:
+                pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
+                if pos_context.shape[0] != BS:
+                    breakpoint()
+                teacher_pos_contexts.append(pos_context)
+            teacher_context = teacher_pos_contexts
+        else:
+            pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
+            if pos_context.shape[0] != BS:
+                breakpoint()
+            teacher_context = pos_context
+
+        return teacher_context
 
 class Arc2FaceTeacher(UNetTeacher):
     def __init__(self, **kwargs):

@@ -185,11 +225,11 @@ class Arc2FaceTeacher(UNetTeacher):
         self.cfg_scale_range = [1, 1]
 
 class UNetEnsembleTeacher(UNetTeacher):
-    def __init__(self, unets, unet_types, extra_unet_dirpaths,
+    # unet_weights_in_ensemble are not model weights, but scalar weights for individual unets.
+    def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights_in_ensemble=None, device='cuda', **kwargs):
         super().__init__(**kwargs)
         self.name = "unet_ensemble"
-        self.unet = UNetEnsemble(unets, unet_types, extra_unet_dirpaths,
+        self.unet = UNetEnsemble(unets, unet_types, extra_unet_dirpaths, unet_weights_in_ensemble, device)
 
 class ConsistentIDTeacher(UNetTeacher):
     def __init__(self, base_model_path="models/sd15-dste8-vae.safetensors", **kwargs):

@@ -199,12 +239,9 @@ class ConsistentIDTeacher(UNetTeacher):
         # In contrast to Arc2FaceTeacher or UNetEnsembleTeacher, ConsistentIDPipeline is not a torch.nn.Module.
         # We couldn't initialize the ConsistentIDPipeline to CPU first and wait it to be automatically moved to GPU.
         # Instead, we have to initialize it to GPU directly.
-        # Compatible with the UNetTeacher interface.
-        self.unet = pipe.unet
-        # Release VAE and text_encoder to save memory. UNet is still needed for denoising
+        # Release VAE and text_encoder to save memory. UNet is needed for denoising
         # (the unet is implemented in diffusers in fp16, so probably faster than the LDM unet).
+        self.unet = create_consistentid_pipeline(base_model_path, unet_only=True)
 
         # We use the default cfg_scale_range=[1.3, 2] for SimpleUNetTeacher.
         # Note p_uses_cfg=0.5 will also be passed in in kwargs.
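For reference, the classifier-free-guidance combination kept in UNetTeacher.forward() above reduces to a single line of arithmetic. The sketch below only illustrates that formula with dummy tensors; the names are placeholders, not the project's API.

import torch

cfg_scale = 2.0
pos_noise_pred = torch.randn(4, 4, 64, 64)   # dummy prediction under the positive (ID) context
neg_noise_pred = torch.randn(4, 4, 64, 64)   # dummy prediction under the negative context

# Same combination as in the diff: extrapolate away from the negative prediction.
# Equivalent to neg + cfg_scale * (pos - neg).
noise_pred = pos_noise_pred * cfg_scale - neg_noise_pred * (cfg_scale - 1)
print(noise_pred.shape)   # torch.Size([4, 4, 64, 64])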
adaface/util.py
CHANGED
@@ -57,7 +57,7 @@ def perturb_np_array(np_array, perturb_std, perturb_std_is_relative=True, std_di
     ts = perturb_tensor(ts, perturb_std, perturb_std_is_relative, std_dim=std_dim)
     return ts.numpy().astype(np_array.dtype)
 
+def calc_stats(emb_name, embeddings, mean_dim=-1):
     print("%s:" %emb_name)
     repeat_count = [1] * embeddings.ndim
     repeat_count[mean_dim] = embeddings.shape[mean_dim]

@@ -153,13 +153,14 @@ def pad_image_obj_to_square(image_obj, new_size=-1):
 
 class UNetEnsemble(nn.Module):
     # The first unet is the unet already loaded in a pipeline.
-    def __init__(self, unets, unet_types, extra_unet_dirpaths,
+    def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights_in_ensemble=None, device='cuda', torch_dtype=torch.float16):
         super().__init__()
 
-        self.unets = nn.ModuleList()
         if unets is not None:
+            unets = [ unet.to(device) for unet in unets ]
+        else:
+            unets = []
+
         if unet_types is not None:
             for unet_type in unet_types:
                 if unet_type == "arc2face":

@@ -169,25 +170,27 @@
                     unet = create_consistentid_pipeline(unet_only=True)
                 else:
                     breakpoint()
+                unets.append(unet.to(device=device))
 
         if extra_unet_dirpaths is not None:
             for unet_path in extra_unet_dirpaths:
                 unet = UNet2DConditionModel.from_pretrained(unet_path, torch_dtype=torch_dtype)
+                unets.append(unet.to(device=device))
 
+        if unet_weights_in_ensemble is None:
+            unet_weights_in_ensemble = [1.] * len(unets)
+        elif len(unets) < len(unet_weights_in_ensemble):
+            unet_weights_in_ensemble = unet_weights_in_ensemble[:len(unets)]
+        elif len(unets) > len(unet_weights_in_ensemble):
             breakpoint()
 
-        self.unet_weights = nn.Parameter(unet_weights, requires_grad=False)
+        unet_weights_in_ensemble = torch.tensor(unet_weights_in_ensemble, dtype=torch_dtype)
+        unet_weights_in_ensemble = unet_weights_in_ensemble / unet_weights_in_ensemble.sum()
 
+        self.unets = nn.ModuleList(unets)
+        # Put the weights in a Parameter so that they will be moved to the same device as the model.
+        self.unet_weights_in_ensemble = nn.Parameter(unet_weights_in_ensemble, requires_grad=False)
+        print(f"UNetEnsemble: {len(self.unets)} UNets loaded with weights: {self.unet_weights_in_ensemble.data.cpu().numpy()}")
         # Set these fields to be compatible with diffusers.
         self.dtype = self.unets[0].dtype
         self.device = self.unets[0].device

@@ -215,8 +218,8 @@
             samples.append(sample)
 
         samples = torch.stack(samples, dim=0)
+        unet_weights_in_ensemble = self.unet_weights_in_ensemble.reshape(-1, *([1] * (samples.ndim - 1)))
+        sample = (samples * unet_weights_in_ensemble).sum(dim=0)
 
         if not return_dict:
             return (sample,)
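The UNetEnsemble change above renames unet_weights to unet_weights_in_ensemble, normalizes the scalar weights, and collapses the stacked UNet outputs with a weighted sum. A small self-contained sketch of that reduction with dummy tensors (the shapes and names here are assumptions for illustration):

import torch

# Three dummy UNet outputs, each of shape [BS, C, H, W].
samples = torch.stack([torch.randn(2, 4, 64, 64) for _ in range(3)], dim=0)

unet_weights_in_ensemble = torch.tensor([1.0, 1.0, 2.0])
unet_weights_in_ensemble = unet_weights_in_ensemble / unet_weights_in_ensemble.sum()

# Reshape to [num_unets, 1, 1, 1, 1] so the weights broadcast over the per-sample dims,
# then collapse the ensemble dimension with a weighted sum.
w = unet_weights_in_ensemble.reshape(-1, *([1] * (samples.ndim - 1)))
sample = (samples * w).sum(dim=0)
print(sample.shape)   # torch.Size([2, 4, 64, 64])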
app.py
CHANGED
@@ -5,40 +5,63 @@ from adaface.adaface_wrapper import AdaFaceWrapper
 import torch
 import numpy as np
 import random
+import os, re
+import time
 import gradio as gr
 import spaces
+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ("yes", "true", "t", "y", "1"):
+        return True
+    elif v.lower() in ("no", "false", "f", "n", "0"):
+        return False
+    else:
+        raise argparse.ArgumentTypeError("Boolean value expected.")
+
 import argparse
 parser = argparse.ArgumentParser()
 parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
                     choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders")
+parser.add_argument('--adaface_ckpt_path', type=str, default='models/adaface/VGGface2_HQ_masks2025-03-06T03-31-21_zero3-ada-1000.pt',
+                    help="Path to the checkpoint of the ID2Ada prompt encoders")
 # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face).
+parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=[6.0, 1.0],
                     help="Scales for the ID2Ada prompt encoders")
 parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None,
                     choices=["arc2face", "consistentID"],
                     help="List of enabled encoders (among the list of adaface_encoder_types). Default: None (all enabled)")
+parser.add_argument('--model_style_type', type=str, default='photorealistic',
                     choices=["realistic", "anime", "photorealistic"], help="Type of the base model")
+parser.add_argument("--guidance_scale", type=float, default=5.0,
+                    help="The guidance scale for the diffusion model. Default: 5.0")
+parser.add_argument("--unet_uses_attn_lora", type=str2bool, nargs="?", const=True, default=False,
+                    help="Whether to use LoRA in the Diffusers UNet model")
+# --attn_lora_layer_names and --q_lora_updates_query are only effective
+# when --unet_uses_attn_lora is set to True.
+parser.add_argument("--attn_lora_layer_names", type=str, nargs="*", default=['q', 'k', 'v', 'out'],
+                    choices=['q', 'k', 'v', 'out'], help="Names of the cross-attn components to apply LoRA on")
+parser.add_argument("--q_lora_updates_query", type=str2bool, nargs="?", const=True, default=False,
+                    help="Whether the q LoRA updates the query in the Diffusers UNet model. "
+                         "If False, the q lora only updates query2.")
+parser.add_argument("--show_disable_adaface_checkbox", type=str2bool, nargs="?", const=True, default=False,
+                    help="Whether to show the checkbox for disabling AdaFace")
+parser.add_argument('--extra_save_dir', type=str, default=None, help="Directory to save the generated images")
+parser.add_argument('--test_ui_only', type=str2bool, nargs="?", const=True, default=False,
+                    help="Only test the UI layout, and skip loadding the adaface model")
 parser.add_argument('--gpu', type=int, default=None)
 parser.add_argument('--ip', type=str, default="0.0.0.0")
 args = parser.parse_args()
 
+from huggingface_hub import snapshot_download
+large_files = ["models/*", "models/**/*"]
+snapshot_download(repo_id="adaface-neurips/adaface-models", repo_type="model", allow_patterns=large_files, local_dir=".")
+
 model_style_type2base_model_path = {
     "realistic": "models/rv51/realisticVisionV51_v51VAE_dste8.safetensors",
     "anime": "models/aingdiffusion/aingdiffusion_v170_ar.safetensors",
-    "photorealistic": "models/sar/sar.safetensors" # LDM format. Needs to be converted.
+    "photorealistic": "models/sar/sar.safetensors", # LDM format. Needs to be converted.
 }
 base_model_path = model_style_type2base_model_path[args.model_style_type]

@@ -48,13 +71,20 @@ device = "cuda" if args.gpu is None else f"cuda:{args.gpu}"
 print(f"Device: {device}")
 
 global adaface
+adaface = None
+
+if not args.test_ui_only:
+    adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path,
+                             adaface_encoder_types=args.adaface_encoder_types,
+                             adaface_ckpt_paths=args.adaface_ckpt_path,
+                             adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales,
+                             enabled_encoders=args.enabled_encoders,
+                             unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
+                             unet_uses_attn_lora=args.unet_uses_attn_lora,
+                             attn_lora_layer_names=args.attn_lora_layer_names,
+                             shrink_cross_attn=False,
+                             q_lora_updates_query=args.q_lora_updates_query,
+                             device='cpu')
 
 def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
     if randomize_seed:

@@ -71,12 +101,14 @@ def remove_back_to_files():
     # Hide uploaded_files_gallery, show clear_button_column, hide files, reset init_img_selected_idx
     # Or:
     # Hide uploaded_init_img_gallery, hide init_clear_button_column, show init_img_files, reset init_img_selected_idx
-    return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True)
+    return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True), \
+           gr.update(value=""), gr.update(value="(none)")
 
 @spaces.GPU
-def generate_image(image_paths,
-                   num_images, prompt, negative_prompt,
+def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,
+                   num_images, prompt, negative_prompt, gender, highlight_face,
+                   ablate_prompt_embed_type, nonmix_prompt_emb_weight,
+                   composition_level, seed, disable_adaface, subj_name_sig, progress=gr.Progress(track_tqdm=True)):
 
     global adaface
 

@@ -85,6 +117,9 @@ def generate_image(image_paths, guidance_scale, do_neg_id_prompt_weight, perturb
     if image_paths is None or len(image_paths) == 0:
         raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
 
+    if image_paths2 is not None and len(image_paths2) > 0:
+        image_paths = image_paths + image_paths2
+
     if prompt is None:
         prompt = ""
 

@@ -100,38 +135,128 @@ def generate_image(image_paths, guidance_scale, do_neg_id_prompt_weight, perturb
     # Sometimes the pipeline is on CPU, although we've put it on CUDA (due to some offloading mechanism).
     # Therefore we set the generator to the correct device.
     generator = torch.Generator(device=device).manual_seed(seed)
+    print(f"Manual seed: {seed}.")
     # Generate two images each time for the user to select from.
     noise = torch.randn(num_images, 3, 512, 512, device=device, generator=generator)
    #print(noise.abs().sum())
     # samples: A list of PIL Image instances.
-    if "portrait" in prompt:
+    if highlight_face:
+        if "portrait" not in prompt:
+            prompt = "face portrait, " + prompt
+        else:
             prompt = prompt.replace("portrait", "face portrait")
+    if composition_level >= 2:
+        if "full body" not in prompt:
+            prompt = prompt + ", full body view"
+
+    if gender != "(none)":
+        if "portrait" in prompt:
+            prompt = prompt.replace("portrait, ", f"portrait, {gender} ")
         else:
+            prompt = gender + ", " + prompt
 
     generator = torch.Generator(device=adaface.pipeline._execution_device).manual_seed(seed)
-    samples = adaface(noise, prompt, negative_prompt,
-                      do_neg_id_prompt_weight=do_neg_id_prompt_weight,
+    samples = adaface(noise, prompt, negative_prompt=negative_prompt,
                       guidance_scale=guidance_scale,
+                      out_image_count=num_images, generator=generator,
+                      repeat_prompt_for_each_encoder=(composition_level >= 1),
+                      ablate_prompt_no_placeholders=disable_adaface,
+                      ablate_prompt_embed_type=ablate_prompt_embed_type,
+                      nonmix_prompt_emb_weight=nonmix_prompt_emb_weight,
+                      verbose=True)
+
+    session_signature = ",".join(image_paths + [prompt, str(seed)])
+    temp_folder = os.path.join("/tmp/gradio", f"{hash(session_signature)}")
+    os.makedirs(temp_folder, exist_ok=True)
+
+    saved_image_paths = []
+    if "models/adaface/" in args.adaface_ckpt_path:
+        # The model is loaded from within the project.
+        # models/adaface/VGGface2_HQ_masks2024-10-14T16-09-24_zero3-ada-3500.pt
+        matches = re.search(r"models/adaface/\w+\d{4}-(\d{2})-(\d{2})T(\d{2})-\d{2}-\d{2}_zero3-ada-(\d+).pt", args.adaface_ckpt_path)
+    else:
+        # The model is loaded from the adaprompt folder.
+        # adaface_ckpt_path = "VGGface2_HQ_masks2024-11-28T13-13-20_zero3-ada/checkpoints/embeddings_gs-2000.pt"
+        matches = re.search(r"\d{4}-(\d{2})-(\d{2})T(\d{2})-\d{2}-\d{2}_zero3-ada/checkpoints/embeddings_gs-(\d+).pt", args.adaface_ckpt_path)
+
+    # Extract the checkpoint signature as 112813-2000
+    ckpt_sig = f"{matches.group(1)}{matches.group(2)}{matches.group(3)}-{matches.group(4)}"
+
+    prompt_keywords = ['armor', 'beach', 'chef', 'dancing', 'iron man', 'jedi',
+                       'street', 'guitar', 'reading', 'running', 'superman', 'new year', 'mars']
+    keywords_reduction = { 'iron man': 'ironman', 'dancing': 'dance',
+                           'running': 'run', 'reading': 'read', 'new year': 'newyear' }
+
+    prompt_sig = None
+    for keyword in prompt_keywords:
+        if keyword in prompt.lower():
+            prompt_sig = keywords_reduction.get(keyword, keyword)
+            break
+
+    if prompt_sig is None:
+        prompt_parts = prompt.lower().split(",")
+        # Remove the view/shot parts (full body view, long shot, etc.) from the prompt.
+        prompt_parts = [ part for part in prompt_parts if not re.search(r"\W(view|shot)(\W|$)", part) ]
+        if len(prompt_parts) > 0:
+            # Use the last word of the prompt as the signature.
+            prompt_sig = prompt_parts[-1].split()[-1]
+        else:
+            prompt_sig = "person"
+
+    if len(prompt_sig) > 0:
+        prompt_sig = "-" + prompt_sig
+
+    extra_save_dir = args.extra_save_dir
+    if extra_save_dir is not None:
+        os.makedirs(extra_save_dir, exist_ok=True)
+
+    for i, sample in enumerate(samples):
+        filename = f"adaface{ckpt_sig}{prompt_sig}-{i+1}.png"
+        if len(subj_name_sig) > 0:
+            filename = f"{subj_name_sig.lower()}-{filename}"
+        filepath = os.path.join(temp_folder, filename)
+        # Save the image
+        sample.save(filepath) # Adjust to your image saving method
+        saved_image_paths.append(filepath)
+
+        if extra_save_dir is not None:
+            extra_filepath = os.path.join(extra_save_dir, filename)
+            sample.save(extra_filepath)
+            print(extra_filepath)
+
+    # Solution suggested by o1 to force the client browser to reload images
+    # when we change guidance scales only.
+    saved_image_paths = [f"{url}?t={int(time.time())}" for url in saved_image_paths]
 
+    return saved_image_paths
 
-def check_prompt_and_model_type(prompt, model_style_type):
+def check_prompt_and_model_type(prompt, model_style_type, adaface_encoder_cfg_scale1):
     global adaface
 
     model_style_type = model_style_type.lower()
-    base_model_path = model_style_type2base_model_path[model_style_type]
     # If the base model type is changed, reload the model.
-    if model_style_type != args.model_style_type:
+    if model_style_type != args.model_style_type or adaface_encoder_cfg_scale1 != args.adaface_encoder_cfg_scales[0]:
+        if model_style_type != args.model_style_type:
+            # Update base model type.
+            args.model_style_type = model_style_type
+            print(f"Switching to the base model type: {model_style_type}.")
+
+            adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=model_style_type2base_model_path[model_style_type],
+                                     adaface_encoder_types=args.adaface_encoder_types,
+                                     adaface_ckpt_paths=args.adaface_ckpt_path,
+                                     adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales,
+                                     enabled_encoders=args.enabled_encoders,
+                                     unet_types=None, extra_unet_dirpaths=None, unet_weights_in_ensemble=None,
+                                     unet_uses_attn_lora=args.unet_uses_attn_lora,
+                                     attn_lora_layer_names=args.attn_lora_layer_names,
+                                     shrink_cross_attn=False,
+                                     q_lora_updates_query=args.q_lora_updates_query,
+                                     device='cpu')
+
+        if adaface_encoder_cfg_scale1 != args.adaface_encoder_cfg_scales[0]:
+            args.adaface_encoder_cfg_scales[0] = adaface_encoder_cfg_scale1
+            adaface.set_adaface_encoder_cfg_scales(args.adaface_encoder_cfg_scales)
+            print(f"Updating the scale for consistentID encoder to {adaface_encoder_cfg_scale1}.")
 
     if not prompt:
         raise gr.Error("Prompt cannot be blank")

@@ -145,13 +270,12 @@ description = r"""
 <b>Official demo</b> for our working paper <b>AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization</b>.<br>
 
 ❗️**What's New**❗️
+- Support switching between three model styles: **Photorealistic**, **Realistic** and **Anime**.
 - If you just changed the model style, the first image/video generation will take extra 20~30 seconds for loading new model weight.
 
 ❗️**Tips**❗️
 1. Upload one or more images of a person. If multiple faces are detected, we use the largest one.
-3. If the face dominates the image, try increasing 'Weight of ID prompt in the negative prompt'.
+2. Check "Highlight face" to highlight fine facial features.
 4. AdaFace Text-to-Video: <a href="https://huggingface.co/spaces/adaface-neurips/adaface-animate" style="display: inline-flex; align-items: center;">
 AdaFace-Animate
 <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow" alt="Hugging Face Spaces" style="margin-left: 5px;">

@@ -162,13 +286,18 @@ description = r"""
 """
 
 css = '''
 .gradio-container {width: 95% !important}
 .custom-gallery {
-    height: 800px;
+    height: 800px !important;
     width: 100%;
     margin: 10px auto;
-    overflow-y: auto;
+    padding: 0px;
+    overflow-y: auto !important;
+}
+.tight-row {
+    gap: 0 !important; /* removes the horizontal gap between columns */
+    margin: 0 !important; /* remove any extra margin if needed */
+    padding: 0 !important; /* remove any extra padding if needed */
 }
 '''
 with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:

@@ -187,53 +316,108 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
                 file_types=["image"],
                 file_count="multiple"
             )
+            # When files are uploaded, show the images in the gallery and hide the file uploader.
+            uploaded_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=1, height=300)
             with gr.Column(visible=False) as clear_button_column:
+                remove_and_reupload = gr.ClearButton(value="Remove and upload subject images",
+                                                     components=img_files, size="sm")
+
+            with gr.Accordion("Second Subject (Optional)", open=False):
+                img_files2 = gr.File(
+                    label="Drag / Select 1 or more photos of second subject's face (optional)",
+                    file_types=["image"],
+                    file_count="multiple"
+                )
+
+                uploaded_files_gallery2 = gr.Gallery(label="2nd Subject images (optional)", visible=False, columns=3, rows=1, height=300)
+                with gr.Column(visible=False) as clear_button_column2:
+                    remove_and_reupload2 = gr.ClearButton(value="Remove and upload 2nd Subject images",
+                                                          components=img_files2, size="sm")
+
+            with gr.Row(elem_classes="tight-row"):
+                with gr.Column(scale=1, min_width=100):
+                    gender = gr.Dropdown(label="Gender", value="(none)",
+                                         info="Gender prefix. Select only when the model errs.",
+                                         container=False,
+                                         choices=[ "(none)", "person", "man", "woman", "girl", "boy" ])
+
+                with gr.Column(scale=100):
+                    prompt = gr.Dropdown(label="Prompt",
+                                         info="Try something like 'walking on the beach'. If the face is not in focus, try checking 'Highlight face'.",
+                                         value="portrait, highlighted hair, futuristic silver armor suit, confident stance, living room, smiling, head tilted, perfect smooth skin",
+                                         allow_custom_value=True,
+                                         choices=[
+                                             "portrait, highlighted hair, futuristic silver armor suit, confident stance, living room, smiling, head tilted, perfect smooth skin",
+                                             "portrait, walking on the beach, sunset, orange sky, front view",
+                                             "portrait, in a white apron and chef hat, garnishing a gourmet dish",
+                                             "portrait, waving hands, dancing pose among folks in a park",
+                                             "portrait, in iron man costume, the sky ablaze with hues of orange and purple",
+                                             "portrait, jedi wielding a lightsaber, star wars",
+                                             "portrait, night view of tokyo street, neon light",
+                                             "portrait, playing guitar on a boat, ocean waves",
+                                             "portrait, with a passion for reading, curled up with a book in a cozy nook near a window, front view",
+                                             "portrait, celebrating new year, fireworks",
+                                             "portrait, running pose in a park",
+                                             "portrait, in space suit, space helmet, walking on mars",
+                                             "portrait, in superman costume, the sky ablaze with hues of orange and purple",
+                                             "in a wheelchair",
+                                             "on a horse"
+                                         ])
 
+            highlight_face = gr.Checkbox(label="Highlight face", value=False,
+                                         info="Enhance the facial features by prepending 'face portrait' to the prompt")
+            composition_level = \
+                gr.Slider(label="Composition Level", visible=True,
+                          info="The degree of overall composition, 0~2. Challenging prompts like 'In a wheelchair' and 'on a horse' need level 2",
+                          minimum=0, maximum=2, step=1, value=0)
+
+            ablate_prompt_embed_type = gr.Dropdown(label="Ablate prompt embeddings type",
+                                                   choices=["ada", "ada-nonmix", "img"], value="ada", visible=False,
+                                                   info="Use this type of prompt embeddings for ablation study")
+
+            nonmix_prompt_emb_weight = gr.Slider(label="Weight of ada-nonmix ID embeddings",
+                                                 minimum=0.0, maximum=0.5, step=0.1, value=0,
+                                                 info="Weight of ada-nonmix ID embeddings in the prompt embeddings",
+                                                 visible=False)
+
+
+            subj_name_sig = gr.Textbox(
+                label="Nickname of Subject (optional; used to name saved images)",
+                value="",
+            )
+            subj_name_sig2 = gr.Textbox(
+                label="Nickname of 2nd Subject (optional; used to name saved images)",
+                value="",
+                visible=False,
+            )
 
             submit = gr.Button("Submit", variant="primary")
 
             negative_prompt = gr.Textbox(
                 label="Negative Prompt",
+                value="sagging face, sagging cheeks, wrinkles, flaws in the eyes, flaws in the face, lowres, "
+                      "non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, "
+                      "mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, "
+                      "deformed eyeballs, cross-eyed, extra legs, extra arms, blurry, mutation, duplicate, "
+                      "out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, "
+                      "nude, naked, nsfw, topless, bare breasts",
             )
 
             guidance_scale = gr.Slider(
                 label="Guidance scale",
                 minimum=1.0,
+                maximum=8.0,
+                step=0.5,
                 value=args.guidance_scale,
             )
 
+            adaface_encoder_cfg_scale1 = gr.Slider(
+                label="Scale for consistentID encoder",
+                minimum=1.0,
+                maximum=12.0,
+                step=1.0,
+                value=args.adaface_encoder_cfg_scales[0],
+                visible=False,
             )
 
             model_style_type = gr.Dropdown(

@@ -256,7 +440,7 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
             num_images = gr.Slider(
                 label="Number of output images",
                 minimum=1,
+                maximum=8,
                 step=1,
                 value=4,
             )

@@ -267,27 +451,41 @@ with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo:
                 step=1,
                 value=0,
             )
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True,
+                                         info="Uncheck for reproducible results")
+            disable_adaface = gr.Checkbox(label="Disable AdaFace", value=False,
+                                          info="Disable AdaFace for ablation. If checked, the results are no longer personalized.",
+                                          visible=args.show_disable_adaface_checkbox)
 
         with gr.Column():
+            out_gallery = gr.Gallery(label="Generated Images", interactive=False, columns=2, rows=4, height=800,
                                      elem_classes="custom-gallery")
 
+    img_files.upload(fn=swap_to_gallery, inputs=img_files, outputs=[uploaded_files_gallery, clear_button_column, img_files])
+    img_files2.upload(fn=swap_to_gallery, inputs=img_files2, outputs=[uploaded_files_gallery2, clear_button_column2, img_files2])
+    remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery, clear_button_column,
+                                                                img_files, subj_name_sig, gender])
+    remove_and_reupload2.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery2, clear_button_column2,
+                                                                 img_files2, subj_name_sig2, gender])
+
+    check_prompt_and_model_type_call_dict = {
+        'fn': check_prompt_and_model_type,
+        'inputs': [prompt, model_style_type, adaface_encoder_cfg_scale1],
+        'outputs': None
+    }
+    randomize_seed_fn_call_dict = {
+        'fn': randomize_seed_fn,
+        'inputs': [seed, randomize_seed],
+        'outputs': seed
+    }
+    generate_image_call_dict = {
+        'fn': generate_image,
+        'inputs': [img_files, img_files2, guidance_scale, perturb_std, num_images, prompt,
+                   negative_prompt, gender, highlight_face, ablate_prompt_embed_type,
+                   nonmix_prompt_emb_weight, composition_level, seed, disable_adaface, subj_name_sig],
+        'outputs': [out_gallery]
+    }
+    submit.click(**check_prompt_and_model_type_call_dict).success(**randomize_seed_fn_call_dict).then(**generate_image_call_dict)
+    subj_name_sig.submit(**check_prompt_and_model_type_call_dict).success(**randomize_seed_fn_call_dict).then(**generate_image_call_dict)
 
 demo.launch(share=True, server_name=args.ip, ssl_verify=False)
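The str2bool helper added to app.py above follows the common argparse pattern for optional-value boolean flags (used here for --unet_uses_attn_lora, --q_lora_updates_query, etc.). Below is a minimal, self-contained illustration of that pattern; the --use_lora flag is made up for the example.

import argparse

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
# nargs="?" + const=True: flag absent -> default (False); bare "--use_lora" -> True; "--use_lora false" -> False.
parser.add_argument("--use_lora", type=str2bool, nargs="?", const=True, default=False)

print(parser.parse_args([]).use_lora)                        # False
print(parser.parse_args(["--use_lora"]).use_lora)            # True
print(parser.parse_args(["--use_lora", "false"]).use_lora)   # False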