diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..b8e3cb0e5a6e27cc49bfd77fd78768942fe4d7d2 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +models/ensemble/ar18-unet/diffusion_pytorch_model.safetensors filter=lfs diff=lfs merge=lfs -text +ConsistentID/images/templates/3f8d901770014c1b8f7f261971f0e92.png filter=lfs diff=lfs merge=lfs -text +ConsistentID/images/templates/6577b962b6346df03fea83211daaf48.png filter=lfs diff=lfs merge=lfs -text +ConsistentID/images/templates/75583964a834abe33b72f52b1a98e84.png filter=lfs diff=lfs merge=lfs -text +ConsistentID/lib/BiSeNet/6.jpg filter=lfs diff=lfs merge=lfs -text +ConsistentID/lib/BiSeNet/makeup/116_1.png filter=lfs diff=lfs merge=lfs -text +ConsistentID/lib/BiSeNet/makeup/116_3.png filter=lfs diff=lfs merge=lfs -text +ConsistentID/lib/BiSeNet/makeup/116_lip_ori.png filter=lfs diff=lfs merge=lfs -text +ConsistentID/lib/BiSeNet/makeup/116_ori.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..98ff67a41c96d9528c238d649f517c34d6652501 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +models/awportrait/* +models/awportrait +__pycache__/* +__pycache__ +samples-ada/* +samples-ada +models/ensemble/awp14-unet/* +models/ensemble/awp14-unet +.gradio/certificate.pem + diff --git a/ConsistentID/.gitattributes b/ConsistentID/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..58f1377d7c533a58a9c9de1a4a43f3cdda09fca4 --- /dev/null +++ b/ConsistentID/.gitattributes @@ -0,0 +1,38 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +images/templates/3f8d901770014c1b8f7f261971f0e92.png filter=lfs diff=lfs merge=lfs -text +images/templates/75583964a834abe33b72f52b1a98e84.png filter=lfs diff=lfs merge=lfs -text 
+models/LLaVA/images/demo_cli.gif filter=lfs diff=lfs merge=lfs -text diff --git a/ConsistentID/.gitignore b/ConsistentID/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..03902be4a3bb461a24426f4af15a055de6c7d553 --- /dev/null +++ b/ConsistentID/.gitignore @@ -0,0 +1,5 @@ +__pycache__/* +__pycache__ +/*.png +models/insightface +models/Realistic_Vision* diff --git a/ConsistentID/LICENSE b/ConsistentID/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4de2ad6baf433ec9f6fc16246814237acd15c38f --- /dev/null +++ b/ConsistentID/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Jiehui Huang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/ConsistentID/README.md b/ConsistentID/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6edc705fd6ef175c3398fa5abebcc6f182880b51 --- /dev/null +++ b/ConsistentID/README.md @@ -0,0 +1,13 @@ +--- +title: ConsistentID +emoji: 🔥 +colorFrom: yellow +colorTo: yellow +sdk: gradio +sdk_version: 4.37.2 +app_file: app.py +pinned: false +license: apache-2.0 +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/ConsistentID/__init__.py b/ConsistentID/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ConsistentID/app.py b/ConsistentID/app.py new file mode 100644 index 0000000000000000000000000000000000000000..ad923a26857e37af1d9ea93933e3193166dd990a --- /dev/null +++ b/ConsistentID/app.py @@ -0,0 +1,168 @@ +import gradio as gr +import torch +import os +import glob +import spaces +import numpy as np + +from PIL import Image +from diffusers.utils import load_image +from diffusers import EulerDiscreteScheduler +from ConsistentID.lib.pipeline_ConsistentID import ConsistentIDPipeline +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--base_model_path', type=str, + default="models/Realistic_Vision_V4.0_noVAE") +parser.add_argument('--gpu', type=int, default=0) +args = parser.parse_args() + +device = f"cuda:{args.gpu}" + +### Load base model +pipe = ConsistentIDPipeline.from_pretrained( + args.base_model_path, + torch_dtype=torch.float16, +) + +### Load consistentID_model checkpoint +pipe.load_ConsistentID_model( + consistentID_weight_path="./models/ConsistentID-v1.bin", + bise_net_weight_path="./models/BiSeNet_pretrained_for_ConsistentID.pth", +) +pipe.scheduler = 
EulerDiscreteScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to(device, torch.float16) + +@spaces.GPU +def process(selected_template_images, custom_image, prompt, + negative_prompt, prompt_selected, model_selected_tab, + prompt_selected_tab, guidance_scale, width, height, merge_steps, seed_set): + + # The gradio UI only supports one image at a time. + if model_selected_tab==0: + subj_images = load_image(Image.open(selected_template_images)) + else: + subj_images = load_image(Image.fromarray(custom_image)) + + if prompt_selected_tab==0: + prompt = prompt_selected + negative_prompt = "" + + # hyper-parameter + num_steps = 50 + seed_set = torch.randint(0, 1000, (1,)).item() + # merge_steps = 30 + + if prompt == "": + prompt = "A man, in a forest" + prompt = "A man, with backpack, in a raining tropical forest, adventuring, holding a flashlight, in mist, seeking animals" + prompt = "A person, in a sowm, wearing santa hat and a scarf, with a cottage behind" + else: + #prompt=Enhance_prompt(prompt, Image.new('RGB', (200, 200), color = 'white')) + print(prompt) + + if negative_prompt == "": + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry" + + #Extend Prompt + #prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed" + #print(prompt) + + negtive_prompt_group="((cross-eye)),((cross-eyed)),(((NFSW))),(nipple),((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))). out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))" + negative_prompt = negative_prompt + negtive_prompt_group + + # seed = torch.randint(0, 1000, (1,)).item() + generator = torch.Generator(device=device).manual_seed(seed_set) + + images = pipe( + prompt=prompt, + width=width, + height=height, + input_subj_image_objs=subj_images, + negative_prompt=negative_prompt, + num_images_per_prompt=1, + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + start_merge_step=merge_steps, + generator=generator, + ).images[0] + + return np.array(images) + +# Gets the templates +preset_template = glob.glob("./images/templates/*.png") +preset_template = preset_template + glob.glob("./images/templates/*.jpg") + +with gr.Blocks(title="ConsistentID Demo") as demo: + gr.Markdown("# ConsistentID Demo") + gr.Markdown("\ + Put the reference figure to be redrawn into the box below (There is a small probability of referensing failure. You can submit it repeatedly)") + gr.Markdown("\ + If you find our work interesting, please leave a star in GitHub for us!
\ + https://github.com/JackAILab/ConsistentID") + with gr.Row(): + with gr.Column(): + model_selected_tab = gr.State(0) + with gr.TabItem("template images") as template_images_tab: + template_gallery_list = [(i, i) for i in preset_template] + gallery = gr.Gallery(template_gallery_list,columns=[4], rows=[2], object_fit="contain", height="auto",show_label=False) + + def select_function(evt: gr.SelectData): + return preset_template[evt.index] + + selected_template_images = gr.Text(show_label=False, visible=False, placeholder="Selected") + gallery.select(select_function, None, selected_template_images) + with gr.TabItem("Upload Image") as upload_image_tab: + custom_image = gr.Image(label="Upload Image") + + model_selected_tabs = [template_images_tab, upload_image_tab] + for i, tab in enumerate(model_selected_tabs): + tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[model_selected_tab]) + + with gr.Column(): + prompt_selected_tab = gr.State(0) + with gr.TabItem("template prompts") as template_prompts_tab: + prompt_selected = gr.Dropdown(value="A person, police officer, half body shot", elem_id='dropdown', choices=[ + "A woman in a wedding dress", + "A woman, queen, in a gorgeous palace", + "A man sitting at the beach with sunset", + "A person, police officer, half body shot", + "A man, sailor, in a boat above ocean", + "A women wearing headphone, listening music", + "A man, firefighter, half body shot"], label=f"prepared prompts") + + with gr.TabItem("custom prompt") as custom_prompt_tab: + prompt = gr.Textbox(label="prompt",placeholder="A man/woman wearing a santa hat") + nagetive_prompt = gr.Textbox(label="negative prompt",placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry") + + prompt_selected_tabs = [template_prompts_tab, custom_prompt_tab] + for i, tab in enumerate(prompt_selected_tabs): + tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[prompt_selected_tab]) + + guidance_scale = gr.Slider( + label="Guidance scale", + minimum=1.0, + maximum=10.0, + step=1.0, + value=5.0, + ) + + width = gr.Slider(label="image width",minimum=256,maximum=768,value=512,step=8) + height = gr.Slider(label="image height",minimum=256,maximum=768,value=512,step=8) + width.release(lambda x,y: min(1280-x,y), inputs=[width,height], outputs=[height]) + height.release(lambda x,y: min(1280-y,x), inputs=[width,height], outputs=[width]) + merge_steps = gr.Slider(label="step starting to merge facial details(30 is recommended)",minimum=10,maximum=50,value=30,step=1) + seed_set = gr.Slider(label="set the random seed for different results",minimum=1,maximum=2147483647,value=2024,step=1) + + btn = gr.Button("Run") + with gr.Column(): + out = gr.Image(label="Output") + gr.Markdown(''' + N.B.:
+ - If the face occupies only a small part of the image, the chance of an error is slightly higher and the identity similarity drops noticeably. + - Prefer prompts that say \"man\" or \"woman\" rather than \"person\"; otherwise the model may be unsure whether the subject is male or female. + - Due to limited GPU memory on the demo server, there is an upper limit on the generation resolution. SDXL generation will be supported as soon as possible.

+ ''') + btn.click(fn=process, inputs=[selected_template_images, custom_image,prompt, nagetive_prompt, prompt_selected, + model_selected_tab, prompt_selected_tab, guidance_scale, width, height, merge_steps, seed_set], outputs=out) + +demo.launch(server_name='0.0.0.0', ssl_verify=False) diff --git a/ConsistentID/images/templates/3f8d901770014c1b8f7f261971f0e92.png b/ConsistentID/images/templates/3f8d901770014c1b8f7f261971f0e92.png new file mode 100644 index 0000000000000000000000000000000000000000..96322e923d0fe21a29aaf0c0ff81593c4dd8e45b --- /dev/null +++ b/ConsistentID/images/templates/3f8d901770014c1b8f7f261971f0e92.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fa9319750b9927075934c40a180766e75ff539711293581dae6bac5963b9d05 +size 2061666 diff --git a/ConsistentID/images/templates/6577b962b6346df03fea83211daaf48.png b/ConsistentID/images/templates/6577b962b6346df03fea83211daaf48.png new file mode 100644 index 0000000000000000000000000000000000000000..3a3f6c6c6f6379ddb685b2e7e43ecf7290d40e4b --- /dev/null +++ b/ConsistentID/images/templates/6577b962b6346df03fea83211daaf48.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5fa55a7291ea76f96144b5caba73b47ffd31b941ee6fcefd17e72976b446439d +size 218420 diff --git a/ConsistentID/images/templates/75583964a834abe33b72f52b1a98e84.png b/ConsistentID/images/templates/75583964a834abe33b72f52b1a98e84.png new file mode 100644 index 0000000000000000000000000000000000000000..cb4c91b4840d6e4e62d0242083d66d7f071d66a7 --- /dev/null +++ b/ConsistentID/images/templates/75583964a834abe33b72f52b1a98e84.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:318c942eb3cc8a1f9320b2ea84a88cd95067785c07f8ae1dd18fe6c4cf8e8282 +size 7543309 diff --git a/ConsistentID/images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg b/ConsistentID/images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c0c4d207ca292c88f846aea31b218ef75a39bcd5 Binary files /dev/null and b/ConsistentID/images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg differ diff --git a/ConsistentID/lib/BiSeNet/6.jpg b/ConsistentID/lib/BiSeNet/6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1ce64ac5be50184dee6d30857b3c57d05a2edd5a --- /dev/null +++ b/ConsistentID/lib/BiSeNet/6.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d79fc9d9b5b04dc3f82782c11f15aa6fd2ba6654f51b73237f088e3805b655a5 +size 134031 diff --git a/ConsistentID/lib/BiSeNet/__init__.py b/ConsistentID/lib/BiSeNet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34c103210a7fa7fda0b895e183e4f3cbc831f92b --- /dev/null +++ b/ConsistentID/lib/BiSeNet/__init__.py @@ -0,0 +1,2 @@ +#__init__.py +# from BiSeNet.model import * \ No newline at end of file diff --git a/ConsistentID/lib/BiSeNet/evaluate.py b/ConsistentID/lib/BiSeNet/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..578d75c7e8b4dceeb20cc599ad9062b67311724e --- /dev/null +++ b/ConsistentID/lib/BiSeNet/evaluate.py @@ -0,0 +1,95 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +from logger import setup_logger +import BiSeNet +from face_dataset import FaceMask + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import torch.nn.functional as F +import torch.distributed as dist + +import os +import os.path as osp +import logging +import time +import numpy as np +from tqdm import tqdm +import math +from PIL import 
Image +import torchvision.transforms as transforms +import cv2 + +def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'): + # Colors for all 20 parts + part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], + [255, 0, 85], [255, 0, 170], + [0, 255, 0], [85, 255, 0], [170, 255, 0], + [0, 255, 85], [0, 255, 170], + [0, 0, 255], [85, 0, 255], [170, 0, 255], + [0, 85, 255], [0, 170, 255], + [255, 255, 0], [255, 255, 85], [255, 255, 170], + [255, 0, 255], [255, 85, 255], [255, 170, 255], + [0, 255, 255], [85, 255, 255], [170, 255, 255]] + + im = np.array(im) + vis_im = im.copy().astype(np.uint8) + vis_parsing_anno = parsing_anno.copy().astype(np.uint8) + vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) + vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 + + num_of_class = np.max(vis_parsing_anno) + + for pi in range(1, num_of_class + 1): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi] + + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) + # print(vis_parsing_anno_color.shape, vis_im.shape) + vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0) + + # Save result or not + if save_im: + cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + + # return vis_im + +def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'): + + if not os.path.exists(respth): + os.makedirs(respth) + + n_classes = 19 + net = BiSeNet(n_classes=n_classes) + net.cuda() + save_pth = osp.join('res/cp', cp) + net.load_state_dict(torch.load(save_pth)) + net.eval() + + to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + with torch.no_grad(): + for image_path in os.listdir(dspth): + img = Image.open(osp.join(dspth, image_path)) + image = img.resize((512, 512), Image.BILINEAR) + img = to_tensor(image) + img = torch.unsqueeze(img, 0) + img = img.cuda() + out = net(img)[0] + parsing = out.squeeze(0).cpu().numpy().argmax(0) + + vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path)) + + + + + + + +if __name__ == "__main__": + setup_logger('./res') + evaluate() diff --git a/ConsistentID/lib/BiSeNet/face_dataset.py b/ConsistentID/lib/BiSeNet/face_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ece7fb0afd127c7bf085c769540145838e270e --- /dev/null +++ b/ConsistentID/lib/BiSeNet/face_dataset.py @@ -0,0 +1,106 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import torch +from torch.utils.data import Dataset +import torchvision.transforms as transforms + +import os.path as osp +import os +from PIL import Image +import numpy as np +import json +import cv2 + +from transform import * + + + +class FaceMask(Dataset): + def __init__(self, rootpth, cropsize=(640, 480), mode='train', *args, **kwargs): + super(FaceMask, self).__init__(*args, **kwargs) + assert mode in ('train', 'val', 'test') + self.mode = mode + self.ignore_lb = 255 + self.rootpth = rootpth + + self.imgs = os.listdir(os.path.join(self.rootpth, 'CelebA-HQ-img')) + + # pre-processing + self.to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + self.trans_train = Compose([ + ColorJitter( + brightness=0.5, + contrast=0.5, 
+ saturation=0.5), + HorizontalFlip(), + RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)), + RandomCrop(cropsize) + ]) + + def __getitem__(self, idx): + impth = self.imgs[idx] + img = Image.open(osp.join(self.rootpth, 'CelebA-HQ-img', impth)) + img = img.resize((512, 512), Image.BILINEAR) + label = Image.open(osp.join(self.rootpth, 'mask', impth[:-3]+'png')).convert('P') + # print(np.unique(np.array(label))) + if self.mode == 'train': + im_lb = dict(im=img, lb=label) + im_lb = self.trans_train(im_lb) + img, label = im_lb['im'], im_lb['lb'] + img = self.to_tensor(img) + label = np.array(label).astype(np.int64)[np.newaxis, :] + return img, label + + def __len__(self): + return len(self.imgs) + + +if __name__ == "__main__": + face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img' + face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno' + mask_path = '/home/zll/data/CelebAMask-HQ/mask' + counter = 0 + total = 0 + for i in range(15): + # files = os.listdir(osp.join(face_sep_mask, str(i))) + + atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', + 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] + + for j in range(i*2000, (i+1)*2000): + + mask = np.zeros((512, 512)) + + for l, att in enumerate(atts, 1): + total += 1 + file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png']) + path = osp.join(face_sep_mask, str(i), file_name) + + if os.path.exists(path): + counter += 1 + sep_mask = np.array(Image.open(path).convert('P')) + # print(np.unique(sep_mask)) + + mask[sep_mask == 225] = l + cv2.imwrite('{}/{}.png'.format(mask_path, j), mask) + print(j) + + print(counter, total) + + + + + + + + + + + + + + diff --git a/ConsistentID/lib/BiSeNet/hair.png b/ConsistentID/lib/BiSeNet/hair.png new file mode 100644 index 0000000000000000000000000000000000000000..07d194f77af5ccbde364500dafc43b96ebfb5c8b Binary files /dev/null and b/ConsistentID/lib/BiSeNet/hair.png differ diff --git a/ConsistentID/lib/BiSeNet/logger.py b/ConsistentID/lib/BiSeNet/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f9ddcc2cae221b4dd881d02404e848b5396f7e --- /dev/null +++ b/ConsistentID/lib/BiSeNet/logger.py @@ -0,0 +1,23 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import os.path as osp +import time +import sys +import logging + +import torch.distributed as dist + + +def setup_logger(logpth): + logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S')) + logfile = osp.join(logpth, logfile) + FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' + log_level = logging.INFO + if dist.is_initialized() and not dist.get_rank()==0: + log_level = logging.ERROR + logging.basicConfig(level=log_level, format=FORMAT, filename=logfile) + logging.root.addHandler(logging.StreamHandler()) + + diff --git a/ConsistentID/lib/BiSeNet/loss.py b/ConsistentID/lib/BiSeNet/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..62657de66de513995c87acb81108a35d941fe37f --- /dev/null +++ b/ConsistentID/lib/BiSeNet/loss.py @@ -0,0 +1,72 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class OhemCELoss(nn.Module): + def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs): + super(OhemCELoss, self).__init__() + self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda() + self.n_min = n_min + self.ignore_lb = ignore_lb + self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') + + 
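+    # Online hard example mining (OHEM): the per-pixel cross-entropy is computed with
+    # reduction='none', sorted in descending order, and only the hardest pixels are kept
+    # (those above `thresh`, or at least the top `n_min`), so easy pixels do not
+    # dominate the gradient.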
def forward(self, logits, labels): + N, C, H, W = logits.size() + loss = self.criteria(logits, labels).view(-1) + loss, _ = torch.sort(loss, descending=True) + if loss[self.n_min] > self.thresh: + loss = loss[loss>self.thresh] + else: + loss = loss[:self.n_min] + return torch.mean(loss) + + +class SoftmaxFocalLoss(nn.Module): + def __init__(self, gamma, ignore_lb=255, *args, **kwargs): + super(SoftmaxFocalLoss, self).__init__() + self.gamma = gamma + self.nll = nn.NLLLoss(ignore_index=ignore_lb) + + def forward(self, logits, labels): + scores = F.softmax(logits, dim=1) + factor = torch.pow(1.-scores, self.gamma) + log_score = F.log_softmax(logits, dim=1) + log_score = factor * log_score + loss = self.nll(log_score, labels) + return loss + + +if __name__ == '__main__': + torch.manual_seed(15) + criteria1 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda() + criteria2 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda() + net1 = nn.Sequential( + nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1), + ) + net1.cuda() + net1.train() + net2 = nn.Sequential( + nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1), + ) + net2.cuda() + net2.train() + + with torch.no_grad(): + inten = torch.randn(16, 3, 20, 20).cuda() + lbs = torch.randint(0, 19, [16, 20, 20]).cuda() + lbs[1, :, :] = 255 + + logits1 = net1(inten) + logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear') + logits2 = net2(inten) + logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear') + + loss1 = criteria1(logits1, lbs) + loss2 = criteria2(logits2, lbs) + loss = loss1 + loss2 + print(loss.detach().cpu()) + loss.backward() diff --git a/ConsistentID/lib/BiSeNet/makeup.py b/ConsistentID/lib/BiSeNet/makeup.py new file mode 100644 index 0000000000000000000000000000000000000000..3b8ceee9944f4f41e97027b2c1f57bbbad912036 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/makeup.py @@ -0,0 +1,129 @@ +import cv2 +import numpy as np +from skimage.filters import gaussian + + +def sharpen(img): + img = img * 1.0 + gauss_out = gaussian(img, sigma=5, multichannel=True) + + alpha = 1.5 + img_out = (img - gauss_out) * alpha + img + + img_out = img_out / 255.0 + + mask_1 = img_out < 0 + mask_2 = img_out > 1 + + img_out = img_out * (1 - mask_1) + img_out = img_out * (1 - mask_2) + mask_2 + img_out = np.clip(img_out, 0, 1) + img_out = img_out * 255 + return np.array(img_out, dtype=np.uint8) + + +def hair(image, parsing, part=17, color=[230, 50, 20]): + b, g, r = color #[10, 50, 250] # [10, 250, 10] + tar_color = np.zeros_like(image) + tar_color[:, :, 0] = b + tar_color[:, :, 1] = g + tar_color[:, :, 2] = r + + image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + tar_hsv = cv2.cvtColor(tar_color, cv2.COLOR_BGR2HSV) + + if part == 12 or part == 13: + image_hsv[:, :, 0:2] = tar_hsv[:, :, 0:2] + else: + image_hsv[:, :, 0:1] = tar_hsv[:, :, 0:1] + + changed = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR) + + if part == 17: + changed = sharpen(changed) + + changed[parsing != part] = image[parsing != part] + # changed = cv2.resize(changed, (512, 512)) + return changed + +# +# def lip(image, parsing, part=17, color=[230, 50, 20]): +# b, g, r = color #[10, 50, 250] # [10, 250, 10] +# tar_color = np.zeros_like(image) +# tar_color[:, :, 0] = b +# tar_color[:, :, 1] = g +# tar_color[:, :, 2] = r +# +# image_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab) +# il, ia, ib = cv2.split(image_lab) +# +# tar_lab = cv2.cvtColor(tar_color, cv2.COLOR_BGR2Lab) +# tl, ta, tb = cv2.split(tar_lab) +# +# image_lab[:, :, 0] = np.clip(il - np.mean(il) + tl, 0, 100) +# 
image_lab[:, :, 1] = np.clip(ia - np.mean(ia) + ta, -127, 128) +# image_lab[:, :, 2] = np.clip(ib - np.mean(ib) + tb, -127, 128) +# +# +# changed = cv2.cvtColor(image_lab, cv2.COLOR_Lab2BGR) +# +# if part == 17: +# changed = sharpen(changed) +# +# changed[parsing != part] = image[parsing != part] +# # changed = cv2.resize(changed, (512, 512)) +# return changed + + +if __name__ == '__main__': + # 1 face + # 10 nose + # 11 teeth + # 12 upper lip + # 13 lower lip + # 17 hair + num = 116 + table = { + 'hair': 17, + 'upper_lip': 12, + 'lower_lip': 13 + } + image_path = '/home/zll/data/CelebAMask-HQ/test-img/{}.jpg'.format(num) + parsing_path = 'res/test_res/{}.png'.format(num) + + image = cv2.imread(image_path) + ori = image.copy() + parsing = np.array(cv2.imread(parsing_path, 0)) + parsing = cv2.resize(parsing, image.shape[0:2], interpolation=cv2.INTER_NEAREST) + + parts = [table['hair'], table['upper_lip'], table['lower_lip']] + # colors = [[20, 20, 200], [100, 100, 230], [100, 100, 230]] + colors = [[100, 200, 100]] + for part, color in zip(parts, colors): + image = hair(image, parsing, part, color) + cv2.imwrite('res/makeup/116_ori.png', cv2.resize(ori, (512, 512))) + cv2.imwrite('res/makeup/116_2.png', cv2.resize(image, (512, 512))) + + cv2.imshow('image', cv2.resize(ori, (512, 512))) + cv2.imshow('color', cv2.resize(image, (512, 512))) + + # cv2.imshow('image', ori) + # cv2.imshow('color', image) + + cv2.waitKey(0) + cv2.destroyAllWindows() + + + + + + + + + + + + + + + diff --git a/ConsistentID/lib/BiSeNet/makeup/116_1.png b/ConsistentID/lib/BiSeNet/makeup/116_1.png new file mode 100644 index 0000000000000000000000000000000000000000..97f550ea0ec878406bb4117e3f9d3d63e9b34f1a --- /dev/null +++ b/ConsistentID/lib/BiSeNet/makeup/116_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee885f179c23b9c3438ef306a90ab69811ccf73c21aaaafd2209e934f5022f23 +size 532494 diff --git a/ConsistentID/lib/BiSeNet/makeup/116_3.png b/ConsistentID/lib/BiSeNet/makeup/116_3.png new file mode 100644 index 0000000000000000000000000000000000000000..8f5ad7033b5d0211ca171f20186f11e644f69c00 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/makeup/116_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6b303bf115e171a5a80212979ab9f2aa052424cd1e76f700b2bf8838802fd6ca +size 531600 diff --git a/ConsistentID/lib/BiSeNet/makeup/116_lip_ori.png b/ConsistentID/lib/BiSeNet/makeup/116_lip_ori.png new file mode 100644 index 0000000000000000000000000000000000000000..0c697a25aba0d904c3d95a8fefd575f05c6aafe6 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/makeup/116_lip_ori.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:633a200b71fa9763f01fc8f53736e252ad52cd0337276f9165a81ae3bd74aa43 +size 532221 diff --git a/ConsistentID/lib/BiSeNet/makeup/116_ori.png b/ConsistentID/lib/BiSeNet/makeup/116_ori.png new file mode 100644 index 0000000000000000000000000000000000000000..65cbcf5d856c10d30689ad22b3c6bb61da98d419 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/makeup/116_ori.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2b820ddad837598fd7108489a839201172ca8c9609b887c5b7fbf15c207e519 +size 471886 diff --git a/ConsistentID/lib/BiSeNet/model.py b/ConsistentID/lib/BiSeNet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..54ecdcf2a50e553e259eb17883c8f2148960b4cc --- /dev/null +++ b/ConsistentID/lib/BiSeNet/model.py @@ -0,0 +1,282 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import torch +import torch.nn as 
nn +import torch.nn.functional as F + +from .resnet import Resnet18 +# from modules.bn import InPlaceABNSync as BatchNorm2d + + +class ConvBNReLU(nn.Module): + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, + out_chan, + kernel_size = ks, + stride = stride, + padding = padding, + bias = False) + self.bn = nn.BatchNorm2d(out_chan) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + +class BiSeNetOutput(nn.Module): + def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = self.conv_out(x) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class AttentionRefinementModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + self.init_weight() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + +class ContextPath(nn.Module): + def __init__(self, *args, **kwargs): + super(ContextPath, self).__init__() + self.resnet = Resnet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + self.init_weight() + + def forward(self, x): + H0, W0 = x.size()[2:] + feat8, feat16, feat32 = self.resnet(x) + H8, W8 = feat8.size()[2:] + H16, W16 = feat16.size()[2:] + H32, W32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (H32, W32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + 
feat32_up + feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +### This is not used, since I replace this with the resnet feature with the same size +class SpatialPath(nn.Module): + def __init__(self, *args, **kwargs): + super(SpatialPath, self).__init__() + self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) + self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) + self.init_weight() + + def forward(self, x): + feat = self.conv1(x) + feat = self.conv2(feat) + feat = self.conv3(feat) + feat = self.conv_out(feat) + return feat + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class FeatureFusionModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, + out_chan//4, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.conv2 = nn.Conv2d(out_chan//4, + out_chan, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + self.init_weight() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class BiSeNet(nn.Module): + def __init__(self, n_classes, *args, **kwargs): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + ## here self.sp is deleted + self.ffm = 
FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, n_classes) + self.conv_out16 = BiSeNetOutput(128, 64, n_classes) + self.conv_out32 = BiSeNetOutput(128, 64, n_classes) + self.init_weight() + + def forward(self, x): + H, W = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature + feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + feat_out = self.conv_out(feat_fuse) + feat_out16 = self.conv_out16(feat_cp8) + feat_out32 = self.conv_out32(feat_cp16) + + feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) + feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) + feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) + return feat_out, feat_out16, feat_out32 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] + for name, child in self.named_children(): + child_wd_params, child_nowd_params = child.get_params() + if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): + lr_mul_wd_params += child_wd_params + lr_mul_nowd_params += child_nowd_params + else: + wd_params += child_wd_params + nowd_params += child_nowd_params + return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params + + +if __name__ == "__main__": + net = BiSeNet(19) + net.cuda() + net.eval() + in_ten = torch.randn(16, 3, 640, 480).cuda() + out, out16, out32 = net(in_ten) + print(out.shape) + + net.get_params() diff --git a/ConsistentID/lib/BiSeNet/modules/__init__.py b/ConsistentID/lib/BiSeNet/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a098dee5911f3613d320d23db37bc401cf57fa4 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/__init__.py @@ -0,0 +1,5 @@ +from .bn import ABN, InPlaceABN, InPlaceABNSync +from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE +from .misc import GlobalAvgPool2d, SingleGPU +from .residual import IdentityResidualBlock +from .dense import DenseModule diff --git a/ConsistentID/lib/BiSeNet/modules/bn.py b/ConsistentID/lib/BiSeNet/modules/bn.py new file mode 100644 index 0000000000000000000000000000000000000000..cd3928bccfd3f70233414d837876b323217864c8 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/bn.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn +import torch.nn.functional as functional + +try: + from queue import Queue +except ImportError: + from Queue import Queue + +from .functions import * + + +class ABN(nn.Module): + """Activated Batch Normalization + + This gathers a `BatchNorm2d` and an activation function in a single module + """ + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): + """Creates an Activated Batch Normalization module + + Parameters + ---------- + num_features : int + Number of feature channels in the input and output. + eps : float + Small constant to prevent numerical issues. + momentum : float + Momentum factor applied to compute running statistics as. + affine : bool + If `True` apply learned scale and shift transformation after normalization. + activation : str + Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 
+ slope : float + Negative slope for the `leaky_relu` activation. + """ + super(ABN, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + self.momentum = momentum + self.activation = activation + self.slope = slope + if self.affine: + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.running_mean, 0) + nn.init.constant_(self.running_var, 1) + if self.affine: + nn.init.constant_(self.weight, 1) + nn.init.constant_(self.bias, 0) + + def forward(self, x): + x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + if self.activation == ACT_RELU: + return functional.relu(x, inplace=True) + elif self.activation == ACT_LEAKY_RELU: + return functional.leaky_relu(x, negative_slope=self.slope, inplace=True) + elif self.activation == ACT_ELU: + return functional.elu(x, inplace=True) + else: + return x + + def __repr__(self): + rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ + ' affine={affine}, activation={activation}' + if self.activation == "leaky_relu": + rep += ', slope={slope})' + else: + rep += ')' + return rep.format(name=self.__class__.__name__, **self.__dict__) + + +class InPlaceABN(ABN): + """InPlace Activated Batch Normalization""" + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): + """Creates an InPlace Activated Batch Normalization module + + Parameters + ---------- + num_features : int + Number of feature channels in the input and output. + eps : float + Small constant to prevent numerical issues. + momentum : float + Momentum factor applied to compute running statistics as. + affine : bool + If `True` apply learned scale and shift transformation after normalization. + activation : str + Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. + slope : float + Negative slope for the `leaky_relu` activation. + """ + super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope) + + def forward(self, x): + return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var, + self.training, self.momentum, self.eps, self.activation, self.slope) + + +class InPlaceABNSync(ABN): + """InPlace Activated Batch Normalization with cross-GPU synchronization + This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`. 
+ """ + + def forward(self, x): + return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var, + self.training, self.momentum, self.eps, self.activation, self.slope) + + def __repr__(self): + rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ + ' affine={affine}, activation={activation}' + if self.activation == "leaky_relu": + rep += ', slope={slope})' + else: + rep += ')' + return rep.format(name=self.__class__.__name__, **self.__dict__) + + diff --git a/ConsistentID/lib/BiSeNet/modules/deeplab.py b/ConsistentID/lib/BiSeNet/modules/deeplab.py new file mode 100644 index 0000000000000000000000000000000000000000..fd25b78369b27ef02c183a0b17b9bf8354c5f7c3 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/deeplab.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn +import torch.nn.functional as functional + +from models._util import try_index +from .bn import ABN + + +class DeeplabV3(nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden_channels=256, + dilations=(12, 24, 36), + norm_act=ABN, + pooling_size=None): + super(DeeplabV3, self).__init__() + self.pooling_size = pooling_size + + self.map_convs = nn.ModuleList([ + nn.Conv2d(in_channels, hidden_channels, 1, bias=False), + nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]), + nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]), + nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2]) + ]) + self.map_bn = norm_act(hidden_channels * 4) + + self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False) + self.global_pooling_bn = norm_act(hidden_channels) + + self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False) + self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False) + self.red_bn = norm_act(out_channels) + + self.reset_parameters(self.map_bn.activation, self.map_bn.slope) + + def reset_parameters(self, activation, slope): + gain = nn.init.calculate_gain(activation, slope) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight.data, gain) + if hasattr(m, "bias") and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, ABN): + if hasattr(m, "weight") and m.weight is not None: + nn.init.constant_(m.weight, 1) + if hasattr(m, "bias") and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + # Map convolutions + out = torch.cat([m(x) for m in self.map_convs], dim=1) + out = self.map_bn(out) + out = self.red_conv(out) + + # Global pooling + pool = self._global_pooling(x) + pool = self.global_pooling_conv(pool) + pool = self.global_pooling_bn(pool) + pool = self.pool_red_conv(pool) + if self.training or self.pooling_size is None: + pool = pool.repeat(1, 1, x.size(2), x.size(3)) + + out += pool + out = self.red_bn(out) + return out + + def _global_pooling(self, x): + if self.training or self.pooling_size is None: + pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1) + pool = pool.view(x.size(0), x.size(1), 1, 1) + else: + pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]), + min(try_index(self.pooling_size, 1), x.shape[3])) + padding = ( + (pooling_size[1] - 1) // 2, + (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1, + (pooling_size[0] - 1) // 2, + (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1 + ) + + 
pool = functional.avg_pool2d(x, pooling_size, stride=1) + pool = functional.pad(pool, pad=padding, mode="replicate") + return pool diff --git a/ConsistentID/lib/BiSeNet/modules/dense.py b/ConsistentID/lib/BiSeNet/modules/dense.py new file mode 100644 index 0000000000000000000000000000000000000000..9638d6e86d2ae838550fefa9002a984af52e6cc8 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/dense.py @@ -0,0 +1,42 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn + +from .bn import ABN + + +class DenseModule(nn.Module): + def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1): + super(DenseModule, self).__init__() + self.in_channels = in_channels + self.growth = growth + self.layers = layers + + self.convs1 = nn.ModuleList() + self.convs3 = nn.ModuleList() + for i in range(self.layers): + self.convs1.append(nn.Sequential(OrderedDict([ + ("bn", norm_act(in_channels)), + ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False)) + ]))) + self.convs3.append(nn.Sequential(OrderedDict([ + ("bn", norm_act(self.growth * bottleneck_factor)), + ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False, + dilation=dilation)) + ]))) + in_channels += self.growth + + @property + def out_channels(self): + return self.in_channels + self.growth * self.layers + + def forward(self, x): + inputs = [x] + for i in range(self.layers): + x = torch.cat(inputs, dim=1) + x = self.convs1[i](x) + x = self.convs3[i](x) + inputs += [x] + + return torch.cat(inputs, dim=1) diff --git a/ConsistentID/lib/BiSeNet/modules/functions.py b/ConsistentID/lib/BiSeNet/modules/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..093615ff4f383e95712c96b57286338ec3b28f3b --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/functions.py @@ -0,0 +1,234 @@ +from os import path +import torch +import torch.distributed as dist +import torch.autograd as autograd +import torch.cuda.comm as comm +from torch.autograd.function import once_differentiable +from torch.utils.cpp_extension import load + +_src_path = path.join(path.dirname(path.abspath(__file__)), "src") +_backend = load(name="inplace_abn", + extra_cflags=["-O3"], + sources=[path.join(_src_path, f) for f in [ + "inplace_abn.cpp", + "inplace_abn_cpu.cpp", + "inplace_abn_cuda.cu", + "inplace_abn_cuda_half.cu" + ]], + extra_cuda_cflags=["--expt-extended-lambda"]) + +# Activation names +ACT_RELU = "relu" +ACT_LEAKY_RELU = "leaky_relu" +ACT_ELU = "elu" +ACT_NONE = "none" + + +def _check(fn, *args, **kwargs): + success = fn(*args, **kwargs) + if not success: + raise RuntimeError("CUDA Error encountered in {}".format(fn)) + + +def _broadcast_shape(x): + out_size = [] + for i, s in enumerate(x.size()): + if i != 1: + out_size.append(1) + else: + out_size.append(s) + return out_size + + +def _reduce(x): + if len(x.size()) == 2: + return x.sum(dim=0) + else: + n, c = x.size()[0:2] + return x.contiguous().view((n, c, -1)).sum(2).sum(0) + + +def _count_samples(x): + count = 1 + for i, s in enumerate(x.size()): + if i != 1: + count *= s + return count + + +def _act_forward(ctx, x): + if ctx.activation == ACT_LEAKY_RELU: + _backend.leaky_relu_forward(x, ctx.slope) + elif ctx.activation == ACT_ELU: + _backend.elu_forward(x) + elif ctx.activation == ACT_NONE: + pass + + +def _act_backward(ctx, x, dx): + if ctx.activation == ACT_LEAKY_RELU: + _backend.leaky_relu_backward(x, dx, ctx.slope) + elif ctx.activation == ACT_ELU: + 
_backend.elu_backward(x, dx) + elif ctx.activation == ACT_NONE: + pass + + +class InPlaceABN(autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, running_mean, running_var, + training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): + # Save context + ctx.training = training + ctx.momentum = momentum + ctx.eps = eps + ctx.activation = activation + ctx.slope = slope + ctx.affine = weight is not None and bias is not None + + # Prepare inputs + count = _count_samples(x) + x = x.contiguous() + weight = weight.contiguous() if ctx.affine else x.new_empty(0) + bias = bias.contiguous() if ctx.affine else x.new_empty(0) + + if ctx.training: + mean, var = _backend.mean_var(x) + + # Update running stats + running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) + running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1)) + + # Mark in-place modified tensors + ctx.mark_dirty(x, running_mean, running_var) + else: + mean, var = running_mean.contiguous(), running_var.contiguous() + ctx.mark_dirty(x) + + # BN forward + activation + _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) + _act_forward(ctx, x) + + # Output + ctx.var = var + ctx.save_for_backward(x, var, weight, bias) + return x + + @staticmethod + @once_differentiable + def backward(ctx, dz): + z, var, weight, bias = ctx.saved_tensors + dz = dz.contiguous() + + # Undo activation + _act_backward(ctx, z, dz) + + if ctx.training: + edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) + else: + # TODO: implement simplified CUDA backward for inference mode + edz = dz.new_zeros(dz.size(1)) + eydz = dz.new_zeros(dz.size(1)) + + dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) + dweight = eydz * weight.sign() if ctx.affine else None + dbias = edz if ctx.affine else None + + return dx, dweight, dbias, None, None, None, None, None, None, None + +class InPlaceABNSync(autograd.Function): + @classmethod + def forward(cls, ctx, x, weight, bias, running_mean, running_var, + training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True): + # Save context + ctx.training = training + ctx.momentum = momentum + ctx.eps = eps + ctx.activation = activation + ctx.slope = slope + ctx.affine = weight is not None and bias is not None + + # Prepare inputs + ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1 + + #count = _count_samples(x) + batch_size = x.new_tensor([x.shape[0]],dtype=torch.long) + + x = x.contiguous() + weight = weight.contiguous() if ctx.affine else x.new_empty(0) + bias = bias.contiguous() if ctx.affine else x.new_empty(0) + + if ctx.training: + mean, var = _backend.mean_var(x) + if ctx.world_size>1: + # get global batch size + if equal_batches: + batch_size *= ctx.world_size + else: + dist.all_reduce(batch_size, dist.ReduceOp.SUM) + + ctx.factor = x.shape[0]/float(batch_size.item()) + + mean_all = mean.clone() * ctx.factor + dist.all_reduce(mean_all, dist.ReduceOp.SUM) + + var_all = (var + (mean - mean_all) ** 2) * ctx.factor + dist.all_reduce(var_all, dist.ReduceOp.SUM) + + mean = mean_all + var = var_all + + # Update running stats + running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) + count = batch_size.item() * x.view(x.shape[0],x.shape[1],-1).shape[-1] + running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1))) + + # Mark in-place modified tensors + ctx.mark_dirty(x, running_mean, running_var) + else: + mean, 
var = running_mean.contiguous(), running_var.contiguous() + ctx.mark_dirty(x) + + # BN forward + activation + _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) + _act_forward(ctx, x) + + # Output + ctx.var = var + ctx.save_for_backward(x, var, weight, bias) + return x + + @staticmethod + @once_differentiable + def backward(ctx, dz): + z, var, weight, bias = ctx.saved_tensors + dz = dz.contiguous() + + # Undo activation + _act_backward(ctx, z, dz) + + if ctx.training: + edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) + edz_local = edz.clone() + eydz_local = eydz.clone() + + if ctx.world_size>1: + edz *= ctx.factor + dist.all_reduce(edz, dist.ReduceOp.SUM) + + eydz *= ctx.factor + dist.all_reduce(eydz, dist.ReduceOp.SUM) + else: + edz_local = edz = dz.new_zeros(dz.size(1)) + eydz_local = eydz = dz.new_zeros(dz.size(1)) + + dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) + dweight = eydz_local * weight.sign() if ctx.affine else None + dbias = edz_local if ctx.affine else None + + return dx, dweight, dbias, None, None, None, None, None, None, None + +inplace_abn = InPlaceABN.apply +inplace_abn_sync = InPlaceABNSync.apply + +__all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"] diff --git a/ConsistentID/lib/BiSeNet/modules/misc.py b/ConsistentID/lib/BiSeNet/modules/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..3c50b69b38c950801baacba8b3684ffd23aef08b --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/misc.py @@ -0,0 +1,21 @@ +import torch.nn as nn +import torch +import torch.distributed as dist + +class GlobalAvgPool2d(nn.Module): + def __init__(self): + """Global average pooling over the input's spatial dimensions""" + super(GlobalAvgPool2d, self).__init__() + + def forward(self, inputs): + in_size = inputs.size() + return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) + +class SingleGPU(nn.Module): + def __init__(self, module): + super(SingleGPU, self).__init__() + self.module=module + + def forward(self, input): + return self.module(input.cuda(non_blocking=True)) + diff --git a/ConsistentID/lib/BiSeNet/modules/residual.py b/ConsistentID/lib/BiSeNet/modules/residual.py new file mode 100644 index 0000000000000000000000000000000000000000..b7d51ad274f3841813c1584a0ceb60ce58979d94 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/residual.py @@ -0,0 +1,88 @@ +from collections import OrderedDict + +import torch.nn as nn + +from .bn import ABN + + +class IdentityResidualBlock(nn.Module): + def __init__(self, + in_channels, + channels, + stride=1, + dilation=1, + groups=1, + norm_act=ABN, + dropout=None): + """Configurable identity-mapping residual block + + Parameters + ---------- + in_channels : int + Number of input channels. + channels : list of int + Number of channels in the internal feature maps. Can either have two or three elements: if three construct + a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then + `3 x 3` then `1 x 1` convolutions. + stride : int + Stride of the first `3 x 3` convolution + dilation : int + Dilation to apply to the `3 x 3` convolutions. + groups : int + Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with + bottleneck blocks. + norm_act : callable + Function to create normalization / activation Module. + dropout: callable + Function to create Dropout Module. 
+ """ + super(IdentityResidualBlock, self).__init__() + + # Check parameters for inconsistencies + if len(channels) != 2 and len(channels) != 3: + raise ValueError("channels must contain either two or three values") + if len(channels) == 2 and groups != 1: + raise ValueError("groups > 1 are only valid if len(channels) == 3") + + is_bottleneck = len(channels) == 3 + need_proj_conv = stride != 1 or in_channels != channels[-1] + + self.bn1 = norm_act(in_channels) + if not is_bottleneck: + layers = [ + ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False, + dilation=dilation)), + ("bn2", norm_act(channels[0])), + ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False, + dilation=dilation)) + ] + if dropout is not None: + layers = layers[0:2] + [("dropout", dropout())] + layers[2:] + else: + layers = [ + ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)), + ("bn2", norm_act(channels[0])), + ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False, + groups=groups, dilation=dilation)), + ("bn3", norm_act(channels[1])), + ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False)) + ] + if dropout is not None: + layers = layers[0:4] + [("dropout", dropout())] + layers[4:] + self.convs = nn.Sequential(OrderedDict(layers)) + + if need_proj_conv: + self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False) + + def forward(self, x): + if hasattr(self, "proj_conv"): + bn1 = self.bn1(x) + shortcut = self.proj_conv(bn1) + else: + shortcut = x.clone() + bn1 = self.bn1(x) + + out = self.convs(bn1) + out.add_(shortcut) + + return out diff --git a/ConsistentID/lib/BiSeNet/modules/src/checks.h b/ConsistentID/lib/BiSeNet/modules/src/checks.h new file mode 100644 index 0000000000000000000000000000000000000000..e761a6fe34d0789815b588eba7e3726026e0e868 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/checks.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +// Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT +#ifndef AT_CHECK +#define AT_CHECK AT_ASSERT +#endif + +#define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor") +#define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous") + +#define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x) \ No newline at end of file diff --git a/ConsistentID/lib/BiSeNet/modules/src/inplace_abn.cpp b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a6b1128cc20cbfc476134154e23e5869a92b856 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn.cpp @@ -0,0 +1,95 @@ +#include + +#include + +#include "inplace_abn.h" + +std::vector mean_var(at::Tensor x) { + if (x.is_cuda()) { + if (x.type().scalarType() == at::ScalarType::Half) { + return mean_var_cuda_h(x); + } else { + return mean_var_cuda(x); + } + } else { + return mean_var_cpu(x); + } +} + +at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + if (x.is_cuda()) { + if (x.type().scalarType() == at::ScalarType::Half) { + return forward_cuda_h(x, mean, var, weight, bias, affine, eps); + } else { + return forward_cuda(x, mean, var, weight, 
bias, affine, eps); + } + } else { + return forward_cpu(x, mean, var, weight, bias, affine, eps); + } +} + +std::vector edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + if (z.is_cuda()) { + if (z.type().scalarType() == at::ScalarType::Half) { + return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps); + } else { + return edz_eydz_cuda(z, dz, weight, bias, affine, eps); + } + } else { + return edz_eydz_cpu(z, dz, weight, bias, affine, eps); + } +} + +at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps) { + if (z.is_cuda()) { + if (z.type().scalarType() == at::ScalarType::Half) { + return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps); + } else { + return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps); + } + } else { + return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps); + } +} + +void leaky_relu_forward(at::Tensor z, float slope) { + at::leaky_relu_(z, slope); +} + +void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) { + if (z.is_cuda()) { + if (z.type().scalarType() == at::ScalarType::Half) { + return leaky_relu_backward_cuda_h(z, dz, slope); + } else { + return leaky_relu_backward_cuda(z, dz, slope); + } + } else { + return leaky_relu_backward_cpu(z, dz, slope); + } +} + +void elu_forward(at::Tensor z) { + at::elu_(z); +} + +void elu_backward(at::Tensor z, at::Tensor dz) { + if (z.is_cuda()) { + return elu_backward_cuda(z, dz); + } else { + return elu_backward_cpu(z, dz); + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("mean_var", &mean_var, "Mean and variance computation"); + m.def("forward", &forward, "In-place forward computation"); + m.def("edz_eydz", &edz_eydz, "First part of backward computation"); + m.def("backward", &backward, "Second part of backward computation"); + m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation"); + m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion"); + m.def("elu_forward", &elu_forward, "Elu forward computation"); + m.def("elu_backward", &elu_backward, "Elu backward computation and inversion"); +} diff --git a/ConsistentID/lib/BiSeNet/modules/src/inplace_abn.h b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn.h new file mode 100644 index 0000000000000000000000000000000000000000..17afd1196449ecb6376f28961e54b55e1537492f --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn.h @@ -0,0 +1,88 @@ +#pragma once + +#include + +#include + +std::vector mean_var_cpu(at::Tensor x); +std::vector mean_var_cuda(at::Tensor x); +std::vector mean_var_cuda_h(at::Tensor x); + +at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, + bool affine, float eps); +at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, + bool affine, float eps); +at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, + bool affine, float eps); + +std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps); +std::vector edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps); +std::vector edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps); + +at::Tensor 
backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps); +at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps); +at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps); + +void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope); +void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope); +void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope); + +void elu_backward_cpu(at::Tensor z, at::Tensor dz); +void elu_backward_cuda(at::Tensor z, at::Tensor dz); + +static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) { + num = x.size(0); + chn = x.size(1); + sp = 1; + for (int64_t i = 2; i < x.ndimension(); ++i) + sp *= x.size(i); +} + +/* + * Specialized CUDA reduction functions for BN + */ +#ifdef __CUDACC__ + +#include "utils/cuda.cuh" + +template +__device__ T reduce(Op op, int plane, int N, int S) { + T sum = (T)0; + for (int batch = 0; batch < N; ++batch) { + for (int x = threadIdx.x; x < S; x += blockDim.x) { + sum += op(batch, plane, x); + } + } + + // sum over NumThreads within a warp + sum = warpSum(sum); + + // 'transpose', and reduce within warp again + __shared__ T shared[32]; + __syncthreads(); + if (threadIdx.x % WARP_SIZE == 0) { + shared[threadIdx.x / WARP_SIZE] = sum; + } + if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { + // zero out the other entries in shared + shared[threadIdx.x] = (T)0; + } + __syncthreads(); + if (threadIdx.x / WARP_SIZE == 0) { + sum = warpSum(shared[threadIdx.x]); + if (threadIdx.x == 0) { + shared[0] = sum; + } + } + __syncthreads(); + + // Everyone picks it up, should be broadcast into the whole gradInput + return shared[0]; +} +#endif diff --git a/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cpu.cpp b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffc6d38c52ea31661b8dd438dc3fe1958f50b61e --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cpu.cpp @@ -0,0 +1,119 @@ +#include + +#include + +#include "utils/checks.h" +#include "inplace_abn.h" + +at::Tensor reduce_sum(at::Tensor x) { + if (x.ndimension() == 2) { + return x.sum(0); + } else { + auto x_view = x.view({x.size(0), x.size(1), -1}); + return x_view.sum(-1).sum(0); + } +} + +at::Tensor broadcast_to(at::Tensor v, at::Tensor x) { + if (x.ndimension() == 2) { + return v; + } else { + std::vector broadcast_size = {1, -1}; + for (int64_t i = 2; i < x.ndimension(); ++i) + broadcast_size.push_back(1); + + return v.view(broadcast_size); + } +} + +int64_t count(at::Tensor x) { + int64_t count = x.size(0); + for (int64_t i = 2; i < x.ndimension(); ++i) + count *= x.size(i); + + return count; +} + +at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) { + if (affine) { + return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z); + } else { + return z; + } +} + +std::vector mean_var_cpu(at::Tensor x) { + auto num = count(x); + auto mean = reduce_sum(x) / num; + auto diff = x - broadcast_to(mean, x); + auto var = reduce_sum(diff.pow(2)) / num; + + return {mean, var}; +} + +at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, 
at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var); + auto mul = at::rsqrt(var + eps) * gamma; + + x.sub_(broadcast_to(mean, x)); + x.mul_(broadcast_to(mul, x)); + if (affine) x.add_(broadcast_to(bias, x)); + + return x; +} + +std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + auto edz = reduce_sum(dz); + auto y = invert_affine(z, weight, bias, affine, eps); + auto eydz = reduce_sum(y * dz); + + return {edz, eydz}; +} + +at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps) { + auto y = invert_affine(z, weight, bias, affine, eps); + auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps); + + auto num = count(z); + auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz); + return dx; +} + +void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) { + CHECK_CPU_INPUT(z); + CHECK_CPU_INPUT(dz); + + AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] { + int64_t count = z.numel(); + auto *_z = z.data(); + auto *_dz = dz.data(); + + for (int64_t i = 0; i < count; ++i) { + if (_z[i] < 0) { + _z[i] *= 1 / slope; + _dz[i] *= slope; + } + } + })); +} + +void elu_backward_cpu(at::Tensor z, at::Tensor dz) { + CHECK_CPU_INPUT(z); + CHECK_CPU_INPUT(dz); + + AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] { + int64_t count = z.numel(); + auto *_z = z.data(); + auto *_dz = dz.data(); + + for (int64_t i = 0; i < count; ++i) { + if (_z[i] < 0) { + _z[i] = log1p(_z[i]); + _dz[i] *= (_z[i] + 1.f); + } + } + })); +} diff --git a/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda.cu b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..b157b06d47173d1645c6a40c89f564b737e84d43 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda.cu @@ -0,0 +1,333 @@ +#include + +#include +#include + +#include + +#include "utils/checks.h" +#include "utils/cuda.cuh" +#include "inplace_abn.h" + +#include + +// Operations for reduce +template +struct SumOp { + __device__ SumOp(const T *t, int c, int s) + : tensor(t), chn(c), sp(s) {} + __device__ __forceinline__ T operator()(int batch, int plane, int n) { + return tensor[(batch * chn + plane) * sp + n]; + } + const T *tensor; + const int chn; + const int sp; +}; + +template +struct VarOp { + __device__ VarOp(T m, const T *t, int c, int s) + : mean(m), tensor(t), chn(c), sp(s) {} + __device__ __forceinline__ T operator()(int batch, int plane, int n) { + T val = tensor[(batch * chn + plane) * sp + n]; + return (val - mean) * (val - mean); + } + const T mean; + const T *tensor; + const int chn; + const int sp; +}; + +template +struct GradOp { + __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s) + : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {} + __device__ __forceinline__ Pair operator()(int batch, int plane, int n) { + T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight; + T _dz = dz[(batch * chn + plane) * sp + n]; + return Pair(_dz, _y * _dz); + } + const T weight; + const T bias; + const T *z; + const T *dz; + const int chn; + const int sp; +}; + +/*********** + * mean_var + ***********/ + +template +__global__ void mean_var_kernel(const T *x, T *mean, T 
*var, int num, int chn, int sp) { + int plane = blockIdx.x; + T norm = T(1) / T(num * sp); + + T _mean = reduce>(SumOp(x, chn, sp), plane, num, sp) * norm; + __syncthreads(); + T _var = reduce>(VarOp(_mean, x, chn, sp), plane, num, sp) * norm; + + if (threadIdx.x == 0) { + mean[plane] = _mean; + var[plane] = _var; + } +} + +std::vector mean_var_cuda(at::Tensor x) { + CHECK_CUDA_INPUT(x); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(x, num, chn, sp); + + // Prepare output tensors + auto mean = at::empty({chn}, x.options()); + auto var = at::empty({chn}, x.options()); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] { + mean_var_kernel<<>>( + x.data(), + mean.data(), + var.data(), + num, chn, sp); + })); + + return {mean, var}; +} + +/********** + * forward + **********/ + +template +__global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias, + bool affine, float eps, int num, int chn, int sp) { + int plane = blockIdx.x; + + T _mean = mean[plane]; + T _var = var[plane]; + T _weight = affine ? abs(weight[plane]) + eps : T(1); + T _bias = affine ? bias[plane] : T(0); + + T mul = rsqrt(_var + eps) * _weight; + + for (int batch = 0; batch < num; ++batch) { + for (int n = threadIdx.x; n < sp; n += blockDim.x) { + T _x = x[(batch * chn + plane) * sp + n]; + T _y = (_x - _mean) * mul + _bias; + + x[(batch * chn + plane) * sp + n] = _y; + } + } +} + +at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + CHECK_CUDA_INPUT(x); + CHECK_CUDA_INPUT(mean); + CHECK_CUDA_INPUT(var); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(bias); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(x, num, chn, sp); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] { + forward_kernel<<>>( + x.data(), + mean.data(), + var.data(), + weight.data(), + bias.data(), + affine, eps, num, chn, sp); + })); + + return x; +} + +/*********** + * edz_eydz + ***********/ + +template +__global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias, + T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) { + int plane = blockIdx.x; + + T _weight = affine ? abs(weight[plane]) + eps : 1.f; + T _bias = affine ? 
bias[plane] : 0.f; + + Pair res = reduce, GradOp>(GradOp(_weight, _bias, z, dz, chn, sp), plane, num, sp); + __syncthreads(); + + if (threadIdx.x == 0) { + edz[plane] = res.v1; + eydz[plane] = res.v2; + } +} + +std::vector edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(bias); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(z, num, chn, sp); + + auto edz = at::empty({chn}, z.options()); + auto eydz = at::empty({chn}, z.options()); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] { + edz_eydz_kernel<<>>( + z.data(), + dz.data(), + weight.data(), + bias.data(), + edz.data(), + eydz.data(), + affine, eps, num, chn, sp); + })); + + return {edz, eydz}; +} + +/*********** + * backward + ***********/ + +template +__global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz, + const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) { + int plane = blockIdx.x; + + T _weight = affine ? abs(weight[plane]) + eps : 1.f; + T _bias = affine ? bias[plane] : 0.f; + T _var = var[plane]; + T _edz = edz[plane]; + T _eydz = eydz[plane]; + + T _mul = _weight * rsqrt(_var + eps); + T count = T(num * sp); + + for (int batch = 0; batch < num; ++batch) { + for (int n = threadIdx.x; n < sp; n += blockDim.x) { + T _dz = dz[(batch * chn + plane) * sp + n]; + T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight; + + dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul; + } + } +} + +at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + CHECK_CUDA_INPUT(var); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(bias); + CHECK_CUDA_INPUT(edz); + CHECK_CUDA_INPUT(eydz); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(z, num, chn, sp); + + auto dx = at::zeros_like(z); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] { + backward_kernel<<>>( + z.data(), + dz.data(), + var.data(), + weight.data(), + bias.data(), + edz.data(), + eydz.data(), + dx.data(), + affine, eps, num, chn, sp); + })); + + return dx; +} + +/************** + * activations + **************/ + +template +inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) { + // Create thrust pointers + thrust::device_ptr th_z = thrust::device_pointer_cast(z); + thrust::device_ptr th_dz = thrust::device_pointer_cast(dz); + + auto stream = at::cuda::getCurrentCUDAStream(); + thrust::transform_if(thrust::cuda::par.on(stream), + th_dz, th_dz + count, th_z, th_dz, + [slope] __device__ (const T& dz) { return dz * slope; }, + [] __device__ (const T& z) { return z < 0; }); + thrust::transform_if(thrust::cuda::par.on(stream), + th_z, th_z + count, th_z, + [slope] __device__ (const T& z) { return z / slope; }, + [] __device__ (const T& z) { return z < 0; }); +} + +void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + + int64_t count = z.numel(); + + AT_DISPATCH_FLOATING_TYPES(z.type(), 
"leaky_relu_backward_cuda", ([&] { + leaky_relu_backward_impl(z.data(), dz.data(), slope, count); + })); +} + +template +inline void elu_backward_impl(T *z, T *dz, int64_t count) { + // Create thrust pointers + thrust::device_ptr th_z = thrust::device_pointer_cast(z); + thrust::device_ptr th_dz = thrust::device_pointer_cast(dz); + + auto stream = at::cuda::getCurrentCUDAStream(); + thrust::transform_if(thrust::cuda::par.on(stream), + th_dz, th_dz + count, th_z, th_z, th_dz, + [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); }, + [] __device__ (const T& z) { return z < 0; }); + thrust::transform_if(thrust::cuda::par.on(stream), + th_z, th_z + count, th_z, + [] __device__ (const T& z) { return log1p(z); }, + [] __device__ (const T& z) { return z < 0; }); +} + +void elu_backward_cuda(at::Tensor z, at::Tensor dz) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + + int64_t count = z.numel(); + + AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] { + elu_backward_impl(z.data(), dz.data(), count); + })); +} diff --git a/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda_half.cu b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda_half.cu new file mode 100644 index 0000000000000000000000000000000000000000..bb63e73f9d90179e5bd5dae5579c4844da9c25e2 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda_half.cu @@ -0,0 +1,275 @@ +#include + +#include + +#include + +#include "utils/checks.h" +#include "utils/cuda.cuh" +#include "inplace_abn.h" + +#include + +// Operations for reduce +struct SumOpH { + __device__ SumOpH(const half *t, int c, int s) + : tensor(t), chn(c), sp(s) {} + __device__ __forceinline__ float operator()(int batch, int plane, int n) { + return __half2float(tensor[(batch * chn + plane) * sp + n]); + } + const half *tensor; + const int chn; + const int sp; +}; + +struct VarOpH { + __device__ VarOpH(float m, const half *t, int c, int s) + : mean(m), tensor(t), chn(c), sp(s) {} + __device__ __forceinline__ float operator()(int batch, int plane, int n) { + const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]); + return (t - mean) * (t - mean); + } + const float mean; + const half *tensor; + const int chn; + const int sp; +}; + +struct GradOpH { + __device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s) + : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {} + __device__ __forceinline__ Pair operator()(int batch, int plane, int n) { + float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight; + float _dz = __half2float(dz[(batch * chn + plane) * sp + n]); + return Pair(_dz, _y * _dz); + } + const float weight; + const float bias; + const half *z; + const half *dz; + const int chn; + const int sp; +}; + +/*********** + * mean_var + ***********/ + +__global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) { + int plane = blockIdx.x; + float norm = 1.f / static_cast(num * sp); + + float _mean = reduce(SumOpH(x, chn, sp), plane, num, sp) * norm; + __syncthreads(); + float _var = reduce(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm; + + if (threadIdx.x == 0) { + mean[plane] = _mean; + var[plane] = _var; + } +} + +std::vector mean_var_cuda_h(at::Tensor x) { + CHECK_CUDA_INPUT(x); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(x, num, chn, sp); + + // Prepare output tensors + auto mean = at::empty({chn},x.options().dtype(at::kFloat)); + auto var = at::empty({chn},x.options().dtype(at::kFloat)); + 
+ // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + mean_var_kernel_h<<>>( + reinterpret_cast(x.data()), + mean.data(), + var.data(), + num, chn, sp); + + return {mean, var}; +} + +/********** + * forward + **********/ + +__global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias, + bool affine, float eps, int num, int chn, int sp) { + int plane = blockIdx.x; + + const float _mean = mean[plane]; + const float _var = var[plane]; + const float _weight = affine ? abs(weight[plane]) + eps : 1.f; + const float _bias = affine ? bias[plane] : 0.f; + + const float mul = rsqrt(_var + eps) * _weight; + + for (int batch = 0; batch < num; ++batch) { + for (int n = threadIdx.x; n < sp; n += blockDim.x) { + half *x_ptr = x + (batch * chn + plane) * sp + n; + float _x = __half2float(*x_ptr); + float _y = (_x - _mean) * mul + _bias; + + *x_ptr = __float2half(_y); + } + } +} + +at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + CHECK_CUDA_INPUT(x); + CHECK_CUDA_INPUT(mean); + CHECK_CUDA_INPUT(var); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(bias); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(x, num, chn, sp); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + forward_kernel_h<<>>( + reinterpret_cast(x.data()), + mean.data(), + var.data(), + weight.data(), + bias.data(), + affine, eps, num, chn, sp); + + return x; +} + +__global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias, + float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) { + int plane = blockIdx.x; + + float _weight = affine ? abs(weight[plane]) + eps : 1.f; + float _bias = affine ? bias[plane] : 0.f; + + Pair res = reduce, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp); + __syncthreads(); + + if (threadIdx.x == 0) { + edz[plane] = res.v1; + eydz[plane] = res.v2; + } +} + +std::vector edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(bias); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(z, num, chn, sp); + + auto edz = at::empty({chn},z.options().dtype(at::kFloat)); + auto eydz = at::empty({chn},z.options().dtype(at::kFloat)); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + edz_eydz_kernel_h<<>>( + reinterpret_cast(z.data()), + reinterpret_cast(dz.data()), + weight.data(), + bias.data(), + edz.data(), + eydz.data(), + affine, eps, num, chn, sp); + + return {edz, eydz}; +} + +__global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz, + const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) { + int plane = blockIdx.x; + + float _weight = affine ? abs(weight[plane]) + eps : 1.f; + float _bias = affine ? 
bias[plane] : 0.f; + float _var = var[plane]; + float _edz = edz[plane]; + float _eydz = eydz[plane]; + + float _mul = _weight * rsqrt(_var + eps); + float count = float(num * sp); + + for (int batch = 0; batch < num; ++batch) { + for (int n = threadIdx.x; n < sp; n += blockDim.x) { + float _dz = __half2float(dz[(batch * chn + plane) * sp + n]); + float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight; + + dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul); + } + } +} + +at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + CHECK_CUDA_INPUT(var); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(bias); + CHECK_CUDA_INPUT(edz); + CHECK_CUDA_INPUT(eydz); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(z, num, chn, sp); + + auto dx = at::zeros_like(z); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + backward_kernel_h<<>>( + reinterpret_cast(z.data()), + reinterpret_cast(dz.data()), + var.data(), + weight.data(), + bias.data(), + edz.data(), + eydz.data(), + reinterpret_cast(dx.data()), + affine, eps, num, chn, sp); + + return dx; +} + +__global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x){ + float _z = __half2float(z[i]); + if (_z < 0) { + dz[i] = __float2half(__half2float(dz[i]) * slope); + z[i] = __float2half(_z / slope); + } + } +} + +void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + + int64_t count = z.numel(); + dim3 threads(getNumThreads(count)); + dim3 blocks = (count + threads.x - 1) / threads.x; + auto stream = at::cuda::getCurrentCUDAStream(); + leaky_relu_backward_impl_h<<>>( + reinterpret_cast(z.data()), + reinterpret_cast(dz.data()), + slope, count); +} + diff --git a/ConsistentID/lib/BiSeNet/modules/src/utils/checks.h b/ConsistentID/lib/BiSeNet/modules/src/utils/checks.h new file mode 100644 index 0000000000000000000000000000000000000000..e761a6fe34d0789815b588eba7e3726026e0e868 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/utils/checks.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +// Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT +#ifndef AT_CHECK +#define AT_CHECK AT_ASSERT +#endif + +#define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor") +#define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous") + +#define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x) \ No newline at end of file diff --git a/ConsistentID/lib/BiSeNet/modules/src/utils/common.h b/ConsistentID/lib/BiSeNet/modules/src/utils/common.h new file mode 100644 index 0000000000000000000000000000000000000000..e8403eef8a233b75dd4bb353c16486fe1be2039a --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/utils/common.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +/* + * Functions to share code between CPU and GPU + */ + +#ifdef __CUDACC__ +// CUDA versions + +#define HOST_DEVICE __host__ __device__ +#define INLINE_HOST_DEVICE __host__ __device__ 
inline +#define FLOOR(x) floor(x) + +#if __CUDA_ARCH__ >= 600 +// Recent compute capabilities have block-level atomicAdd for all data types, so we use that +#define ACCUM(x,y) atomicAdd_block(&(x),(y)) +#else +// Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float +// and use the known atomicCAS-based implementation for double +template +__device__ inline data_t atomic_add(data_t *address, data_t val) { + return atomicAdd(address, val); +} + +template<> +__device__ inline double atomic_add(double *address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed))); + } while (assumed != old); + return __longlong_as_double(old); +} + +#define ACCUM(x,y) atomic_add(&(x),(y)) +#endif // #if __CUDA_ARCH__ >= 600 + +#else +// CPU versions + +#define HOST_DEVICE +#define INLINE_HOST_DEVICE inline +#define FLOOR(x) std::floor(x) +#define ACCUM(x,y) (x) += (y) + +#endif // #ifdef __CUDACC__ \ No newline at end of file diff --git a/ConsistentID/lib/BiSeNet/modules/src/utils/cuda.cuh b/ConsistentID/lib/BiSeNet/modules/src/utils/cuda.cuh new file mode 100644 index 0000000000000000000000000000000000000000..60c0023835e02c5f7c539c28ac07b75b72df394b --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/utils/cuda.cuh @@ -0,0 +1,71 @@ +#pragma once + +/* + * General settings and functions + */ +const int WARP_SIZE = 32; +const int MAX_BLOCK_SIZE = 1024; + +static int getNumThreads(int nElem) { + int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE}; + for (int i = 0; i < 6; ++i) { + if (nElem <= threadSizes[i]) { + return threadSizes[i]; + } + } + return MAX_BLOCK_SIZE; +} + +/* + * Reduction utilities + */ +template +__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { +#if CUDART_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +__device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } + +template +struct Pair { + T v1, v2; + __device__ Pair() {} + __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {} + __device__ Pair(T v) : v1(v), v2(v) {} + __device__ Pair(int v) : v1(v), v2(v) {} + __device__ Pair &operator+=(const Pair &a) { + v1 += a.v1; + v2 += a.v2; + return *this; + } +}; + +template +static __device__ __forceinline__ T warpSum(T val) { +#if __CUDA_ARCH__ >= 300 + for (int i = 0; i < getMSB(WARP_SIZE); ++i) { + val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); + } +#else + __shared__ T values[MAX_BLOCK_SIZE]; + values[threadIdx.x] = val; + __threadfence_block(); + const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; + for (int i = 1; i < WARP_SIZE; i++) { + val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; + } +#endif + return val; +} + +template +static __device__ __forceinline__ Pair warpSum(Pair value) { + value.v1 = warpSum(value.v1); + value.v2 = warpSum(value.v2); + return value; +} \ No newline at end of file diff --git a/ConsistentID/lib/BiSeNet/optimizer.py b/ConsistentID/lib/BiSeNet/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0c99e0645164b22f1e743ee99daadadd26a1cd80 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/optimizer.py @@ -0,0 +1,69 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + 
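+# Optimizer wraps torch.optim.SGD with a warmup + polynomial ("poly") learning-rate schedule:
+# for the first `warmup_steps` iterations the lr grows geometrically from `warmup_start_lr` to `lr0`;
+# afterwards lr = lr0 * (1 - (it - warmup_steps) / (max_iter - warmup_steps)) ** power.
+# Parameter groups built from model.get_params() that carry 'lr_mul' are trained at 10x the base lr.
+# Typical use (mirroring train.py in this repo, assuming a BiSeNet-style model exposing get_params()):
+#   optim = Optimizer(model, lr0=1e-2, momentum=0.9, wd=5e-4, warmup_steps=1000,
+#                     warmup_start_lr=1e-5, max_iter=80000, power=0.9)
+#   optim.zero_grad(); loss.backward(); optim.step()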
+import torch +import logging + +logger = logging.getLogger() + +class Optimizer(object): + def __init__(self, + model, + lr0, + momentum, + wd, + warmup_steps, + warmup_start_lr, + max_iter, + power, + *args, **kwargs): + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr + self.lr0 = lr0 + self.lr = self.lr0 + self.max_iter = float(max_iter) + self.power = power + self.it = 0 + wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params() + param_list = [ + {'params': wd_params}, + {'params': nowd_params, 'weight_decay': 0}, + {'params': lr_mul_wd_params, 'lr_mul': True}, + {'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr_mul': True}] + self.optim = torch.optim.SGD( + param_list, + lr = lr0, + momentum = momentum, + weight_decay = wd) + self.warmup_factor = (self.lr0/self.warmup_start_lr)**(1./self.warmup_steps) + + + def get_lr(self): + if self.it <= self.warmup_steps: + lr = self.warmup_start_lr*(self.warmup_factor**self.it) + else: + factor = (1-(self.it-self.warmup_steps)/(self.max_iter-self.warmup_steps))**self.power + lr = self.lr0 * factor + return lr + + + def step(self): + self.lr = self.get_lr() + for pg in self.optim.param_groups: + if pg.get('lr_mul', False): + pg['lr'] = self.lr * 10 + else: + pg['lr'] = self.lr + if self.optim.defaults.get('lr_mul', False): + self.optim.defaults['lr'] = self.lr * 10 + else: + self.optim.defaults['lr'] = self.lr + self.it += 1 + self.optim.step() + if self.it == self.warmup_steps+2: + logger.info('==> warmup done, start to implement poly lr strategy') + + def zero_grad(self): + self.optim.zero_grad() + diff --git a/ConsistentID/lib/BiSeNet/prepropess_data.py b/ConsistentID/lib/BiSeNet/prepropess_data.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7ed56dd8c0372d482e6a53f323da17043bd521 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/prepropess_data.py @@ -0,0 +1,38 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import os.path as osp +import os +import cv2 +from transform import * +from PIL import Image + +face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img' +face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno' +mask_path = '/home/zll/data/CelebAMask-HQ/mask' +counter = 0 +total = 0 +for i in range(15): + + atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', + 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] + + for j in range(i * 2000, (i + 1) * 2000): + + mask = np.zeros((512, 512)) + + for l, att in enumerate(atts, 1): + total += 1 + file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png']) + path = osp.join(face_sep_mask, str(i), file_name) + + if os.path.exists(path): + counter += 1 + sep_mask = np.array(Image.open(path).convert('P')) + # print(np.unique(sep_mask)) + + mask[sep_mask == 225] = l + cv2.imwrite('{}/{}.png'.format(mask_path, j), mask) + print(j) + +print(counter, total) \ No newline at end of file diff --git a/ConsistentID/lib/BiSeNet/resnet.py b/ConsistentID/lib/BiSeNet/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2bf95130e9815ba378cb6f73207068b81a04b9 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/resnet.py @@ -0,0 +1,109 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as modelzoo + +# from modules.bn import InPlaceABNSync as BatchNorm2d + +resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' + + +def 
conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum-1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class Resnet18(nn.Module): + def __init__(self): + super(Resnet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + self.init_weight() + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 + + def init_weight(self): + state_dict = modelzoo.load_url(resnet18_url) + self_state_dict = self.state_dict() + for k, v in state_dict.items(): + if 'fc' in k: continue + self_state_dict.update({k: v}) + self.load_state_dict(self_state_dict) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +if __name__ == "__main__": + net = Resnet18() + x = torch.randn(16, 3, 224, 224) + out = net(x) + print(out[0].size()) + print(out[1].size()) + print(out[2].size()) + net.get_params() diff --git a/ConsistentID/lib/BiSeNet/test.py b/ConsistentID/lib/BiSeNet/test.py new file mode 100644 index 0000000000000000000000000000000000000000..604a89f6e86a6a18581022620c413a43abece91b --- /dev/null +++ b/ConsistentID/lib/BiSeNet/test.py @@ -0,0 +1,90 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +from logger import setup_logger +import BiSeNet + +import torch + +import os +import os.path as osp +import numpy as np +from PIL import Image +import torchvision.transforms as transforms +import cv2 + +def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'): + # Colors for all 20 parts + part_colors = [[255, 0, 0], [255, 85, 0], [255, 
170, 0], + [255, 0, 85], [255, 0, 170], + [0, 255, 0], [85, 255, 0], [170, 255, 0], + [0, 255, 85], [0, 255, 170], + [0, 0, 255], [85, 0, 255], [170, 0, 255], + [0, 85, 255], [0, 170, 255], + [255, 255, 0], [255, 255, 85], [255, 255, 170], + [255, 0, 255], [255, 85, 255], [255, 170, 255], + [0, 255, 255], [85, 255, 255], [170, 255, 255]] + + im = np.array(im) + vis_im = im.copy().astype(np.uint8) + vis_parsing_anno = parsing_anno.copy().astype(np.uint8) + vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) + vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 + + num_of_class = np.max(vis_parsing_anno) + + for pi in range(1, num_of_class + 1): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi] + + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) + # print(vis_parsing_anno_color.shape, vis_im.shape) + vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0) + + # Save result or not + if save_im: + cv2.imwrite(save_path[:-4] +'.png', vis_parsing_anno) + cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + + # return vis_im + +def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'): + + if not os.path.exists(respth): + os.makedirs(respth) + + n_classes = 19 + net = BiSeNet(n_classes=n_classes) + net.cuda() + save_pth = osp.join('res/cp', cp) + net.load_state_dict(torch.load(save_pth)) + net.eval() + + to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + with torch.no_grad(): + for image_path in os.listdir(dspth): + img = Image.open(osp.join(dspth, image_path)) + image = img.resize((512, 512), Image.BILINEAR) + img = to_tensor(image) + img = torch.unsqueeze(img, 0) + img = img.cuda() + out = net(img)[0] + parsing = out.squeeze(0).cpu().numpy().argmax(0) + # print(parsing) + print(np.unique(parsing)) + + vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path)) + + + + + + + +if __name__ == "__main__": + evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='79999_iter.pth') + + diff --git a/ConsistentID/lib/BiSeNet/train.py b/ConsistentID/lib/BiSeNet/train.py new file mode 100644 index 0000000000000000000000000000000000000000..ca481944086fc19f320f01c4f2c0f1ab7aef5a83 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/train.py @@ -0,0 +1,179 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +from logger import setup_logger +import BiSeNet +from face_dataset import FaceMask +from loss import OhemCELoss +from evaluate import evaluate +from optimizer import Optimizer +import cv2 +import numpy as np + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import torch.nn.functional as F +import torch.distributed as dist + +import os +import os.path as osp +import logging +import time +import datetime +import argparse + + +respth = './res' +if not osp.exists(respth): + os.makedirs(respth) +logger = logging.getLogger() + + +def parse_args(): + parse = argparse.ArgumentParser() + parse.add_argument( + '--local_rank', + dest = 'local_rank', + type = int, + default = -1, + ) + return parse.parse_args() + + +def train(): + args = parse_args() + torch.cuda.set_device(args.local_rank) + dist.init_process_group( + backend = 'nccl', + init_method = 'tcp://127.0.0.1:33241', + world_size = 
torch.cuda.device_count(), + rank=args.local_rank + ) + setup_logger(respth) + + # dataset + n_classes = 19 + n_img_per_gpu = 16 + n_workers = 8 + cropsize = [448, 448] + data_root = '/home/zll/data/CelebAMask-HQ/' + + ds = FaceMask(data_root, cropsize=cropsize, mode='train') + sampler = torch.utils.data.distributed.DistributedSampler(ds) + dl = DataLoader(ds, + batch_size = n_img_per_gpu, + shuffle = False, + sampler = sampler, + num_workers = n_workers, + pin_memory = True, + drop_last = True) + + # model + ignore_idx = -100 + net = BiSeNet(n_classes=n_classes) + net.cuda() + net.train() + net = nn.parallel.DistributedDataParallel(net, + device_ids = [args.local_rank, ], + output_device = args.local_rank + ) + score_thres = 0.7 + n_min = n_img_per_gpu * cropsize[0] * cropsize[1]//16 + LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) + Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) + Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) + + ## optimizer + momentum = 0.9 + weight_decay = 5e-4 + lr_start = 1e-2 + max_iter = 80000 + power = 0.9 + warmup_steps = 1000 + warmup_start_lr = 1e-5 + optim = Optimizer( + model = net.module, + lr0 = lr_start, + momentum = momentum, + wd = weight_decay, + warmup_steps = warmup_steps, + warmup_start_lr = warmup_start_lr, + max_iter = max_iter, + power = power) + + ## train loop + msg_iter = 50 + loss_avg = [] + st = glob_st = time.time() + diter = iter(dl) + epoch = 0 + for it in range(max_iter): + try: + im, lb = next(diter) + if not im.size()[0] == n_img_per_gpu: + raise StopIteration + except StopIteration: + epoch += 1 + sampler.set_epoch(epoch) + diter = iter(dl) + im, lb = next(diter) + im = im.cuda() + lb = lb.cuda() + H, W = im.size()[2:] + lb = torch.squeeze(lb, 1) + + optim.zero_grad() + out, out16, out32 = net(im) + lossp = LossP(out, lb) + loss2 = Loss2(out16, lb) + loss3 = Loss3(out32, lb) + loss = lossp + loss2 + loss3 + loss.backward() + optim.step() + + loss_avg.append(loss.item()) + + # print training log message + if (it+1) % msg_iter == 0: + loss_avg = sum(loss_avg) / len(loss_avg) + lr = optim.lr + ed = time.time() + t_intv, glob_t_intv = ed - st, ed - glob_st + eta = int((max_iter - it) * (glob_t_intv / it)) + eta = str(datetime.timedelta(seconds=eta)) + msg = ', '.join([ + 'it: {it}/{max_it}', + 'lr: {lr:4f}', + 'loss: {loss:.4f}', + 'eta: {eta}', + 'time: {time:.4f}', + ]).format( + it = it+1, + max_it = max_iter, + lr = lr, + loss = loss_avg, + time = t_intv, + eta = eta + ) + logger.info(msg) + loss_avg = [] + st = ed + if dist.get_rank() == 0: + if (it+1) % 5000 == 0: + state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict() + if dist.get_rank() == 0: + torch.save(state, './res/cp/{}_iter.pth'.format(it)) + evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='{}_iter.pth'.format(it)) + + # dump the final model + save_pth = osp.join(respth, 'model_final_diss.pth') + # net.cpu() + state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict() + if dist.get_rank() == 0: + torch.save(state, save_pth) + logger.info('training done, model saved to: {}'.format(save_pth)) + + +if __name__ == "__main__": + train() diff --git a/ConsistentID/lib/BiSeNet/transform.py b/ConsistentID/lib/BiSeNet/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..9479ae356a151f5da8eedf288abeae7458739d24 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/transform.py @@ -0,0 +1,129 @@ +#!/usr/bin/python +# -*- 
encoding: utf-8 -*- + + +from PIL import Image +import PIL.ImageEnhance as ImageEnhance +import random +import numpy as np + +class RandomCrop(object): + def __init__(self, size, *args, **kwargs): + self.size = size + + def __call__(self, im_lb): + im = im_lb['im'] + lb = im_lb['lb'] + assert im.size == lb.size + W, H = self.size + w, h = im.size + + if (W, H) == (w, h): return dict(im=im, lb=lb) + if w < W or h < H: + scale = float(W) / w if w < h else float(H) / h + w, h = int(scale * w + 1), int(scale * h + 1) + im = im.resize((w, h), Image.BILINEAR) + lb = lb.resize((w, h), Image.NEAREST) + sw, sh = random.random() * (w - W), random.random() * (h - H) + crop = int(sw), int(sh), int(sw) + W, int(sh) + H + return dict( + im = im.crop(crop), + lb = lb.crop(crop) + ) + + +class HorizontalFlip(object): + def __init__(self, p=0.5, *args, **kwargs): + self.p = p + + def __call__(self, im_lb): + if random.random() > self.p: + return im_lb + else: + im = im_lb['im'] + lb = im_lb['lb'] + + # atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', + # 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat'] + + flip_lb = np.array(lb) + flip_lb[lb == 2] = 3 + flip_lb[lb == 3] = 2 + flip_lb[lb == 4] = 5 + flip_lb[lb == 5] = 4 + flip_lb[lb == 7] = 8 + flip_lb[lb == 8] = 7 + flip_lb = Image.fromarray(flip_lb) + return dict(im = im.transpose(Image.FLIP_LEFT_RIGHT), + lb = flip_lb.transpose(Image.FLIP_LEFT_RIGHT), + ) + + +class RandomScale(object): + def __init__(self, scales=(1, ), *args, **kwargs): + self.scales = scales + + def __call__(self, im_lb): + im = im_lb['im'] + lb = im_lb['lb'] + W, H = im.size + scale = random.choice(self.scales) + w, h = int(W * scale), int(H * scale) + return dict(im = im.resize((w, h), Image.BILINEAR), + lb = lb.resize((w, h), Image.NEAREST), + ) + + +class ColorJitter(object): + def __init__(self, brightness=None, contrast=None, saturation=None, *args, **kwargs): + if not brightness is None and brightness>0: + self.brightness = [max(1-brightness, 0), 1+brightness] + if not contrast is None and contrast>0: + self.contrast = [max(1-contrast, 0), 1+contrast] + if not saturation is None and saturation>0: + self.saturation = [max(1-saturation, 0), 1+saturation] + + def __call__(self, im_lb): + im = im_lb['im'] + lb = im_lb['lb'] + r_brightness = random.uniform(self.brightness[0], self.brightness[1]) + r_contrast = random.uniform(self.contrast[0], self.contrast[1]) + r_saturation = random.uniform(self.saturation[0], self.saturation[1]) + im = ImageEnhance.Brightness(im).enhance(r_brightness) + im = ImageEnhance.Contrast(im).enhance(r_contrast) + im = ImageEnhance.Color(im).enhance(r_saturation) + return dict(im = im, + lb = lb, + ) + + +class MultiScale(object): + def __init__(self, scales): + self.scales = scales + + def __call__(self, img): + W, H = img.size + sizes = [(int(W*ratio), int(H*ratio)) for ratio in self.scales] + imgs = [] + [imgs.append(img.resize(size, Image.BILINEAR)) for size in sizes] + return imgs + + +class Compose(object): + def __init__(self, do_list): + self.do_list = do_list + + def __call__(self, im_lb): + for comp in self.do_list: + im_lb = comp(im_lb) + return im_lb + + + + +if __name__ == '__main__': + flip = HorizontalFlip(p = 1) + crop = RandomCrop((321, 321)) + rscales = RandomScale((0.75, 1.0, 1.5, 1.75, 2.0)) + img = Image.open('data/img.jpg') + lb = Image.open('data/label.png') diff --git a/ConsistentID/lib/attention.py 
b/ConsistentID/lib/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9124b3cabd0bc2cafba5b23cfa09bc6aa6261ca8 --- /dev/null +++ b/ConsistentID/lib/attention.py @@ -0,0 +1,287 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.lora import LoRALinearLayer +from .functions import AttentionMLP + +class FuseModule(nn.Module): + def __init__(self, embed_dim): + super().__init__() + self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False) + self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True) + self.layer_norm = nn.LayerNorm(embed_dim) + + def fuse_fn(self, prompt_embeds, id_embeds): + stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) + stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds + stacked_id_embeds = self.mlp2(stacked_id_embeds) + stacked_id_embeds = self.layer_norm(stacked_id_embeds) + return stacked_id_embeds + + def forward( + self, + prompt_embeds, + id_embeds, + class_tokens_mask, + valid_id_mask, + ) -> torch.Tensor: + id_embeds = id_embeds.to(prompt_embeds.dtype) + batch_size, max_num_inputs = id_embeds.shape[:2] # 1,5 + seq_length = prompt_embeds.shape[1] # 77 + flat_id_embeds = id_embeds.view(-1, id_embeds.shape[-2], id_embeds.shape[-1]) + # flat_id_embeds torch.Size([5, 1, 768]) + valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] + # valid_id_embeds torch.Size([4, 1, 768]) + prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) # torch.Size([77, 768]) + class_tokens_mask = class_tokens_mask.view(-1) # torch.Size([77]) + valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) # torch.Size([4, 768]) + image_token_embeds = prompt_embeds[class_tokens_mask] # torch.Size([4, 768]) + stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) # torch.Size([4, 768]) + assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" + prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) + updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) + + return updated_prompt_embeds + +class MLP(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = nn.LayerNorm(in_dim) + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + if self.use_residual: + x = x + residual + return x + +class FacialEncoder(nn.Module): + def __init__(self): + super().__init__() + self.visual_projection = AttentionMLP() + self.fuse_module = FuseModule(768) + + def forward(self, prompt_embeds, multi_image_embeds, class_tokens_mask, valid_id_mask): + bs, num_inputs, token_length, image_dim = multi_image_embeds.shape + multi_image_embeds_view = multi_image_embeds.view(bs * num_inputs, token_length, image_dim) + id_embeds = self.visual_projection(multi_image_embeds_view) # torch.Size([5, 1, 768]) + id_embeds = id_embeds.view(bs, num_inputs, 1, -1) + # fuse_module replaces the class tokens in prompt_embeds with the fused (id_embeds, prompt_embeds[class_tokens_mask]) + # whose indices are specified by class_tokens_mask. 
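+        # valid_id_mask marks which of the (padded) num_inputs facial embeddings are real; FuseModule
+        # keeps only those, fuses each one with its class-token embedding through mlp1/mlp2 + LayerNorm,
+        # and writes the result back into prompt_embeds via masked_scatter_ at the class-token positions.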
+ updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask, valid_id_mask) + return updated_prompt_embeds + +class Consistent_AttProcessor(nn.Module): + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class Consistent_IPAttProcessor(nn.Module): + + def __init__( + self, + hidden_size, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + scale=1.0, + num_tokens=4): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + self.num_tokens = num_tokens + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim 
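+        # `scale` weights the extra image-prompt (IP) attention branch added onto the text cross-attention
+        # output in __call__; `num_tokens` is the number of identity tokens appended at the end of
+        # encoder_hidden_states, which are split off and attended through to_k_ip / to_v_ip.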
+ self.scale = scale + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + for module in [self.to_q_lora, self.to_k_lora, self.to_v_lora, self.to_out_lora, self.to_k_ip, self.to_v_ip]: + for param in module.parameters(): + param.requires_grad = False + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + scale=1.0, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states \ No newline at end of file diff --git 
a/ConsistentID/lib/functions.py b/ConsistentID/lib/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..e343952cc891459c21ba3a773a2cbd8ad3cede64 --- /dev/null +++ b/ConsistentID/lib/functions.py @@ -0,0 +1,606 @@ +import numpy as np +import math +import types +import torch +import torch.nn as nn +import numpy as np +import cv2 +import re +import torch.nn.functional as F +from einops import rearrange +from einops.layers.torch import Rearrange +from PIL import Image + +def extract_first_sentence(text): + end_index = text.find('.') + if end_index != -1: + first_sentence = text[:end_index + 1] + return first_sentence.strip() + else: + return text.strip() + +import re +def remove_duplicate_keywords(text, keywords): + keyword_counts = {} + + words = re.findall(r'\b\w+\b|[.,;!?]', text) + + for keyword in keywords: + keyword_counts[keyword] = 0 + for i, word in enumerate(words): + if word.lower() == keyword.lower(): + keyword_counts[keyword] += 1 + if keyword_counts[keyword] > 1: + words[i] = "" + processed_text = " ".join(words) + + return processed_text + +# text: 'The person has one nose , two eyes , two ears , and a mouth .' +def insert_markers_to_prompt(text, parsing_mask_dict): + keywords = ["face", "ears", "eyes", "nose", "mouth"] + text = remove_duplicate_keywords(text, keywords) + key_parsing_mask_markers = ["Nose", "Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Upper_Lip", "Lower_Lip"] + mapping = { + "Face": "face", + "Left_Ear": "ears", + "Right_Ear": "ears", + "Left_Eye": "eyes", + "Right_Eye": "eyes", + "Nose": "nose", + "Upper_Lip": "mouth", + "Lower_Lip": "mouth", + } + facial_features_align = [] + markers_align = [] + for key in key_parsing_mask_markers: + if key in parsing_mask_dict: + mapped_key = mapping.get(key, key.lower()) + if mapped_key not in facial_features_align: + facial_features_align.append(mapped_key) + markers_align.append("<|" + mapped_key + "|>") + + text_marked = text + align_parsing_mask_dict = parsing_mask_dict + for feature, marker in zip(facial_features_align[::-1], markers_align[::-1]): + pattern = rf'\b{feature}\b' + text_marked_new = re.sub(pattern, f'{feature} {marker}', text_marked, count=1) + if text_marked == text_marked_new: + for key, value in mapping.items(): + if value == feature: + if key in align_parsing_mask_dict: + del align_parsing_mask_dict[key] + + text_marked = text_marked_new + + text_marked = text_marked.replace('\n', '') + + ordered_text = [] + text_none_makers = [] + facial_marked_count = 0 + skip_count = 0 + for marker in markers_align: + start_idx = text_marked.find(marker) + end_idx = start_idx + len(marker) + + while start_idx > 0 and text_marked[start_idx - 1] not in [",", ".", ";"]: + start_idx -= 1 + + while end_idx < len(text_marked) and text_marked[end_idx] not in [",", ".", ";"]: + end_idx += 1 + + context = text_marked[start_idx:end_idx].strip() + if context == "": + text_none_makers.append(text_marked[:end_idx]) + else: + if skip_count!=0: + skip_count -= 1 + continue + else: + ordered_text.append(context + ", ") + text_delete_makers = text_marked[:start_idx] + text_marked[end_idx:] + text_marked = text_delete_makers + facial_marked_count += 1 + + # ordered_text: ['The person has one nose <|nose|>, ', 'two ears <|ears|>, ', + # 'two eyes <|eyes|>, ', 'and a mouth <|mouth|>, '] + # align_parsing_mask_dict.keys(): ['Right_Eye', 'Right_Ear', 'Nose', 'Upper_Lip'] + align_marked_text = "".join(ordered_text) + replace_list = ["<|face|>", "<|ears|>", "<|nose|>", "<|eyes|>", "<|mouth|>"] + 
for item in replace_list: + align_marked_text = align_marked_text.replace(item, "<|facial|>") + + # align_marked_text: 'The person has one nose <|facial|>, two ears <|facial|>, two eyes <|facial|>, and a mouth <|facial|>, ' + return align_marked_text, align_parsing_mask_dict + +def tokenize_and_mask_noun_phrases_ends(text, image_token_id, facial_token_id, tokenizer): + input_ids = tokenizer.encode(text) + image_noun_phrase_end_mask = [False for _ in input_ids] + facial_noun_phrase_end_mask = [False for _ in input_ids] + clean_input_ids = [] + clean_index = 0 + image_num = 0 + + for i, id in enumerate(input_ids): + if id == image_token_id: + image_noun_phrase_end_mask[clean_index + image_num - 1] = True + image_num += 1 + elif id == facial_token_id: + facial_noun_phrase_end_mask[clean_index - 1] = True + else: + clean_input_ids.append(id) + clean_index += 1 + + max_len = tokenizer.model_max_length + + if len(clean_input_ids) > max_len: + clean_input_ids = clean_input_ids[:max_len] + else: + clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * ( + max_len - len(clean_input_ids) + ) + + if len(image_noun_phrase_end_mask) > max_len: + image_noun_phrase_end_mask = image_noun_phrase_end_mask[:max_len] + else: + image_noun_phrase_end_mask = image_noun_phrase_end_mask + [False] * ( + max_len - len(image_noun_phrase_end_mask) + ) + + if len(facial_noun_phrase_end_mask) > max_len: + facial_noun_phrase_end_mask = facial_noun_phrase_end_mask[:max_len] + else: + facial_noun_phrase_end_mask = facial_noun_phrase_end_mask + [False] * ( + max_len - len(facial_noun_phrase_end_mask) + ) + clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long) + image_noun_phrase_end_mask = torch.tensor(image_noun_phrase_end_mask, dtype=torch.bool) + facial_noun_phrase_end_mask = torch.tensor(facial_noun_phrase_end_mask, dtype=torch.bool) + + return clean_input_ids.unsqueeze(0), image_noun_phrase_end_mask.unsqueeze(0), facial_noun_phrase_end_mask.unsqueeze(0) + +def prepare_image_token_idx(image_token_mask, facial_token_mask, max_num_objects=2, max_num_facials=5): + image_token_idx = torch.nonzero(image_token_mask, as_tuple=True)[1] + image_token_idx_mask = torch.ones_like(image_token_idx, dtype=torch.bool) + if len(image_token_idx) < max_num_objects: + image_token_idx = torch.cat( + [ + image_token_idx, + torch.zeros(max_num_objects - len(image_token_idx), dtype=torch.long), + ] + ) + image_token_idx_mask = torch.cat( + [ + image_token_idx_mask, + torch.zeros( + max_num_objects - len(image_token_idx_mask), + dtype=torch.bool, + ), + ] + ) + facial_token_idx = torch.nonzero(facial_token_mask, as_tuple=True)[1] + facial_token_idx_mask = torch.ones_like(facial_token_idx, dtype=torch.bool) + if len(facial_token_idx) < max_num_facials: + facial_token_idx = torch.cat( + [ + facial_token_idx, + torch.zeros(max_num_facials - len(facial_token_idx), dtype=torch.long), + ] + ) + facial_token_idx_mask = torch.cat( + [ + facial_token_idx_mask, + torch.zeros( + max_num_facials - len(facial_token_idx_mask), + dtype=torch.bool, + ), + ] + ) + image_token_idx = image_token_idx.unsqueeze(0) + image_token_idx_mask = image_token_idx_mask.unsqueeze(0) + + facial_token_idx = facial_token_idx.unsqueeze(0) + facial_token_idx_mask = facial_token_idx_mask.unsqueeze(0) + + return image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask + +def get_object_localization_loss_for_one_layer( + cross_attention_scores, + object_segmaps, + object_token_idx, + object_token_idx_mask, + loss_fn, +): + bxh, 
num_noise_latents, num_text_tokens = cross_attention_scores.shape + b, max_num_objects, _, _ = object_segmaps.shape + size = int(num_noise_latents**0.5) + + object_segmaps = F.interpolate(object_segmaps, size=(size, size), mode="bilinear", antialias=True) + + object_segmaps = object_segmaps.view( + b, max_num_objects, -1 + ) + + num_heads = bxh // b + cross_attention_scores = cross_attention_scores.view(b, num_heads, num_noise_latents, num_text_tokens) + + + object_token_attn_prob = torch.gather( + cross_attention_scores, + dim=3, + index=object_token_idx.view(b, 1, 1, max_num_objects).expand( + b, num_heads, num_noise_latents, max_num_objects + ), + ) + object_segmaps = ( + object_segmaps.permute(0, 2, 1) + .unsqueeze(1) + .expand(b, num_heads, num_noise_latents, max_num_objects) + ) + loss = loss_fn(object_token_attn_prob, object_segmaps) + + loss = loss * object_token_idx_mask.view(b, 1, max_num_objects) + object_token_cnt = object_token_idx_mask.sum(dim=1).view(b, 1) + 1e-5 + loss = (loss.sum(dim=2) / object_token_cnt).mean() + + return loss + + +def get_object_localization_loss( + cross_attention_scores, + object_segmaps, + image_token_idx, + image_token_idx_mask, + loss_fn, +): + num_layers = len(cross_attention_scores) + loss = 0 + for k, v in cross_attention_scores.items(): + layer_loss = get_object_localization_loss_for_one_layer( + v, object_segmaps, image_token_idx, image_token_idx_mask, loss_fn + ) + loss += layer_loss + return loss / num_layers + +def unet_store_cross_attention_scores(unet, attention_scores, layers=5): + from diffusers.models.attention_processor import Attention + + UNET_LAYER_NAMES = [ + "down_blocks.0", + "down_blocks.1", + "down_blocks.2", + "mid_block", + "up_blocks.1", + "up_blocks.2", + "up_blocks.3", + ] + + start_layer = (len(UNET_LAYER_NAMES) - layers) // 2 + end_layer = start_layer + layers + applicable_layers = UNET_LAYER_NAMES[start_layer:end_layer] + + def make_new_get_attention_scores_fn(name): + def new_get_attention_scores(module, query, key, attention_mask=None): + attention_probs = module.old_get_attention_scores( + query, key, attention_mask + ) + attention_scores[name] = attention_probs + return attention_probs + + return new_get_attention_scores + + for name, module in unet.named_modules(): + if isinstance(module, Attention) and "attn1" in name: + if not any(layer in name for layer in applicable_layers): + continue + + module.old_get_attention_scores = module.get_attention_scores + module.get_attention_scores = types.MethodType( + make_new_get_attention_scores_fn(name), module + ) + return unet + +class BalancedL1Loss(nn.Module): + def __init__(self, threshold=1.0, normalize=False): + super().__init__() + self.threshold = threshold + self.normalize = normalize + + def forward(self, object_token_attn_prob, object_segmaps): + if self.normalize: + object_token_attn_prob = object_token_attn_prob / ( + object_token_attn_prob.max(dim=2, keepdim=True)[0] + 1e-5 + ) + background_segmaps = 1 - object_segmaps + background_segmaps_sum = background_segmaps.sum(dim=2) + 1e-5 + object_segmaps_sum = object_segmaps.sum(dim=2) + 1e-5 + + background_loss = (object_token_attn_prob * background_segmaps).sum( + dim=2 + ) / background_segmaps_sum + + object_loss = (object_token_attn_prob * object_segmaps).sum( + dim=2 + ) / object_segmaps_sum + + return background_loss - object_loss + +def apply_mask_to_raw_image(raw_image, mask_image): + mask_image = mask_image.resize(raw_image.size) + mask_raw_image = Image.composite(raw_image, Image.new('RGB', 
raw_image.size, (0, 0, 0)), mask_image) + return mask_raw_image + +mapping_table = [ + {"Mask Value": 0, "Body Part": "Background", "RGB Color": [0, 0, 0]}, + {"Mask Value": 1, "Body Part": "Face", "RGB Color": [255, 0, 0]}, + {"Mask Value": 2, "Body Part": "Left_Eyebrow", "RGB Color": [255, 85, 0]}, + {"Mask Value": 3, "Body Part": "Right_Eyebrow", "RGB Color": [255, 170, 0]}, + {"Mask Value": 4, "Body Part": "Left_Eye", "RGB Color": [255, 0, 85]}, + {"Mask Value": 5, "Body Part": "Right_Eye", "RGB Color": [255, 0, 170]}, + {"Mask Value": 6, "Body Part": "Hair", "RGB Color": [0, 0, 255]}, + {"Mask Value": 7, "Body Part": "Left_Ear", "RGB Color": [85, 0, 255]}, + {"Mask Value": 8, "Body Part": "Right_Ear", "RGB Color": [170, 0, 255]}, + {"Mask Value": 9, "Body Part": "Mouth_External Contour", "RGB Color": [0, 255, 85]}, + {"Mask Value": 10, "Body Part": "Nose", "RGB Color": [0, 255, 0]}, + {"Mask Value": 11, "Body Part": "Mouth_Inner_Contour", "RGB Color": [0, 255, 170]}, + {"Mask Value": 12, "Body Part": "Upper_Lip", "RGB Color": [85, 255, 0]}, + {"Mask Value": 13, "Body Part": "Lower_Lip", "RGB Color": [170, 255, 0]}, + {"Mask Value": 14, "Body Part": "Neck", "RGB Color": [0, 85, 255]}, + {"Mask Value": 15, "Body Part": "Neck_Inner Contour", "RGB Color": [0, 170, 255]}, + {"Mask Value": 16, "Body Part": "Cloth", "RGB Color": [255, 255, 0]}, + {"Mask Value": 17, "Body Part": "Hat", "RGB Color": [255, 0, 255]}, + {"Mask Value": 18, "Body Part": "Earring", "RGB Color": [255, 85, 255]}, + {"Mask Value": 19, "Body Part": "Necklace", "RGB Color": [255, 255, 85]}, + {"Mask Value": 20, "Body Part": "Glasses", "RGB Color": [255, 170, 255]}, + {"Mask Value": 21, "Body Part": "Hand", "RGB Color": [255, 0, 255]}, + {"Mask Value": 22, "Body Part": "Wristband", "RGB Color": [0, 255, 255]}, + {"Mask Value": 23, "Body Part": "Clothes_Upper", "RGB Color": [85, 255, 255]}, + {"Mask Value": 24, "Body Part": "Clothes_Lower", "RGB Color": [170, 255, 255]} +] + + +def masks_for_unique_values(image_raw_mask): + + image_array = np.array(image_raw_mask) + unique_values, counts = np.unique(image_array, return_counts=True) + masks_dict = {} + for value in unique_values: + binary_image = np.uint8(image_array == value) * 255 + contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + mask = np.zeros_like(image_array) + for contour in contours: + cv2.drawContours(mask, [contour], -1, (255), thickness=cv2.FILLED) + + if value == 0: + body_part="WithoutBackground" + mask2 = np.where(mask == 255, 0, 255).astype(mask.dtype) + masks_dict[body_part] = Image.fromarray(mask2) + + body_part = next((entry["Body Part"] for entry in mapping_table if entry["Mask Value"] == value), f"Unknown_{value}") + if body_part.startswith("Unknown_"): + continue + + masks_dict[body_part] = Image.fromarray(mask) + + return masks_dict +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + x = x.view(bs, length, heads, -1) + x = x.transpose(1, 2) + x = x.reshape(bs, heads, length, -1) + return x + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + 
self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # x -> kv, latents -> q + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + +class FacePerceiverResampler(torch.nn.Module): + def __init__( + self, + *, + dim=768, + depth=4, + dim_head=64, + heads=16, + embedding_dim=1280, + output_dim=768, + ff_mult=4, + ): + super().__init__() + + self.proj_in = torch.nn.Linear(embedding_dim, dim) + self.proj_out = torch.nn.Linear(dim, output_dim) + self.norm_out = torch.nn.LayerNorm(output_dim) + self.layers = torch.nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + torch.nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + # x -> kv, latents -> q + def forward(self, latents, x): # latents.torch.Size([2, 4, 768]) x.torch.Size([2, 257, 1280]) + x = self.proj_in(x) # x.torch.Size([2, 257, 768]) + for attn, ff in self.layers: + # x -> kv, latents -> q + latents = attn(x, latents) + latents # latents.torch.Size([2, 4, 768]) + latents = ff(latents) + latents # latents.torch.Size([2, 4, 768]) + latents = self.proj_out(latents) + return self.norm_out(latents) + +class ProjPlusModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + self.perceiver_resampler = FacePerceiverResampler( + dim=cross_attention_dim, + depth=4, + dim_head=64, + heads=cross_attention_dim // 64, + embedding_dim=clip_embeddings_dim, + output_dim=cross_attention_dim, + ff_mult=4, + ) + + def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0): + + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + # id_embeds -> x -> kv, clip_embeds -> q + out = self.perceiver_resampler(x, clip_embeds) + if shortcut: + out = scale * x + out + return out + +class AttentionMLP(nn.Module): + def __init__( + self, + dtype=torch.float16, + dim=1024, + depth=8, + dim_head=64, + heads=16, + single_num_tokens=1, + embedding_dim=1280, + output_dim=768, + ff_mult=4, + max_seq_len: int = 257*2, + apply_pos_emb: bool = False, + num_latents_mean_pooled: int = 0, + ): + super().__init__() + self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None + + self.single_num_tokens = single_num_tokens + self.latents = nn.Parameter(torch.randn(1, 
self.single_num_tokens, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.to_latents_from_mean_pooled_seq = ( + nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * num_latents_mean_pooled), + Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), + ) + if num_latents_mean_pooled > 0 + else None + ) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + if self.pos_emb is not None: + n, device = x.shape[1], x.device + pos_emb = self.pos_emb(torch.arange(n, device=device)) + x = x + pos_emb + # x torch.Size([5, 257, 1280]) + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) # torch.Size([5, 257, 1024]) + + if self.to_latents_from_mean_pooled_seq: + meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) + meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) + latents = torch.cat((meanpooled_latents, latents), dim=-2) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +def masked_mean(t, *, dim, mask=None): + if mask is None: + return t.mean(dim=dim) + + denom = mask.sum(dim=dim, keepdim=True) + mask = rearrange(mask, "b n -> b n 1") + masked_t = t.masked_fill(~mask, 0.0) + + return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) + + diff --git a/ConsistentID/lib/pipeline_ConsistentID.py b/ConsistentID/lib/pipeline_ConsistentID.py new file mode 100644 index 0000000000000000000000000000000000000000..129c7bb21a5367d673299e5bb7a5501333a6d4e7 --- /dev/null +++ b/ConsistentID/lib/pipeline_ConsistentID.py @@ -0,0 +1,605 @@ +from typing import Any, Callable, Dict, List, Optional, Union, Tuple +import cv2 +import PIL +import numpy as np +from PIL import Image +import torch +from torchvision import transforms +from insightface.app import FaceAnalysis +### insight-face installation can be found at https://github.com/deepinsight/insightface +from safetensors import safe_open +from huggingface_hub.utils import validate_hf_hub_args +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline +from .functions import insert_markers_to_prompt, masks_for_unique_values, apply_mask_to_raw_image, tokenize_and_mask_noun_phrases_ends, prepare_image_token_idx +from .functions import ProjPlusModel, masks_for_unique_values +from .attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder +from easydict import EasyDict as edict +from huggingface_hub import hf_hub_download +### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file +### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812 +### Thanks for the open source of face-parsing model. 
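+### Illustrative usage of this pipeline (a sketch; the paths below are placeholders, not files shipped with this repo):
+###   pipe = ConsistentIDPipeline.from_pretrained("path/to/base_sd15_model", torch_dtype=torch.float16)
+###   pipe.load_ConsistentID_model(consistentID_weight_path="path/to/ConsistentID-weights.bin", bise_net_weight_path="path/to/BiSeNet-weights.pth")
+###   pipe = pipe.to("cuda", torch.float16)
+###   output = pipe("a man, wearing a suit", input_subj_image_objs=Image.open("path/to/face.jpg"), guidance_scale=5.0, num_inference_steps=50)
+###   output.images[0].save("sample.png")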
+from .BiSeNet.model import BiSeNet +import os + +PipelineImageInput = Union[ + PIL.Image.Image, + torch.FloatTensor, + List[PIL.Image.Image], + List[torch.FloatTensor], +] + +### Download the pretrained model from huggingface and put it locally, then place the model in a local directory and specify the directory location. +class ConsistentIDPipeline(StableDiffusionPipeline): + # to() should be only called after all modules are loaded. + def to( + self, + torch_device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(torch_device, dtype=dtype) + self.bise_net.to(torch_device, dtype=dtype) + self.clip_encoder.to(torch_device, dtype=dtype) + self.image_proj_model.to(torch_device, dtype=dtype) + self.FacialEncoder.to(torch_device, dtype=dtype) + # If the unet is not released, the ip_layers should be moved to the specified device and dtype. + if not isinstance(self.unet, edict): + self.ip_layers.to(torch_device, dtype=dtype) + return self + + @validate_hf_hub_args + def load_ConsistentID_model( + self, + consistentID_weight_path: str, + bise_net_weight_path: str, + trigger_word_facial: str = '<|facial|>', + # A CLIP ViT-H/14 model trained with the LAION-2B English subset of LAION-5B using OpenCLIP. + # output dim: 1280. + image_encoder_path: str = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K', + torch_dtype = torch.float16, + num_tokens = 4, + lora_rank= 128, + **kwargs, + ): + self.lora_rank = lora_rank + self.torch_dtype = torch_dtype + self.num_tokens = num_tokens + self.set_ip_adapter() + self.image_encoder_path = image_encoder_path + self.clip_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path) + self.clip_preprocessor = CLIPImageProcessor() + self.id_image_processor = CLIPImageProcessor() + self.crop_size = 512 + + # face_app: FaceAnalysis object + self.face_app = FaceAnalysis(name="buffalo_l", root='models/insightface', providers=['CPUExecutionProvider']) + # The original det_size=(640, 640) is too large and face_app often fails to detect faces. + self.face_app.prepare(ctx_id=0, det_size=(512, 512)) + + if not os.path.exists(consistentID_weight_path): + ### Download pretrained models + hf_hub_download(repo_id="JackAILab/ConsistentID", repo_type="model", + filename=os.path.basename(consistentID_weight_path), + local_dir=os.path.dirname(consistentID_weight_path)) + if not os.path.exists(bise_net_weight_path): + hf_hub_download(repo_id="JackAILab/ConsistentID", + filename=os.path.basename(bise_net_weight_path), + local_dir=os.path.dirname(bise_net_weight_path)) + + bise_net = BiSeNet(n_classes = 19) + bise_net.load_state_dict(torch.load(bise_net_weight_path, map_location="cpu")) + bise_net.eval() + self.bise_net = bise_net + + # Colors for all 20 parts + self.part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], + [255, 0, 85], [255, 0, 170], + [0, 255, 0], [85, 255, 0], [170, 255, 0], + [0, 255, 85], [0, 255, 170], + [0, 0, 255], [85, 0, 255], [170, 0, 255], + [0, 85, 255], [0, 170, 255], + [255, 255, 0], [255, 255, 85], [255, 255, 170], + [255, 0, 255], [255, 85, 255], [255, 170, 255], + [0, 255, 255], [85, 255, 255], [170, 255, 255]] + + # image_proj_model maps 1280-dim OpenCLIP embeddings to 768-dim face prompt embeddings. 
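+ # (ProjPlusModel, defined in functions.py, projects the 512-dim InsightFace ID embedding into
+ # num_tokens x cross_attention_dim prompt tokens and refines them by cross-attending with the
+ # CLIP penultimate hidden states through a FacePerceiverResampler, so the resulting global ID tokens
+ # combine identity and appearance information.)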
+ self.image_proj_model = ProjPlusModel(
+ cross_attention_dim=self.unet.config.cross_attention_dim,
+ id_embeddings_dim=512,
+ clip_embeddings_dim=self.clip_encoder.config.hidden_size,
+ num_tokens=self.num_tokens, # 4 - inspired by IPAdapter and Midjourney
+ )
+ self.FacialEncoder = FacialEncoder()
+
+ if consistentID_weight_path.endswith(".safetensors"):
+ state_dict = {"FacialEncoder": {}, "image_proj": {}}
+ with safe_open(consistentID_weight_path, framework="pt", device="cpu") as f:
+ ### TODO: this safetensors branch only collects FacialEncoder and image_proj weights; adapter_modules still needs to be added.
+ for key in f.keys():
+ if key.startswith("FacialEncoder."):
+ state_dict["FacialEncoder"][key.replace("FacialEncoder.", "")] = f.get_tensor(key)
+ elif key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ else:
+ state_dict = torch.load(consistentID_weight_path, map_location="cpu")
+
+ self.trigger_word_facial = trigger_word_facial
+
+ self.FacialEncoder.load_state_dict(state_dict["FacialEncoder"], strict=True)
+ self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
+ self.ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
+ self.ip_layers.load_state_dict(state_dict["adapter_modules"], strict=True)
+ print("Successfully loaded weights from checkpoint")
+
+ # Add trigger word token
+ if self.tokenizer is not None:
+ self.tokenizer.add_tokens([self.trigger_word_facial], special_tokens=True)
+
+ def set_ip_adapter(self):
+ unet = self.unet
+ attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ attn_procs[name] = Consistent_AttProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
+ )
+ else:
+ attn_procs[name] = Consistent_IPAttProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
+ )
+
+ unet.set_attn_processor(attn_procs)
+
+ @torch.inference_mode()
+ # parsed_image_parts2 is a batched tensor of parsed_image_parts with bs=1. It only contains the facial areas of one input image.
+ # clip_encoder maps image parts to image-space diffusion prompts.
+ # Then the facial class token embeddings are replaced with the fused (multi_facial_embeds, prompt_embeds[class_tokens_mask]).
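+ # Shape walk-through for a single reference image (shapes as noted elsewhere in this file):
+ # parsed_image_parts2 [1, 5, 3, 224, 224] -> per-part CLIP hidden states stacked into multi_facial_embeds [1, 5, 257, 1280]
+ # -> FacialEncoder fuses them into the 77-token prompt embeddings [1, 77, 768] at the masked facial-token positions.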
+ def extract_local_facial_embeds(self, prompt_embeds, uncond_prompt_embeds, parsed_image_parts2,
+ facial_token_masks, valid_facial_token_idx_mask, calc_uncond=True):
+
+ hidden_states = []
+ uncond_hidden_states = []
+ for parsed_image_parts in parsed_image_parts2:
+ hidden_state = self.clip_encoder(parsed_image_parts.to(self.device, dtype=self.torch_dtype), output_hidden_states=True).hidden_states[-2]
+ uncond_hidden_state = self.clip_encoder(torch.zeros_like(parsed_image_parts, dtype=self.torch_dtype).to(self.device), output_hidden_states=True).hidden_states[-2]
+ hidden_states.append(hidden_state)
+ uncond_hidden_states.append(uncond_hidden_state)
+ multi_facial_embeds = torch.stack(hidden_states)
+ uncond_multi_facial_embeds = torch.stack(uncond_hidden_states)
+
+ # conditional prompt.
+ # FacialEncoder maps multi_facial_embeds to facial ID embeddings, and replaces the class tokens in prompt_embeds
+ # with the fused (facial ID embeddings, prompt_embeds[class_tokens_mask]).
+ # multi_facial_embeds: [1, 5, 257, 1280].
+ facial_prompt_embeds = self.FacialEncoder(prompt_embeds, multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
+
+ if not calc_uncond:
+ return facial_prompt_embeds, None
+ # unconditional prompt.
+ uncond_facial_prompt_embeds = self.FacialEncoder(uncond_prompt_embeds, uncond_multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
+
+ return facial_prompt_embeds, uncond_facial_prompt_embeds
+
+ @torch.inference_mode()
+ # Extract OpenCLIP embeddings from the input image and map them to face prompt embeddings.
+ def extract_global_id_embeds(self, face_image_obj, s_scale=1.0, shortcut=False):
+ clip_image_ts = self.clip_preprocessor(images=face_image_obj, return_tensors="pt").pixel_values
+ clip_image_ts = clip_image_ts.to(self.device, dtype=self.torch_dtype)
+ clip_image_embeds = self.clip_encoder(clip_image_ts, output_hidden_states=True).hidden_states[-2]
+ uncond_clip_image_embeds = self.clip_encoder(torch.zeros_like(clip_image_ts), output_hidden_states=True).hidden_states[-2]
+
+ faceid_embeds = self.extract_faceid(face_image_obj)
+ faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
+ # image_proj_model maps 1280-dim OpenCLIP embeddings to 768-dim face prompt embeddings.
+ # clip_image_embeds are used as queries to transform faceid_embeds.
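+ # Typical shapes here for one reference image: faceid_embeds [1, 512] (InsightFace), clip_image_embeds [1, 257, 1280],
+ # and the returned global_id_embeds / uncond_global_id_embeds [1, 4, 768] (later appended after the text embeddings).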
+ # faceid_embeds -> kv, clip_image_embeds -> q + global_id_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale) + uncond_global_id_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale) + + return global_id_embeds, uncond_global_id_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, Consistent_IPAttProcessor): + attn_processor.scale = scale + + @torch.inference_mode() + def extract_faceid(self, face_image_obj): + faceid_image = np.array(face_image_obj) + faces = self.face_app.get(faceid_image) + if faces==[]: + faceid_embeds = torch.zeros_like(torch.empty((1, 512))) + else: + faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0) + + return faceid_embeds + + @torch.inference_mode() + def parse_face_mask(self, raw_image_refer): + + to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + to_pil = transforms.ToPILImage() + + with torch.no_grad(): + image = raw_image_refer.resize((512, 512), Image.BILINEAR) + image_resize_PIL = image + img = to_tensor(image) + img = torch.unsqueeze(img, 0) + img = img.to(self.device, dtype=self.torch_dtype) + out = self.bise_net(img)[0] + parsing_anno = out.squeeze(0).cpu().numpy().argmax(0) + + im = np.array(image_resize_PIL) + vis_im = im.copy().astype(np.uint8) + stride=1 + vis_parsing_anno = parsing_anno.copy().astype(np.uint8) + vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) + vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 + + num_of_class = np.max(vis_parsing_anno) + + for pi in range(1, num_of_class + 1): # num_of_class=17 pi=1~16 + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = self.part_colors[pi] + + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) + vis_parsing_anno_color = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0) + + return vis_parsing_anno_color, vis_parsing_anno + + @torch.inference_mode() + def extract_facemask(self, input_image_obj): + vis_parsing_anno_color, vis_parsing_anno = self.parse_face_mask(input_image_obj) + parsing_mask_list = masks_for_unique_values(vis_parsing_anno) + + key_parsing_mask_dict = {} + key_list = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"] + processed_keys = set() + for key, mask_image in parsing_mask_list.items(): + if key in key_list: + if "_" in key: + prefix = key.split("_")[1] + if prefix in processed_keys: + continue + else: + key_parsing_mask_dict[key] = mask_image + processed_keys.add(prefix) + + key_parsing_mask_dict[key] = mask_image + + return key_parsing_mask_dict, vis_parsing_anno_color + + def augment_prompt_with_trigger_word( + self, + prompt: str, + face_caption: str, + key_parsing_mask_dict = None, + facial_token = "<|facial|>", + max_num_facials = 5, + num_id_images: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + # face_caption_align: 'The person has one nose <|facial|>, two ears <|facial|>, two eyes <|facial|>, and a mouth <|facial|>, ' + face_caption_align, key_parsing_mask_dict_align = insert_markers_to_prompt(face_caption, key_parsing_mask_dict) + + prompt_face = prompt + " Detail: 
" + face_caption_align + + max_text_length=330 + if len(self.tokenizer(prompt_face, max_length=self.tokenizer.model_max_length, + padding="max_length", truncation=False, return_tensors="pt").input_ids[0]) != 77: + # Put face_caption_align at the beginning of the prompt, so that the original prompt is truncated, + # but the face_caption_align is well kept. + prompt_face = "Detail: " + face_caption_align + " Caption:" + prompt + + # Remove "<|facial|>" from prompt_face. + # augmented_prompt: 'A person, police officer, half body shot Detail: + # The person has one nose , two ears , two eyes , and a mouth , ' + augmented_prompt = prompt_face.replace("<|facial|>", "") + tokenizer = self.tokenizer + facial_token_id = tokenizer.convert_tokens_to_ids(facial_token) + image_token_id = None + + # image_token_id: the token id of "<|image|>". Disabled, as it's set to None. + # facial_token_id: the token id of "<|facial|>". + clean_input_id, image_token_mask, facial_token_mask = \ + tokenize_and_mask_noun_phrases_ends(prompt_face, image_token_id, facial_token_id, tokenizer) + + image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask = \ + prepare_image_token_idx(image_token_mask, facial_token_mask, num_id_images, max_num_facials) + + return augmented_prompt, clean_input_id, key_parsing_mask_dict_align, facial_token_mask, facial_token_idx, facial_token_idx_mask + + @torch.inference_mode() + def extract_parsed_image_parts(self, input_image_obj, key_parsing_mask_dict, image_size=512, max_num_facials=5): + facial_masks = [] + parsed_image_parts = [] + key_masked_raw_images_dict = {} + transform_mask = transforms.Compose([transforms.CenterCrop(size=image_size), transforms.ToTensor(),]) + clip_preprocessor = CLIPImageProcessor() + + num_facial_part = len(key_parsing_mask_dict) + + for key in key_parsing_mask_dict: + key_mask=key_parsing_mask_dict[key] + facial_masks.append(transform_mask(key_mask)) + key_masked_raw_image = apply_mask_to_raw_image(input_image_obj, key_mask) + key_masked_raw_images_dict[key] = key_masked_raw_image + # clip_preprocessor normalizes key_masked_raw_image, so that (masked) zero pixels become non-zero. + # It also resizes the image to 224x224. + parsed_image_part = clip_preprocessor(images=key_masked_raw_image, return_tensors="pt").pixel_values + parsed_image_parts.append(parsed_image_part) + + padding_ficial_clip_image = torch.zeros_like(torch.zeros([1, 3, 224, 224])) + padding_ficial_mask = torch.zeros_like(torch.zeros([1, image_size, image_size])) + + if num_facial_part < max_num_facials: + parsed_image_parts += [ torch.zeros_like(padding_ficial_clip_image) for _ in range(max_num_facials - num_facial_part) ] + facial_masks += [ torch.zeros_like(padding_ficial_mask) for _ in range(max_num_facials - num_facial_part) ] + + parsed_image_parts = torch.stack(parsed_image_parts, dim=1).squeeze(0) + facial_masks = torch.stack(facial_masks, dim=0).squeeze(dim=1) + + return parsed_image_parts, facial_masks, key_masked_raw_images_dict + + # Release the unet/vae/text_encoder to save memory. + def release_components(self, released_components=["unet", "vae", "text_encoder"]): + if "unet" in released_components: + unet = edict() + # Only keep the config and in_channels attributes that are used in the pipeline. + unet.config = self.unet.config + self.unet = unet + + if "vae" in released_components: + self.vae = None + if "text_encoder" in released_components: + self.text_encoder = None + + # input_subj_image_obj: an Image object. 
+ def extract_double_id_prompt_embeds(self, prompt, negative_prompt, input_subj_image_obj, device, calc_uncond=True):
+ face_caption = "The person has one nose, two eyes, two ears, and a mouth."
+ key_parsing_mask_dict, vis_parsing_anno_color = self.extract_facemask(input_subj_image_obj)
+
+ augmented_prompt, clean_input_id, key_parsing_mask_dict_align, \
+ facial_token_mask, facial_token_idx, facial_token_idx_mask \
+ = self.augment_prompt_with_trigger_word(
+ prompt = prompt,
+ face_caption = face_caption,
+ key_parsing_mask_dict=key_parsing_mask_dict,
+ device=device,
+ max_num_facials = 5,
+ num_id_images = 1
+ )
+
+ text_embeds, uncond_text_embeds = self.encode_prompt(
+ augmented_prompt,
+ device=device,
+ num_images_per_prompt=1,
+ do_classifier_free_guidance=calc_uncond,
+ negative_prompt=negative_prompt,
+ )
+
+ # 5. Prepare the input ID images
+ # global_id_embeds: [1, 4, 768]
+ # extract_global_id_embeds() extracts OpenCLIP embeddings from the input image and maps them to global face prompt embeddings.
+ global_id_embeds, uncond_global_id_embeds = \
+ self.extract_global_id_embeds(face_image_obj=input_subj_image_obj, s_scale=1.0, shortcut=False)
+
+ # parsed_image_parts: [5, 3, 224, 224]. 5 parts, each part is a 3-channel 224x224 image (resized by CLIP Preprocessor).
+ parsed_image_parts, facial_masks, key_masked_raw_images_dict = \
+ self.extract_parsed_image_parts(input_subj_image_obj, key_parsing_mask_dict_align, image_size=512, max_num_facials=5)
+ parsed_image_parts2 = parsed_image_parts.unsqueeze(0).to(device, dtype=self.torch_dtype)
+ facial_token_mask = facial_token_mask.to(device)
+ facial_token_idx_mask = facial_token_idx_mask.to(device)
+
+ # key_masked_raw_images_dict: ['Right_Eye', 'Right_Ear', 'Nose', 'Upper_Lip']
+ # for key in key_masked_raw_images_dict:
+ # key_masked_raw_images_dict[key].save(f"{key}.png")
+
+ # 6. Get the updated text embeddings
+ # parsed_image_parts2: the facial areas of the input image
+ # extract_local_facial_embeds() maps parsed_image_parts2 to multi_facial_embeds, and then replaces the class tokens in prompt_embeds
+ # with the fused (id_embeds, prompt_embeds[class_tokens_mask]) whose indices are specified by class_tokens_mask.
+ # parsed_image_parts2: [1, 5, 3, 224, 224]
+ text_local_id_embeds, uncond_text_local_id_embeds = \
+ self.extract_local_facial_embeds(text_embeds, uncond_text_embeds, \
+ parsed_image_parts2, facial_token_mask, facial_token_idx_mask,
+ calc_uncond=calc_uncond)
+
+ # text_global_id_embeds, text_local_global_id_embeds: [1, 81, 768]
+ # text_local_id_embeds: [1, 77, 768], only differs from text_embeds on the 4 ID embeddings, and is identical
+ # to text_embeds on the remaining 73 tokens.
+ text_global_id_embeds = torch.cat([text_embeds, global_id_embeds], dim=1)
+ text_local_global_id_embeds = torch.cat([text_local_id_embeds, global_id_embeds], dim=1)
+
+ if calc_uncond:
+ uncond_text_global_id_embeds = torch.cat([uncond_text_local_id_embeds, uncond_global_id_embeds], dim=1)
+ coarse_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_global_id_embeds], dim=0)
+ fine_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_local_global_id_embeds], dim=0)
+ else:
+ coarse_prompt_embeds = text_global_id_embeds
+ fine_prompt_embeds = text_local_global_id_embeds
+
+ # fine_prompt_embeds: the conditional part is
+ # (text_global_id_embeds + text_local_global_id_embeds) / 2.
+ fine_prompt_embeds = (coarse_prompt_embeds + fine_prompt_embeds) / 2 + + return coarse_prompt_embeds, fine_prompt_embeds + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + original_size: Optional[Tuple[int, int]] = None, + target_size: Optional[Tuple[int, int]] = None, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + input_subj_image_objs: PipelineImageInput = None, + start_merge_step: int = 0, + ): + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + do_classifier_free_guidance = guidance_scale >= 1.0 + assert do_classifier_free_guidance + + if input_subj_image_objs is not None: + if not isinstance(input_subj_image_objs, list): + input_subj_image_objs = [input_subj_image_objs] + + # 3. Encode input prompt + coarse_prompt_embeds, fine_prompt_embeds = \ + self.extract_double_id_prompt_embeds(prompt, negative_prompt, input_subj_image_objs[0], device) + else: + # Replace the coarse_prompt_embeds and fine_prompt_embeds with the input prompt_embeds. + # This is used when prompt_embeds are computed in advance. + cfg_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + coarse_prompt_embeds = cfg_prompt_embeds + fine_prompt_embeds = cfg_prompt_embeds + + # 7. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 8. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + self.dtype, + device, + generator, + latents, + ) + + # {'eta': 0.0, 'generator': None}. eta is 0 for DDIM. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + cross_attention_kwargs = {} + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + # DDIM doesn't scale latent_model_input. 
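+ # latent_model_input stacks [uncond, cond] copies of the latents, matching the order of the prompt-embedding
+ # batches built in extract_double_id_prompt_embeds; scale_model_input is kept so schedulers that do scale
+ # their inputs also work here.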
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if i <= start_merge_step: + current_prompt_embeds = coarse_prompt_embeds + else: + current_prompt_embeds = fine_prompt_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=current_prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + else: + assert 0, "Not Implemented" + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or \ + ( (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + elif output_type == "pil": + # 9.1 Post-processing + image = self.decode_latents(latents) + # 9.3 Convert to PIL + image = self.numpy_to_pil(image) + else: + # 9.1 Post-processing + image = self.decode_latents(latents) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, None) + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=None + ) + + + + + + + + diff --git a/ConsistentID/requirements.txt b/ConsistentID/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5ba980a578d3d2963080eac9b431efd05f6ff674 --- /dev/null +++ b/ConsistentID/requirements.txt @@ -0,0 +1,15 @@ +accelerate +safetensors +einops +onnxruntime-gpu +omegaconf +peft +opencv-python +insightface +diffusers +torch +torchvision +transformers +spaces +huggingface-hub +sentencepiece \ No newline at end of file diff --git a/README.md b/README.md index f643de873a59a82050ebbdcb439835f323a71ebd..5eb8e78888c4a79862ae029db78d9af6d610210b 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,14 @@ --- -title: Adaface2 -emoji: 📈 -colorFrom: green -colorTo: pink +title: Adaface +emoji: 😻 +colorFrom: indigo +colorTo: red sdk: gradio -sdk_version: 5.23.1 +sdk_version: 5.0.2 app_file: app.py pinned: false +license: apache-2.0 +short_description: 'AdaFace: Face Encoder for 0-Shot Diffusion Personalization' --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/adaface/__init__.py b/adaface/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/adaface/adaface_infer.py b/adaface/adaface_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..29c0f9821c93fa85e802c2023970c3699077955f --- /dev/null +++ b/adaface/adaface_infer.py @@ -0,0 +1,155 @@ +from adaface.adaface_wrapper import AdaFaceWrapper +import torch +#import torch.nn.functional as F +from PIL import Image +import numpy as np +import os, argparse, glob, re + +def save_images(images, num_images_per_row, subject_name, prompt, perturb_std, save_dir = "samples-ada"): + if num_images_per_row > len(images): + num_images_per_row = 
len(images) + + os.makedirs(save_dir, exist_ok=True) + + num_columns = int(np.ceil(len(images) / num_images_per_row)) + # Save 4 images as a grid image in save_dir + grid_image = Image.new('RGB', (512 * num_images_per_row, 512 * num_columns)) + for i, image in enumerate(images): + image = image.resize((512, 512)) + grid_image.paste(image, (512 * (i % num_images_per_row), 512 * (i // num_images_per_row))) + + prompt_sig = prompt.replace(" ", "_").replace(",", "_") + grid_filepath = os.path.join(save_dir, f"{subject_name}-{prompt_sig}-perturb{perturb_std:.02f}.png") + if os.path.exists(grid_filepath): + grid_count = 2 + grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-perturb{perturb_std:.02f}-{grid_count}.png') + while os.path.exists(grid_filepath): + grid_count += 1 + grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-perturb{perturb_std:.02f}-{grid_count}.png') + + grid_image.save(grid_filepath) + print(f"Saved to {grid_filepath}") + +def seed_everything(seed): + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ["PL_GLOBAL_SEED"] = str(seed) + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--pipeline", type=str, default="text2img", + choices=["text2img", "text2imgxl", "img2img", "text2img3", "flux"], + help="Type of pipeline to use (default: txt2img)") + parser.add_argument("--base_model_path", type=str, default=None, + help="Type of checkpoints to use (default: None, using the official model)") + parser.add_argument('--adaface_ckpt_paths', type=str, nargs="+", + default=['models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt']) + parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"], + choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders") + parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None, + choices=["arc2face", "consistentID"], + help="List of enabled encoders (among the list of adaface_encoder_types). Default: None (all enabled)") + # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face). 
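+ # Illustrative invocation (the subject path is a placeholder):
+ # python adaface/adaface_infer.py --subject path/to/subject_images --prompt "a woman z in superman costume" \
+ # --adaface_encoder_types consistentID arc2face --adaface_encoder_cfg_scales 6.0 1.0 --scale 4 --seed 42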
+ parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None, + help="CFG scales of output embeddings of the ID2Ada prompt encoders") + parser.add_argument("--main_unet_filepath", type=str, default=None, + help="Path to the checkpoint of the main UNet model, if you want to replace the default UNet within --base_model_path") + parser.add_argument("--extra_unet_dirpaths", type=str, nargs="*", + default=[], + help="Extra paths to the checkpoints of the UNet models") + parser.add_argument('--unet_weights', type=float, nargs="+", default=[1], + help="Weights for the UNet models") + parser.add_argument("--subject", type=str) + parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use") + parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate") + parser.add_argument("--prompt", type=str, default="a woman z in superman costume") + parser.add_argument("--noise", dest='perturb_std', type=float, default=0) + parser.add_argument("--randface", action="store_true") + parser.add_argument("--scale", dest='guidance_scale', type=float, default=4, + help="Guidance scale for the diffusion model") + parser.add_argument("--subject_string", + type=str, default="z", + help="Subject placeholder string used in prompts to denote the concept.") + parser.add_argument("--num_images_per_row", type=int, default=4, + help="Number of images to display in a row in the output grid image.") + parser.add_argument("--num_inference_steps", type=int, default=50, + help="Number of inference steps") + parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on") + parser.add_argument("--seed", type=int, default=42, + help="the seed (for reproducible sampling). Set to -1 to disable.") + args = parser.parse_args() + + return args + +if __name__ == "__main__": + args = parse_args() + if args.seed != -1: + seed_everything(args.seed) + + if re.match(r"^\d+$", args.device): + args.device = f"cuda:{args.device}" + print(f"Using device {args.device}") + + if args.pipeline not in ["text2img", "img2img"]: + args.extra_unet_dirpaths = None + args.unet_weights = None + + adaface = AdaFaceWrapper(args.pipeline, args.base_model_path, + args.adaface_encoder_types, args.adaface_ckpt_paths, + args.adaface_encoder_cfg_scales, args.enabled_encoders, + args.subject_string, args.num_inference_steps, + unet_types=None, + main_unet_filepath=args.main_unet_filepath, + extra_unet_dirpaths=args.extra_unet_dirpaths, + unet_weights=args.unet_weights, device=args.device) + + if not args.randface: + image_folder = args.subject + if image_folder.endswith("/"): + image_folder = image_folder[:-1] + + if os.path.isfile(image_folder): + # Get the second to the last part of the path + subject_name = os.path.basename(os.path.dirname(image_folder)) + image_paths = [image_folder] + + else: + subject_name = os.path.basename(image_folder) + image_types = ["*.jpg", "*.png", "*.jpeg"] + alltype_image_paths = [] + for image_type in image_types: + # glob returns the full path. + image_paths = glob.glob(os.path.join(image_folder, image_type)) + if len(image_paths) > 0: + alltype_image_paths.extend(image_paths) + + # Filter out images of "*_mask.png" + alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path] + + # image_paths contain at most args.example_image_count full image paths. 
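+ # example_image_count <= 0 (the default -1) keeps every non-mask image found in the folder.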
+ if args.example_image_count > 0: + image_paths = alltype_image_paths[:args.example_image_count] + else: + image_paths = alltype_image_paths + else: + subject_name = None + image_paths = None + image_folder = None + + subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name + rand_init_id_embs = torch.randn(1, 512) + + init_id_embs = rand_init_id_embs if args.randface else None + noise = torch.randn(args.out_image_count, 4, 64, 64).cuda() + # args.perturb_std: the *relative* std of the noise added to the face embeddings. + # A noise level of 0.08 could change gender, but 0.06 is usually safe. + # adaface_subj_embs is not used. It is generated for the purpose of updating the text encoder (within this function call). + adaface_subj_embs = \ + adaface.prepare_adaface_embeddings(image_paths, init_id_embs, + perturb_at_stage='img_prompt_emb', + perturb_std=args.perturb_std, update_text_encoder=True) + images = adaface(noise, args.prompt, None, 'append', args.guidance_scale, args.out_image_count, verbose=True) + save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.perturb_std) diff --git a/adaface/adaface_translate.py b/adaface/adaface_translate.py new file mode 100644 index 0000000000000000000000000000000000000000..690f2dd81beb95865168a8a5657363fca851e0ed --- /dev/null +++ b/adaface/adaface_translate.py @@ -0,0 +1,226 @@ +from adaface.adaface_wrapper import AdaFaceWrapper +import torch +#import torch.nn.functional as F +from PIL import Image +import numpy as np +import os, argparse, glob, re, shutil + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + +def seed_everything(seed): + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ["PL_GLOBAL_SEED"] = str(seed) + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--base_model_path", type=str, default='models/realisticvision/realisticVisionV40_v40VAE.safetensors', + help="Path to the UNet checkpoint (default: RealisticVision 4.0)") + parser.add_argument('--adaface_ckpt_paths', type=str, nargs="+", + default=['models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt']) + parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"], + choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders") + parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None, + choices=["arc2face", "consistentID"], + help="List of enabled encoders (among the list of adaface_encoder_types). Default: None (all enabled)") + # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face). 
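+ # Presumably one scale per entry in --adaface_encoder_types, e.g. 6.0 for consistentID and 1.0 for arc2face as above.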
+ parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None, + help="CFG scales of output embeddings of the ID2Ada prompt encoders") + parser.add_argument('--extra_unet_dirpaths', type=str, nargs="*", + default=[], + help="Extra paths to the checkpoints of the UNet models") + parser.add_argument('--unet_weights', type=float, nargs="+", default=[1], + help="Weights for the UNet models") + parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images") + # If True, the input folder contains images of mixed subjects. + # If False, the input folder contains multiple subfolders, each of which contains images of the same subject. + parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?", + help="Whether the input folder contains images of mixed subjects") + parser.add_argument("--max_images_per_subject", type=int, default=5, help="Number of example images used per subject") + parser.add_argument("--trans_subject_count", type=int, default=-1, help="Number of example images to be translated") + parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images") + parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image") + parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder") + parser.add_argument("--noise", dest='perturb_std', type=float, default=0) + parser.add_argument("--scale", dest='guidance_scale', type=float, default=4, + help="Guidance scale for the diffusion model") + parser.add_argument("--ref_img_strength", type=float, default=0.8, + help="Strength of the reference image in the output image.") + parser.add_argument("--subject_string", + type=str, default="z", + help="Subject placeholder string used in prompts to denote the concept.") + parser.add_argument("--prompt", type=str, default="a person z") + parser.add_argument("--num_images_per_row", type=int, default=4, + help="Number of images to display in a row in the output grid image.") + parser.add_argument("--num_inference_steps", type=int, default=50, + help="Number of DDIM inference steps") + parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.") + parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on") + parser.add_argument("--seed", type=int, default=42, + help="the seed (for reproducible sampling). 
Set to -1 to disable.") + args = parser.parse_args() + + return args + +if __name__ == "__main__": + args = parse_args() + if args.seed != -1: + seed_everything(args.seed) + +# screen -dm -L -Logfile trans_rv4-2.txt accelerate launch --multi_gpu --num_processes=2 scripts/adaface-translate.py +# --adaface_ckpt_paths logs/subjects-celebrity2024-05-16T17-22-46_zero3-ada/checkpoints/embeddings_gs-30000.pt +# --base_model_path models/realisticvision/realisticVisionV40_v40VAE.safetensors --in_folder /path/to/VGGface2_HQ_masks/ +# --is_mix_subj_folder 0 --out_folder /path/to/VGGface2_HQ_masks_rv4a --copy_masks --num_gpus 2 + if args.num_gpus > 1: + from accelerate import PartialState + distributed_state = PartialState() + args.device = distributed_state.device + process_index = distributed_state.process_index + elif re.match(r"^\d+$", args.device): + args.device = f"cuda:{args.device}" + distributed_state = None + process_index = 0 + + adaface = AdaFaceWrapper("img2img", args.base_model_path, + args.adaface_encoder_types, args.adaface_ckpt_paths, + args.adaface_encoder_cfg_scales, args.enabled_encoders, + args.subject_string, args.num_inference_steps, + unet_types=None, + extra_unet_dirpaths=args.extra_unet_dirpaths, unet_weights=args.unet_weights, + device=args.device) + + in_folder = args.in_folder + if os.path.isfile(in_folder): + subject_folders = [ os.path.dirname(in_folder) ] + images_by_subject = [[in_folder]] + else: + if not args.is_mix_subj_folder: + in_folders = [in_folder] + else: + in_folders = [ os.path.join(in_folder, subfolder) for subfolder in sorted(os.listdir(in_folder)) ] + + images_by_subject = [] + subject_folders = [] + for in_folder in in_folders: + image_types = ["*.jpg", "*.png", "*.jpeg"] + alltype_image_paths = [] + for image_type in image_types: + # glob returns the full path. + image_paths = glob.glob(os.path.join(in_folder, image_type)) + if len(image_paths) > 0: + alltype_image_paths.extend(image_paths) + + # Filter out images of "*_mask.png" + alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path] + alltype_image_paths = sorted(alltype_image_paths) + + if not args.is_mix_subj_folder: + # image_paths contain at most args.max_images_per_subject full image paths. + if args.max_images_per_subject > 0: + image_paths = alltype_image_paths[:args.max_images_per_subject] + else: + image_paths = alltype_image_paths + + images_by_subject.append(image_paths) + subject_folders.append(in_folder) + else: + # Each image in the folder is treated as an individual subject. + images_by_subject.extend([[image_path] for image_path in alltype_image_paths]) + subject_folders.extend([in_folder] * len(alltype_image_paths)) + + if args.trans_subject_count > 0 and len(subject_folders) >= args.trans_subject_count: + break + + if args.trans_subject_count > 0: + images_by_subject = images_by_subject[:args.trans_subject_count] + subject_folders = subject_folders[:args.trans_subject_count] + + out_image_count = 0 + out_mask_count = 0 + if not args.out_folder.endswith("/"): + args.out_folder += "/" + + if args.num_gpus > 1: + # Split the subjects across the GPUs. 
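# --- [Editor's note: illustrative sketch, not part of adaface_translate.py] ---
# The strided slicing right below assigns subjects to processes round-robin:
# process k gets items k, k+num_gpus, k+2*num_gpus, ..., so the shards are disjoint
# and together cover every subject. A self-contained illustration:
subjects_demo = [f"subj_{i}" for i in range(7)]
num_gpus_demo = 2
shards_demo = [subjects_demo[rank::num_gpus_demo] for rank in range(num_gpus_demo)]
# shards_demo[0] == ['subj_0', 'subj_2', 'subj_4', 'subj_6'], shards_demo[1] == ['subj_1', 'subj_3', 'subj_5']
assert sorted(sum(shards_demo, [])) == sorted(subjects_demo)
# --- [end of sketch] ---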
+ subject_folders = subject_folders[process_index::args.num_gpus] + images_by_subject = images_by_subject[process_index::args.num_gpus] + #subject_folders, images_by_subject = distributed_state.split_between_processes(zip(subject_folders, images_by_subject)) + + for (subject_folder, image_paths) in zip(subject_folders, images_by_subject): + # If is_mix_subj_folder, then image_paths only contains 1 image, and we use the file name as the signature of the image. + # Otherwise, we use the folder name as the signature of the images. + images_sig = subject_folder if not args.is_mix_subj_folder else os.path.basename(image_paths[0]) + + print(f"Translating {images_sig}...") + with torch.no_grad(): + adaface_subj_embs = \ + adaface.prepare_adaface_embeddings(image_paths, None, + perturb_at_stage='img_prompt_emb', + perturb_std=args.perturb_std, + update_text_encoder=True) + + # Replace the first occurrence of "in_folder" with "out_folder" in the path of the subject_folder. + subject_out_folder = subject_folder.replace(args.in_folder, args.out_folder, 1) + if not os.path.exists(subject_out_folder): + os.makedirs(subject_out_folder) + print(f"Output images will be saved to {subject_out_folder}") + + in_images = [] + for image_path in image_paths: + image = Image.open(image_path).convert("RGB").resize((512, 512)) + # [512, 512, 3] -> [3, 512, 512]. + image = np.array(image).transpose(2, 0, 1) + # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU. + image = torch.tensor(image).unsqueeze(0).float().cuda() + in_images.append(image) + + # Put all input images of the subject into a batch. This assumes max_images_per_subject is small. + # NOTE: For simplicity, we do not check overly large batch sizes. + in_images = torch.cat(in_images, dim=0) + # in_images: [5, 3, 512, 512]. + # Normalize the pixel values to [0, 1]. + in_images = in_images / 255.0 + num_out_images = len(in_images) * args.out_count_per_input_image + + with torch.no_grad(): + # args.perturb_std: the *relative* std of the noise added to the face embeddings. + # A noise level of 0.08 could change gender, but 0.06 is usually safe. + # The returned adaface_subj_embs are already incorporated in the text encoder, and not used explicitly. + # NOTE: We assume out_count_per_input_image == 1, so that the output images are of the same number as the input images. + out_images = adaface(in_images, args.prompt, None, 'append', args.guidance_scale, num_out_images, ref_img_strength=args.ref_img_strength) + + for img_i, img in enumerate(out_images): + # out_images: subj_1, subj_2, ..., subj_n, subj_1, subj_2, ..., subj_n, ... 
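# --- [Editor's note: illustrative sketch, not part of adaface_translate.py] ---
# The modulo/floor-division right below recovers, for each output image, which input image
# it came from (subj_i) and which copy it is (copy_i), given the ordering described above.
# A tiny worked example with 3 inputs and out_count_per_input_image == 2:
n_inputs_demo, n_copies_demo = 3, 2
index_map_demo = [(i % n_inputs_demo, i // n_inputs_demo) for i in range(n_inputs_demo * n_copies_demo)]
# index_map_demo == [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]
# --- [end of sketch] ---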
+ subj_i = img_i % len(in_images) + copy_i = img_i // len(in_images) + image_filename_stem, image_fileext = os.path.splitext(os.path.basename(image_paths[subj_i])) + if copy_i == 0: + img.save(os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}")) + else: + img.save(os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}")) + + if args.copy_masks: + mask_path = image_paths[subj_i].replace(image_fileext, "_mask.png") + if os.path.exists(mask_path): + if copy_i == 0: + shutil.copy(mask_path, subject_out_folder) + else: + mask_filename_stem = image_filename_stem + shutil.copy(mask_path, os.path.join(subject_out_folder, f"{mask_filename_stem}_{copy_i}_mask.png")) + + out_mask_count += 1 + + out_image_count += len(out_images) + + print(f"{out_image_count} output images and {out_mask_count} masks saved to {args.out_folder}") diff --git a/adaface/adaface_wrapper.py b/adaface/adaface_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..05f944145a7026a89c485dbeb6164bee22f58804 --- /dev/null +++ b/adaface/adaface_wrapper.py @@ -0,0 +1,483 @@ +import torch +import torch.nn as nn +from transformers import CLIPTextModel +from diffusers import ( + StableDiffusionPipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionXLPipeline, + StableDiffusion3Pipeline, + #FluxPipeline, + DDIMScheduler, + AutoencoderKL, +) +from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint +from adaface.util import UNetEnsemble +from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder +from safetensors.torch import load_file as safetensors_load_file +import re, os +import numpy as np + +class AdaFaceWrapper(nn.Module): + def __init__(self, pipeline_name, base_model_path, adaface_encoder_types, + adaface_ckpt_paths, adaface_encoder_cfg_scales=None, + enabled_encoders=None, + subject_string='z', num_inference_steps=50, negative_prompt=None, + use_840k_vae=False, use_ds_text_encoder=False, + main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights=None, + device='cuda', is_training=False): + ''' + pipeline_name: "text2img", "text2imgxl", "img2img", "text2img3", "flux", or None. + If None, it's used only as a face encoder, and the unet and vae are + removed from the pipeline to release RAM. 
+ ''' + super().__init__() + self.pipeline_name = pipeline_name + self.base_model_path = base_model_path + self.adaface_encoder_types = adaface_encoder_types + + self.adaface_ckpt_paths = adaface_ckpt_paths + self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales + self.enabled_encoders = enabled_encoders + self.subject_string = subject_string + + self.num_inference_steps = num_inference_steps + self.use_840k_vae = use_840k_vae + self.use_ds_text_encoder = use_ds_text_encoder + self.main_unet_filepath = main_unet_filepath + self.unet_types = unet_types + self.extra_unet_dirpaths = extra_unet_dirpaths + self.unet_weights = unet_weights + self.device = device + self.is_training = is_training + + if negative_prompt is None: + self.negative_prompt = \ + "flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, " \ + "mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, " \ + "mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, " \ + "nude, naked, nsfw, topless, bare breasts" + else: + self.negative_prompt = negative_prompt + + self.initialize_pipeline() + # During inference, we never use static image suffix embeddings. + # So num_id_vecs is the length of the returned adaface embeddings for each encoder. + self.encoders_num_id_vecs = self.id2ada_prompt_encoder.encoders_num_id_vecs + self.extend_tokenizer_and_text_encoder() + + def to(self, device): + self.device = device + self.id2ada_prompt_encoder.to(device) + self.pipeline.to(device) + print(f"Moved AdaFaceWrapper to {device}.") + return self + + def initialize_pipeline(self): + self.id2ada_prompt_encoder = create_id2ada_prompt_encoder(self.adaface_encoder_types, + self.adaface_ckpt_paths, + self.adaface_encoder_cfg_scales, + self.enabled_encoders) + + self.id2ada_prompt_encoder.to(self.device) + print(f"adaface_encoder_cfg_scales: {self.adaface_encoder_cfg_scales}") + + if self.use_840k_vae: + # The 840000-step vae model is slightly better in face details than the original vae model. + # https://huggingface.co/stabilityai/sd-vae-ft-mse-original + vae = AutoencoderKL.from_single_file("models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt", + torch_dtype=torch.float16) + else: + vae = None + + if self.use_ds_text_encoder: + # The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder. + # https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder + text_encoder = CLIPTextModel.from_pretrained("models/diffusers/ds_text_encoder", + torch_dtype=torch.float16) + else: + text_encoder = None + + remove_unet = False + + if self.pipeline_name == "img2img": + PipelineClass = StableDiffusionImg2ImgPipeline + elif self.pipeline_name == "text2img": + PipelineClass = StableDiffusionPipeline + elif self.pipeline_name == "text2imgxl": + PipelineClass = StableDiffusionXLPipeline + elif self.pipeline_name == "text2img3": + PipelineClass = StableDiffusion3Pipeline + #elif self.pipeline_name == "flux": + # PipelineClass = FluxPipeline + # pipeline_name is None means only use this instance to generate adaface embeddings, not to generate images. 
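# --- [Editor's note: illustrative usage sketch, not part of adaface_wrapper.py] ---
# As the comment above says, passing pipeline_name=None keeps only the tokenizer and text
# encoder, so the wrapper acts purely as a face-to-AdaFace-embedding encoder. A hedged usage
# sketch (the checkpoint/model paths are the defaults used elsewhere in this repo, and
# "subject/1.jpg" is a placeholder input image):
def _adaface_encoder_only_example():
    encoder = AdaFaceWrapper(
        pipeline_name=None,                                   # no image generation; UNet/VAE are dropped
        base_model_path="models/sd15-dste8-vae.safetensors",
        adaface_encoder_types=["consistentID", "arc2face"],
        adaface_ckpt_paths=["models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt"],
    )
    # Returns the AdaFace subject embeddings without updating the text encoder's token embeddings.
    return encoder.prepare_adaface_embeddings(["subject/1.jpg"], update_text_encoder=False)
# --- [end of sketch] ---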
+ elif self.pipeline_name is None: + PipelineClass = StableDiffusionPipeline + remove_unet = True + else: + raise ValueError(f"Unknown pipeline name: {self.pipeline_name}") + + if self.base_model_path is None: + base_model_path_dict = { + 'text2img': 'models/sd15-dste8-vae.safetensors', + 'text2imgxl': 'stabilityai/stable-diffusion-xl-base-1.0', + 'text2img3': 'stabilityai/stable-diffusion-3-medium-diffusers', + 'flux': 'black-forest-labs/FLUX.1-schnell', + } + self.base_model_path = base_model_path_dict[self.pipeline_name] + + if os.path.isfile(self.base_model_path): + pipeline = PipelineClass.from_single_file( + self.base_model_path, + torch_dtype=torch.float16 + ) + else: + pipeline = PipelineClass.from_pretrained( + self.base_model_path, + torch_dtype=torch.float16, + safety_checker=None + ) + + if self.main_unet_filepath is not None: + print(f"Replacing the UNet with the UNet from {self.main_unet_filepath}.") + ret = pipeline.unet.load_state_dict(self.load_unet_from_file(self.main_unet_filepath, device='cpu')) + if len(ret.missing_keys) > 0: + print(f"Missing keys: {ret.missing_keys}") + if len(ret.unexpected_keys) > 0: + print(f"Unexpected keys: {ret.unexpected_keys}") + + if (self.unet_types is not None and len(self.unet_types) > 0) \ + or (self.extra_unet_dirpaths is not None and len(self.extra_unet_dirpaths) > 0): + unet_ensemble = UNetEnsemble([pipeline.unet], self.unet_types, self.extra_unet_dirpaths, self.unet_weights, + device=self.device, torch_dtype=torch.float16) + pipeline.unet = unet_ensemble + + print(f"Loaded pipeline from {self.base_model_path}.") + + if self.use_840k_vae: + pipeline.vae = vae + print("Replaced the VAE with the 840k-step VAE.") + + if self.use_ds_text_encoder: + pipeline.text_encoder = text_encoder + print("Replaced the text encoder with the DreamShaper text encoder.") + + if remove_unet: + # Remove unet and vae to release RAM. Only keep tokenizer and text_encoder. + pipeline.unet = None + pipeline.vae = None + print("Removed UNet and VAE from the pipeline.") + + if self.pipeline_name not in ["text2imgxl", "text2img3", "flux"]: + noise_scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + pipeline.scheduler = noise_scheduler + # Otherwise, pipeline.scheduler == FlowMatchEulerDiscreteScheduler + self.pipeline = pipeline.to(self.device) + + def load_unet_from_file(self, unet_path, device=None): + if os.path.isfile(unet_path): + if unet_path.endswith(".safetensors"): + unet_state_dict = safetensors_load_file(unet_path, device=device) + else: + unet_state_dict = torch.load(unet_path, map_location=device) + + key0 = list(unet_state_dict.keys())[0] + if key0.startswith("model.diffusion_model"): + key_prefix = "" + is_ldm_unet = True + elif key0.startswith("diffusion_model"): + key_prefix = "model." + is_ldm_unet = True + else: + is_ldm_unet = False + + if is_ldm_unet: + unet_state_dict2 = {} + for key, value in unet_state_dict.items(): + key2 = key_prefix + key + unet_state_dict2[key2] = value + print(f"LDM UNet detected. 
Convert to diffusers") + ldm_unet_config = { 'layers_per_block': 2 } + unet_state_dict = convert_ldm_unet_checkpoint(unet_state_dict2, ldm_unet_config) + else: + raise ValueError(f"UNet path {unet_path} is not a file.") + return unet_state_dict + + def extend_tokenizer_and_text_encoder(self): + if np.sum(self.encoders_num_id_vecs) < 1: + raise ValueError(f"encoders_num_id_vecs has to be larger or equal to 1, but is {self.encoders_num_id_vecs}") + + tokenizer = self.pipeline.tokenizer + # If adaface_encoder_types is ["arc2face", "consistentID"], then total_num_id_vecs = 20. + # We add z_0_0, z_0_1, z_0_2, ..., z_0_15, z_1_0, z_1_1, z_1_2, z_1_3 to the tokenizer. + self.all_placeholder_tokens = [] + self.placeholder_tokens_strs = [] + for i in range(len(self.adaface_encoder_types)): + placeholder_tokens = [] + for j in range(self.encoders_num_id_vecs[i]): + placeholder_tokens.append(f"{self.subject_string}_{i}_{j}") + placeholder_tokens_str = " ".join(placeholder_tokens) + + self.all_placeholder_tokens.extend(placeholder_tokens) + self.placeholder_tokens_strs.append(placeholder_tokens_str) + + self.all_placeholder_tokens_str = " ".join(self.placeholder_tokens_strs) + # all_null_placeholder_tokens_str: ", , , , ..." (20 times). + # It just contains the commas and spaces with the same length, but no actual tokens. + self.all_null_placeholder_tokens_str = " ".join([", "] * len(self.all_placeholder_tokens)) + + # Add the new tokens to the tokenizer. + num_added_tokens = tokenizer.add_tokens(self.all_placeholder_tokens) + if num_added_tokens != np.sum(self.encoders_num_id_vecs): + raise ValueError( + f"The tokenizer already contains some of the tokens {self.all_placeholder_tokens_str}. Please pass a different" + " `subject_string` that is not already in the tokenizer.") + + print(f"Added {num_added_tokens} tokens ({self.all_placeholder_tokens_str}) to the tokenizer.") + + # placeholder_token_ids: [49408, ..., 49423]. + self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.all_placeholder_tokens) + #print("New tokens:", self.placeholder_token_ids) + # Resize the token embeddings as we are adding new special tokens to the tokenizer + old_weight_shape = self.pipeline.text_encoder.get_input_embeddings().weight.shape + self.pipeline.text_encoder.resize_token_embeddings(len(tokenizer)) + new_weight = self.pipeline.text_encoder.get_input_embeddings().weight + print(f"Resized text encoder token embeddings from {old_weight_shape} to {new_weight.shape} on {new_weight.device}.") + + # Extend pipeline.text_encoder with the adaface subject emeddings. + # subj_embs: [16, 768]. + def update_text_encoder_subj_embeddings(self, subj_embs): + # Initialise the newly added placeholder token with the embeddings of the initializer token + # token_embeds: [49412, 768] + token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data + with torch.no_grad(): + for i, token_id in enumerate(self.placeholder_token_ids): + token_embeds[token_id] = subj_embs[i] + print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.all_placeholder_tokens_str}) in the text encoder.") + + def update_prompt(self, prompt, placeholder_tokens_pos='append', + use_null_placeholders=False): + if prompt is None: + prompt = "" + + if use_null_placeholders: + all_placeholder_tokens_str = self.all_null_placeholder_tokens_str + else: + all_placeholder_tokens_str = self.all_placeholder_tokens_str + + # Delete the subject_string from the prompt. 
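# --- [Editor's note: illustrative sketch, not part of adaface_wrapper.py] ---
# What the two substitutions right below do to a typical prompt, assuming subject_string == "z":
# the placeholder word (and any leading article) is removed, then the expanded per-encoder
# placeholder tokens are appended at the end.
import re
_demo = "a woman z in superman costume"                    # the default --prompt in this repo
_demo = re.sub(r'\b(a|an|the)\s+z\b,?', "", _demo)         # no match here (no "a z"/"the z")
_demo = re.sub(r'\bz\b,?', "", _demo)                      # -> "a woman  in superman costume"
_demo = _demo + " " + "z_0_0 z_0_1 z_0_2 z_0_3"            # abbreviated; the real code appends all
                                                           # placeholder tokens (20 with both encoders)
# --- [end of sketch] ---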
+ prompt = re.sub(r'\b(a|an|the)\s+' + self.subject_string + r'\b,?', "", prompt) + prompt = re.sub(r'\b' + self.subject_string + r'\b,?', "", prompt) + # Prevously, arc2face ada prompts work better if they are prepended to the prompt, + # and consistentID ada prompts work better if they are appended to the prompt. + # When we do joint training, seems both work better if they are appended to the prompt. + # Therefore we simply appended all placeholder_tokens_str's to the prompt. + # NOTE: Prepending them hurts compositional prompts. + if placeholder_tokens_pos == 'prepend': + prompt = all_placeholder_tokens_str + " " + prompt + elif placeholder_tokens_pos == 'append': + prompt = prompt + " " + all_placeholder_tokens_str + else: + breakpoint() + + return prompt + + # If face_id_embs is None, then it extracts face_id_embs from the images, + # then map them to ada prompt embeddings. + # avg_at_stage: 'id_emb', 'img_prompt_emb', or None. + # avg_at_stage == ada_prompt_emb usually produces the worst results. + # id_emb is slightly better than img_prompt_emb, but sometimes img_prompt_emb is better. + def prepare_adaface_embeddings(self, image_paths, face_id_embs=None, + avg_at_stage='id_emb', # id_emb, img_prompt_emb, ada_prompt_emb, or None. + perturb_at_stage=None, # id_emb, img_prompt_emb, or None. + perturb_std=0, update_text_encoder=True): + + all_adaface_subj_embs = \ + self.id2ada_prompt_encoder.generate_adaface_embeddings(\ + image_paths, face_id_embs=face_id_embs, + img_prompt_embs=None, + avg_at_stage=avg_at_stage, + perturb_at_stage=perturb_at_stage, + perturb_std=perturb_std, + enable_static_img_suffix_embs=False) + + if all_adaface_subj_embs is None: + return None + + if all_adaface_subj_embs.ndim == 4: + # [1, 1, 16, 768] -> [16, 768] + all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0).squeeze(0) + elif all_adaface_subj_embs.ndim == 3: + # [1, 16, 768] -> [16, 768] + all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0) + + if update_text_encoder: + self.update_text_encoder_subj_embeddings(all_adaface_subj_embs) + return all_adaface_subj_embs + + def diffusers_encode_prompts(self, prompt, plain_prompt, negative_prompt, device): + # pooled_prompt_embeds_, negative_pooled_prompt_embeds_ are used by text2img3 and flux. + pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = None, None + + # Compatible with older versions of diffusers. + if not hasattr(self.pipeline, "encode_prompt"): + # prompt_embeds_, negative_prompt_embeds_: [77, 768] -> [1, 77, 768]. + prompt_embeds_, negative_prompt_embeds_ = \ + self.pipeline._encode_prompt(prompt, device=device, num_images_per_prompt=1, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt) + prompt_embeds_ = prompt_embeds_.unsqueeze(0) + negative_prompt_embeds_ = negative_prompt_embeds_.unsqueeze(0) + else: + if self.pipeline_name in ["text2imgxl", "text2img3", "flux"]: + prompt_2 = plain_prompt + # CLIP Text Encoder prompt uses a maximum sequence length of 77. + # T5 Text Encoder prompt uses a maximum sequence length of 256. + # 333 = 256 + 77. 
+ prompt_t5 = prompt + "".join([", "] * 256) + + # prompt_embeds_, negative_prompt_embeds_: [1, 333, 4096] + # pooled_prompt_embeds_, negative_pooled_prompt_embeds_: [1, 2048] + if self.pipeline_name == "text2imgxl": + prompt_embeds_, negative_prompt_embeds_, \ + pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \ + self.pipeline.encode_prompt(prompt, prompt_2, device=device, + num_images_per_prompt=1, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt) + elif self.pipeline_name == "text2img3": + prompt_embeds_, negative_prompt_embeds_, \ + pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \ + self.pipeline.encode_prompt(prompt, prompt_2, prompt_t5, device=device, + num_images_per_prompt=1, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt) + elif self.pipeline_name == "flux": + # prompt_embeds_: [1, 512, 4096] + # pooled_prompt_embeds_: [1, 768] + prompt_embeds_, pooled_prompt_embeds_, text_ids = \ + self.pipeline.encode_prompt(prompt, prompt_t5, device=device, + num_images_per_prompt=1) + negative_prompt_embeds_ = negative_pooled_prompt_embeds_ = None + else: + breakpoint() + else: + # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768] + prompt_embeds_, negative_prompt_embeds_ = \ + self.pipeline.encode_prompt(prompt, device=device, + num_images_per_prompt=1, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt) + + return prompt_embeds_, negative_prompt_embeds_, \ + pooled_prompt_embeds_, negative_pooled_prompt_embeds_ + + def encode_prompt(self, prompt, negative_prompt=None, + placeholder_tokens_pos='append', + do_neg_id_prompt_weight=0, + device=None, verbose=False): + if negative_prompt is None: + negative_prompt = self.negative_prompt + + if device is None: + device = self.device + + plain_prompt = prompt + prompt = self.update_prompt(prompt, placeholder_tokens_pos=placeholder_tokens_pos) + if verbose: + print(f"Subject prompt:\n{prompt}") + + if do_neg_id_prompt_weight > 0: + # Use 'prepend' for the negative prompt, since it's long and we want to make sure + # the placeholder tokens are not cut off. + negative_prompt0 = negative_prompt + negative_prompt = self.update_prompt(negative_prompt0, placeholder_tokens_pos='prepend') + null_negative_prompt = self.update_prompt(negative_prompt0, placeholder_tokens_pos='prepend', + use_null_placeholders=True) + ''' if verbose: + print(f"Negative prompt:\n{negative_prompt}") + print(f"Null negative prompt:\n{null_negative_prompt}") + + ''' + else: + null_negative_prompt = None + + # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device). + # So we manually move it to GPU here. + self.pipeline.text_encoder.to(device) + + prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \ + self.diffusers_encode_prompts(prompt, plain_prompt, negative_prompt, device) + + if 0 < do_neg_id_prompt_weight < 1: + _, negative_prompt_embeds_null, _, _ = \ + self.diffusers_encode_prompts(prompt, plain_prompt, null_negative_prompt, device) + negative_prompt_embeds_ = negative_prompt_embeds_ * do_neg_id_prompt_weight + \ + negative_prompt_embeds_null * (1 - do_neg_id_prompt_weight) + + return prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ + + # ref_img_strength is used only in the img2img pipeline. 
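# --- [Editor's note: illustrative sketch, not part of adaface_wrapper.py] ---
# The do_neg_id_prompt_weight logic in encode_prompt() above is a plain linear interpolation
# between the ID-conditioned negative embedding and a "null placeholder" negative embedding:
# weight 0 ignores the ID tokens in the negative prompt, weight 1 uses them fully.
import torch

def _blend_negative_embeds(neg_id: torch.Tensor, neg_null: torch.Tensor, w: float) -> torch.Tensor:
    # Same weighting as in encode_prompt(): neg_id * w + neg_null * (1 - w).
    return neg_id * w + neg_null * (1.0 - w)

_neg_id, _neg_null = torch.randn(1, 77, 768), torch.randn(1, 77, 768)
_blended = _blend_negative_embeds(_neg_id, _neg_null, w=0.3)   # shape [1, 77, 768]
# --- [end of sketch] ---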
+ def forward(self, noise, prompt, negative_prompt=None, + placeholder_tokens_pos='append', + do_neg_id_prompt_weight=0, + guidance_scale=6.0, out_image_count=4, + ref_img_strength=0.8, generator=None, verbose=False): + noise = noise.to(device=self.device, dtype=torch.float16) + + if negative_prompt is None: + negative_prompt = self.negative_prompt + # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768] + prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, \ + negative_pooled_prompt_embeds_ = \ + self.encode_prompt(prompt, negative_prompt, + placeholder_tokens_pos=placeholder_tokens_pos, + do_neg_id_prompt_weight=do_neg_id_prompt_weight, + device=self.device, verbose=verbose) + # Repeat the prompt embeddings for all images in the batch. + prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1) + if negative_prompt_embeds_ is not None: + negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1) + + if self.pipeline_name in ["text2imgxl", "text2img3"]: + pooled_prompt_embeds_ = pooled_prompt_embeds_.repeat(out_image_count, 1) + negative_pooled_prompt_embeds_ = negative_pooled_prompt_embeds_.repeat(out_image_count, 1) + + # noise: [BS, 4, 64, 64] + # When the pipeline is text2img, strength is ignored. + images = self.pipeline(prompt_embeds=prompt_embeds_, + negative_prompt_embeds=negative_prompt_embeds_, + pooled_prompt_embeds=pooled_prompt_embeds_, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds_, + num_inference_steps=self.num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=1, + generator=generator).images + elif self.pipeline_name == "flux": + images = self.pipeline(prompt_embeds=prompt_embeds_, + pooled_prompt_embeds=pooled_prompt_embeds_, + num_inference_steps=4, + guidance_scale=guidance_scale, + num_images_per_prompt=1, + generator=generator).images + else: + # When the pipeline is text2img, noise: [BS, 4, 64, 64], and strength is ignored. + # When the pipeline is img2img, noise is an initiali image of [BS, 3, 512, 512], + # whose pixels are normalized to [0, 1]. 
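# --- [Editor's note: illustrative sketch, not part of adaface_wrapper.py] ---
# For the img2img call right below, "noise" is actually a batch of initial images in [0, 1].
# adaface_translate.py builds that batch roughly as follows ("subject/1.jpg" is a placeholder path):
import numpy as np
import torch
from PIL import Image

_img = Image.open("subject/1.jpg").convert("RGB").resize((512, 512))
_img = torch.from_numpy(np.array(_img).transpose(2, 0, 1)).unsqueeze(0).float() / 255.0
# _img: [1, 3, 512, 512]; torch.cat() several of these to get the [BS, 3, 512, 512] batch.
# --- [end of sketch] ---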
+ images = self.pipeline(image=noise, + prompt_embeds=prompt_embeds_, + negative_prompt_embeds=negative_prompt_embeds_, + num_inference_steps=self.num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=1, + strength=ref_img_strength, + generator=generator).images + # images: [BS, 3, 512, 512] + return images + \ No newline at end of file diff --git a/adaface/arc2face_models.py b/adaface/arc2face_models.py new file mode 100644 index 0000000000000000000000000000000000000000..73e1c5d8383fb148ead9afc201c91816b96830fd --- /dev/null +++ b/adaface/arc2face_models.py @@ -0,0 +1,382 @@ +import torch +import torch.nn as nn +from transformers import CLIPTextModel +from transformers.models.clip.modeling_clip import CLIPAttention +from typing import Optional, Tuple, Union +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from diffusers import ( + StableDiffusionPipeline, + UNet2DConditionModel, + DDIMScheduler, +) +# from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask +_make_causal_mask = AttentionMaskConverter._make_causal_mask +_expand_mask = AttentionMaskConverter._expand_mask + +from .util import perturb_tensor + +def create_arc2face_pipeline(base_model_path="models/sd15-dste8-vae.safetensors", + dtype=torch.float16, unet_only=False): + unet = UNet2DConditionModel.from_pretrained( + 'models/arc2face', subfolder="arc2face", torch_dtype=dtype + ) + if unet_only: + return unet + + text_encoder = CLIPTextModelWrapper.from_pretrained( + 'models/arc2face', subfolder="encoder", torch_dtype=dtype + ) + noise_scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + pipeline = StableDiffusionPipeline.from_single_file( + base_model_path, + text_encoder=text_encoder, + unet=unet, + torch_dtype=dtype, + safety_checker=None + ) + pipeline.scheduler = noise_scheduler + return pipeline + +# Extend CLIPAttention by using multiple k_proj and v_proj in each head. +# To avoid too much increase of computation, we don't extend q_proj. +class CLIPAttentionMKV(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config, multiplier=2): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.multiplier = multiplier + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + # The (approximately) repeated token features are repeated along the last dim in tensor + # (multiplier * num_heads * head_dim), and then reshaped to (bsz, -1, num_heads, head_dim). + # Therefore, the "multiplier" dim is tucked into the seq_len dim, which looks like + # [token1_emb, token1_emb, token2_emb, token2_emb, ..., tokenN_emb, tokenN_emb]. 
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + # clip_attn_layer is usually self. + def extend_weights(self, clip_attn_layer, layer_idx, multiplier, perturb_std=0.2, + perturb_std_is_relative=True, perturb_keep_norm=False, verbose=False): + ORIG_V_SHAPE = list(clip_attn_layer.v_proj.weight.shape) + ORIG_V_SHAPE_D0 = ORIG_V_SHAPE[0] + ORIG_K_SHAPE = list(clip_attn_layer.k_proj.weight.shape) + ORIG_K_SHAPE_D0 = ORIG_K_SHAPE[0] + + self.multiplier *= multiplier + + # q_proj and out_proj are the same as the original CLIPAttention. + self.q_proj.weight.data = clip_attn_layer.q_proj.weight.data.clone() + self.q_proj.bias.data = clip_attn_layer.q_proj.bias.data.clone() + self.out_proj.weight.data = clip_attn_layer.out_proj.weight.data.clone() + self.out_proj.bias.data = clip_attn_layer.out_proj.bias.data.clone() + + # bias doesn't need noise perturbation, as after the weights are noised, + # different copies of the weight/bias will receive different gradients, + # making the bias terms diverge and identifiable after training. + self.k_proj.bias.data = clip_attn_layer.k_proj.bias.data.repeat(multiplier) + self.v_proj.bias.data = clip_attn_layer.v_proj.bias.data.repeat(multiplier) + + self.k_proj.weight.data = clip_attn_layer.k_proj.weight.data.repeat(multiplier, 1) + self.v_proj.weight.data = clip_attn_layer.v_proj.weight.data.repeat(multiplier, 1) + + # Correct the out_features attribute of k_proj and v_proj. + self.k_proj.out_features = self.k_proj.weight.shape[0] + self.v_proj.out_features = self.v_proj.weight.shape[0] + + if perturb_std > 0: + # Adding noise to the extra copies of the weights (keep the first copy unchanged). + self.v_proj.weight.data[ORIG_V_SHAPE_D0:] = \ + perturb_tensor(self.v_proj.weight.data[ORIG_V_SHAPE_D0:], + perturb_std, perturb_std_is_relative, perturb_keep_norm, verbose=verbose) + if verbose: + NEW_V_SHAPE = list(self.v_proj.weight.shape) + NOISED_V_SHAPE = list(self.v_proj.weight.data[ORIG_V_SHAPE_D0:].shape) + print(f"Layer {layer_idx}: {NOISED_V_SHAPE} in {NEW_V_SHAPE} of v_proj is added with {perturb_std} noise") + + # Adding noise to the extra copies of the weights. + self.k_proj.weight.data[ORIG_K_SHAPE_D0:] = \ + perturb_tensor(self.k_proj.weight.data[ORIG_K_SHAPE_D0:], + perturb_std, perturb_std_is_relative, perturb_keep_norm, verbose=verbose) + if verbose: + NEW_K_SHAPE = list(self.k_proj.weight.shape) + NOISED_K_SHAPE = list(self.k_proj.weight.data[ORIG_K_SHAPE_D0:].shape) + print(f"Layer {layer_idx}: {NOISED_K_SHAPE} in {NEW_K_SHAPE} of k_proj is added with {perturb_std} noise") + + def squeeze_weights(self, clip_attn_layer, divisor): + if self.multiplier % divisor != 0: + breakpoint() + self.multiplier //= divisor + + self.k_proj.bias.data = clip_attn_layer.k_proj.bias.data.reshape(divisor, -1).mean(dim=0) + self.v_proj.bias.data = clip_attn_layer.v_proj.bias.data.reshape(divisor, -1).mean(dim=0) + + self.k_proj.weight.data = clip_attn_layer.k_proj.weight.data.reshape(divisor, -1, self.k_proj.weight.shape[1]).mean(dim=0) + self.v_proj.weight.data = clip_attn_layer.v_proj.weight.data.reshape(divisor, -1, self.v_proj.weight.shape[1]).mean(dim=0) + + # Correct the out_features attribute of k_proj and v_proj. 
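# --- [Editor's note: illustrative sketch, not part of arc2face_models.py] ---
# extend_weights() above replicates the K/V projection weights along dim 0 (torch.repeat),
# and squeeze_weights() undoes that by averaging the copies; with perturb_std == 0 the two
# operations are exact inverses, as this tiny standalone check shows:
import torch

_w = torch.randn(4, 8)                                     # original proj weight [out_features, in_features]
_w_ext = _w.repeat(2, 1)                                   # extended with multiplier 2 -> [8, 8]
_w_back = _w_ext.reshape(2, -1, _w.shape[1]).mean(dim=0)   # squeezed with divisor 2 -> [4, 8]
assert torch.allclose(_w, _w_back)
# --- [end of sketch] ---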
+ self.k_proj.out_features = self.k_proj.weight.shape[0] + self.v_proj.out_features = self.v_proj.weight.shape[0] + + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + query_states = self.q_proj(hidden_states) * self.scale + # For key_states and value_states, the multiplier is absorbed into the seq_len (dim 1, shape specified as -1). + # [token0_head_emb, token0_head_emb, token1_head_emb, token1_head_emb, ..., tokenN-1_head_emb, tokenN-1_head_emb]. + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + # src_len0 is the original src_len without the multiplier. + src_len0 = src_len // self.multiplier + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2).contiguous()) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len0): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is" + f" {causal_attention_mask.size()}" + ) + # The last dim of attn_weights corresponds to [token0, token0, token1, token1, ..., tokenN-1, tokenN-1]. + # If reshaping it as (self.multiplier, src_len0), it will become + # [[token0, token0, token1, token1, ..., tokenN//2], [tokenN//2+1, tokenN//2+1, ..., tokenN-1, tokenN-1]], + # and the mask will be applied to wrong elements. + # If reshaping it as (src_len0, self.multiplier), it will become + # [[token0, token1, ..., tokenN-1], [token0, token1, ..., tokenN-1]], and then + # the mask at element i will mask all the multiplier elements at i, which is desired. + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + causal_attention_mask.unsqueeze(4) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len0): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + attention_mask.unsqueeze(4) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + +class CLIPTextModelWrapper(CLIPTextModel): + # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812 + # Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them. + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + input_token_embs: Optional[torch.Tensor] = None, + hidden_state_layer_weights: Optional[torch.Tensor] = None, + return_token_embs: Optional[bool] = False, + ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]: + + if return_token_embs: + return self.text_model.embeddings.token_embedding(input_ids) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states + ) + if hidden_state_layer_weights is not None: + output_hidden_states = True + return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs) + + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.text_model.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + # output_hidden_states is False by default, and only True if hidden_state_layer_weights is provided. 
+ output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # If output_hidden_states is True, then encoder_outputs[0] is last_hidden_state [1, 22, 768]. + # encoder_outputs[1] is hidden_states, which is a tuple of 13 hidden states, each being [1, 22, 768]. + # encoder_outputs[0] == encoder_outputs[1][12]. + if hidden_state_layer_weights is None: + last_hidden_state = encoder_outputs[0] + else: + num_hidden_state_layers = len(hidden_state_layer_weights) + last_hidden_states = encoder_outputs[1][-num_hidden_state_layers:] + hidden_state_layer_weights = hidden_state_layer_weights.to(last_hidden_states[0].dtype) + # Normalize the weights of to sum to 1 across layers. + # hidden_state_layer_weights: [3, 1] or [3, 768]. + hidden_state_layer_weights = hidden_state_layer_weights / hidden_state_layer_weights.sum(dim=0, keepdim=True) + # [3, 1/768] -> [3, 1, 1, 1/768] + hidden_state_layer_weights = hidden_state_layer_weights.unsqueeze(1).unsqueeze(1) + # A weighted sum of last_hidden_states. + # [3, 1, 22, 768] * [3, 1, 1, 1/768] -> [3, 1, 22, 768] -> [1, 22, 768] + last_hidden_state = (torch.stack(last_hidden_states, dim=0) * hidden_state_layer_weights).sum(dim=0) + + last_hidden_state = self.text_model.final_layer_norm(last_hidden_state) + + # self.text_model.eos_token_id == 2 is True. + if self.text_model.eos_token_id == 2: + # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. + # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added + # ------------------------------------------------------------ + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + else: + # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) + (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id) + .int() + .argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + # Applied to all attention layers in the encoder, if the corresponding multiplier is not 1. + # The layer indexed by end_layer_idx is not included. + # If both layer indices are -1, then apply to all layers (0-11). + def extend_clip_attention_MKV_multiplier(self, prompt2token_proj_attention_multipliers, perturb_std=0.1): + num_extended_layers = 0 + + for layer_idx, layer in enumerate(self.text_model.encoder.layers): + multiplier = prompt2token_proj_attention_multipliers[layer_idx] + if multiplier == 1: + continue + # This shouldn't happen, unless self_attn has already been extended as CLIPAttentionMKV. 
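# --- [Editor's note: illustrative usage sketch, not part of arc2face_models.py] ---
# extend_clip_attention_MKV_multiplier() (the method whose body continues right below) takes one
# multiplier per CLIP text-encoder layer (12 layers for the CLIP ViT-L text model used here);
# entries equal to 1 leave the layer untouched. The specific per-layer values are hypothetical,
# not taken from the repo's configs:
def _extend_last_layers_example():
    text_encoder = CLIPTextModelWrapper.from_pretrained("models/arc2face", subfolder="encoder")
    multipliers = [1] * 8 + [2] * 4            # hypothetical: double K/V heads only in the last 4 layers
    num_extended = text_encoder.extend_clip_attention_MKV_multiplier(multipliers, perturb_std=0.1)
    return num_extended                        # expected: 4
# --- [end of sketch] ---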
+ if not isinstance(layer.self_attn, (CLIPAttention, CLIPAttentionMKV)): + breakpoint() + old_attn_layer = layer.self_attn + if not isinstance(old_attn_layer, CLIPAttentionMKV): + layer.self_attn = CLIPAttentionMKV(old_attn_layer.config, 1) + # Extends the v_proj and k_proj weights in the self_attn layer. + layer.self_attn.extend_weights(old_attn_layer, layer_idx, multiplier, perturb_std, verbose=True) + num_extended_layers += 1 + + return num_extended_layers + + # Applied to layers [begin_layer_idx, end_layer_idx) in the encoder. + # The layer indexed by end_layer_idx is not included. + # If both layer indices are -1, then apply to all layers (0-11). + def squeeze_clip_attention_MKV_divisor(self, prompt2token_proj_attention_divisors): + num_squeezed_layers = 0 + + for layer_idx, layer in enumerate(self.text_model.encoder.layers): + divisor = prompt2token_proj_attention_divisors[layer_idx] + if divisor == 1: + continue + # This shouldn't happen, unless self_attn has already been extended as CLIPAttentionMKV. + if not isinstance(layer.self_attn, (CLIPAttention, CLIPAttentionMKV)): + breakpoint() + old_attn_layer = layer.self_attn + if not isinstance(old_attn_layer, CLIPAttentionMKV): + layer.self_attn = CLIPAttentionMKV(old_attn_layer.config, 1) + # Squeeze the k_proj and v_proj weights in the self_attn layer. + layer.self_attn.squeeze_weights(old_attn_layer, divisor) + num_squeezed_layers += 1 + + return num_squeezed_layers diff --git a/adaface/face_id_to_ada_prompt.py b/adaface/face_id_to_ada_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0f1d4b48fb46c95d31e62eeb27862a689cab84 --- /dev/null +++ b/adaface/face_id_to_ada_prompt.py @@ -0,0 +1,1175 @@ +import torch +import torch.nn as nn +from transformers import CLIPTokenizer, CLIPImageProcessor +from .arc2face_models import CLIPTextModelWrapper +from ConsistentID.lib.pipeline_ConsistentID import ConsistentIDPipeline +from .util import perturb_tensor, pad_image_obj_to_square, \ + calc_stats, patch_clip_image_encoder_with_mask, CLIPVisionModelWithMask +from adaface.subj_basis_generator import SubjBasisGenerator +import torch.nn.functional as F +import numpy as np +import cv2 +from PIL import Image +from insightface.app import FaceAnalysis +import os +from omegaconf.listconfig import ListConfig + +# adaface_encoder_types can be a list of one or more encoder types. +# adaface_ckpt_paths can be one or a list of ckpt paths. +# adaface_encoder_cfg_scales is None, or a list of scales for the adaface encoder types. +def create_id2ada_prompt_encoder(adaface_encoder_types, adaface_ckpt_paths=None, + adaface_encoder_cfg_scales=None, enabled_encoders=None, + *args, **kwargs): + if len(adaface_encoder_types) == 1: + adaface_encoder_type = adaface_encoder_types[0] + adaface_ckpt_path = adaface_ckpt_paths[0] if adaface_ckpt_paths is not None else None + if adaface_encoder_type == 'arc2face': + id2ada_prompt_encoder = \ + Arc2Face_ID2AdaPrompt(adaface_ckpt_path=adaface_ckpt_path, + *args, **kwargs) + elif adaface_encoder_type == 'consistentID': + id2ada_prompt_encoder = \ + ConsistentID_ID2AdaPrompt(pipe=None, + adaface_ckpt_path=adaface_ckpt_path, + *args, **kwargs) + else: + id2ada_prompt_encoder = Joint_FaceID2AdaPrompt(adaface_encoder_types, adaface_ckpt_paths, + adaface_encoder_cfg_scales, enabled_encoders, + *args, **kwargs) + + return id2ada_prompt_encoder + +class FaceID2AdaPrompt(nn.Module): + # To be initialized in derived classes. 
+ def __init__(self, *args, **kwargs): + super().__init__() + # Initialize model components. + # These components of ConsistentID_ID2AdaPrompt will be shared with the teacher model. + # So we don't initialize them in the ctor(), but borrow them from the teacher model. + # These components of Arc2Face_ID2AdaPrompt will be initialized in its ctor(). + self.clip_image_encoder = None + self.clip_preprocessor = None + self.face_app = None + self.text_to_image_prompt_encoder = None + self.tokenizer = None + self.dtype = kwargs.get('dtype', torch.float16) + + # Load Img2Ada SubjectBasisGenerator. + self.subject_string = kwargs.get('subject_string', 'z') + self.adaface_ckpt_path = kwargs.get('adaface_ckpt_path', None) + self.subj_basis_generator = None + # -1: use the default scale for the adaface encoder type. + # i.e., 6 for arc2face and 1 for consistentID. + self.out_id_embs_cfg_scale = kwargs.get('out_id_embs_cfg_scale', -1) + self.is_training = kwargs.get('is_training', False) + # extend_prompt2token_proj_attention_multiplier is an integer >= 1. + # TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers. + self.extend_prompt2token_proj_attention_multiplier = kwargs.get('extend_prompt2token_proj_attention_multiplier', 1) + self.prompt2token_proj_ext_attention_perturb_ratio = kwargs.get('prompt2token_proj_ext_attention_perturb_ratio', 0.1) + + # Set model behavior configurations. + self.gen_neg_img_prompt = False + self.clip_neg_features = None + + self.use_clip_embs = False + self.do_contrast_clip_embs_on_bg_features = False + # num_id_vecs is the output embeddings of the ID2ImgPrompt module. + # If there's no static image suffix embeddings, then num_id_vecs is also + # the number of ada embeddings returned by the subject basis generator. + # num_id_vecs will be set in each derived class. + self.num_static_img_suffix_embs = kwargs.get('num_static_img_suffix_embs', 0) + print(f'{self.name} Adaface uses {self.num_id_vecs} ID image embeddings and {self.num_static_img_suffix_embs} fixed image embeddings as input.') + + self.id_img_prompt_max_length = 77 + self.face_id_dim = 512 + # clip_embedding_dim: by default it's the OpenAI CLIP embedding dim. + # Could be overridden by derived classes. + self.clip_embedding_dim = 1024 + self.output_dim = 768 + + def get_id2img_learnable_modules(self): + raise NotImplementedError + + def load_id2img_learnable_modules(self, id2img_learnable_modules_state_dict_list): + id2img_prompt_encoder_learnable_modules = self.get_id2img_learnable_modules() + for module, state_dict in zip(id2img_prompt_encoder_learnable_modules, id2img_learnable_modules_state_dict_list): + module.load_state_dict(state_dict) + print(f'{len(id2img_prompt_encoder_learnable_modules)} ID2ImgPrompt encoder modules loaded.') + + # init_subj_basis_generator() can only be called after the derived class is initialized, + # when self.num_id_vecs, self.num_static_img_suffix_embs and self.clip_embedding_dim have been set. 
+ def init_subj_basis_generator(self): + self.subj_basis_generator = \ + SubjBasisGenerator(num_id_vecs = self.num_id_vecs, + num_static_img_suffix_embs = self.num_static_img_suffix_embs, + bg_image_embedding_dim = self.clip_embedding_dim, + output_dim = self.output_dim, + placeholder_is_bg = False, + prompt2token_proj_grad_scale = 1, + bg_prompt_translator_has_to_out_proj=False) + + def load_adaface_ckpt(self, adaface_ckpt_path): + ckpt = torch.load(adaface_ckpt_path, map_location='cpu') + string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"] + if self.subject_string not in string_to_subj_basis_generator_dict: + print(f"Subject '{self.subject_string}' not found in the embedding manager.") + breakpoint() + + ckpt_subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string] + ckpt_subj_basis_generator.N_ID = self.num_id_vecs + # Since we directly use the subject basis generator object from the ckpt, + # fixing the number of static image suffix embeddings is much simpler. + # Otherwise if we want to load the subject basis generator from its state_dict, + # things are more complicated, see embedding manager's load(). + ckpt_subj_basis_generator.N_SFX = self.num_static_img_suffix_embs + # obj_proj_in and pos_embs are for non-faces. So they are useless for human faces. + ckpt_subj_basis_generator.obj_proj_in = None + ckpt_subj_basis_generator.pos_embs = None + # Handle differences in num_static_img_suffix_embs between the current model and the ckpt. + ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.num_static_img_suffix_embs, img_prompt_dim=self.output_dim) + # Fix missing variables in old ckpt. + ckpt_subj_basis_generator.patch_old_subj_basis_generator_ckpt() + + self.subj_basis_generator.extend_prompt2token_proj_attention(\ + ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0) + ret = self.subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict(), strict=False) + print(f"{adaface_ckpt_path}: subject basis generator loaded for '{self.name}'.") + print(repr(ckpt_subj_basis_generator)) + + if ret is not None and len(ret.missing_keys) > 0: + print(f"Missing keys: {ret.missing_keys}") + if ret is not None and len(ret.unexpected_keys) > 0: + print(f"Unexpected keys: {ret.unexpected_keys}") + + # extend_prompt2token_proj_attention_multiplier is an integer >= 1. + # TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers. + # If extend_prompt2token_proj_attention_multiplier > 1, then after loading state_dict, + # extend subj_basis_generator again. + if self.extend_prompt2token_proj_attention_multiplier > 1: + # During this extension, the added noise does change the extra copies of attention weights, since they are not in the ckpt. + # During training, prompt2token_proj_ext_attention_perturb_ratio == 0.1. + # During inference, prompt2token_proj_ext_attention_perturb_ratio == 0. + self.subj_basis_generator.extend_prompt2token_proj_attention(\ + None, -1, -1, self.extend_prompt2token_proj_attention_multiplier, + perturb_std=self.prompt2token_proj_ext_attention_perturb_ratio) + + self.subj_basis_generator.freeze_prompt2token_proj() + + @torch.no_grad() + def get_clip_neg_features(self, BS): + if self.clip_neg_features is None: + # neg_pixel_values: [1, 3, 224, 224]. clip_neg_features is invariant to the actual image. 
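# --- [Editor's note: illustrative sketch, not part of face_id_to_ada_prompt.py] ---
# get_clip_neg_features() computes the CLIP features of an all-zero image once, caches them,
# and then tiles the cached [1, N, D] tensor to the requested batch size. The same
# compute-once-then-repeat pattern in isolation (the shape matches the [1, 257, 1280] hidden
# states mentioned elsewhere in this file, and the lambda stands in for the CLIP encoder call):
import torch

class _LazyNegFeatures:
    def __init__(self, compute_fn):
        self.compute_fn = compute_fn
        self.cached = None
    def get(self, batch_size: int) -> torch.Tensor:
        if self.cached is None:
            self.cached = self.compute_fn()               # run the expensive encoder only once
        return self.cached.repeat(batch_size, 1, 1)       # [1, N, D] -> [BS, N, D]

_neg_cache = _LazyNegFeatures(lambda: torch.zeros(1, 257, 1280))
_neg_feats = _neg_cache.get(4)                            # [4, 257, 1280]
# --- [end of sketch] ---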
+ neg_pixel_values = torch.zeros([1, 3, 224, 224], device=self.clip_image_encoder.device, dtype=self.dtype) + # Precompute CLIP negative features for the negative image prompt. + self.clip_neg_features = self.clip_image_encoder(neg_pixel_values, attn_mask=None, output_hidden_states=True).hidden_states[-2] + + clip_neg_features = self.clip_neg_features.repeat(BS, 1, 1) + return clip_neg_features + + # image_objs: a list of np array / tensor / Image objects of different sizes [Hi, Wi]. + # If image_objs is a list of tensors, then each tensor should be [3, Hi, Wi]. + # If image_objs is None, then image_paths should be provided, + # and image_objs will be loaded from image_paths. + # fg_masks: None, or a list of [Hi, Wi]. + def extract_init_id_embeds_from_images(self, image_objs, image_paths, fg_masks=None, + size=(512, 512), calc_avg=False, + skip_non_faces=True, return_clip_embs=None, + do_contrast_clip_embs_on_bg_features=None, + verbose=False): + # If return_clip_embs or do_contrast_clip_embs_on_bg_features is not provided, + # then use their default values. + if return_clip_embs is None: + return_clip_embs = self.use_clip_embs + if do_contrast_clip_embs_on_bg_features is None: + do_contrast_clip_embs_on_bg_features = self.do_contrast_clip_embs_on_bg_features + + # clip_image_encoder should be already put on GPU. + # So its .device is the device of its parameters. + device = self.clip_image_encoder.device + + image_pixel_values = [] + all_id_embs = [] + faceless_img_count = 0 + + if image_objs is None and image_paths is not None: + image_objs = [] + for image_path in image_paths: + image_obj = Image.open(image_path) + image_objs.append(image_obj) + print(f'Loaded {len(image_objs)} images from {image_paths[0]}...') + + # image_objs could be a batch of images that have been collated into a tensor or np array. + # image_objs can also be a list of images. + # The code below that processes them one by one can be applied in both cases. + # If image_objs are a collated batch, processing them one by one will not add much overhead. + for idx, image_obj in enumerate(image_objs): + if return_clip_embs: + # input to clip_preprocessor: an image or a batch of images, each being PIL.Image.Image, numpy.ndarray, + # torch.Tensor, tf.Tensor or jax.ndarray. + # Different sizes of images are standardized to the same size 224*224. + clip_image_pixel_values = self.clip_preprocessor(images=image_obj, return_tensors="pt").pixel_values + image_pixel_values.append(clip_image_pixel_values) + + # Convert tensor to numpy array. + if isinstance(image_obj, torch.Tensor): + image_obj = image_obj.cpu().numpy().transpose(1, 2, 0) + if isinstance(image_obj, np.ndarray): + image_obj = Image.fromarray(image_obj) + # Resize image_obj to (512, 512). The scheme is Image.NEAREST, to be consistent with + # PersonalizedBase dataset class. + image_obj, _, _ = pad_image_obj_to_square(image_obj) + image_np = np.array(image_obj.resize(size, Image.NEAREST)) + face_info = self.face_app.get(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)) + if len(face_info) > 0: + face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face + # id_emb: [512,] + id_emb = torch.from_numpy(face_info.normed_embedding) + else: + faceless_img_count += 1 + print(f'No face detected in {image_paths[idx]}.', end=' ') + if not skip_non_faces: + print('Replace with random face embedding.') + # During training, use a random tensor as the face embedding. 
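# --- [Editor's note: illustrative sketch, not part of face_id_to_ada_prompt.py] ---
# The face-selection step above keeps the "maximum" face. Note that, due to operator precedence,
# the sort key as written evaluates to (x2 - x1) * y2 - y1 rather than the bounding-box area;
# ranking by the actual area would look like the sketch below (face_app is an insightface
# FaceAnalysis instance and image_np an RGB uint8 array, as in the surrounding code):
import cv2

def _pick_largest_face(face_app, image_np):
    faces = face_app.get(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
    if len(faces) == 0:
        return None
    # Rank by bounding-box area: width * height.
    return max(faces, key=lambda f: (f['bbox'][2] - f['bbox'][0]) * (f['bbox'][3] - f['bbox'][1]))
# --- [end of sketch] ---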
+ id_emb = torch.randn(512) + else: + print(f'Skip.') + continue + + all_id_embs.append(id_emb) + + if verbose: + print(f'{len(all_id_embs)} face images identified, {faceless_img_count} faceless images.') + + # No face is detected in the input images. + if len(all_id_embs) == 0: + return faceless_img_count, None, None + + # all_id_embs: [BS, 512]. + all_id_embs = torch.stack(all_id_embs, dim=0).to(device=device, dtype=torch.float16) + + if return_clip_embs: + # image_pixel_values: [BS, 3, 224, 224] + image_pixel_values = torch.cat(image_pixel_values, dim=0) + image_pixel_values = image_pixel_values.to(device=device, dtype=torch.float16) + + if fg_masks is not None: + assert len(fg_masks) == len(image_objs) + # fg_masks is a list of masks. + if isinstance(fg_masks, (list, tuple)): + fg_masks2 = [] + for fg_mask in fg_masks: + # fg_mask: [Hi, Wi] + # BUG: clip_preprocessor will do central crop on images. But fg_mask is not central cropped. + # If the ref image is not square, then the fg_mask will not match the image. + # TODO: crop fg_mask and images to square before calling extract_init_id_embeds_from_images(). + # fg_mask2: [Hi, Wi] -> [1, 1, 224, 224] + fg_mask2 = torch.tensor(fg_mask, device=device, dtype=torch.float16).unsqueeze(0).unsqueeze(0) + fg_mask2 = F.interpolate(fg_mask2, size=image_pixel_values.shape[-2:], mode='bilinear', align_corners=False) + fg_masks2.append(fg_mask2) + # fg_masks2: [BS, 224, 224] + fg_masks2 = torch.cat(fg_masks2, dim=0).squeeze(1) + else: + # fg_masks is a collated batch of masks. + # The actual size doesn't matter, + # as fg_mask2 will be resized to the same size as image features + # (much smaller than image_pixel_values). + fg_masks2 = fg_masks.to(device=device, dtype=torch.float16).unsqueeze(1) + # F.interpolate() always return a copy, even if scale_factor=1. So we don't need to clone fg_masks2. + fg_masks2 = F.interpolate(fg_masks2, size=image_pixel_values.shape[-2:], mode='bilinear', align_corners=False) + fg_masks2 = fg_masks2.squeeze(1) + else: + # fg_mask2: [BS, 224, 224]. + fg_masks2 = torch.ones_like(image_pixel_values[:, 0, :, :], device=device, dtype=torch.float16) + + clip_neg_features = self.get_clip_neg_features(BS=image_pixel_values.shape[0]) + + with torch.no_grad(): + # image_fg_features: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds). + image_fg_dict = self.clip_image_encoder(image_pixel_values, attn_mask=fg_masks2, output_hidden_states=True) + # attn_mask: [BS, 1, 257] + image_fg_features = image_fg_dict.hidden_states[-2] + if image_fg_dict.attn_mask is not None: + image_fg_features = image_fg_features * image_fg_dict.attn_mask + + # A negative mask is used to extract the background features. + # If fg_masks is None, then fg_masks2 is all ones, and bg masks is all zeros. + # Therefore, all pixels are masked. The extracted image_bg_features will be + # meaningless in this case. + image_bg_dict = self.clip_image_encoder(image_pixel_values, attn_mask=1-fg_masks2, output_hidden_states=True) + image_bg_features = image_bg_dict.hidden_states[-2] + # Subtract the feature bias (null features) from the bg features, to highlight the useful bg features. + if do_contrast_clip_embs_on_bg_features: + image_bg_features = image_bg_features - clip_neg_features + if image_bg_dict.attn_mask is not None: + image_bg_features = image_bg_features * image_bg_dict.attn_mask + + # clip_fgbg_features: [BS, 514, 1280]. 514 = 257*2. + # all_id_embs: [BS, 512]. 
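+            # Concatenate fg and bg features along the token dim, so that consumers
+            # (e.g. map_init_id_to_img_prompt_embs) can recover them with .chunk(2, dim=1).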
+ clip_fgbg_features = torch.cat([image_fg_features, image_bg_features], dim=1) + else: + clip_fgbg_features = None + clip_neg_features = None + + if calc_avg: + if return_clip_embs: + # clip_fgbg_features: [BS, 514, 1280] -> [1, 514, 1280]. + # all_id_embs: [BS, 512] -> [1, 512]. + clip_fgbg_features = clip_fgbg_features.mean(dim=0, keepdim=True) + clip_neg_features = clip_neg_features.mean(dim=0, keepdim=True) + + debug = False + if debug and all_id_embs is not None: + print(image_paths) + calc_stats('all_id_embs', all_id_embs) + # Compute pairwise similarities of the embeddings. + all_id_embs = F.normalize(all_id_embs, p=2, dim=1) + pairwise_sim = torch.matmul(all_id_embs, all_id_embs.t()) + print('pairwise_sim:', pairwise_sim) + top_dir = os.path.dirname(image_paths[0]) + mean_emb_path = os.path.join(top_dir, "mean_emb.pt") + if os.path.exists(mean_emb_path): + mean_emb = torch.load(mean_emb_path) + sim_to_mean = torch.matmul(all_id_embs, mean_emb.t()) + print('sim_to_mean:', sim_to_mean) + + if all_id_embs is not None: + id_embs = all_id_embs.mean(dim=0, keepdim=True) + # Without normalization, id_embs.norm(dim=1) is ~0.9. So normalization doesn't have much effect. + id_embs = F.normalize(id_embs, p=2, dim=-1) + # id_embs is None only if insightface_app is None, i.e., disabled by the user. + else: + # Don't do average of all_id_embs. + id_embs = all_id_embs + + return faceless_img_count, id_embs, clip_fgbg_features + + # This function should be implemented in derived classes. + # We don't plan to fine-tune the ID2ImgPrompt module. So disable the gradient computation. + def map_init_id_to_img_prompt_embs(self, init_id_embs, + clip_features=None, + called_for_neg_img_prompt=False): + raise NotImplementedError + + # If init_id_embs/pre_clip_features is provided, then use the provided face embeddings. + # Otherwise, if image_paths/image_objs are provided, extract face embeddings from the images. + # Otherwise, we generate random face embeddings [id_batch_size, 512]. + def get_img_prompt_embs(self, init_id_embs, pre_clip_features, image_paths, image_objs, + id_batch_size, + skip_non_faces=True, + avg_at_stage=None, # id_emb, img_prompt_emb, or None. + perturb_at_stage=None, # id_emb, img_prompt_emb, or None. + perturb_std=0.0, + verbose=False): + face_image_count = 0 + device = self.clip_image_encoder.device + clip_neg_features = self.get_clip_neg_features(BS=id_batch_size) + + if init_id_embs is None: + # Input images are not provided. Generate random face embeddings. + if image_paths is None and image_objs is None: + faceid_embeds_from_images = False + # Use random face embeddings as faceid_embeds. [BS, 512]. + faceid_embeds = torch.randn(id_batch_size, 512).to(device=device, dtype=torch.float16) + # Since it's a batch of random IDs, the CLIP features are all zeros as a placeholder. + # Only ConsistentID_ID2AdaPrompt will use clip_fgbg_features and clip_neg_features. + # Experiments show that using random clip features yields much better images than using zeros. + clip_fgbg_features = torch.randn(id_batch_size, 514, 1280).to(device=device, dtype=torch.float16) \ + if self.use_clip_embs else None + else: + # Extract face ID embeddings and CLIP features from the images. 
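+                # When avg_at_stage == 'id_emb', the per-image ID embeddings and CLIP features
+                # are averaged inside extract_init_id_embeds_from_images(), so the returned
+                # batch dim is 1 even if multiple images are given.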
+ faceid_embeds_from_images = True + faceless_img_count, faceid_embeds, clip_fgbg_features \ + = self.extract_init_id_embeds_from_images( \ + image_objs, image_paths=image_paths, size=(512, 512), + calc_avg=(avg_at_stage == 'id_emb'), + skip_non_faces=skip_non_faces, + verbose=verbose) + + if image_paths is not None: + face_image_count = len(image_paths) - faceless_img_count + else: + face_image_count = len(image_objs) - faceless_img_count + else: + faceid_embeds_from_images = False + # Use the provided init_id_embs as faceid_embeds. + faceid_embeds = init_id_embs + if pre_clip_features is not None: + clip_fgbg_features = pre_clip_features + else: + clip_fgbg_features = None + + if faceid_embeds.shape[0] == 1: + faceid_embeds = faceid_embeds.repeat(id_batch_size, 1) + if clip_fgbg_features is not None: + clip_fgbg_features = clip_fgbg_features.repeat(id_batch_size, 1, 1) + + # If skip_non_faces, then faceid_embeds won't be None. + # Otherwise, if faceid_embeds_from_images, and no face images are detected, + # then we return Nones. + if faceid_embeds is None: + return face_image_count, None, None, None + + if perturb_at_stage == 'id_emb' and perturb_std > 0: + # If id_batch_size > 1, after adding noises, the id_batch_size embeddings will be different. + faceid_embeds = perturb_tensor(faceid_embeds, perturb_std, perturb_std_is_relative=True, keep_norm=True) + if self.name == 'consistentID' or self.name == 'jointIDs': + clip_fgbg_features = perturb_tensor(clip_fgbg_features, perturb_std, perturb_std_is_relative=True, keep_norm=True) + + faceid_embeds = F.normalize(faceid_embeds, p=2, dim=-1) + + # pos_prompt_embs, neg_prompt_embs: [BS, 77, 768] or [BS, 22, 768]. + with torch.no_grad(): + pos_prompt_embs = \ + self.map_init_id_to_img_prompt_embs(faceid_embeds, clip_fgbg_features, + called_for_neg_img_prompt=False) + + if avg_at_stage == 'img_prompt_emb': + pos_prompt_embs = pos_prompt_embs.mean(dim=0, keepdim=True) + faceid_embeds = faceid_embeds.mean(dim=0, keepdim=True) + if clip_fgbg_features is not None: + clip_fgbg_features = clip_fgbg_features.mean(dim=0, keepdim=True) + + if perturb_at_stage == 'img_prompt_emb' and perturb_std > 0: + # NOTE: for simplicity, pos_prompt_embs and pos_core_prompt_emb are perturbed independently. + # This could cause inconsistency between pos_prompt_embs and pos_core_prompt_emb. + # But in practice, unless we use both pos_prompt_embs and pos_core_prompt_emb + # this is not an issue. But we rarely use pos_prompt_embs and pos_core_prompt_emb together. + pos_prompt_embs = perturb_tensor(pos_prompt_embs, perturb_std, perturb_std_is_relative=True, keep_norm=True) + + # If faceid_embeds_from_images, and the prompt embeddings are already averaged, then + # we assume all images are from the same subject, and the batch dim of faceid_embeds is 1. + # So we need to repeat faceid_embeds. + if faceid_embeds_from_images and avg_at_stage is not None: + faceid_embeds = faceid_embeds.repeat(id_batch_size, 1) + pos_prompt_embs = pos_prompt_embs.repeat(id_batch_size, 1, 1) + if clip_fgbg_features is not None: + clip_fgbg_features = clip_fgbg_features.repeat(id_batch_size, 1, 1) + + if self.gen_neg_img_prompt: + # Never perturb the negative prompt embeddings. 
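+            # The negative image prompt is built from an all-zero face embedding and the
+            # cached CLIP features of a blank image, i.e., it encodes "no identity".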
+ with torch.no_grad(): + neg_prompt_embs = \ + self.map_init_id_to_img_prompt_embs(torch.zeros_like(faceid_embeds), + clip_neg_features, + called_for_neg_img_prompt=True) + + return face_image_count, faceid_embeds, pos_prompt_embs, neg_prompt_embs + else: + return face_image_count, faceid_embeds, pos_prompt_embs, None + + # get_batched_img_prompt_embs() is a wrapper of get_img_prompt_embs() + # which is convenient for batched training. + # NOTE: get_batched_img_prompt_embs() should only be called during training. + # It is a wrapper of get_img_prompt_embs() which is convenient for batched training. + # If init_id_embs is None, generate random face embeddings [BS, 512]. + # Returns faceid_embeds, id2img_prompt_emb. + def get_batched_img_prompt_embs(self, batch_size, init_id_embs, pre_clip_features): + # pos_prompt_embs, neg_prompt_embs are generated without gradient computation. + # So we don't need to worry that the teacher model weights are updated. + return self.get_img_prompt_embs(init_id_embs=init_id_embs, + pre_clip_features=pre_clip_features, + image_paths=None, + image_objs=None, + id_batch_size=batch_size, + # During training, don't skip non-face images. Instead, + # setting skip_non_faces=False will replace them by random face embeddings. + skip_non_faces=False, + # We always assume the instances belong to different subjects. + # So never average the embeddings across instances. + avg_at_stage=None, + verbose=False) + + # If img_prompt_embs is provided, we use it directly. + # Otherwise, if face_id_embs is provided, we use it to generate img_prompt_embs. + # Otherwise, if image_paths is provided, we extract face_id_embs from the images. + # image_paths: a list of image paths. image_folder: the parent folder name. + # avg_at_stage: 'id_emb', 'img_prompt_emb', or None. + # avg_at_stage == ada_prompt_emb usually produces the worst results. + # avg_at_stage == id_emb is slightly better than img_prompt_emb, but sometimes img_prompt_emb is better. + # p_dropout and return_zero_embs_for_dropped_encoders are only used by Joint_FaceID2AdaPrompt. + def generate_adaface_embeddings(self, image_paths, face_id_embs=None, img_prompt_embs=None, + p_dropout=0, + return_zero_embs_for_dropped_encoders=True, + avg_at_stage='id_emb', # id_emb, img_prompt_emb, or None. + perturb_at_stage=None, # id_emb, img_prompt_emb, or None. + perturb_std=0, enable_static_img_suffix_embs=False): + if (avg_at_stage is None) or avg_at_stage.lower() == 'none': + img_prompt_avg_at_stage = None + else: + img_prompt_avg_at_stage = avg_at_stage + + if img_prompt_embs is None: + # Do averaging. So id_batch_size becomes 1 after averaging. + if img_prompt_avg_at_stage is not None: + id_batch_size = 1 + else: + if face_id_embs is not None: + id_batch_size = face_id_embs.shape[0] + elif image_paths is not None: + id_batch_size = len(image_paths) + else: + id_batch_size = 1 + + # faceid_embeds: [BS, 512] is a batch of extracted face analysis embeddings. NOT used later. + # NOTE: If face_id_embs, image_paths and image_objs are all None, + # then get_img_prompt_embs() generates random faceid_embeds/img_prompt_embs, + # and each instance is different. + # Otherwise, if face_id_embs is provided, it's used. + # If not, image_paths/image_objs are used to extract face embeddings. + # img_prompt_embs is in the image prompt space. + # img_prompt_embs: [BS, 16/4, 768]. 
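+            # (16 ID vectors for the arc2face encoder, 4 for the consistentID encoder.)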
+ face_image_count, faceid_embeds, img_prompt_embs, neg_img_prompt_embs \ + = self.get_img_prompt_embs(\ + init_id_embs=face_id_embs, + pre_clip_features=None, + # image_folder is passed only for logging purpose. + # image_paths contains the paths of the images. + image_paths=image_paths, image_objs=None, + id_batch_size=id_batch_size, + perturb_at_stage=perturb_at_stage, + perturb_std=perturb_std, + avg_at_stage=img_prompt_avg_at_stage, + verbose=True) + + if face_image_count == 0: + return None + + # No matter whether avg_at_stage is id_emb or img_prompt_emb, we average img_prompt_embs. + elif avg_at_stage is not None and avg_at_stage.lower() != 'none': + # img_prompt_embs: [BS, 16/4, 768] -> [1, 16/4, 768]. + img_prompt_embs = img_prompt_embs.mean(dim=0, keepdim=True) + + # adaface_subj_embs: [BS, 16/4, 768]. + adaface_subj_embs = \ + self.subj_basis_generator(img_prompt_embs, clip_features=None, raw_id_embs=None, + out_id_embs_cfg_scale=self.out_id_embs_cfg_scale, + is_face=True, + enable_static_img_suffix_embs=enable_static_img_suffix_embs) + # During training, img_prompt_avg_at_stage is None, and BS >= 1. + # During inference, img_prompt_avg_at_stage is 'id_emb' or 'img_prompt_emb', and BS == 1. + if img_prompt_avg_at_stage is not None: + # adaface_subj_embs: [1, 16, 768] -> [16, 768] + adaface_subj_embs = adaface_subj_embs.squeeze(0) + + return adaface_subj_embs + +class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt): + def __init__(self, *args, **kwargs): + self.name = 'arc2face' + self.num_id_vecs = 16 + + super().__init__(*args, **kwargs) + + self.clip_image_encoder = CLIPVisionModelWithMask.from_pretrained('openai/clip-vit-large-patch14') + self.clip_preprocessor = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14') + self.clip_image_encoder.eval() + if self.dtype == torch.float16: + self.clip_image_encoder.half() + print(f'CLIP image encoder loaded.') + + ''' + {'landmark_3d_68': , + 'landmark_2d_106': , + 'detection': , + 'genderage': , + 'recognition': } + ''' + # Use the same model as ID2AdaPrompt does. + # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2. + # Note there's a second "model" in the path. + # Note DON'T use CUDAExecutionProvider, as it will hang DDP training. + # Seems when loading insightface onto the GPU, it will only reside on the first GPU. + # Then the process on the second GPU has issue to communicate with insightface on the first GPU, causing hanging. + self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface', + providers=['CPUExecutionProvider']) + self.face_app.prepare(ctx_id=0, det_size=(512, 512)) + print(f'Face encoder loaded on CPU.') + + self.text_to_image_prompt_encoder = CLIPTextModelWrapper.from_pretrained( + 'models/arc2face', subfolder="encoder", + torch_dtype=self.dtype + ) + self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + + if self.out_id_embs_cfg_scale == -1: + self.out_id_embs_cfg_scale = 1 + #### Arc2Face pipeline specific configs #### + self.gen_neg_img_prompt = False + # bg CLIP features are used by the bg subject basis generator. + self.use_clip_embs = True + self.do_contrast_clip_embs_on_bg_features = True + # self.num_static_img_suffix_embs is initialized in the parent class. 
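+        # The image prompt "photo of a id person" is truncated/padded to 22 tokens;
+        # only positions 4:20 (the 16 ID embeddings) are kept by map_init_id_to_img_prompt_embs().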
+ self.id_img_prompt_max_length = 22 + self.clip_embedding_dim = 1024 + + self.init_subj_basis_generator() + if self.adaface_ckpt_path is not None: + self.load_adaface_ckpt(self.adaface_ckpt_path) + + print(f"{self.name} ada prompt encoder initialized, " + f"ID vecs: {self.num_id_vecs}, static suffix: {self.num_static_img_suffix_embs}.") + + # Arc2Face_ID2AdaPrompt never uses clip_features or called_for_neg_img_prompt. + def map_init_id_to_img_prompt_embs(self, init_id_embs, + clip_features=None, + called_for_neg_img_prompt=False): + + ''' + self.text_to_image_prompt_encoder: arc2face_models.py:CLIPTextModelWrapper instance. + init_id_embs: (N, 512) normalized Face ID embeddings. + ''' + + # arcface_token_id: 1014 + arcface_token_id = self.tokenizer.encode("id", add_special_tokens=False)[0] + + # This step should be quite fast, and there's no need to cache the input_ids. + input_ids = self.tokenizer( + "photo of a id person", + truncation=True, + padding="max_length", + # In Arc2Face_ID2AdaPrompt, id_img_prompt_max_length is 22. + # Arc2Face's image prompt is meanlingless in tokens other than ID tokens. + max_length=self.id_img_prompt_max_length, + return_tensors="pt", + ).input_ids.to(init_id_embs.device) + # input_ids: [1, 22] or [3, 22] (during training). + input_ids = input_ids.repeat(len(init_id_embs), 1) + init_id_embs = init_id_embs.to(self.dtype) + # face_embs_padded: [1, 512] -> [1, 768]. + face_embs_padded = F.pad(init_id_embs, (0, self.text_to_image_prompt_encoder.config.hidden_size - init_id_embs.shape[-1]), "constant", 0) + # self.text_to_image_prompt_encoder(input_ids=input_ids, ...) is called twice. The first is only to get the token embeddings (the shallowest mapping). + # The second call does the ordinary CLIP text encoding pass. + token_embs = self.text_to_image_prompt_encoder(input_ids=input_ids, return_token_embs=True) + token_embs[input_ids==arcface_token_id] = face_embs_padded + + prompt_embeds = self.text_to_image_prompt_encoder( + input_ids=input_ids, + input_token_embs=token_embs, + return_token_embs=False + )[0] + + # Restore the original dtype of prompt_embeds: float16 -> float32. + prompt_embeds = prompt_embeds.to(self.dtype) + + # token 4: 'id' in "photo of a id person". + # 4:20 are the most important 16 embeddings that contain the subject's identity. + # [N, 22, 768] -> [N, 16, 768] + return prompt_embeds[:, 4:20] + + def get_id2img_learnable_modules(self): + return [ self.text_to_image_prompt_encoder ] + +# ConsistentID_ID2AdaPrompt is just a wrapper of ConsistentIDPipeline, so it's not an nn.Module. +class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt): + def __init__(self, pipe=None, base_model_path="models/sd15-dste8-vae.safetensors", + *args, **kwargs): + self.name = 'consistentID' + self.num_id_vecs = 4 + + super().__init__(*args, **kwargs) + if pipe is None: + # The base_model_path is kind of arbitrary, as the UNet and VAE in the model + # are not used and will be released soon. + # Only the consistentID modules and bise_net are used. + assert base_model_path is not None, "base_model_path should be provided." + # Avoid passing dtype to ConsistentIDPipeline.from_single_file(), + # because we've overloaded .to() to convert consistentID specific modules as well, + # but diffusers will call .to(dtype) in .from_single_file(), + # and at that moment, the consistentID specific modules are not loaded yet. 
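+            # So we first build the pipeline, then load the ConsistentID modules, and only then
+            # convert everything to self.dtype with the overloaded .to().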
+ pipe = ConsistentIDPipeline.from_single_file(base_model_path) + pipe.load_ConsistentID_model(consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin", + bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth") + pipe.to(dtype=self.dtype) + # Since the passed-in pipe is None, this should be called during inference, + # when the teacher ConsistentIDPipeline is not initialized. + # Therefore, we release VAE, UNet and text_encoder to save memory. + pipe.release_components(["unet", "vae"]) + + # Otherwise, we share the pipeline with the teacher. + # So we don't release the components. + self.pipe = pipe + self.face_app = pipe.face_app + # ConsistentID uses 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K'. + self.clip_image_encoder = patch_clip_image_encoder_with_mask(pipe.clip_encoder) + self.clip_preprocessor = pipe.clip_preprocessor + self.text_to_image_prompt_encoder = pipe.text_encoder + self.tokenizer = pipe.tokenizer + self.image_proj_model = pipe.image_proj_model + + self.clip_image_encoder.eval() + self.image_proj_model.eval() + if self.dtype == torch.float16: + self.clip_image_encoder.half() + self.image_proj_model.half() + + if self.out_id_embs_cfg_scale == -1: + self.out_id_embs_cfg_scale = 6 + #### ConsistentID pipeline specific configs #### + # self.num_static_img_suffix_embs is initialized in the parent class. + self.gen_neg_img_prompt = True + self.use_clip_embs = True + self.do_contrast_clip_embs_on_bg_features = True + self.clip_embedding_dim = 1280 + self.s_scale = 1.0 + self.shortcut = False + + self.init_subj_basis_generator() + if self.adaface_ckpt_path is not None: + self.load_adaface_ckpt(self.adaface_ckpt_path) + + print(f"{self.name} ada prompt encoder initialized, " + f"ID vecs: {self.num_id_vecs}, static suffix: {self.num_static_img_suffix_embs}.") + + def map_init_id_to_img_prompt_embs(self, init_id_embs, + clip_features=None, + called_for_neg_img_prompt=False): + assert init_id_embs is not None, "init_id_embs should be provided." + + init_id_embs = init_id_embs.to(self.dtype) + clip_features = clip_features.to(self.dtype) + + if not called_for_neg_img_prompt: + # clip_features: [BS, 514, 1280]. + # clip_features is provided when the function is called within + # ConsistentID_ID2AdaPrompt:extract_init_id_embeds_from_images(), which is + # image_fg_features and image_bg_features concatenated at dim=1. + # Therefore, we split clip_image_double_embeds into image_fg_features and image_bg_features. + # image_bg_features is not used in ConsistentID_ID2AdaPrompt. + image_fg_features, image_bg_features = clip_features.chunk(2, dim=1) + # clip_image_embeds: [BS, 257, 1280]. + clip_image_embeds = image_fg_features + else: + # clip_features is the negative image features. So we don't need to split it. + clip_image_embeds = clip_features + init_id_embs = torch.zeros_like(init_id_embs) + + faceid_embeds = init_id_embs + # image_proj_model maps 1280-dim OpenCLIP embeddings to 768-dim face prompt embeddings. + # clip_image_embeds are used as queries to transform faceid_embeds. + # faceid_embeds -> kv, clip_image_embeds -> q + if faceid_embeds.shape[0] != clip_image_embeds.shape[0]: + breakpoint() + + try: + global_id_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=self.shortcut, scale=self.s_scale) + except: + breakpoint() + + return global_id_embeds + + def get_id2img_learnable_modules(self): + return [ self.image_proj_model ] + +# A wrapper for combining multiple FaceID2AdaPrompt instances. 
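+# It splits/concatenates face ID embeddings, CLIP features and prompt embeddings along their
+# feature/token dims, so e.g. arc2face (16 ID vecs) + consistentID (4 ID vecs) yields 20 subject vecs.
+# Minimal usage sketch (hypothetical ckpt path; other constructor kwargs inherited from
+# FaceID2AdaPrompt are omitted here):
+#   encoder = Joint_FaceID2AdaPrompt(adaface_encoder_types=['arc2face', 'consistentID'],
+#                                    adaface_ckpt_paths=['models/adaface/jointIDs.ckpt'])
+#   subj_embs = encoder.generate_adaface_embeddings(image_paths=['face1.jpg', 'face2.jpg'],
+#                                                   avg_at_stage='id_emb')   # -> [20, 768]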
+class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt): + def __init__(self, adaface_encoder_types, adaface_ckpt_paths, + out_id_embs_cfg_scales=None, enabled_encoders=None, + *args, **kwargs): + self.name = 'jointIDs' + assert len(adaface_encoder_types) > 0, "adaface_encoder_types should not be empty." + adaface_encoder_types2num_id_vecs = { 'arc2face': 16, 'consistentID': 4 } + self.encoders_num_id_vecs = [ adaface_encoder_types2num_id_vecs[encoder_type] \ + for encoder_type in adaface_encoder_types ] + self.num_id_vecs = sum(self.encoders_num_id_vecs) + # super() sets self.is_training. + super().__init__(*args, **kwargs) + + self.num_sub_encoders = len(adaface_encoder_types) + self.id2ada_prompt_encoders = nn.ModuleList() + self.encoders_num_static_img_suffix_embs = [] + + # TODO: apply adaface_encoder_cfg_scales to influence the final prompt embeddings. + # Now they are just placeholders. + if out_id_embs_cfg_scales is None: + # -1: use the default scale for the adaface encoder type. + # i.e., 6 for arc2face and 1 for consistentID. + self.out_id_embs_cfg_scales = [-1] * self.num_sub_encoders + else: + # Do not normalize the weights, and just use them as is. + self.out_id_embs_cfg_scales = out_id_embs_cfg_scales + + # Note we don't pass the adaface_ckpt_paths to the base class, but instead, + # we load them once and for all in self.load_adaface_ckpt(). + for i, encoder_type in enumerate(adaface_encoder_types): + kwargs['out_id_embs_cfg_scale'] = self.out_id_embs_cfg_scales[i] + if encoder_type == 'arc2face': + encoder = Arc2Face_ID2AdaPrompt(*args, **kwargs) + elif encoder_type == 'consistentID': + encoder = ConsistentID_ID2AdaPrompt(*args, **kwargs) + else: + breakpoint() + self.id2ada_prompt_encoders.append(encoder) + self.encoders_num_static_img_suffix_embs.append(encoder.num_static_img_suffix_embs) + + self.num_static_img_suffix_embs = sum(self.encoders_num_static_img_suffix_embs) + # No need to set gen_neg_img_prompt, as we don't access it in this class, but rather + # in the derived classes. + # self.gen_neg_img_prompt = True + # self.use_clip_embs = True + # self.do_contrast_clip_embs_on_bg_features = True + self.face_id_dims = [encoder.face_id_dim for encoder in self.id2ada_prompt_encoders] + self.face_id_dim = sum(self.face_id_dims) + # Different adaface encoders may have different clip_embedding_dim. + # clip_embedding_dim is only used for bg subject basis generator. + # Here we use the joint clip embeddings of both OpenAI CLIP and laion CLIP. + # Therefore, the clip_embedding_dim is the sum of the clip_embedding_dims of all adaface encoders. + self.clip_embedding_dims = [encoder.clip_embedding_dim for encoder in self.id2ada_prompt_encoders] + self.clip_embedding_dim = sum(self.clip_embedding_dims) + + # The ctors of the derived classes have already initialized encoder.subj_basis_generator. + # If subj_basis_generator expansion params are specified, they are equally applied to all adaface encoders. + # This self.subj_basis_generator is not meant to be called as self.subj_basis_generator(), but instead, + # it's used as a unified interface to save/load the subj_basis_generator of all adaface encoders. + self.subj_basis_generator = \ + nn.ModuleList( [encoder.subj_basis_generator for encoder \ + in self.id2ada_prompt_encoders] ) + + # load_adaface_ckpt() loads into self.subj_basis_generator. So we need to initialize self.subj_basis_generator first. 
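+        # A single ckpt path is treated as a joint ckpt whose subj_basis_generator is an
+        # nn.ModuleList covering all sub-encoders; a list of paths loads each sub-encoder's ckpt separately.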
+ if adaface_ckpt_paths is not None: + self.load_adaface_ckpt(adaface_ckpt_paths) + + print(f"{self.name} ada prompt encoder initialized with {self.num_sub_encoders} sub-encoders. " + f"ID vecs: {self.num_id_vecs}, static suffix embs: {self.num_static_img_suffix_embs}.") + + if enabled_encoders is not None: + self.are_encoders_enabled = \ + torch.tensor([True if encoder_type in enabled_encoders else False \ + for encoder_type in adaface_encoder_types]) + if not self.are_encoders_enabled.any(): + print(f"All encoders are disabled, which shoudn't happen.") + breakpoint() + if self.are_encoders_enabled.sum() < self.num_sub_encoders: + disabled_encoders = [ encoder_type for i, encoder_type in enumerate(adaface_encoder_types) \ + if not self.are_encoders_enabled[i] ] + print(f"{len(disabled_encoders)} encoders are disabled: {disabled_encoders}.") + else: + self.are_encoders_enabled = \ + torch.tensor([True] * self.num_sub_encoders) + + for i, encoder in enumerate(self.id2ada_prompt_encoders): + if not (self.is_training and self.are_encoders_enabled[i]): + for param in encoder.parameters(): + param.requires_grad = False + else: + for param in encoder.parameters(): + param.requires_grad = True + + def load_adaface_ckpt(self, adaface_ckpt_paths): + # If only one adaface ckpt path is provided, then we assume it's the ckpt of the Joint_FaceID2AdaPrompt, + # so we dereference the list to get the actual path and load the subj_basis_generators of all adaface encoders. + if isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)): + if len(adaface_ckpt_paths) == 1 and self.num_sub_encoders > 1: + adaface_ckpt_paths = adaface_ckpt_paths[0] + + if isinstance(adaface_ckpt_paths, str): + # This is only applicable to newest ckpts of Joint_FaceID2AdaPrompt, where + # the ckpt_subj_basis_generator is an nn.ModuleList of multiple subj_basis_generators. + # Therefore, no need to patch missing variables. + ckpt = torch.load(adaface_ckpt_paths, map_location='cpu') + string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"] + if self.subject_string not in string_to_subj_basis_generator_dict: + print(f"Subject '{self.subject_string}' not found in the embedding manager.") + breakpoint() + + ckpt_subj_basis_generators = string_to_subj_basis_generator_dict[self.subject_string] + if len(ckpt_subj_basis_generators) != self.num_sub_encoders: + print(f"Number of subj_basis_generators in the ckpt ({len(ckpt_subj_basis_generators)}) " + f"doesn't match the number of adaface encoders ({self.num_sub_encoders}).") + breakpoint() + + for i, subj_basis_generator in enumerate(self.subj_basis_generator): + ckpt_subj_basis_generator = ckpt_subj_basis_generators[i] + # Handle differences in num_static_img_suffix_embs between the current model and the ckpt. 
+ ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.encoders_num_static_img_suffix_embs[i], + img_prompt_dim=self.output_dim) + + if subj_basis_generator.prompt2token_proj_attention_multipliers \ + == [1] * 12: + subj_basis_generator.extend_prompt2token_proj_attention(\ + ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0) + elif subj_basis_generator.prompt2token_proj_attention_multipliers \ + != ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers: + raise ValueError("Inconsistent prompt2token_proj_attention_multipliers.") + + assert subj_basis_generator.prompt2token_proj_attention_multipliers \ + == ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, \ + "Inconsistent prompt2token_proj_attention_multipliers." + subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict()) + + # extend_prompt2token_proj_attention_multiplier is an integer >= 1. + # TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers. + # If extend_prompt2token_proj_attention_multiplier > 1, then after loading state_dict, + # extend subj_basis_generator again. + if self.extend_prompt2token_proj_attention_multiplier > 1: + # During this extension, the added noise does change the extra copies of attention weights, since they are not in the ckpt. + # During training, prompt2token_proj_ext_attention_perturb_ratio == 0.1. + # During inference, prompt2token_proj_ext_attention_perturb_ratio == 0. + subj_basis_generator.extend_prompt2token_proj_attention(\ + None, -1, -1, self.extend_prompt2token_proj_attention_multiplier, + perturb_std=self.prompt2token_proj_ext_attention_perturb_ratio) + + subj_basis_generator.freeze_prompt2token_proj() + + print(f"{adaface_ckpt_paths}: {len(self.subj_basis_generator)} subj_basis_generators loaded for {self.name}.") + + elif isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)): + for i, ckpt_path in enumerate(adaface_ckpt_paths): + self.id2ada_prompt_encoders[i].load_adaface_ckpt(ckpt_path) + else: + breakpoint() + + def extract_init_id_embeds_from_images(self, *args, **kwargs): + total_faceless_img_count = 0 + all_id_embs = [] + all_clip_fgbg_features = [] + id_embs_shape = None + clip_fgbg_features_shape = None + # clip_image_encoder should be already put on GPU. + # So its .device is the device of its parameters. + device = self.id2ada_prompt_encoders[0].clip_image_encoder.device + + for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders): + faceless_img_count, id_embs, clip_fgbg_features = \ + id2ada_prompt_encoder.extract_init_id_embeds_from_images(*args, **kwargs) + total_faceless_img_count += faceless_img_count + # id_embs: [BS, 512] or [1, 512] (if calc_avg == True), or None. + # id_embs has the same shape across all id2ada_prompt_encoders. + all_id_embs.append(id_embs) + # clip_fgbg_features: [BS, 514, 1280/1024] or [1, 514, 1280/1024] (if calc_avg == True), or None. + # clip_fgbg_features has the same shape except for the last dimension across all id2ada_prompt_encoders. + all_clip_fgbg_features.append(clip_fgbg_features) + if id_embs is not None: + id_embs_shape = id_embs.shape + if clip_fgbg_features is not None: + clip_fgbg_features_shape = clip_fgbg_features.shape + + num_extracted_id_embs = 0 + for i in range(len(all_id_embs)): + if all_id_embs[i] is not None: + # As calc_avg is the same for all id2ada_prompt_encoders, + # each id_embs and clip_fgbg_features should have the same shape, if they are not None. 
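+                # Sub-encoders that detected no face return None; below, their slots are filled
+                # with zero tensors of the common shape so the concatenation stays aligned.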
+ if all_id_embs[i].shape != id_embs_shape: + print("Inconsistent ID embedding shapes.") + breakpoint() + else: + num_extracted_id_embs += 1 + else: + all_id_embs[i] = torch.zeros(id_embs_shape, dtype=torch.float16, device=device) + + clip_fgbg_features_shape2 = torch.Size(clip_fgbg_features_shape[:-1] + (self.clip_embedding_dims[i],)) + if all_clip_fgbg_features[i] is not None: + if all_clip_fgbg_features[i].shape != clip_fgbg_features_shape2: + print("Inconsistent clip features shapes.") + breakpoint() + else: + all_clip_fgbg_features[i] = torch.zeros(clip_fgbg_features_shape2, + dtype=torch.float16, device=device) + + # If at least one face encoder detects faces, then return the embeddings. + # Otherwise return None embeddings. + # It's possible that some face encoders detect faces, while others don't, + # since different face encoders use different face detection models. + if num_extracted_id_embs == 0: + return 0, None, None + + all_id_embs = torch.cat(all_id_embs, dim=1) + # clip_fgbg_features: [BS, 514, 1280] or [BS, 514, 1024]. So we concatenate them along dim=2. + all_clip_fgbg_features = torch.cat(all_clip_fgbg_features, dim=2) + return total_faceless_img_count, all_id_embs, all_clip_fgbg_features + + # init_id_embs, clip_features are never None. + def map_init_id_to_img_prompt_embs(self, init_id_embs, + clip_features=None, + called_for_neg_img_prompt=False): + if init_id_embs is None or clip_features is None: + breakpoint() + + # each id_embs and clip_fgbg_features should have the same shape. + # If some of them were None, they have been replaced by zero embeddings. + all_init_id_embs = init_id_embs.split(self.face_id_dims, dim=1) + all_clip_features = clip_features.split(self.clip_embedding_dims, dim=2) + all_img_prompt_embs = [] + + for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders): + img_prompt_embs = id2ada_prompt_encoder.map_init_id_to_img_prompt_embs( + all_init_id_embs[i], clip_features=all_clip_features[i], + called_for_neg_img_prompt=called_for_neg_img_prompt, + ) + all_img_prompt_embs.append(img_prompt_embs) + + all_img_prompt_embs = torch.cat(all_img_prompt_embs, dim=1) + return all_img_prompt_embs + + # If init_id_embs/pre_clip_features is provided, then use the provided face embeddings. + # Otherwise, if image_paths/image_objs are provided, extract face embeddings from the images. + # Otherwise, we generate random face embeddings [id_batch_size, 512]. + def get_img_prompt_embs(self, init_id_embs, pre_clip_features, *args, **kwargs): + face_image_counts = [] + all_faceid_embeds = [] + all_pos_prompt_embs = [] + all_neg_prompt_embs = [] + faceid_embeds_shape = None + # clip_image_encoder should be already put on GPU. + # So its .device is the device of its parameters. + device = self.id2ada_prompt_encoders[0].clip_image_encoder.device + + # init_id_embs, pre_clip_features could be None. If they are None, + # we split them into individual vectors for each id2ada_prompt_encoder. 
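+        # More precisely: when provided, they are split along the face-ID / CLIP feature dims;
+        # when None, each sub-encoder receives None and generates its own embeddings.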
+ if init_id_embs is not None: + all_init_id_embs = init_id_embs.split(self.face_id_dims, dim=1) + else: + all_init_id_embs = [None] * self.num_sub_encoders + if pre_clip_features is not None: + all_pre_clip_features = pre_clip_features.split(self.clip_embedding_dims, dim=2) + else: + all_pre_clip_features = [None] * self.num_sub_encoders + + faceid_embeds_shape = None + for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders): + face_image_count, faceid_embeds, pos_prompt_embs, neg_prompt_embs = \ + id2ada_prompt_encoder.get_img_prompt_embs(all_init_id_embs[i], all_pre_clip_features[i], + *args, **kwargs) + face_image_counts.append(face_image_count) + all_faceid_embeds.append(faceid_embeds) + all_pos_prompt_embs.append(pos_prompt_embs) + all_neg_prompt_embs.append(neg_prompt_embs) + # all faceid_embeds have the same shape across all id2ada_prompt_encoders. + # But pos_prompt_embs and neg_prompt_embs may have different number of ID embeddings. + if faceid_embeds is not None: + faceid_embeds_shape = faceid_embeds.shape + + if faceid_embeds_shape is None: + return 0, None, None, None + + # We take the maximum face_image_count among all adaface encoders. + face_image_count = max(face_image_counts) + BS = faceid_embeds.shape[0] + + for i in range(len(all_faceid_embeds)): + if all_faceid_embeds[i] is not None: + if all_faceid_embeds[i].shape != faceid_embeds_shape: + print("Inconsistent face embedding shapes.") + breakpoint() + else: + all_faceid_embeds[i] = torch.zeros(faceid_embeds_shape, dtype=torch.float16, device=device) + + N_ID = self.encoders_num_id_vecs[i] + if all_pos_prompt_embs[i] is None: + # Both pos_prompt_embs and neg_prompt_embs have N_ID == num_id_vecs embeddings. + all_pos_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device) + if all_neg_prompt_embs[i] is None: + all_neg_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device) + + all_faceid_embeds = torch.cat(all_faceid_embeds, dim=1) + all_pos_prompt_embs = torch.cat(all_pos_prompt_embs, dim=1) + all_neg_prompt_embs = torch.cat(all_neg_prompt_embs, dim=1) + + return face_image_count, all_faceid_embeds, all_pos_prompt_embs, all_neg_prompt_embs + + # We don't need to implement get_batched_img_prompt_embs() since the interface + # is fully compatible with FaceID2AdaPrompt.get_batched_img_prompt_embs(). + + def generate_adaface_embeddings(self, image_paths, face_id_embs=None, + img_prompt_embs=None, p_dropout=0, + return_zero_embs_for_dropped_encoders=True, + *args, **kwargs): + # clip_image_encoder should be already put on GPU. + # So its .device is the device of its parameters. + device = self.id2ada_prompt_encoders[0].clip_image_encoder.device + is_emb_averaged = kwargs.get('avg_at_stage', None) is not None + BS = -1 + + if face_id_embs is not None: + BS = face_id_embs.shape[0] + all_face_id_embs = face_id_embs.split(self.face_id_dims, dim=1) + else: + all_face_id_embs = [None] * self.num_sub_encoders + if img_prompt_embs is not None: + BS = img_prompt_embs.shape[0] if BS == -1 else BS + if img_prompt_embs.shape[1] != self.num_id_vecs: + breakpoint() + all_img_prompt_embs = img_prompt_embs.split(self.encoders_num_id_vecs, dim=1) + else: + all_img_prompt_embs = [None] * self.num_sub_encoders + if image_paths is not None: + BS = len(image_paths) if BS == -1 else BS + if BS == -1: + breakpoint() + + # During training, p_dropout is 0.1. During inference, p_dropout is 0. 
+ # When there are two sub-encoders, the prob of one encoder being dropped is + # p_dropout * 2 - p_dropout^2 = 0.18. + if p_dropout > 0: + # self.are_encoders_enabled is a global mask. + # are_encoders_enabled is a local mask for each batch. + are_encoders_enabled = torch.rand(self.num_sub_encoders) < p_dropout + are_encoders_enabled = are_encoders_enabled & self.are_encoders_enabled + # We should at least enable one encoder. + if not are_encoders_enabled.any(): + # Randomly enable an encoder with self.are_encoders_enabled[i] == True. + enabled_indices = torch.nonzero(self.are_encoders_enabled).squeeze(1) + sel_idx = torch.randint(0, len(enabled_indices), (1,)).item() + are_encoders_enabled[enabled_indices[sel_idx]] = True + else: + are_encoders_enabled = self.are_encoders_enabled + + self.curr_are_encoders_enabled = are_encoders_enabled + all_adaface_subj_embs = [] + num_available_id_vecs = 0 + + for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders): + if not are_encoders_enabled[i]: + adaface_subj_embs = None + print(f"Encoder {id2ada_prompt_encoder.name} is dropped.") + else: + # ddpm.embedding_manager.train() -> id2ada_prompt_encoder.train() -> each sub-enconder's train(). + # -> each sub-enconder's subj_basis_generator.train(). + # Therefore grad for the following call is enabled. + adaface_subj_embs = \ + id2ada_prompt_encoder.generate_adaface_embeddings(image_paths, + all_face_id_embs[i], + all_img_prompt_embs[i], + *args, **kwargs) + + # adaface_subj_embs: [16, 768] or [4, 768]. + N_ID = self.encoders_num_id_vecs[i] + if adaface_subj_embs is None: + if not return_zero_embs_for_dropped_encoders: + continue + else: + subj_emb_shape = (N_ID, 768) if is_emb_averaged else (BS, N_ID, 768) + # adaface_subj_embs is zero-filled. So N_ID is not counted as available subject embeddings. + adaface_subj_embs = torch.zeros(subj_emb_shape, dtype=torch.float16, device=device) + all_adaface_subj_embs.append(adaface_subj_embs) + else: + all_adaface_subj_embs.append(adaface_subj_embs) + num_available_id_vecs += N_ID + + # No faces are found in the images, so return None embeddings. + # We don't want to return an all-zero embedding, which is useless. + if num_available_id_vecs == 0: + return None + + # If id2ada_prompt_encoders are ["arc2face", "consistentID"], then + # during inference, we average across the batch dim. + # all_adaface_subj_embs[0]: [4, 768]. all_adaface_subj_embs[1]: [16, 768]. + # all_adaface_subj_embs: [20, 768]. + # during training, we don't average across the batch dim. + # all_adaface_subj_embs[0]: [BS, 4, 768]. all_adaface_subj_embs[1]: [BS, 16, 768]. + # all_adaface_subj_embs: [BS, 20, 768]. + all_adaface_subj_embs = torch.cat(all_adaface_subj_embs, dim=-2) + return all_adaface_subj_embs + + +''' +# For ip-adapter distillation on objects. Strictly speaking, it's not face-to-image prompts, but +# CLIP/DINO visual features to image prompts. +class Objects_Vis2ImgPrompt(nn.Module): + def __init__(self): + self.dino_encoder = ViTModel.from_pretrained('facebook/dino-vits16') + self.dino_encoder.eval() + self.dino_encoder.half() + self.dino_preprocess = ViTFeatureExtractor.from_pretrained('facebook/dino-vits16') + print(f'DINO encoder loaded.') + +''' diff --git a/adaface/subj_basis_generator.py b/adaface/subj_basis_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..d8944461435e46d03a953078b85430fdaeafb87e --- /dev/null +++ b/adaface/subj_basis_generator.py @@ -0,0 +1,868 @@ +# Borrowed from ip-adapter resampler.py. 
+# https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py + +import math + +import torch +from torch import nn +from einops import rearrange +from einops.layers.torch import Rearrange +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig + +from torch import einsum +from adaface.util import gen_gradient_scaler +from adaface.arc2face_models import CLIPTextModelWrapper + +def reshape_tensor(x, num_heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, num_heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2).contiguous() + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, num_heads, length, -1) + return x + +# FFN. Added a Dropout layer at the end, so that it can still load the old ckpt. +def FeedForward(dim, mult=4, p_dropout=0.1): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + nn.Dropout(p_dropout), + ) + +# IP-Adapter FaceID class. Only used in knn-faces.py. +# From: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid_separate.py +class IP_MLPProjModel(nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = nn.Sequential( + nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + nn.GELU(), + nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, id_embeds): + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + return x + +# group_dim: the tensor dimension that corresponds to the multiple groups. +class LearnedSoftAggregate(nn.Module): + def __init__(self, num_feat, group_dim, keepdim=False): + super(LearnedSoftAggregate, self).__init__() + self.group_dim = group_dim + # num_feat = 1: element-wise score function & softmax. + # num_feat > 1: the linear score function is applied to the last dim (features) of the input tensor. + self.num_feat = num_feat + self.feat2score = nn.Linear(num_feat, 1, bias=False) + self.keepdim = keepdim + + def forward(self, x, score_basis=None): + # If there's only one mode, do nothing. + if x.shape[self.group_dim] == 1: + if self.keepdim: + return x + else: + return x.squeeze(self.group_dim) + + # Assume the last dim of x is the feature dim. + if score_basis is None: + score_basis = x + + if self.num_feat == 1: + mode_scores = self.feat2score(score_basis.unsqueeze(-1)).squeeze(-1) + else: + mode_scores = self.feat2score(score_basis) + attn_probs = mode_scores.softmax(dim=self.group_dim) + x_aggr = (x * attn_probs).sum(dim=self.group_dim, keepdim=self.keepdim) + return x_aggr + +def LoRA_ExpandEmbs(input_dim, lora_rank, output_dim, num_modes, + num_output_vecs, elementwise_affine=True, p_dropout=0.1): + return nn.Sequential( + # Project to [BS, lora_rank * output_dim * num_modes]. + # It takes a huge param size. 512 * 32 * 768 * 4 = 6,291,456. 
+ nn.Linear(input_dim, lora_rank * output_dim * num_modes, bias=False), + # Reshape to [BS, lora_rank, output_dim]. + Rearrange('b (m q d) -> b m q d', q=lora_rank, m=num_modes, d=output_dim), + nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine), + # Aggregate [BS, num_modes, loar_rank, output_dim] -> [BS, lora_rank, output_dim]. + LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \ + else Rearrange('b () q d -> b q d'), + nn.Dropout(p_dropout), + # Permute to [BS, output_dim, lora_rank]. + Rearrange('b q d -> b d q'), + # Project to [BS, output_dim, num_output_vecs]. + nn.Linear(lora_rank, num_output_vecs, bias=False), + # Permute to [BS, num_output_vecs, output_dim]. + Rearrange('b d q -> b q d'), + nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine), + nn.Dropout(p_dropout), + ) + +def ExpandEmbs(input_dim, output_dim, expansion_ratio, elementwise_affine=True, p_dropout=0.1): + return nn.Sequential( + # Project to [BS, num_output_vecs * output_dim]. + nn.Linear(input_dim, expansion_ratio * output_dim, bias=False), + # Reshape to [BS, num_output_vecs, output_dim]. + Rearrange('b (e d) -> b e d', e=expansion_ratio, d=output_dim), + nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine), + nn.Dropout(p_dropout), + ) + +# Input: [BS, N, D]. +def MultimodeProjection(input_dim, output_dim=-1, num_modes=4, elementwise_affine=True, p_dropout=0.1): + if output_dim == -1: + output_dim = input_dim + + return nn.Sequential( + nn.Linear(input_dim, output_dim * num_modes, bias=False), + # Reshape to [BS, num_output_vecs, output_dim]. + Rearrange('b n (m d) -> b n m d', m=num_modes, d=output_dim), + nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine), + # If num_modes == 1, then simply remove the mode dim. Otherwise, aggregate the modes. + LearnedSoftAggregate(num_feat=output_dim, group_dim=2, keepdim=False) if num_modes > 1 \ + else Rearrange('b n () d -> b n d'), + nn.Dropout(p_dropout), + ) + +# Low-rank to high-rank transformation. +def Lora2Hira(lora_rank, hira_rank, output_dim, num_modes, elementwise_affine=True, p_dropout=0.1): + return nn.Sequential( + # Permute to [BS, output_dim, lora_rank]. + Rearrange('b q d -> b d q'), + # Project to [BS, output_dim, hira_rank]. + nn.Linear(lora_rank, hira_rank * num_modes, bias=False), + # Reshape and permute to [BS, num_modes, num_output_vecs, output_dim]. + Rearrange('b d (m q) -> b m q d', m=num_modes, q=hira_rank), + nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine), + # Aggregate [BS, num_modes, hira_rank, output_dim] -> [BS, hira_rank, output_dim]. 
+ LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \ + else Rearrange('b () q d -> b q d'), + nn.Dropout(p_dropout), + ) + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, num_heads=8, elementwise_affine=True): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.num_heads = num_heads + inner_dim = dim_head * num_heads + + self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine) + self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latent_queries): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latent_queries = self.norm2(latent_queries) + + b, l, _ = latent_queries.shape + + q = self.to_q(latent_queries) + kv_input = torch.cat((x, latent_queries), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.num_heads) + k = reshape_tensor(k, self.num_heads) + v = reshape_tensor(v, self.num_heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + attn = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = attn @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class CrossAttention(nn.Module): + # output_dim is always the same as input_dim. + # num_q only matters when q_aware_to_v is True. + # If q_aware_to_v is False, query x in forward() is still usable. + def __init__(self, input_dim, num_heads=6, p_dropout=0.05, + identity_to_q=False, identity_to_k=False, identity_to_v=False, v_has_skip=True, + q_aware_to_v=True, num_q=416, v_repeat=4, q_aware_to_v_lora_rank=64, + identity_to_out=False, out_has_skip=False): + super().__init__() + dim_head = input_dim // num_heads + inner_dim = dim_head * num_heads + + self.num_heads = num_heads + self.q_aware_to_v = q_aware_to_v + self.v_has_skip = v_has_skip + self.to_q = nn.Sequential( + nn.Linear(input_dim, inner_dim, bias=False), + nn.LayerNorm(inner_dim, elementwise_affine=True) + ) if not identity_to_q else nn.Identity() + self.to_k = nn.Sequential( + nn.Linear(input_dim, inner_dim, bias=False), + nn.LayerNorm(inner_dim, elementwise_affine=True) + ) if not identity_to_k else nn.Identity() + + self.v_repeat = v_repeat + self.num_q_group = num_q_group = num_q // v_repeat # 416 / 4 = 104. + + # If q_aware_to_v is True, then self.to_v consists of num_q projections of input_dim to inner_dim. + # Otherwise, self.to_v consists of a single projection of input_dim to inner_dim. + if q_aware_to_v: + # all_q_mid: 104 * 64 = 6656. + all_q_mid = num_q_group * q_aware_to_v_lora_rank + self.to_v = nn.Sequential( + # number of params: 768 * 6656 = 5,111,808. + # Input: [BS, 16, 768]. Output: [BS, 16, 104*64] = [BS, 16, 6656]. + # Each 768-dim vec is dispersed into 104 64-dim vecs. + nn.Linear(input_dim, all_q_mid, bias=False), + nn.LayerNorm(all_q_mid, elementwise_affine=True), + # Change the dim of the tensor to [BS, 6656, 16], as Conv1d transforms dim 1. + Rearrange('b n q -> b q n', q=all_q_mid), + # Each q_aware_to_v projection has its own linear layer. + # The total number of parameters will be 6656*768 = 5,111,808. 
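+                # The grouped 1x1 Conv1d below acts as 104 independent 64 -> 768 linear maps,
+                # one per query group.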
+ # Output: [BS, 104*768, 16]. Each 64 dim feature is expanded to 768 dim. + nn.Conv1d( + in_channels=all_q_mid, + out_channels=num_q_group * input_dim, + kernel_size=1, + groups=num_q_group, + bias=False, + ), + # Output: [BS, 104, 16, 768]. + Rearrange('b (q d) n -> b q n d', q=num_q_group, d=input_dim), + nn.LayerNorm(input_dim, elementwise_affine=True), + ) + else: + self.to_v = nn.Sequential( + nn.Linear(input_dim, inner_dim, bias=False), + nn.LayerNorm(inner_dim, elementwise_affine=True) + ) if not identity_to_v else nn.Identity() + + if identity_to_out: + assert not out_has_skip, "identity_to_out=True, then out_has_skip has to be False." + + if identity_to_out: + self.to_out = nn.Identity() + else: + self.to_out = nn.Sequential( + nn.Linear(input_dim, input_dim, bias=False), + nn.Dropout(p_dropout), + nn.LayerNorm(inner_dim, elementwise_affine=True) + ) + + self.out_has_skip = out_has_skip + self.attn_drop = nn.Dropout(p_dropout) + + def forward(self, x, context=None, attn_mat=None, return_attn=False): + h = self.num_heads + + if context is None: + context = x + + if attn_mat is None: + # q: [BS, Q, D] -> [BS, Q, D]. + q = self.to_q(x) + # k: [BS, L, D] -> [BS, L, D]. + k = self.to_k(context) + # q: [6, 512, 128], k: [6, 17, 128]. + q, k = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k)) + + if self.q_aware_to_v: + # context: [BS, L, D]. v: [BS, Q, L, D]. + # There are effectively Q to_v projections. + v = self.to_v(context) + if self.v_has_skip: + v = v + context.unsqueeze(1) + else: + # v: [BS, L, D]. + v = self.to_v(context) + if self.v_has_skip: + v = v + context + + #print(v.shape) + + if self.q_aware_to_v: + # v: [6, 64, 17, 128]. + # v is query-specific, so there's an extra dim for the query. + v = rearrange(v, 'b q n (h d) -> (b h) q n d', h=h).contiguous() + # Each v is for a query group with 512/64 = 8 queries. + # So each v is repeated 8 times to match the number of queries. + # v: [6, 64, 17, 128] -> [6, 512, 17, 128]. + v = v.repeat(1, self.v_repeat, 1, 1) + else: + v = rearrange(v, 'b n (h d) -> (b h) n d', h=h).contiguous() + + if attn_mat is None: + scale = q.size(-1) ** -0.25 + sim = einsum('b i d, b j d -> b i j', q * scale, k * scale) + # sim: [6, 64, 17]. 6: bs 1 * h 6. + # attention, what we cannot get enough of + # NOTE: the normalization is done across tokens, not across pixels. + # So for each pixel, the sum of attention scores across tokens is 1. + attn = sim.softmax(dim=-1) + attn = self.attn_drop(attn) + #print(attn.std()) + else: + attn = attn_mat + + if self.q_aware_to_v: + # attn: [6, 32, 17]. v: [6, 32, 17, 128]. 128: dim of each head. out: [6, 32, 128]. + # out is combined with different attn weights and v for different queries. + out = einsum('b i j, b i j d -> b i d', attn, v) + else: + # v: [6, 17, 128]. out: [6, 32, 128]. + out = einsum('b i j, b j d -> b i d', attn, v) + + # [6, 32, 128] -> [1, 32, 768]. + out = rearrange(out, '(b h) n d -> b n (h d)', h=h).contiguous() + + if self.out_has_skip: + out = self.to_out(out) + out + else: + out = self.to_out(out) + + if return_attn: + return out, attn + else: + return out + +class ImgPrompt2TextPrompt(nn.Module): + def __init__(self, placeholder_is_bg, num_id_vecs, dtype=torch.float32, *args, **kwargs): + super().__init__() + self.N_ID = num_id_vecs + # If not placeholder_is_bg, then N_SFX will be updated in initialize_text_components(). 
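+        # N_SFX: number of static image suffix embeddings appended after the N_ID subject embeddings.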
+ self.N_SFX = 0 + + if not placeholder_is_bg: + self.initialize_text_components(*args, **kwargs) + + # prompt2token_proj: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**. + # prompt2token_proj is with the same architecture as the original arc2face text encoder, + # but retrained to do inverse mapping. + # To be initialized in the subclass. + self.prompt2token_proj = None + self.dtype = dtype + + def initialize_static_img_suffix_embs(self, num_static_img_suffix_embs, img_prompt_dim=768): + self.N_SFX = num_static_img_suffix_embs + # We always take the first num_static_img_suffix_embs embeddings out of static_img_suffix_embs. + # So it's OK that static_img_suffix_embs is larger than required number num_static_img_suffix_embs. + # This holds even if num_static_img_suffix_embs is 0. + if hasattr(self, 'static_img_suffix_embs') and self.static_img_suffix_embs is not None: + if self.static_img_suffix_embs.shape[1] == self.N_SFX: + print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs ({self.N_SFX} required). Skip initialization.") + elif self.static_img_suffix_embs.shape[1] < self.N_SFX: + print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (< {self.N_SFX} required). Reinitialize.") + self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim)) + elif self.N_SFX > 0: + # self.static_img_suffix_embs.shape[1] > self.N_SFX > 0. + print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (> {self.N_SFX} required). Truncate.") + self.static_img_suffix_embs = nn.Parameter(self.static_img_suffix_embs[:, :self.N_SFX]) + else: + # self.static_img_suffix_embs.shape[1] > self.N_SFX == 0. + print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (0 required). Erase.") + self.static_img_suffix_embs = None + else: + if self.N_SFX > 0: + # Either static_img_suffix_embs does not exist or is None, + # or it's initialized but has fewer than num_static_img_suffix_embs embeddings (this situation should be very rare, + # so we don't consider to reuse and extend a shorter static_img_suffix_embs). + # So we reinitialize it. + self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim)) + else: + # If static_img_suffix_embs had been initialized, then it will be set to None, i.e., erased from the SubjBasisGenerator instance. + self.static_img_suffix_embs = None + + # Implement a separate initialization function, so that it can be called from SubjBasisGenerator + # after the SubjBasisGenerator is initialized. This can be used to fix old SubjBasisGenerator + # ckpts which were not subclassed from ImgPrompt2TextPrompt. + def initialize_text_components(self, max_prompt_length=77, num_id_vecs=16, + num_static_img_suffix_embs=0, img_prompt_dim=768): + self.initialize_static_img_suffix_embs(num_static_img_suffix_embs, img_prompt_dim) + self.max_prompt_length = max_prompt_length + self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + # clip_text_embeddings: CLIPTextEmbeddings instance. + clip_text_embeddings = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").text_model.embeddings + # clip_text_embeddings() and clip_text_embeddings.token_embedding() differ in that + # clip_text_embeddings() adds positional embeddings, while clip_text_embeddings.token_embedding() doesn't. + # Adding positional embeddings seems to help somewhat. 
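+        # pad_embeddings are precomputed once here and later spliced into the padding region of the
+        # output prompt (see the 'full_pad' / 'full_half_pad' branches of inverse_img_prompt_embs()).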
+ # pad_tokens: pad_token_id 49407 repeated 77 times. + # pad_token_id is the EOS token. But BOS is 49406. + pad_tokens = torch.tensor([self.tokenizer.pad_token_id]).repeat(self.max_prompt_length) + # pad_embeddings: [77, 768]. + # pad_embeddings is still on CPU. But should be moved to GPU automatically. + # Note: detach pad_embeddings from the computation graph, otherwise + # deepcopy() in embedding_manager.py:make_frozen_copy_of_subj_basis_generators() will fail. + self.pad_embeddings = clip_text_embeddings(pad_tokens)[0].detach() + + # image prompt space -> text prompt space. + # return_emb_types: a list of strings, each string is among + # ['full', 'core', 'full_pad', 'full_half_pad']. + def inverse_img_prompt_embs(self, face_prompt_embs, list_extra_words, + return_emb_types, hidden_state_layer_weights=None, + enable_static_img_suffix_embs=False): + + ''' + face_prompt_embs: (BS, self.N_ID, 768), in the image prompt space. + Only the core embeddings, no paddings. + list_extra_words: None or [s_1, ..., s_BS], each s_i is a list of extra words to be added to the prompt. + ''' + if list_extra_words is not None: + if len(list_extra_words) != len(face_prompt_embs): + if len(face_prompt_embs) > 1: + print("Warn: list_extra_words has different length as face_prompt_embs.") + if len(list_extra_words) == 1: + list_extra_words = list_extra_words * len(face_prompt_embs) + else: + breakpoint() + else: + # len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_comp_prompt_distillation. + # But list_extra_words always corresponds to the actual batch size. So we only take the first element. + list_extra_words = list_extra_words[:1] + + for extra_words in list_extra_words: + assert len(extra_words.split()) <= 2, "Each extra_words string should consist of at most 2 words." + # 16 or 4 ", " are placeholders for face_prompt_embs. + prompt_templates = [ "photo of a " + ", " * self.N_ID + list_extra_words[i] for i in range(len(list_extra_words)) ] + else: + # 16 or 4 ", " are placeholders for face_prompt_embs. + # No extra words are added to the prompt. So we add 2 more ", " to the template to keep + # the number of tokens roughly the same as when extra words are added. + prompt_templates = [ "photo of a " + ", " * (self.N_ID + 2) for _ in range(len(face_prompt_embs)) ] + + # This step should be quite fast, and there's no need to cache the input_ids. + # input_ids: [BS, 77]. + input_ids = self.tokenizer( + prompt_templates, + truncation=True, + padding="max_length", + max_length=self.max_prompt_length, + return_tensors="pt", + ).input_ids.to(face_prompt_embs.device) + + face_prompt_embs_orig_dtype = face_prompt_embs.dtype + face_prompt_embs = face_prompt_embs.to(self.dtype) + + ID_END = 4 + self.N_ID + PAD_BEGIN = ID_END + self.N_SFX + 2 + + # token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping). + token_embs = self.prompt2token_proj(input_ids=input_ids, return_token_embs=True) + # token 4: first ", " in the template prompt. + # Replace embeddings of 16 or 4 placeholder ", " with face_prompt_embs. + token_embs[:, 4:ID_END] = face_prompt_embs + # Only when do_unet_distill == True, we append the static image suffix embeddings. + # Otherwise, static image suffix embeddings are ignored, + # and token_embs[:, ID_END:ID_END+self.N_SFX] are the filler embeddings of the + # extra ", " in the template prompt. 
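# Illustrative standalone sketch (hypothetical values, not part of this file):
# it shows the token layout that the index arithmetic above relies on
# (ID_END = 4 + N_ID, PAD_BEGIN = ID_END + N_SFX + 2). Position 0 is BOS,
# positions 1..3 are "photo of a", positions 4..4+N_ID-1 are the ", "
# placeholders whose embeddings get overwritten with face_prompt_embs, and the
# next two positions hold the (at most) two extra words.
from transformers import CLIPTokenizer

_N_ID = 16
_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
_template = "photo of a " + ", " * _N_ID + "smiling person"
_ids = _tok(_template, truncation=True, padding="max_length",
            max_length=77, return_tensors="pt").input_ids[0]
_tokens = _tok.convert_ids_to_tokens(_ids.tolist())
print(_tokens[:4])                        # ['<|startoftext|>', 'photo</w>', 'of</w>', 'a</w>']
print(_tokens[4:4 + _N_ID])               # _N_ID comma tokens: the placeholder slots
print(_tokens[4 + _N_ID:4 + _N_ID + 2])   # the extra-word tokens ('smiling', 'person')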
+ if enable_static_img_suffix_embs and self.N_SFX > 0: + # Put the static image suffix embeddings right after face_prompt_embs. + token_embs[:, ID_END:ID_END+self.N_SFX] = self.static_img_suffix_embs[:, :self.N_SFX] + + # This call does the ordinary CLIP text encoding pass. + prompt_embeds = self.prompt2token_proj( + input_ids=input_ids, + input_token_embs=token_embs, + hidden_state_layer_weights=hidden_state_layer_weights, + return_token_embs=False + )[0] + + # Restore the original dtype of prompt_embeds: float16 -> float32. + prompt_embeds = prompt_embeds.to(face_prompt_embs_orig_dtype) + # token 4: first ", " in the template prompt. + # When N_ID == 16, + # prompt_embeds 4:20 are the most important 16 embeddings that contain the subject's identity. + # 20:22 are embeddings of the (at most) two extra words. + # [N, 77, 768] -> [N, 16, 768] + if enable_static_img_suffix_embs: + core_prompt_embs = prompt_embeds[:, 4:ID_END+self.N_SFX] + else: + core_prompt_embs = prompt_embeds[:, 4:ID_END] + + if list_extra_words is not None: + # [N, 16, 768] -> [N, 18, 768] + extra_words_embs = prompt_embeds[:, ID_END+self.N_SFX:PAD_BEGIN] + core_prompt_embs = torch.cat([core_prompt_embs, extra_words_embs], dim=1) + + returned_prompt_embs = [] + for emb_type in return_emb_types: + if emb_type == 'full': + returned_prompt_embs.append(prompt_embeds) + elif emb_type == 'full_half_pad': + prompt_embeds2 = prompt_embeds.clone() + # PAD_BEGIN is 22 or 10. Also exclude the last EOS token. + # So we subtract max_prompt_length by (PAD_BEGIN + 1). + PADS = self.max_prompt_length - PAD_BEGIN - 1 + if PADS >= 2: + # Fill half of the remaining embeddings with pad embeddings. + prompt_embeds2[:, PAD_BEGIN:PAD_BEGIN+PADS//2] = self.pad_embeddings[PAD_BEGIN:PAD_BEGIN+PADS//2] + returned_prompt_embs.append(prompt_embeds2) + elif emb_type == 'full_pad': + prompt_embeds2 = prompt_embeds.clone() + # Replace the PAD_BEGIN-th to the second last embeddings with pad embeddings. + # Skip replacing the last embedding, which might has special roles. + # (Although all padding tokens are the same EOS, the last token might acquire special semantics + # due to its special position.) + prompt_embeds2[:, PAD_BEGIN:-1] = self.pad_embeddings[PAD_BEGIN:-1] + returned_prompt_embs.append(prompt_embeds2) + elif emb_type == 'full_zeroed_extra': + prompt_embeds2 = prompt_embeds.clone() + # Only add two pad embeddings. The remaining embeddings are set to 0. + # Make the positional embeddings align with the actual positions. + prompt_embeds2[:, 22:24] = self.pad_embeddings[22:24] + prompt_embeds2[:, 24:-1] = 0 + returned_prompt_embs.append(prompt_embeds2) + elif emb_type == 'core': + returned_prompt_embs.append(core_prompt_embs) + else: + breakpoint() + + return returned_prompt_embs + +class SubjBasisGenerator(ImgPrompt2TextPrompt): + def __init__( + self, + # number of cross-attention heads of the bg prompt translator. + # Taken as a half of the number of heads 12 of OpenAI clip-vit-large-patch14: + # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json + num_bg_encoder_heads=6, + # number of subject input identity vectors (only when the subject is not face), + # or number of background input identity vectors (no matter the subject is face or not). + # 257: 257 CLIP tokens. + num_nonface_in_id_vecs={ 'subj': 77, 'bg': 257 }, + num_id_vecs=16, # num_id_vecs: subj: 16. bg: 4. + num_static_img_suffix_embs: int = 0, # Number of extra static learnable image embeddings appended to translated ID embeddings. 
+ bg_image_embedding_dim=1024, # CLIP image hidden layer feature dimension, as per config.json above. + obj_embedding_dim=384, # DINO object feature dimension for objects. + output_dim=768, # CLIP text embedding input dimension. + placeholder_is_bg: bool = False, # Whether the placeholder is for the image background tokens. + prompt2token_proj_grad_scale: float = 0.4, # Gradient scale for prompt2token_proj. + learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer. + bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection. + ): + + # If not placeholder_is_bg, then it calls initialize_text_components() in the superclass. + super().__init__(placeholder_is_bg=placeholder_is_bg, num_id_vecs=num_id_vecs, max_prompt_length=77, + num_static_img_suffix_embs=num_static_img_suffix_embs, img_prompt_dim=output_dim) + + self.placeholder_is_bg = placeholder_is_bg + self.num_out_embs = self.N_ID + self.N_SFX + self.output_dim = output_dim + # num_nonface_in_id_vecs should be the number of core ID embs, 16. + # However, in such case, pos_embs is not used. So it doesn't matter if it's wrongly set. + self.num_nonface_in_id_vecs = num_nonface_in_id_vecs['bg'] if placeholder_is_bg else num_nonface_in_id_vecs['subj'] + self.bg_prompt_translator_has_to_out_proj = bg_prompt_translator_has_to_out_proj + + if not self.placeholder_is_bg: + # [1, 384] -> [1, 16, 768]. + # TODO: use CLIPTextModelWrapper as obj_proj_in. + self.obj_proj_in = ExpandEmbs(obj_embedding_dim, output_dim, expansion_ratio=self.num_nonface_in_id_vecs) + + # ** prompt2token_proj does the actual job: ** + # it is the inverse projection that maps from faceid2img_prompt_embs to adaface_prompt_embs. + # self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings) or [1, 16, 768] (without paddings). + # If self.placeholder_is_bg: prompt2token_proj is set to None. + # Use an attention dropout of 0.2 to increase robustness. + clip_dropout_config = None #CLIPTextConfig.from_pretrained('openai/clip-vit-large-patch14', attention_dropout=0.05, dropout=0.05) + self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14', + config=clip_dropout_config) + self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale + self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale) + print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.") + # If prompt2token_proj_grad_scale is 0, freeze all params in prompt2token_proj. + # Otherwise, only freeze token and positional embeddings of the original CLIPTextModel. + self.freeze_prompt2token_proj() + + # These multipliers are relative to the original CLIPTextModel. + self.prompt2token_proj_attention_multipliers = [1] * 12 + self.initialize_hidden_state_layer_weights(learnable_hidden_state_weights_scheme, 'cpu') + self.bg_proj_in = None + self.pos_embs = self.pos_embs_ln = self.latent_queries = self.latent_queries_ln = None + else: + # For background placeholders, face and object embeddings are not used as they are foreground. 
+ self.obj_proj_in = None + + self.bg_proj_in = nn.Sequential( + nn.Linear(bg_image_embedding_dim, output_dim, bias=False), + nn.LayerNorm(output_dim), + ) + + self.pos_embs = nn.Parameter(torch.zeros(1, self.num_nonface_in_id_vecs, output_dim)) + self.pos_embs_ln = nn.LayerNorm(output_dim) + self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, output_dim)) + self.latent_queries_ln = nn.LayerNorm(output_dim) + + identity_to_v = False + v_has_skip = not identity_to_v # True + identity_to_out = not bg_prompt_translator_has_to_out_proj # True + out_has_skip = not identity_to_out # False + # prompt_translator maps the clip image features (of the background) to the prompt embedding space. + # It is only used during training when placeholder_is_bg is True. + # prompt_translator has a to_v projection with skip connection, and doesn't have a to_out projection. + # dim=768, num_bg_encoder_heads=6. + self.prompt_translator = \ + CrossAttention(input_dim=output_dim, num_heads=num_bg_encoder_heads, p_dropout=0.05, + identity_to_q=False, identity_to_k=False, identity_to_v=identity_to_v, + q_aware_to_v=False, v_has_skip=v_has_skip, + num_q=0, # When not q_aware_to_v, num_q is not referenced. + identity_to_out=identity_to_out, + out_has_skip=out_has_skip) + + self.output_scale = output_dim ** -0.5 + + ''' + prompt_translator: CLIPEncoder + # https://github.com/huggingface/transformers/blob/1872bde7fc6a5d6796bd742bc2dc38eaf8069c5d/src/transformers/models/clip/modeling_clip.py#L566 + # CLIPEncoder.layers: 12 layers of CLIPEncoderLayer, each being + (0): CLIPEncoderLayer( + (self_attn): CLIPAttention( + (k_proj): Linear(in_features=768, out_features=768, bias=True) + (v_proj): Linear(in_features=768, out_features=768, bias=True) + (q_proj): Linear(in_features=768, out_features=768, bias=True) + (out_proj): Linear(in_features=768, out_features=768, bias=True) + ) + (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True) + (mlp): CLIPMLP( + (activation_fn): QuickGELUActivation() + (fc1): Linear(in_features=768, out_features=3072, bias=True) + (fc2): Linear(in_features=3072, out_features=768, bias=True) + ) + (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True) + ) + ''' + + print(repr(self)) + + # raw_id_embs: only used when the subject is non-faces. In that case it's DINO embeddings. + # Otherwise, raw_id_embs is not used. + # faceid2img_prompt_embs: [BS, 16, 768], the core ID prompt embeddings generated by ID2ImgPrompt. + def forward(self, faceid2img_prompt_embs, clip_features=None, raw_id_embs=None, out_id_embs_cfg_scale=1.0, + is_face=True, enable_static_img_suffix_embs=False): + + if not self.placeholder_is_bg: + BS = faceid2img_prompt_embs.shape[0] + else: + # If bg, then faceid2img_prompt_embs is set to None, but clip_features is not None. + BS = clip_features.shape[0] + clip_features = clip_features.to(self.dtype) + + # No need to use raw_id_embs if placeholder_is_bg. + if not self.placeholder_is_bg: + if is_face: + assert faceid2img_prompt_embs is not None + # id2img_embs has been projected to the (modified) prompt embedding space + # by ID2AdaPrompt::map_init_id_to_img_prompt_embs(). This prompt embedding space is modified because + # the ID2ImgPrompt module (at least when it's arc2face) may have finetuned the + # text encoder and the U-Net. + # in embedding_manager: [BS, 16, 768] -> [BS, 77, 768]. + # faceid2img_prompt_embs is part of id2img_embs: [BS, 77, 768] -> [BS, 16, 768]. + # adaface_prompt_embs is projected to the prompt embedding spaces. 
This is the + # original U-Net prompt embedding space. + + # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]] + hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights) + + # faceid2img_prompt_embs -> ada_id_embs: image prompt space -> text prompt space. + with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0): + # If list_extra_words is not None, then ada_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens + # and (at most) two extra words in adaface_prompt_embs, without BOS and EOS. + # If list_extra_words is None, then ada_id_embs: [BS, 16, 768], the 16 identity tokens in adaface_prompt_embs. + # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]] + # ada_id_embs: [BS, 16, 768]. + # return_emb_types: a list of strings, each string is among + # ['full', 'core', 'full_pad', 'full_half_pad']. + ada_id_embs, = \ + self.inverse_img_prompt_embs(faceid2img_prompt_embs, + list_extra_words=None, + return_emb_types=['core'], + hidden_state_layer_weights=hidden_state_layer_weights, + enable_static_img_suffix_embs=enable_static_img_suffix_embs) + ada_id_embs = self.prompt2token_proj_grad_scaler(ada_id_embs) + elif raw_id_embs is not None: + # id_embs: [BS, 384] -> [BS, 18, 768]. + # obj_proj_in is expected to project the DINO object features to + # the token embedding space. So no need to use prompt2token_proj. + id_embs = self.obj_proj_in(raw_id_embs) + else: + breakpoint() + else: + # Otherwise, context is the ad-hoc CLIP image features. + # id_embs: [BS, 257, 768]. + id_embs = self.bg_proj_in(clip_features) + + if self.placeholder_is_bg: + id_embs = id_embs + self.pos_embs_ln(self.pos_embs) + latent_queries = self.latent_queries_ln(self.latent_queries).repeat(BS, 1, 1) + # If bg, we don't have to use a specific attn layer for each 4-vec set. Instead, one attn layer can generate 257 embs, + # and we take the first 16*4=64. + # Output of prompt_translator is exactly num_out_embs == 64 tokens. id_embs_out: [BS, 64, 768]. + # prompt_translator: better named as bg_prompt_translator. It maps the bg features + # to bg prompt embeddings. + with torch.set_grad_enabled(self.training): + id_embs_out = self.prompt_translator(latent_queries, id_embs) + + adaface_out_embs = id_embs_out * self.output_scale # * 0.036 + else: + adaface_out_embs = ada_id_embs + # If out_id_embs_cfg_scale < 1, adaface_out_embs is a mix of adaface_out_embs and pad_embeddings. + if out_id_embs_cfg_scale != 1: + # pad_embeddings: [77, 768] -> [16, 768] -> [1, 16, 768]. + # NOTE: Never do cfg on static image suffix embeddings. + # So we take self.N_ID embeddings, instead of self.N_ID + self.N_SFX, + # even if enable_static_img_suffix_embs=True. + pad_embeddings = self.pad_embeddings[4:4+self.N_ID].unsqueeze(0).to(ada_id_embs.device) + adaface_out_embs[:, :self.N_ID] = ada_id_embs[:, :self.N_ID] * out_id_embs_cfg_scale \ + + pad_embeddings * (1 - out_id_embs_cfg_scale) + + return adaface_out_embs + + def initialize_hidden_state_layer_weights(self, learnable_hidden_state_weights_scheme, device): + if learnable_hidden_state_weights_scheme == 'none': + self.hidden_state_layer_weights = None + # A grad scaler with alpha =1 is nn.Identity(), which outputs None given None as input. 
+ self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(1) + print("hidden_state_layer_weights is set to None.") + + elif learnable_hidden_state_weights_scheme == 'per-layer': + # Learnable weights of the last 3 layers, initialized to putting more focus on the last layer. + # 'per-layer': Different weights for different layers, but the same for different channels. + # hidden_state_layer_weights: [3, 1]. + self.hidden_state_layer_weights = nn.Parameter(torch.tensor([[1.0], [2.0], [4.0]], device=device), + requires_grad=True) + # A gradient scaler of 5 makes the gradients on hidden_state_layer_weights 5 times larger. + self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(5) + print("hidden_state_layer_weights initialized as per-layer [1, 2, 4], with grad scaler 5.") + else: + breakpoint() + + def extend_prompt2token_proj_attention(self, prompt2token_proj_attention_multipliers=None, + begin_layer_idx=-1, end_layer_idx=-1, multiplier=1, perturb_std=0.1): + if begin_layer_idx == -1: + begin_layer_idx = 0 + if end_layer_idx == -1: + end_layer_idx = 11 + + if prompt2token_proj_attention_multipliers is None and multiplier == 1: + print("prompt2token_proj_attention_multipliers are all 1. No extension is done.") + return + + elif prompt2token_proj_attention_multipliers is None: + # prompt2token_proj_attention_multipliers are relative to the current prompt2token_proj. + prompt2token_proj_attention_multipliers = [1] * 12 + for i in range(begin_layer_idx, end_layer_idx+1): + prompt2token_proj_attention_multipliers[i] = multiplier + # Otherwise, use the given prompt2token_proj_attention_multipliers. + + num_extended_layers = self.prompt2token_proj.extend_clip_attention_MKV_multiplier(prompt2token_proj_attention_multipliers, perturb_std) + # Update prompt2token_proj_attention_multipliers (relative to the original CLIPTextModel). + for i in range(begin_layer_idx, end_layer_idx+1): + self.prompt2token_proj_attention_multipliers[i] *= prompt2token_proj_attention_multipliers[i] + + print(f"{num_extended_layers} layers in prompt2token_proj_attention are extended by {prompt2token_proj_attention_multipliers}") + return num_extended_layers + + def squeeze_prompt2token_proj_attention(self, prompt2token_proj_attention_divisors=None, + begin_layer_idx=-1, end_layer_idx=-1, divisor=1): + if begin_layer_idx == -1: + begin_layer_idx = 0 + if end_layer_idx == -1: + end_layer_idx = 11 + + if prompt2token_proj_attention_divisors is None and divisor == 1: + print("prompt2token_proj_attention_divisors are all 1. No squeezing is done.") + return + elif prompt2token_proj_attention_divisors is None: + prompt2token_proj_attention_divisors = [1] * 12 + for i in range(begin_layer_idx, end_layer_idx+1): + prompt2token_proj_attention_divisors[i] = divisor + # Otherwise, use the given prompt2token_proj_attention_divisors. + + num_squeezed_layers = self.prompt2token_proj.squeeze_clip_attention_MKV_divisor(prompt2token_proj_attention_divisors) + # Update prompt2token_proj_attention_multipliers (relative to the original CLIPTextModel). + for i in range(begin_layer_idx, end_layer_idx+1): + self.prompt2token_proj_attention_multipliers[i] //= prompt2token_proj_attention_divisors[i] + + print(f"{num_squeezed_layers} layers in prompt2token_proj_attention are squeezed by {prompt2token_proj_attention_divisors}") + return num_squeezed_layers + + def freeze_prompt2token_proj(self): + # Only applicable to fg basis generator. + if self.placeholder_is_bg: + return + # If bg, then prompt2token_proj is set to None. 
Therefore no need to freeze it. + # Then we don't have to check whether it's for subj or bg. + if self.prompt2token_proj_grad_scale == 0: + frozen_components_name = 'all' + frozen_param_set = self.prompt2token_proj.named_parameters() + else: + frozen_components_name = 'token_pos_embeddings' + frozen_param_set = self.prompt2token_proj.text_model.embeddings.named_parameters() + + if self.prompt2token_proj is not None: + frozen_param_names = [] + for param_name, param in frozen_param_set: + if param.requires_grad: + param.requires_grad = False + frozen_param_names.append(param_name) + # If param is already frozen, then no need to freeze it again. + print(f"{frozen_components_name} {len(frozen_param_names)} params in Subj prompt2token_proj is frozen.") + #print(f"Frozen parameters:\n{frozen_param_names}") + + def patch_old_subj_basis_generator_ckpt(self): + # Fix compatability with the previous version. + if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'): + self.bg_prompt_translator_has_to_out_proj = False + if not hasattr(self, 'num_out_embs'): + self.num_out_embs = -1 + if hasattr(self, 'num_id_vecs') and not hasattr(self, 'N_ID'): + self.N_ID = self.num_id_vecs + if not hasattr(self, 'num_nonface_in_id_vecs') and hasattr(self, 'N_ID'): + self.num_nonface_in_id_vecs = self.N_ID + if not hasattr(self, 'dtype'): + self.dtype = torch.float32 + + if self.placeholder_is_bg: + if not hasattr(self, 'pos_embs') or self.pos_embs is None: + self.pos_embs = nn.Parameter(torch.zeros(1, self.num_nonface_in_id_vecs, self.output_dim)) + if not hasattr(self, 'latent_queries') or self.latent_queries is None: + self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, self.output_dim)) + # Background encoder doesn't require initializing text components. + else: + self.initialize_hidden_state_layer_weights('per-layer', 'cpu') + if not hasattr(self, 'prompt2token_proj_attention_multipliers'): + # Please manually set prompt2token_proj_attention_multipliers in the ckpt. 
+ breakpoint() + + self.initialize_text_components(max_prompt_length=77, num_id_vecs=self.N_ID, + num_static_img_suffix_embs=self.N_SFX, + img_prompt_dim=self.output_dim) + + def __repr__(self): + type_sig = 'subj' if not self.placeholder_is_bg else 'bg' + + return f"{type_sig} SubjBasisGenerator: num_out_embs={self.num_out_embs}, " \ + f"bg_prompt_translator_has_to_out_proj={self.bg_prompt_translator_has_to_out_proj}" + diff --git a/adaface/test_img_prompt_model.py b/adaface/test_img_prompt_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b45f936e6b55570e6dc9fcd847f9661ecc0919ce --- /dev/null +++ b/adaface/test_img_prompt_model.py @@ -0,0 +1,199 @@ +import torch +from PIL import Image +import os, argparse, glob +import numpy as np +from .face_id_to_ada_prompt import create_id2ada_prompt_encoder +from .util import create_consistentid_pipeline +from .arc2face_models import create_arc2face_pipeline +from transformers import CLIPTextModel + +def save_images(images, subject_name, id2img_prompt_encoder_type, + prompt, perturb_std, save_dir = "samples-ada"): + os.makedirs(save_dir, exist_ok=True) + # Save 4 images as a grid image in save_dir + grid_image = Image.new('RGB', (512 * 2, 512 * 2)) + for i, image in enumerate(images): + image = image.resize((512, 512)) + grid_image.paste(image, (512 * (i % 2), 512 * (i // 2))) + + prompt_sig = prompt.replace(" ", "_").replace(",", "_") + grid_filepath = os.path.join(save_dir, + "-".join([subject_name, id2img_prompt_encoder_type, + prompt_sig, f"perturb{perturb_std:.02f}.png"])) + + if os.path.exists(grid_filepath): + grid_count = 2 + grid_filepath = os.path.join(save_dir, + "-".join([ subject_name, id2img_prompt_encoder_type, + prompt_sig, f"perturb{perturb_std:.02f}", str(grid_count) ]) + ".png") + while os.path.exists(grid_filepath): + grid_count += 1 + grid_filepath = os.path.join(save_dir, + "-".join([ subject_name, id2img_prompt_encoder_type, + prompt_sig, f"perturb{perturb_std:.02f}", str(grid_count) ]) + ".png") + + grid_image.save(grid_filepath) + print(f"Saved to {grid_filepath}") + +def seed_everything(seed): + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ["PL_GLOBAL_SEED"] = str(seed) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # --base_model_path models/Realistic_Vision_V4.0_noVAE + parser.add_argument("--base_model_path", type=str, default="models/sar/sar.safetensors") + parser.add_argument("--id2img_prompt_encoder_type", type=str, + choices=["arc2face", "consistentID"], + help="Types of the ID2Img prompt encoder") + parser.add_argument("--subject", type=str, default="subjects-celebrity/taylorswift") + parser.add_argument("--example_image_count", type=int, default=5, help="Number of example images to use") + parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate") + parser.add_argument("--init_img", type=str, default=None) + parser.add_argument("--prompt", type=str, default="portrait photo of a person in superman costume") + parser.add_argument("--use_core_only", action="store_true") + parser.add_argument("--truncate_prompt_at", type=int, default=-1, + help="Truncate the prompt to this length") + parser.add_argument("--randface", action="store_true") + parser.add_argument("--seed", type=int, default=-1) + parser.add_argument("--perturb_std", type=float, default=1) + + args = parser.parse_args() + if args.seed > 
0: + seed_everything(args.seed) + + if args.id2img_prompt_encoder_type == "arc2face": + pipeline = create_arc2face_pipeline(args.base_model_path) + use_teacher_neg = False + elif args.id2img_prompt_encoder_type == "consistentID": + pipeline = create_consistentid_pipeline(args.base_model_path) + use_teacher_neg = True + + pipeline = pipeline.to('cuda', torch.float16) + + # When the second argument, adaface_ckpt_path = None, create_id2ada_prompt_encoder() + # returns an id2ada_prompt_encoder object, with .subj_basis_generator uninitialized. + # But it doesn't matter, as we don't use the subj_basis_generator to generate ada embeddings. + id2img_prompt_encoder = create_id2ada_prompt_encoder([args.id2img_prompt_encoder_type], + num_static_img_suffix_embs=0) + id2img_prompt_encoder.to('cuda') + + if not args.randface: + image_folder = args.subject + if image_folder.endswith("/"): + image_folder = image_folder[:-1] + + if os.path.isfile(image_folder): + # Get the second to the last part of the path + subject_name = os.path.basename(os.path.dirname(image_folder)) + image_paths = [image_folder] + + else: + subject_name = os.path.basename(image_folder) + image_types = ["*.jpg", "*.png", "*.jpeg"] + alltype_image_paths = [] + for image_type in image_types: + # glob returns the full path. + image_paths = glob.glob(os.path.join(image_folder, image_type)) + if len(image_paths) > 0: + alltype_image_paths.extend(image_paths) + # image_paths contain at most args.example_image_count full image paths. + image_paths = alltype_image_paths[:args.example_image_count] + else: + subject_name = None + image_paths = None + image_folder = None + + subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name + id_batch_size = args.out_image_count + + text_encoder = pipeline.text_encoder + orig_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16).to("cuda") + + noise = torch.randn(args.out_image_count, 4, 64, 64, device='cuda', dtype=torch.float16) + if args.randface: + init_id_embs = torch.randn(1, 512, device='cuda', dtype=torch.float16) + if args.id2img_prompt_encoder_type == "arc2face": + pre_clip_features = None + elif args.id2img_prompt_encoder_type == "consistentID": + # For ConsistentID, random clip features are much better than zero clip features. + rand_clip_fgbg_features = torch.randn(1, 514, 1280, device='cuda', dtype=torch.float16) + pre_clip_features = rand_clip_fgbg_features + else: + breakpoint() + else: + init_id_embs = None + pre_clip_features = None + + # perturb_std is the *relative* std of the noise added to the face ID embeddings. + # For Arc2Face, a perturb_std of 0.08 could change gender, but 0.06 is usually safe. + # For ConsistentID, the image prompt embeddings are extremely robust to noise, + # and the perturb_std can be set to 0.5, only leading to a slight change in the result images. + # Seems ConsistentID mainly relies on CLIP features, instead of the face ID embeddings. + for perturb_std in (args.perturb_std, 0): + # id_prompt_emb is in the image prompt space. + # neg_id_prompt_emb is used in ConsistentID only. 
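# Illustrative standalone sketch (the helper name perturb_relative is hypothetical):
# "relative std" above means that the noise std is scaled by the spread of the
# embedding itself, so perturb_std=0.08 injects noise worth about 8% of the
# embedding's own std regardless of its scale. perturb_tensor() in adaface/util.py
# implements this, plus an optional norm-preserving variant.
import torch
import torch.nn.functional as F

def perturb_relative(emb: torch.Tensor, perturb_std: float) -> torch.Tensor:
    abs_std = perturb_std * emb.std(dim=-1).mean()
    return emb + torch.randn_like(emb) * abs_std

_face_id = torch.randn(1, 512)                              # stand-in for a face ID embedding
_perturbed = perturb_relative(_face_id, 0.08)
print(F.cosine_similarity(_face_id, _perturbed).item())     # close to 1.0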
+ face_image_count, faceid_embeds, id_prompt_emb, neg_id_prompt_emb \ + = id2img_prompt_encoder.get_img_prompt_embs( \ + init_id_embs=init_id_embs, + pre_clip_features=pre_clip_features, + image_paths=image_paths, + image_objs=None, + id_batch_size=id_batch_size, + perturb_at_stage='img_prompt_emb', + perturb_std=perturb_std, + avg_at_stage='id_emb', + verbose=True) + + pipeline.text_encoder = orig_text_encoder + + comp_prompt = args.prompt + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + # prompt_embeds_, negative_prompt_embeds_: [4, 77, 768] + prompt_embeds_, negative_prompt_embeds_ = \ + pipeline.encode_prompt(comp_prompt, device='cuda', num_images_per_prompt=args.out_image_count, + do_classifier_free_guidance=True, negative_prompt=negative_prompt) + #pipeline.text_encoder = text_encoder + # Append the id prompt embeddings to the prompt embeddings. + # For arc2face, id_prompt_emb can be either pre- or post-pended. + # But for ConsistentID, id_prompt_emb has to be **post-pended**. Otherwise, the result images are blank. + + full_negative_prompt_embeds_ = negative_prompt_embeds_ + if args.truncate_prompt_at >= 0: + prompt_embeds_ = prompt_embeds_[:, :args.truncate_prompt_at] + negative_prompt_embeds_ = negative_prompt_embeds_[:, :args.truncate_prompt_at] + + prompt_embeds_ = torch.cat([prompt_embeds_, id_prompt_emb], dim=1) + M = id_prompt_emb.shape[1] + + if (not use_teacher_neg) or neg_id_prompt_emb is None: + # For arc2face, neg_id_prompt_emb is None. So we concatenate the last M negative prompt embeddings, + # to make the negative prompt embeddings have the same length as the prompt embeddings. + negative_prompt_embeds_ = torch.cat([negative_prompt_embeds_, full_negative_prompt_embeds_[:, -M:]], dim=1) + else: + # NOTE: For ConsistentID, neg_id_prompt_emb has to be present in the negative prompt embeddings. + # Otherwise, the result images are cartoonish. + negative_prompt_embeds_ = torch.cat([negative_prompt_embeds_, neg_id_prompt_emb], dim=1) + + if args.use_core_only: + prompt_embeds_ = id_prompt_emb + if (not use_teacher_neg) or neg_id_prompt_emb is None: + negative_prompt_embeds_ = full_negative_prompt_embeds_[:, :M] + else: + negative_prompt_embeds_ = neg_id_prompt_emb + + for guidance_scale in [6]: + images = pipeline(latents=noise, + prompt_embeds=prompt_embeds_, + negative_prompt_embeds=negative_prompt_embeds_, + num_inference_steps=50, + guidance_scale=guidance_scale, + num_images_per_prompt=1).images + + save_images(images, subject_name, args.id2img_prompt_encoder_type, + f"guide{guidance_scale}", perturb_std) diff --git a/adaface/unet_teachers.py b/adaface/unet_teachers.py new file mode 100644 index 0000000000000000000000000000000000000000..4030999e4706e4e516079049cbcc0194d2e66d4e --- /dev/null +++ b/adaface/unet_teachers.py @@ -0,0 +1,218 @@ +import torch +import numpy as np +import pytorch_lightning as pl +from diffusers import UNet2DConditionModel +from adaface.util import UNetEnsemble, create_consistentid_pipeline +from diffusers import UNet2DConditionModel +from omegaconf.listconfig import ListConfig + +def create_unet_teacher(teacher_type, device='cpu', **kwargs): + # If teacher_type is a list with only one element, we dereference it. + if isinstance(teacher_type, (tuple, list, ListConfig)) and len(teacher_type) == 1: + teacher_type = teacher_type[0] + + if teacher_type == "arc2face": + return Arc2FaceTeacher(**kwargs) + elif teacher_type == "unet_ensemble": + # unet, extra_unet_dirpaths and unet_weights are passed in kwargs. 
+ # Even if we distill from unet_ensemble, we still need to load arc2face for generating + # arc2face embeddings. + # The first (optional) ctor param of UNetEnsembleTeacher is an instantiated unet, + # in our case, the ddpm unet. Ideally we should reuse it to save GPU RAM. + # However, since the __call__ method of the ddpm unet takes different formats of params, + # for simplicity, we still use the diffusers unet. + # unet_teacher is put on CPU first, then moved to GPU when DDPM is moved to GPU. + return UNetEnsembleTeacher(device=device, **kwargs) + elif teacher_type == "consistentID": + return ConsistentIDTeacher(**kwargs) + elif teacher_type == "simple_unet": + return SimpleUNetTeacher(**kwargs) + # Since we've dereferenced the list if it has only one element, + # this holding implies the list has more than one element. Therefore it's UNetEnsembleTeacher. + elif isinstance(teacher_type, (tuple, list, ListConfig)): + # teacher_type is a list of teacher types. So it's UNetEnsembleTeacher. + return UNetEnsembleTeacher(unet_types=teacher_type, device=device, **kwargs) + else: + raise NotImplementedError(f"Teacher type {teacher_type} not implemented.") + +class UNetTeacher(pl.LightningModule): + def __init__(self, **kwargs): + super().__init__() + self.name = None + # self.unet will be initialized in the child class. + self.unet = None + self.p_uses_cfg = kwargs.get("p_uses_cfg", 0) + # self.cfg_scale will be randomly sampled from cfg_scale_range. + self.cfg_scale_range = kwargs.get("cfg_scale_range", [1.3, 2]) + # Initialize cfg_scale to 1. It will be randomly sampled during forward pass. + self.cfg_scale = 1 + if self.p_uses_cfg > 0: + print(f"Using CFG with probability {self.p_uses_cfg} and scale range {self.cfg_scale_range}.") + else: + print(f"Never using CFG.") + + # Passing in ddpm_model to use its q_sample and predict_start_from_noise methods. + # We don't implement the two functions here, because they involve a few tensors + # to be initialized, which will unnecessarily complicate the code. + # noise: the initial noise for the first iteration. + # t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t. + # uses_same_t: when sampling t, use the same t for all instances. + def forward(self, ddpm_model, x_start, noise, t, teacher_context, + num_denoising_steps=1, uses_same_t=False): + assert num_denoising_steps <= 10 + + if self.p_uses_cfg > 0: + self.uses_cfg = np.random.rand() < self.p_uses_cfg + if self.uses_cfg: + # Randomly sample a cfg_scale from cfg_scale_range. + self.cfg_scale = np.random.uniform(*self.cfg_scale_range) + if self.cfg_scale == 1: + self.uses_cfg = False + + if self.uses_cfg: + print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.") + else: + self.cfg_scale = 1 + print("Teacher does not use CFG.") + + # If p_uses_cfg > 0, we always pass both pos_context and neg_context to the teacher. + # But the neg_context is only used when self.uses_cfg is True and cfg_scale > 1. + # So we manually split the teacher_context into pos_context and neg_context, and only keep pos_context. + if self.name == 'unet_ensemble': + teacher_pos_contexts = [] + # teacher_context is a list of teacher contexts. 
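# Illustrative standalone sketch (the helper name cfg_combine is hypothetical) of
# why both pos_context and neg_context are kept around: classifier-free guidance
# starts from the unconditional (negative) prediction and moves `scale` times along
# the direction towards the conditional one, which is algebraically identical to
# the form pos * scale - neg * (scale - 1) used further below.
import torch

def cfg_combine(pos_pred: torch.Tensor, neg_pred: torch.Tensor, scale: float) -> torch.Tensor:
    return neg_pred + scale * (pos_pred - neg_pred)

_pos = torch.randn(2, 4, 64, 64)
_neg = torch.randn(2, 4, 64, 64)
_scale = 1.5
assert torch.allclose(cfg_combine(_pos, _neg, _scale),
                      _pos * _scale - _neg * (_scale - 1), atol=1e-6)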
+ for teacher_context_i in teacher_context: + pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0) + if pos_context.shape[0] != x_start.shape[0]: + breakpoint() + teacher_pos_contexts.append(pos_context) + teacher_context = teacher_pos_contexts + else: + pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0) + if pos_context.shape[0] != x_start.shape[0]: + breakpoint() + teacher_context = pos_context + else: + # p_uses_cfg = 0. Never use CFG. + self.uses_cfg = False + # In this case, the student only passes pos_context to the teacher, + # so no need to split teacher_context into pos_context and neg_context. + # self.cfg_scale will be accessed by the student, + # so we need to make sure it is always set correctly, + # in case someday we want to switch from CFG to non-CFG during runtime. + self.cfg_scale = 1 + + if self.name == 'unet_ensemble': + # teacher_context is a list of teacher contexts. + for teacher_context_i in teacher_context: + if teacher_context_i.shape[0] != x_start.shape[0] * (1 + self.uses_cfg): + breakpoint() + else: + if teacher_context.shape[0] != x_start.shape[0] * (1 + self.uses_cfg): + breakpoint() + + # Initially, x_starts only contains the original x_start. + x_starts = [ x_start ] + noises = [ noise ] + ts = [ t ] + noise_preds = [] + + with torch.autocast(device_type='cuda', dtype=torch.float16): + for i in range(num_denoising_steps): + x_start = x_starts[i] + t = ts[i] + noise = noises[i] + # sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise + x_noisy = ddpm_model.q_sample(x_start, t, noise) + + if self.uses_cfg: + x_noisy2 = x_noisy.repeat(2, 1, 1, 1) + t2 = t.repeat(2) + else: + x_noisy2 = x_noisy + t2 = t + + # If do_arc2face_distill, then pos_context is [BS=6, 21, 768]. + noise_pred = self.unet(sample=x_noisy2, timestep=t2, encoder_hidden_states=teacher_context, + return_dict=False)[0] + if self.uses_cfg and self.cfg_scale > 1: + pos_noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0) + noise_pred = pos_noise_pred * self.cfg_scale - neg_noise_pred * (self.cfg_scale - 1) + + # sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise + pred_x0 = ddpm_model.predict_start_from_noise(x_noisy, t, noise_pred) + noise_preds.append(noise_pred) + + # The predicted x0 is used as the x_start for the next denoising step. + x_starts.append(pred_x0) + + # Sample an earlier timestep for the next denoising step. + if i < num_denoising_steps - 1: + # NOTE: rand_like() samples from U(0, 1), not like randn_like(). + relative_ts = torch.rand_like(t.float()) + # Make sure at the middle step (i = sqrt(num_denoising_steps - 1), the timestep + # is between 50% and 70% of the current timestep. So if num_denoising_steps = 5, + # we take timesteps within [0.5^0.66, 0.7^0.66] = [0.63, 0.79] of the current timestep. + # If num_denoising_steps = 4, we take timesteps within [0.5^0.72, 0.7^0.72] = [0.61, 0.77] + # of the current timestep. + t_lb = t * np.power(0.5, np.power(num_denoising_steps - 1, -0.3)) + t_ub = t * np.power(0.7, np.power(num_denoising_steps - 1, -0.3)) + earlier_timesteps = (t_ub - t_lb) * relative_ts + t_lb + earlier_timesteps = earlier_timesteps.long() + + if uses_same_t: + # If uses_same_t, we use the same earlier_timesteps for all instances. + earlier_timesteps = earlier_timesteps[0].repeat(x_start.shape[0]) + + # earlier_timesteps = ts[i+1] < ts[i]. 
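# Illustrative standalone sketch reproducing the bracketed ranges in the comment
# above: the next timestep is drawn uniformly from [0.5 ** e, 0.7 ** e] of the
# current timestep, with e = (num_denoising_steps - 1) ** -0.3, so using more
# denoising steps shrinks the timestep more gently at each step.
import numpy as np

for _steps in (2, 4, 5, 10):
    _e = np.power(_steps - 1, -0.3)
    _lb, _ub = np.power(0.5, _e), np.power(0.7, _e)
    print(f"{_steps} steps: next t sampled in [{_lb:.2f}, {_ub:.2f}] of current t")
# e.g. 4 steps -> [0.61, 0.77] and 5 steps -> [0.63, 0.79], matching the comment above.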
+ ts.append(earlier_timesteps) + + noise = torch.randn_like(pred_x0) + noises.append(noise) + + return noise_preds, x_starts, noises, ts + +class Arc2FaceTeacher(UNetTeacher): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.name = "arc2face" + self.unet = UNet2DConditionModel.from_pretrained( + #"runwayml/stable-diffusion-v1-5", subfolder="unet" + 'models/arc2face', subfolder="arc2face", torch_dtype=torch.float16 + ) + # Disable CFG. Even if p_uses_cfg > 0, the randomly drawn cfg_scale is still 1, + # so the CFG is effectively disabled. + self.cfg_scale_range = [1, 1] + +class UNetEnsembleTeacher(UNetTeacher): + # unet_weights are not model weights, but scalar weights for individual unets. + def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights=None, device='cuda', **kwargs): + super().__init__(**kwargs) + self.name = "unet_ensemble" + self.unet = UNetEnsemble(unets, unet_types, extra_unet_dirpaths, unet_weights, device) + +class ConsistentIDTeacher(UNetTeacher): + def __init__(self, base_model_path="models/sd15-dste8-vae.safetensors", **kwargs): + super().__init__(**kwargs) + self.name = "consistentID" + ### Load base model + # In contrast to Arc2FaceTeacher or UNetEnsembleTeacher, ConsistentIDPipeline is not a torch.nn.Module. + # We couldn't initialize the ConsistentIDPipeline to CPU first and wait it to be automatically moved to GPU. + # Instead, we have to initialize it to GPU directly. + pipe = create_consistentid_pipeline(base_model_path) + # Compatible with the UNetTeacher interface. + self.unet = pipe.unet + # Release VAE and text_encoder to save memory. UNet is still needed for denoising + # (the unet is implemented in diffusers in fp16, so probably faster than the LDM unet). + pipe.release_components(["vae", "text_encoder"]) + +# We use the default cfg_scale_range=[1.3, 2] for SimpleUNetTeacher. +# Note p_uses_cfg=0.5 will also be passed in in kwargs. +class SimpleUNetTeacher(UNetTeacher): + def __init__(self, unet_dirpath='models/ensemble/sd15-unet', + torch_dtype=torch.float16, **kwargs): + super().__init__(**kwargs) + self.name = "simple_unet" + self.unet = UNet2DConditionModel.from_pretrained( + unet_dirpath, torch_dtype=torch_dtype + ) diff --git a/adaface/util.py b/adaface/util.py new file mode 100644 index 0000000000000000000000000000000000000000..f44159bb96abcc140ed837b060e6689733d2937a --- /dev/null +++ b/adaface/util.py @@ -0,0 +1,391 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from diffusers import UNet2DConditionModel +from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput +from transformers import CLIPVisionModel +from dataclasses import dataclass +from typing import Optional, Tuple +from transformers.utils import ModelOutput +import numpy as np +import argparse +from ConsistentID.lib.pipeline_ConsistentID import ConsistentIDPipeline +from diffusers import ( + UNet2DConditionModel, + DDIMScheduler, +) + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + +# perturb_tensor() adds a fixed amount of noise to the tensor. 
+def perturb_tensor(ts, perturb_std, perturb_std_is_relative=True, keep_norm=False, + std_dim=-1, norm_dim=-1, verbose=True): + orig_ts = ts + if perturb_std_is_relative: + ts_std_mean = ts.std(dim=std_dim).mean().detach() + + perturb_std *= ts_std_mean + # ts_std_mean: 50~80 for unnormalized images, perturb_std: 2.5-4 for 0.05 noise. + if verbose: + print(f"ts_std_mean: {ts_std_mean:.03f}, perturb_std: {perturb_std:.03f}") + + noise = torch.randn_like(ts) * perturb_std + if keep_norm: + orig_norm = ts.norm(dim=norm_dim, keepdim=True) + ts = ts + noise + new_norm = ts.norm(dim=norm_dim, keepdim=True).detach() + ts = ts * orig_norm / (new_norm + 1e-8) + else: + ts = ts + noise + + if verbose: + print(f"Correlations between new and original tensors: {F.cosine_similarity(ts.flatten(), orig_ts.flatten(), dim=0).item():.03f}") + + return ts + +def perturb_np_array(np_array, perturb_std, perturb_std_is_relative=True, std_dim=-1): + ts = torch.from_numpy(np_array).to(dtype=torch.float32) + ts = perturb_tensor(ts, perturb_std, perturb_std_is_relative, std_dim=std_dim) + return ts.numpy().astype(np_array.dtype) + +def calc_stats(emb_name, embeddings, mean_dim=0): + print("%s:" %emb_name) + repeat_count = [1] * embeddings.ndim + repeat_count[mean_dim] = embeddings.shape[mean_dim] + # Average across the mean_dim dim. + # Make emb_mean the same size as embeddings, as required by F.l1_loss. + emb_mean = embeddings.mean(mean_dim, keepdim=True).repeat(repeat_count) + l1_loss = F.l1_loss(embeddings, emb_mean) + # F.l2_loss doesn't take sqrt. So the loss is very small. + # Compute it manually. + l2_loss = ((embeddings - emb_mean) ** 2).mean().sqrt() + norms = torch.norm(embeddings, dim=1).detach().cpu().numpy() + print("L1: %.4f, L2: %.4f" %(l1_loss.item(), l2_loss.item())) + print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std())) + + +# Revised from RevGrad, by removing the grad negation. +class ScaleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, alpha_, debug=False): + ctx.save_for_backward(alpha_, debug) + output = input_ + if debug: + print(f"input: {input_.abs().mean().item()}") + return output + + @staticmethod + def backward(ctx, grad_output): # pragma: no cover + # saved_tensors returns a tuple of tensors. + alpha_, debug = ctx.saved_tensors + if ctx.needs_input_grad[0]: + grad_output2 = grad_output * alpha_ + if debug: + print(f"grad_output2: {grad_output2.abs().mean().item()}") + else: + grad_output2 = None + return grad_output2, None, None + +class GradientScaler(nn.Module): + def __init__(self, alpha=1., debug=False, *args, **kwargs): + """ + A gradient scaling layer. + This layer has no parameters, and simply scales the gradient in the backward pass. + """ + super().__init__(*args, **kwargs) + + self._alpha = torch.tensor(alpha, requires_grad=False) + self._debug = torch.tensor(debug, requires_grad=False) + + def forward(self, input_): + _debug = self._debug if hasattr(self, '_debug') else False + return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug) + +def gen_gradient_scaler(alpha, debug=False): + if alpha == 1: + return nn.Identity() + if alpha > 0: + return GradientScaler(alpha, debug=debug) + else: + assert alpha == 0 + # Don't use lambda function here, otherwise the object can't be pickled. + return torch.detach + +def pad_image_obj_to_square(image_obj, new_size=-1): + # Remove alpha channel if it exists. 
+ if image_obj.mode == 'RGBA': + image_obj = image_obj.convert('RGB') + + # Pad input to be width == height + width, height = orig_size = image_obj.size + new_width, new_height = max(width, height), max(width, height) + + if width != height: + if width > height: + pads = (0, (width - height) // 2) + elif height > width: + pads = ((height - width) // 2, 0) + square_image_obj = Image.new("RGB", (new_width, new_height)) + # pads indicates the upper left corner to paste the input. + square_image_obj.paste(image_obj, pads) + #square_image_obj = square_image_obj.resize((512, 512)) + print(f"{width}x{height} -> {new_width}x{new_height} -> {square_image_obj.size}") + long_short_ratio = max(width, height) / min(width, height) + else: + square_image_obj = image_obj + pads = (0, 0) + long_short_ratio = 1 + + if new_size > 0: + # Resize the shorter edge to 512. + square_image_obj = square_image_obj.resize([int(new_size * long_short_ratio), int(new_size * long_short_ratio)]) + + return square_image_obj, pads, orig_size + +class UNetEnsemble(nn.Module): + # The first unet is the unet already loaded in a pipeline. + def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights=None, device='cuda', torch_dtype=torch.float16): + super().__init__() + + self.unets = nn.ModuleList() + if unets is not None: + self.unets += [ unet.to(device) for unet in unets ] + + if unet_types is not None: + for unet_type in unet_types: + if unet_type == "arc2face": + from adaface.arc2face_models import create_arc2face_pipeline + unet = create_arc2face_pipeline(unet_only=True) + elif unet_type == "consistentID": + unet = create_consistentid_pipeline(unet_only=True) + else: + breakpoint() + self.unets.append(unet.to(device=device)) + + if extra_unet_dirpaths is not None: + for unet_path in extra_unet_dirpaths: + unet = UNet2DConditionModel.from_pretrained(unet_path, torch_dtype=torch_dtype) + self.unets.append(unet.to(device=device)) + + if unet_weights is None: + unet_weights = [1.] * len(self.unets) + elif len(self.unets) < len(unet_weights): + unet_weights = unet_weights[:len(self.unets)] + elif len(self.unets) > len(unet_weights): + breakpoint() + + unet_weights = torch.tensor(unet_weights, dtype=torch_dtype) + unet_weights = unet_weights / unet_weights.sum() + self.unet_weights = nn.Parameter(unet_weights, requires_grad=False) + + print(f"UNetEnsemble: {len(self.unets)} UNets loaded with weights: {self.unet_weights.data.cpu().numpy()}") + # Set these fields to be compatible with diffusers. + self.dtype = self.unets[0].dtype + self.device = self.unets[0].device + self.config = self.unets[0].config + + def forward(self, *args, **kwargs): + return_dict = kwargs.get('return_dict', True) + teacher_contexts = kwargs.pop('encoder_hidden_states', None) + # Only one teacher_context is provided. That means all unets will use the same teacher_context. + # We repeat the teacher_contexts to match the number of unets. 
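# Illustrative standalone sketch (shapes are hypothetical) of the weighted-averaging
# step that UNetEnsemble.forward() performs below: the per-UNet noise predictions are
# stacked along a new leading dim and reduced with the normalized scalar weights,
# broadcast over the sample dimensions.
import torch

_weights = torch.tensor([2.0, 1.0])
_weights = _weights / _weights.sum()                                 # -> [0.667, 0.333]
_preds = torch.stack([torch.randn(2, 4, 64, 64) for _ in range(len(_weights))], dim=0)
_w = _weights.reshape(-1, *([1] * (_preds.ndim - 1)))                # [n_unets, 1, 1, 1, 1]
_ensembled = (_preds * _w).sum(dim=0)                                # [batch, C, H, W]
print(_ensembled.shape)                                              # torch.Size([2, 4, 64, 64])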
+ if not isinstance(teacher_contexts, (list, tuple)): + teacher_contexts = [teacher_contexts] + if len(teacher_contexts) == 1 and len(self.unets) > 1: + teacher_contexts = teacher_contexts * len(self.unets) + + samples = [] + + for unet, teacher_context in zip(self.unets, teacher_contexts): + sample = unet(encoder_hidden_states=teacher_context, *args, **kwargs) + if not return_dict: + sample = sample[0] + else: + sample = sample.sample + + samples.append(sample) + + samples = torch.stack(samples, dim=0) + unet_weights = self.unet_weights.reshape(-1, *([1] * (samples.ndim - 1))) + sample = (samples * unet_weights).sum(dim=0) + + if not return_dict: + return (sample,) + else: + return UNet2DConditionOutput(sample=sample) + +def create_consistentid_pipeline(base_model_path="models/sd15-dste8-vae.safetensors", + dtype=torch.float16, unet_only=False): + pipe = ConsistentIDPipeline.from_single_file(base_model_path) + # consistentID specific modules are still in fp32. Will be converted to fp16 + # later with .to(device, torch_dtype) by the caller. + pipe.load_ConsistentID_model( + consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin", + bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth", + ) + # Avoid passing dtype to ConsistentIDPipeline.from_single_file(), + # because we've overloaded .to() to convert consistentID specific modules as well, + # but diffusers will call .to(dtype) in .from_single_file(), + # and at that moment, the consistentID specific modules are not loaded yet. + pipe.to(dtype=dtype) + # We load the pipeline first, then use the unet in the pipeline. + # Since the pipeline initialization will load LoRA into the unet, + # now we have the unet with LoRA loaded. + if unet_only: + # We release text_encoder and VAE to save memory. + pipe.release_components(["text_encoder", "vae"]) + return pipe.unet + + noise_scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + pipe.scheduler = noise_scheduler + + return pipe + +@dataclass +class BaseModelOutputWithPooling2(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + attn_mask: Optional[torch.FloatTensor] = None + +# Revised from CLIPVisionTransformer to support attention mask. +# self: a CLIPVisionTransformer instance. +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L821 +# pixel_values: preprocessed B*C*H*W images. [BS, 3, 224, 224] +# attn_mask: B*H*W attention mask. +def CLIPVisionTransformer_forward_with_mask(self, pixel_values = None, attn_mask=None, + output_attentions = None, + output_hidden_states = None, return_dict = None): + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Visual tokens are flattended in embeddings(). + # self.embeddings: CLIPVisionEmbeddings. + # hidden_states: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds). + # 16*16 is output from Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False). + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + if attn_mask is not None: + # feat_edge_size: 16. + feat_edge_size = np.sqrt(hidden_states.shape[1] - 1).astype(int) + # attn_mask: [BS, 512, 512] -> [BS, 1, 16, 16]. + attn_mask = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest') + # Flatten the mask: [BS, 1, 16, 16] => [BS, 1, 256]. + attn_mask = attn_mask.flatten(2) + # Prepend 1 to the mask: [BS, 1, 256] => [BS, 1, 257]. + # This 1 corresponds to class_embeds, which is always attended to. + attn_mask = torch.cat([torch.ones_like(attn_mask[:, :, :1]), attn_mask], dim=-1) + attn_mask_pairs = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1) + else: + attn_mask_pairs = None + + # encoder: CLIPEncoder. + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + # New feature: (***The official documentation is wrong***) + # attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*): + # Mask to avoid performing attention on pairs of token. Mask values selected in `[0, 1]`: + # - 1 for pairs that are **not masked**, + # - 0 for pairs that are **masked**. 
+ # attention_mask is eventually used by CLIPEncoderLayer: + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L370 + attention_mask=attn_mask_pairs, + output_attentions=output_attentions, # False + output_hidden_states=output_hidden_states, # True + return_dict=return_dict, # True + ) + + # last_hidden_state: [BS, 257, 1280] + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + # return_dict is True. + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling2( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + # Newly added: return resized flattened attention mask. + # [BS, 1, 257] -> [BS, 257, 1] + attn_mask=attn_mask.permute(0, 2, 1) if attn_mask is not None else None + ) + +def CLIPVisionModel_forward_with_mask(self, pixel_values = None, attn_mask = None, output_attentions = None, + output_hidden_states = None, return_dict = None): + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + attn_mask=attn_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + +# patch_clip_image_encoder_with_mask() is applicable to both CLIPVisionModel and CLIPVisionModelWithProjection. +def patch_clip_image_encoder_with_mask(clip_image_encoder): + clip_image_encoder.vision_model.forward = CLIPVisionTransformer_forward_with_mask.__get__(clip_image_encoder.vision_model) + clip_image_encoder.forward = CLIPVisionModel_forward_with_mask.__get__(clip_image_encoder) + return clip_image_encoder + +class CLIPVisionModelWithMask(CLIPVisionModel): + def __init__(self, config): + super().__init__(config) + # Replace vision_model.forward() with the new one that supports mask. + patch_clip_image_encoder_with_mask(self) + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..01a330d8b24880c38f3e071994b1bedf5f3240ef --- /dev/null +++ b/app.py @@ -0,0 +1,293 @@ +import sys +sys.path.append('./') + +from adaface.adaface_wrapper import AdaFaceWrapper +import torch +import numpy as np +import random + +import gradio as gr +import spaces +import argparse +parser = argparse.ArgumentParser() +parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"], + choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders") +parser.add_argument('--adaface_ckpt_path', type=str, default='models/adaface/VGGface2_HQ_masks2024-10-14T16-09-24_zero3-ada-3500.pt', + help="Paths to the checkpoints of the ID2Ada prompt encoders") +# If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face). +parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None, + help="Scales for the ID2Ada prompt encoders") +parser.add_argument("--enabled_encoders", type=str, nargs="+", default=None, + choices=["arc2face", "consistentID"], + help="List of enabled encoders (among the list of adaface_encoder_types). 
Default: None (all enabled)") +parser.add_argument('--model_style_type', type=str, default='realistic', + choices=["realistic", "anime", "photorealistic"], help="Type of the base model") +parser.add_argument('--extra_unet_dirpaths', type=str, nargs="*", default=[], + help="Extra paths to the checkpoints of the UNet models") +parser.add_argument('--unet_weights', type=float, nargs="+", default=[1], + help="Weights for the UNet models") +parser.add_argument("--guidance_scale", type=float, default=8.0, + help="The guidance scale for the diffusion model. Default: 8.0") +parser.add_argument("--do_neg_id_prompt_weight", type=float, default=0.0, + help="The weight of added ID prompt embeddings into the negative prompt. Default: 0, disabled.") + +parser.add_argument('--gpu', type=int, default=None) +parser.add_argument('--ip', type=str, default="0.0.0.0") +args = parser.parse_args() + +model_style_type2base_model_path = { + "realistic": "models/rv51/realisticVisionV51_v51VAE_dste8.safetensors", + "anime": "models/aingdiffusion/aingdiffusion_v170_ar.safetensors", + "photorealistic": "models/sar/sar.safetensors" # LDM format. Needs to be converted. +} +base_model_path = model_style_type2base_model_path[args.model_style_type] + +# global variable +MAX_SEED = np.iinfo(np.int32).max +device = "cuda" if args.gpu is None else f"cuda:{args.gpu}" +print(f"Device: {device}") + +global adaface +adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path, + adaface_encoder_types=args.adaface_encoder_types, + adaface_ckpt_paths=args.adaface_ckpt_path, + adaface_encoder_cfg_scales=args.adaface_encoder_cfg_scales, + enabled_encoders=args.enabled_encoders, + unet_types=None, extra_unet_dirpaths=args.extra_unet_dirpaths, + unet_weights=args.unet_weights, device='cpu') + +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, MAX_SEED) + return seed + +def swap_to_gallery(images): + # Update uploaded_files_gallery, show files, hide clear_button_column + # Or: + # Update uploaded_init_img_gallery, show init_img_files, hide init_clear_button_column + return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=images, visible=False) + +def remove_back_to_files(): + # Hide uploaded_files_gallery, show clear_button_column, hide files, reset init_img_selected_idx + # Or: + # Hide uploaded_init_img_gallery, hide init_clear_button_column, show init_img_files, reset init_img_selected_idx + return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True) + +@spaces.GPU +def generate_image(image_paths, guidance_scale, do_neg_id_prompt_weight, perturb_std, + num_images, prompt, negative_prompt, enhance_face, + seed, progress=gr.Progress(track_tqdm=True)): + + global adaface + + adaface.to(device) + + if image_paths is None or len(image_paths) == 0: + raise gr.Error(f"Cannot find any input face image! Please upload a face image.") + + if prompt is None: + prompt = "" + + adaface_subj_embs = \ + adaface.prepare_adaface_embeddings(image_paths=image_paths, face_id_embs=None, + avg_at_stage='id_emb', + perturb_at_stage='img_prompt_emb', + perturb_std=perturb_std, update_text_encoder=True) + + if adaface_subj_embs is None: + raise gr.Error(f"Failed to detect any faces! Please try with other images") + + # Sometimes the pipeline is on CPU, although we've put it on CUDA (due to some offloading mechanism). + # Therefore we set the generator to the correct device. 
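+ # With a fixed seed, the init noise sampled below (and thus the generated images) is reproducible.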
+ generator = torch.Generator(device=device).manual_seed(seed) + print(f"Manual seed: {seed}. do_neg_id_prompt_weight: {do_neg_id_prompt_weight}.") + # Generate num_images images each time for the user to select from. + noise = torch.randn(num_images, 3, 512, 512, device=device, generator=generator) + #print(noise.abs().sum()) + # samples: A list of PIL Image instances. + if enhance_face and "face portrait" not in prompt: + if "portrait" in prompt: + # Enhance the face features by replacing "portrait" with "face portrait". + prompt = prompt.replace("portrait", "face portrait") + else: + prompt = "face portrait, " + prompt + + generator = torch.Generator(device=adaface.pipeline._execution_device).manual_seed(seed) + samples = adaface(noise, prompt, negative_prompt, + do_neg_id_prompt_weight=do_neg_id_prompt_weight, + guidance_scale=guidance_scale, + out_image_count=num_images, generator=generator, verbose=True) + return samples + + +def check_prompt_and_model_type(prompt, model_style_type): + global adaface + + model_style_type = model_style_type.lower() + base_model_path = model_style_type2base_model_path[model_style_type] + # If the base model type is changed, reload the model. + if model_style_type != args.model_style_type: + adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path, + adaface_encoder_types=args.adaface_encoder_types, + adaface_ckpt_paths=args.adaface_ckpt_path, device='cpu') + # Update base model type. + args.model_style_type = model_style_type + + if not prompt: + raise gr.Error("Prompt cannot be blank") + +### Description +title = r""" +AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization
+""" + +description = r""" +Official demo for our working paper AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization.
+ +❗️**What's New**❗️ +- Support switching between two model styles: **Realistic** and **Anime**. +- If you just changed the model style, the first image/video generation will take an extra 20~30 seconds to load the new model weights. + +❗️**Tips**❗️ +1. Upload one or more images of a person. If multiple faces are detected, we use the largest one. +2. Check "Enhance Face" to highlight fine facial features. +3. If the face dominates the image, try increasing 'Weight of ID prompt in the negative prompt'. +4. AdaFace Text-to-Video: see the AdaFace-Animate demo on Hugging Face Spaces. + +**TODO:** +- ControlNet integration. +""" + +css = ''' +.gradio-container {width: 95% !important} +.custom-gallery { + height: 800px; + width: 100%; + margin: 10px auto; + padding: 10px; + overflow-y: auto; +} +''' +with gr.Blocks(css=css, theme=gr.themes.Origin()) as demo: + + # description + gr.Markdown(title) + gr.Markdown(description) + + with gr.Row(): + with gr.Column(): + + # upload face image + # img_file = gr.Image(label="Upload a photo with a face", type="filepath") + img_files = gr.File( + label="Drag / Select 1 or more photos of a person's face", + file_types=["image"], + file_count="multiple" + ) + uploaded_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=1, height=300) + with gr.Column(visible=False) as clear_button_column: + remove_and_reupload = gr.ClearButton(value="Remove and upload subject images", components=img_files, size="sm") + + prompt = gr.Dropdown(label="Prompt", + info="Try something like 'walking on the beach'. If the face is not in focus, try checking 'enhance face'.", + value="portrait, ((best quality)), ((masterpiece)), ((realistic)), highlighted hair, futuristic silver armor suit, confident stance, high-resolution, living room, smiling, head tilted, perfect smooth skin", + allow_custom_value=True, + filterable=False, + choices=[ + "portrait, ((best quality)), ((masterpiece)), ((realistic)), highlighted hair, futuristic silver armor suit, confident stance, high-resolution, living room, smiling, head tilted, perfect smooth skin", + "portrait, walking on the beach, sunset, orange sky", + "portrait, in a white apron and chef hat, garnishing a gourmet dish", + "portrait, dancing pose among folks in a park, waving hands", + "portrait, in iron man costume, the sky ablaze with hues of orange and purple", + "portrait, jedi wielding a lightsaber, star wars, eye level shot", + "portrait, playing guitar on a boat, ocean waves", + "portrait, with a passion for reading, curled up with a book in a cozy nook near a window", + "portrait, running pose in a park, eye level shot", + "portrait, in superman costume, the sky ablaze with hues of orange and purple" + ]) + + enhance_face = gr.Checkbox(label="Enhance face", value=False, + info="Enhance the face features by prepending 'face portrait' to the prompt") + + submit = gr.Button("Submit", variant="primary") + + negative_prompt = gr.Textbox( + label="Negative Prompt", + value="flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, nude, naked, nsfw, topless, bare breasts", + ) + + guidance_scale = gr.Slider( + label="Guidance scale", + minimum=1.0, + maximum=12.0, + step=1.0, + value=args.guidance_scale, + ) + + do_neg_id_prompt_weight = 
gr.Slider( + label="Weight of ID prompt in the negative prompt", + minimum=0.0, + maximum=0.3, + step=0.1, + value=args.do_neg_id_prompt_weight, + visible=True, + ) + + model_style_type = gr.Dropdown( + label="Base Model Style Type", + info="Switching the base model type will take 10~20 seconds to reload the model", + value=args.model_style_type.capitalize(), + choices=["Realistic", "Anime", "Photorealistic"], + allow_custom_value=False, + filterable=False, + ) + + perturb_std = gr.Slider( + label="Std of noise added to input (may help stablize face embeddings)", + minimum=0.0, + maximum=0.05, + step=0.025, + value=0.0, + visible=False, + ) + num_images = gr.Slider( + label="Number of output images", + minimum=1, + maximum=6, + step=1, + value=4, + ) + seed = gr.Slider( + label="Seed", + minimum=0, + maximum=MAX_SEED, + step=1, + value=0, + ) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True, info="Uncheck for reproducible results") + + with gr.Column(): + out_gallery = gr.Gallery(label="Generated Images", interactive=False, columns=2, rows=2, height=800, + elem_classes="custom-gallery") + + img_files.upload(fn=swap_to_gallery, inputs=img_files, outputs=[uploaded_files_gallery, clear_button_column, img_files]) + remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery, clear_button_column, img_files]) + + submit.click(fn=check_prompt_and_model_type, + inputs=[prompt, model_style_type],outputs=None).success( + fn=randomize_seed_fn, + inputs=[seed, randomize_seed], + outputs=seed, + queue=False, + api_name=False, + ).then( + fn=generate_image, + inputs=[img_files, guidance_scale, do_neg_id_prompt_weight, perturb_std, num_images, + prompt, negative_prompt, enhance_face, seed], + outputs=[out_gallery] + ) + +demo.launch(share=True, server_name=args.ip, ssl_verify=False) \ No newline at end of file diff --git a/models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth b/models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth new file mode 100644 index 0000000000000000000000000000000000000000..ca57f3257ca7715bc340d065764bc249d985c287 --- /dev/null +++ b/models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567 +size 53289463 diff --git a/models/arc2face/arc2face/config.json b/models/arc2face/arc2face/config.json new file mode 100644 index 0000000000000000000000000000000000000000..b98bc620c69b69af13acbe9f56ddc0db18ca8a04 --- /dev/null +++ b/models/arc2face/arc2face/config.json @@ -0,0 +1,67 @@ +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.22.0", + "act_fn": "silu", + "addition_embed_type": null, + "addition_embed_type_num_heads": 64, + "addition_time_embed_dim": null, + "attention_head_dim": 8, + "attention_type": "default", + "block_out_channels": [ + 320, + 640, + 1280, + 1280 + ], + "center_input_sample": false, + "class_embed_type": null, + "class_embeddings_concat": false, + "conv_in_kernel": 3, + "conv_out_kernel": 3, + "cross_attention_dim": 768, + "cross_attention_norm": null, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "dropout": 0.0, + "dual_cross_attention": false, + "encoder_hid_dim": null, + "encoder_hid_dim_type": null, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_only_cross_attention": null, + 
"mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_attention_heads": null, + "num_class_embeds": null, + "only_cross_attention": false, + "out_channels": 4, + "projection_class_embeddings_input_dim": null, + "resnet_out_scale_factor": 1.0, + "resnet_skip_time_act": false, + "resnet_time_scale_shift": "default", + "reverse_transformer_layers_per_block": null, + "sample_size": 64, + "time_cond_proj_dim": null, + "time_embedding_act_fn": null, + "time_embedding_dim": null, + "time_embedding_type": "positional", + "timestep_post_act": null, + "transformer_layers_per_block": 1, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ], + "upcast_attention": false, + "use_linear_projection": false +} diff --git a/models/arc2face/encoder/config.json b/models/arc2face/encoder/config.json new file mode 100644 index 0000000000000000000000000000000000000000..49c600b99d5738877abb69598b678f2a0005a309 --- /dev/null +++ b/models/arc2face/encoder/config.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "CLIPTextModel" + ], + "attention_dropout": 0.0, + "bos_token_id": 0, + "dropout": 0.0, + "eos_token_id": 2, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "max_position_embeddings": 77, + "model_type": "clip_text_model", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 1, + "projection_dim": 768, + "torch_dtype": "float32", + "transformers_version": "4.34.1", + "vocab_size": 49408 +} diff --git a/models/ensemble/ar18-unet/config.json b/models/ensemble/ar18-unet/config.json new file mode 100644 index 0000000000000000000000000000000000000000..1ce31f9032aaf1b832d3f2c95bafae22dbad15aa --- /dev/null +++ b/models/ensemble/ar18-unet/config.json @@ -0,0 +1,61 @@ +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.17.1", + "act_fn": "silu", + "addition_embed_type": null, + "addition_embed_type_num_heads": 64, + "attention_head_dim": 8, + "block_out_channels": [ + 320, + 640, + 1280, + 1280 + ], + "center_input_sample": false, + "class_embed_type": null, + "class_embeddings_concat": false, + "conv_in_kernel": 3, + "conv_out_kernel": 3, + "cross_attention_dim": 768, + "cross_attention_norm": null, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "dual_cross_attention": false, + "encoder_hid_dim": null, + "encoder_hid_dim_type": null, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_only_cross_attention": null, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": null, + "only_cross_attention": false, + "out_channels": 4, + "projection_class_embeddings_input_dim": null, + "resnet_out_scale_factor": 1.0, + "resnet_skip_time_act": false, + "resnet_time_scale_shift": "default", + "sample_size": 64, + "time_cond_proj_dim": null, + "time_embedding_act_fn": null, + "time_embedding_dim": null, + "time_embedding_type": "positional", + "timestep_post_act": null, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ], + "upcast_attention": null, + "use_linear_projection": false +} diff --git a/models/ensemble/rv4-unet/config.json 
b/models/ensemble/rv4-unet/config.json new file mode 100644 index 0000000000000000000000000000000000000000..7b922c8ef09d51a58eea8c865feb668ca1fd5451 --- /dev/null +++ b/models/ensemble/rv4-unet/config.json @@ -0,0 +1,60 @@ +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.16.1", + "act_fn": "silu", + "addition_embed_type": null, + "addition_embed_type_num_heads": 64, + "attention_head_dim": 8, + "block_out_channels": [ + 320, + 640, + 1280, + 1280 + ], + "center_input_sample": false, + "class_embed_type": null, + "class_embeddings_concat": false, + "conv_in_kernel": 3, + "conv_out_kernel": 3, + "cross_attention_dim": 768, + "cross_attention_norm": null, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "dual_cross_attention": false, + "encoder_hid_dim": null, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_only_cross_attention": null, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": null, + "only_cross_attention": false, + "out_channels": 4, + "projection_class_embeddings_input_dim": null, + "resnet_out_scale_factor": 1.0, + "resnet_skip_time_act": false, + "resnet_time_scale_shift": "default", + "sample_size": 64, + "time_cond_proj_dim": null, + "time_embedding_act_fn": null, + "time_embedding_dim": null, + "time_embedding_type": "positional", + "timestep_post_act": null, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ], + "upcast_attention": false, + "use_linear_projection": false +} diff --git a/models/insightface/models/antelopev2/2d106det.onnx b/models/insightface/models/antelopev2/2d106det.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cdb163d88b5f51396855ebc795e0114322c98b6b --- /dev/null +++ b/models/insightface/models/antelopev2/2d106det.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf +size 5030888 diff --git a/models/insightface/models/antelopev2/genderage.onnx b/models/insightface/models/antelopev2/genderage.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fcf638481cea978e99ddabd914ccd3b70c8401cb --- /dev/null +++ b/models/insightface/models/antelopev2/genderage.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb +size 1322532 diff --git a/models/insightface/models/antelopev2/scrfd_10g_bnkps.onnx b/models/insightface/models/antelopev2/scrfd_10g_bnkps.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa586e034379fa5ea5babc8aa73d47afcd0fa6c2 --- /dev/null +++ b/models/insightface/models/antelopev2/scrfd_10g_bnkps.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91 +size 16923827 diff --git a/models/insightface/models/buffalo_l/2d106det.onnx b/models/insightface/models/buffalo_l/2d106det.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cdb163d88b5f51396855ebc795e0114322c98b6b --- /dev/null +++ b/models/insightface/models/buffalo_l/2d106det.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf 
+size 5030888 diff --git a/models/insightface/models/buffalo_l/det_10g.onnx b/models/insightface/models/buffalo_l/det_10g.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa586e034379fa5ea5babc8aa73d47afcd0fa6c2 --- /dev/null +++ b/models/insightface/models/buffalo_l/det_10g.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91 +size 16923827 diff --git a/models/insightface/models/buffalo_l/genderage.onnx b/models/insightface/models/buffalo_l/genderage.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fcf638481cea978e99ddabd914ccd3b70c8401cb --- /dev/null +++ b/models/insightface/models/buffalo_l/genderage.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb +size 1322532 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ecaad8ce20b5b9038691388e38247c557fa4cd2b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +torch +torchvision +einops +gradio +transformers +insightface +omegaconf +opencv-python +diffusers==0.29.2 +onnx>=1.16.0 +onnxruntime +safetensors +spaces +ftfy
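For reference, below is a minimal usage sketch of the mask-aware CLIP patching defined above. The import path (adaface.util) and the checkpoint name are illustrative assumptions rather than this repository's documented API, and whether the 0/1 pairwise mask is honored downstream depends on the installed transformers version and attention implementation.

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

# Hypothetical import path -- point this at whichever module defines the helpers above.
from adaface.util import patch_clip_image_encoder_with_mask

# Illustrative checkpoint; the shape comments above assume a ViT-H/14-style encoder (1280-dim),
# but the patch itself is checkpoint-agnostic.
ckpt = "openai/clip-vit-large-patch14"
encoder = patch_clip_image_encoder_with_mask(CLIPVisionModel.from_pretrained(ckpt))
processor = CLIPImageProcessor.from_pretrained(ckpt)

image = Image.new("RGB", (512, 512))      # placeholder for a real face crop
pixel_values = processor(images=image, return_tensors="pt").pixel_values   # [1, 3, 224, 224]
attn_mask = torch.ones(1, 512, 512)       # [BS, H, W] float mask: 1 = attend, 0 = ignore

out = encoder(pixel_values=pixel_values, attn_mask=attn_mask, return_dict=True)
print(out.last_hidden_state.shape)        # [1, 257, hidden_size]
print(out.attn_mask.shape)                # [1, 257, 1]: the resized, flattened mask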