diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..d24153d0515dd26228f2eec44c407f67fb41bdf4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/comparison.jpg filter=lfs diff=lfs merge=lfs -text +assets/grid_cat2dog.jpg filter=lfs diff=lfs merge=lfs -text +assets/grid_dog2cat.jpg filter=lfs diff=lfs merge=lfs -text +assets/grid_horse2zebra.jpg filter=lfs diff=lfs merge=lfs -text +assets/grid_tree2fall.jpg filter=lfs diff=lfs merge=lfs -text +assets/grid_zebra2horse.jpg filter=lfs diff=lfs merge=lfs -text +assets/main.gif filter=lfs diff=lfs merge=lfs -text +assets/method.jpeg filter=lfs diff=lfs merge=lfs -text +assets/results_real.jpg filter=lfs diff=lfs merge=lfs -text +assets/results_syn.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0adcb74bc1fb618ab59832e94537e750f0627690 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +output +scripts +src/folder_*.py +src/ig_*.py +assets/edit_sentences +src/utils/edit_pipeline_spatial.py \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..cf00e660e60cd42dfd0be1c7515005b1ee2f0e9b --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 pix2pixzero + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..9e279e3e5eba2d3baf793ec9957c88ee9db67d04 --- /dev/null +++ b/app.py @@ -0,0 +1,19 @@ +import os +import gradio as gr + +def update(name): + os.system('''python src/inversion.py \ + --input_image "assets/test_images/cats/cat_1.png" \ + --results_folder "output/test_cat" + ''') + return f"Inverted!" 
+ +with gr.Blocks() as demo: + gr.Markdown("Start typing below and then click **Run** to see the output.") + with gr.Row(): + inp = gr.Textbox(placeholder="Do you want to invert?") + out = gr.Textbox() + btn = gr.Button("Run") + btn.click(fn=update, inputs=inp, outputs=out) + +demo.launch() \ No newline at end of file diff --git a/assets/.DS_Store b/assets/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/assets/.DS_Store differ diff --git a/assets/comparison.jpg b/assets/comparison.jpg new file mode 100644 index 0000000000000000000000000000000000000000..87800e604da369660b15aa139a31b77810ad6e4a --- /dev/null +++ b/assets/comparison.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:acab8ed1680a42dd2f540e8188a43eb0d101895fca8ed36c0e06c8b351d2c276 +size 3389665 diff --git a/assets/embeddings_sd_1.4/cat.pt b/assets/embeddings_sd_1.4/cat.pt new file mode 100644 index 0000000000000000000000000000000000000000..acb972c5e8bf6e14029516ab106cd667c1ff735f --- /dev/null +++ b/assets/embeddings_sd_1.4/cat.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa9441dc014d5e86567c5ef165e10b50d2a7b3a68d90686d0cd1006792adf334 +size 237300 diff --git a/assets/embeddings_sd_1.4/dog.pt b/assets/embeddings_sd_1.4/dog.pt new file mode 100644 index 0000000000000000000000000000000000000000..d7abb6c17268edc1aba7d89df4283d70fca1221e --- /dev/null +++ b/assets/embeddings_sd_1.4/dog.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:becf079d61d7f35727bcc0d8506ddcdcddb61e62d611840ff3d18eca7fb6338c +size 237300 diff --git a/assets/embeddings_sd_1.4/horse.pt b/assets/embeddings_sd_1.4/horse.pt new file mode 100644 index 0000000000000000000000000000000000000000..770452bf6a98237a44756859bd7457e04e8dbdae --- /dev/null +++ b/assets/embeddings_sd_1.4/horse.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5d499299544d11371f84674761292b0512055ef45776c700c0b0da164cbf6c7 +size 118949 diff --git a/assets/embeddings_sd_1.4/zebra.pt b/assets/embeddings_sd_1.4/zebra.pt new file mode 100644 index 0000000000000000000000000000000000000000..b959a6b0fab0204321b7b6b2b3922490b3632814 --- /dev/null +++ b/assets/embeddings_sd_1.4/zebra.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a29f6a11d91f3a276e27326b7623fae9d61a3d253ad430bb868bd40fb7e02fec +size 118949 diff --git a/assets/grid_cat2dog.jpg b/assets/grid_cat2dog.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b1982bf17ed748e90301952bd4c6cf7293637be9 --- /dev/null +++ b/assets/grid_cat2dog.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0080134b70277af723e25c4627494fda8555d43a9f6376e682b67b3341d1f1f3 +size 1212309 diff --git a/assets/grid_dog2cat.jpg b/assets/grid_dog2cat.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6eea5a33ce6f90e4e973c9f9086133014008086f --- /dev/null +++ b/assets/grid_dog2cat.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e5059ec1ad8e4b07fe8b715295e82fcead652b9c366733793674e84d51427d9 +size 1248487 diff --git a/assets/grid_horse2zebra.jpg b/assets/grid_horse2zebra.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e4070b888d1dbc6c725127d83e2efc6c7e2e779a --- /dev/null +++ b/assets/grid_horse2zebra.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:a31e0a456e9323697c966e675b02403511ebf0b7c334416a8da91df1c14723df +size 1052534 diff --git a/assets/grid_tree2fall.jpg b/assets/grid_tree2fall.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5119215b29acae75ff0c4afb567edc1d475effc5 --- /dev/null +++ b/assets/grid_tree2fall.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:559ab066e4ef0972748d0a7f004d2ca18fd15062c667ac6665309727f6dc0cc8 +size 1628989 diff --git a/assets/grid_zebra2horse.jpg b/assets/grid_zebra2horse.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4affec39f117cf7f2d37835fc03d5a4699a4eaa6 --- /dev/null +++ b/assets/grid_zebra2horse.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b44b4aa4576be49289515f0aa9023dfd4424b3ba2476c66516b876dd83a06713 +size 1053305 diff --git a/assets/main.gif b/assets/main.gif new file mode 100644 index 0000000000000000000000000000000000000000..8881d22ed205bafe4462b176e44ac340e8be6ec9 --- /dev/null +++ b/assets/main.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1ebc380a461c4847beece13bdc9b5ea88312e8a8013f384eb8809109ff198fc +size 6188602 diff --git a/assets/method.jpeg b/assets/method.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..5f35b3262ccef4c3337fd5c89dba1b82fa74e369 --- /dev/null +++ b/assets/method.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b1b4ea3608b9ad3797c4c7423bf2fd88e5e24f34fecbb00d3d2de22a99fd2ee +size 2351071 diff --git a/assets/results_real.jpg b/assets/results_real.jpg new file mode 100644 index 0000000000000000000000000000000000000000..05024f66d6c2fcbc406aef937986c5e168dfa0ff --- /dev/null +++ b/assets/results_real.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94095526e76b7a000ed56df15f7b5208c0f5a069b20b04fc9bcade14c54d92dc +size 1484789 diff --git a/assets/results_syn.jpg b/assets/results_syn.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a87d6220a5b3428cda272654b9b3f328700b074b --- /dev/null +++ b/assets/results_syn.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5731190e33098406995de563ca12bd6d2f84d9db725618a6d6580b4d1f2f0813 +size 1275841 diff --git a/assets/results_teaser.jpg b/assets/results_teaser.jpg new file mode 100644 index 0000000000000000000000000000000000000000..046c6c9427738c860e88316b27ca588dd30ee4c2 Binary files /dev/null and b/assets/results_teaser.jpg differ diff --git a/assets/test_images/cats/cat_1.png b/assets/test_images/cats/cat_1.png new file mode 100644 index 0000000000000000000000000000000000000000..75ed605314a6f54422a17aed4d7f3f34462b5d45 Binary files /dev/null and b/assets/test_images/cats/cat_1.png differ diff --git a/assets/test_images/cats/cat_2.png b/assets/test_images/cats/cat_2.png new file mode 100644 index 0000000000000000000000000000000000000000..33e72b884822ead8ac757be06aae89efdaa0cf5b Binary files /dev/null and b/assets/test_images/cats/cat_2.png differ diff --git a/assets/test_images/cats/cat_3.png b/assets/test_images/cats/cat_3.png new file mode 100644 index 0000000000000000000000000000000000000000..2c128950a61ffba3a2441390b0382dce6d01eeed Binary files /dev/null and b/assets/test_images/cats/cat_3.png differ diff --git a/assets/test_images/cats/cat_4.png b/assets/test_images/cats/cat_4.png new file mode 100644 index 0000000000000000000000000000000000000000..60f47883a547e6c32be3b10a9e0f45554b839dd1 Binary files /dev/null and b/assets/test_images/cats/cat_4.png 
differ diff --git a/assets/test_images/cats/cat_5.png b/assets/test_images/cats/cat_5.png new file mode 100644 index 0000000000000000000000000000000000000000..5c0c2ce2c3f5c4ddfda0ce26fbbead4b52683ee6 Binary files /dev/null and b/assets/test_images/cats/cat_5.png differ diff --git a/assets/test_images/cats/cat_6.png b/assets/test_images/cats/cat_6.png new file mode 100644 index 0000000000000000000000000000000000000000..2ad6ca42cd185f2f8fd7fb164958c91525dded8e Binary files /dev/null and b/assets/test_images/cats/cat_6.png differ diff --git a/assets/test_images/cats/cat_7.png b/assets/test_images/cats/cat_7.png new file mode 100644 index 0000000000000000000000000000000000000000..021a8753ae829e9ead908680927af543ef33759a Binary files /dev/null and b/assets/test_images/cats/cat_7.png differ diff --git a/assets/test_images/cats/cat_8.png b/assets/test_images/cats/cat_8.png new file mode 100644 index 0000000000000000000000000000000000000000..4f25b1ac68839dcea11a2108f55bc6aefe36d10b Binary files /dev/null and b/assets/test_images/cats/cat_8.png differ diff --git a/assets/test_images/cats/cat_9.png b/assets/test_images/cats/cat_9.png new file mode 100644 index 0000000000000000000000000000000000000000..97227e0eb990e011546cabba2b040f8a4a171ab5 Binary files /dev/null and b/assets/test_images/cats/cat_9.png differ diff --git a/assets/test_images/dogs/dog_1.png b/assets/test_images/dogs/dog_1.png new file mode 100644 index 0000000000000000000000000000000000000000..d8994f63ad8ece9f2103d0e9397d5898a8634dcb Binary files /dev/null and b/assets/test_images/dogs/dog_1.png differ diff --git a/assets/test_images/dogs/dog_2.png b/assets/test_images/dogs/dog_2.png new file mode 100644 index 0000000000000000000000000000000000000000..ff9f5981120f466615f5c871822d69f3548cbdf6 Binary files /dev/null and b/assets/test_images/dogs/dog_2.png differ diff --git a/assets/test_images/dogs/dog_3.png b/assets/test_images/dogs/dog_3.png new file mode 100644 index 0000000000000000000000000000000000000000..bf7d779ece8c516241ef8af1a084c2fc53a48c48 Binary files /dev/null and b/assets/test_images/dogs/dog_3.png differ diff --git a/assets/test_images/dogs/dog_4.png b/assets/test_images/dogs/dog_4.png new file mode 100644 index 0000000000000000000000000000000000000000..5776103705a90b489e54cd60ce1710dbcca4a078 Binary files /dev/null and b/assets/test_images/dogs/dog_4.png differ diff --git a/assets/test_images/dogs/dog_5.png b/assets/test_images/dogs/dog_5.png new file mode 100644 index 0000000000000000000000000000000000000000..4933dbf31565f44b104b8afdab00b170598fd99c Binary files /dev/null and b/assets/test_images/dogs/dog_5.png differ diff --git a/assets/test_images/dogs/dog_6.png b/assets/test_images/dogs/dog_6.png new file mode 100644 index 0000000000000000000000000000000000000000..9b61080a2fda72ec8495b76a3cdfd1d7fee2d5e0 Binary files /dev/null and b/assets/test_images/dogs/dog_6.png differ diff --git a/assets/test_images/dogs/dog_7.png b/assets/test_images/dogs/dog_7.png new file mode 100644 index 0000000000000000000000000000000000000000..87329a09ff50cfcfe0c2163e9d44275acb117ce3 Binary files /dev/null and b/assets/test_images/dogs/dog_7.png differ diff --git a/assets/test_images/dogs/dog_8.png b/assets/test_images/dogs/dog_8.png new file mode 100644 index 0000000000000000000000000000000000000000..01e97facd3d4ece3d5b46b49784fd4fd4c2ef4d5 Binary files /dev/null and b/assets/test_images/dogs/dog_8.png differ diff --git a/assets/test_images/dogs/dog_9.png b/assets/test_images/dogs/dog_9.png new file mode 100644 index 
0000000000000000000000000000000000000000..05357065dd6795ec1a4bdcf2397cafe2cebb0f54 Binary files /dev/null and b/assets/test_images/dogs/dog_9.png differ diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..a153acd176f0c325fe7ab7235c68f114c1d75f27 --- /dev/null +++ b/environment.yml @@ -0,0 +1,23 @@ +name: pix2pix-zero +channels: + - pytorch + - nvidia + - defaults +dependencies: + - pip + - pytorch-cuda=11.6 + - torchvision + - pytorch + - pip: + - accelerate + - diffusers + - einops + - gradio + - ipython + - numpy + - opencv-python-headless + - pillow + - psutil + - tqdm + - transformers + - salesforce-lavis diff --git a/output/test_cat/edit/cat_1.png b/output/test_cat/edit/cat_1.png new file mode 100644 index 0000000000000000000000000000000000000000..33b7b147f751f9c4c280a637e3b16e6439426ecf Binary files /dev/null and b/output/test_cat/edit/cat_1.png differ diff --git a/output/test_cat/inversion/cat_1.pt b/output/test_cat/inversion/cat_1.pt new file mode 100644 index 0000000000000000000000000000000000000000..1becf41016260001df87144e6c0c326a0171a7ac --- /dev/null +++ b/output/test_cat/inversion/cat_1.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1fd7cd2554d2d695841ede6038f7906b50085841706a2f62429ee32c08a0dc85 +size 66283 diff --git a/output/test_cat/prompt/cat_1.txt b/output/test_cat/prompt/cat_1.txt new file mode 100644 index 0000000000000000000000000000000000000000..29a2cd5b1da5ab3605b0a6eeaa194d2395a16828 --- /dev/null +++ b/output/test_cat/prompt/cat_1.txt @@ -0,0 +1 @@ +a dog with his paws on top of a ball, painting \ No newline at end of file diff --git a/output/test_cat/reconstruction/cat_1.png b/output/test_cat/reconstruction/cat_1.png new file mode 100644 index 0000000000000000000000000000000000000000..3163ff0dee957b5661ee0812b195a16e322ea0e1 Binary files /dev/null and b/output/test_cat/reconstruction/cat_1.png differ diff --git a/output/test_cat2/edit/cat_1.png b/output/test_cat2/edit/cat_1.png new file mode 100644 index 0000000000000000000000000000000000000000..919506fe90d6657fca719a533db3b9dd41dfa4f0 Binary files /dev/null and b/output/test_cat2/edit/cat_1.png differ diff --git a/output/test_cat2/reconstruction/cat_1.png b/output/test_cat2/reconstruction/cat_1.png new file mode 100644 index 0000000000000000000000000000000000000000..081c91694cfac90ea36458217b1236578d4afd32 Binary files /dev/null and b/output/test_cat2/reconstruction/cat_1.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b60f9c97fd2ef7c78c0c99be7247d227898cee24 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +accelerate +diffusers +einops +numpy +opencv-python-headless +pillow +psutil +transformers +tqdm +torch +salesforce-lavis \ No newline at end of file diff --git a/src/edit_real.py b/src/edit_real.py new file mode 100644 index 0000000000000000000000000000000000000000..5f801165bc299fa72b4e0bdf4a112f6ece7edb70 --- /dev/null +++ b/src/edit_real.py @@ -0,0 +1,67 @@ +import os, pdb + +import argparse +from glob import glob +import numpy as np +import torch +import requests +from PIL import Image + +from diffusers import DDIMScheduler +from utils.ddim_inv import DDIMInversion +from utils.edit_directions import construct_direction +from utils.edit_pipeline import EditingPipeline + + +if __name__=="__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--inversion', required=True) + parser.add_argument('--prompt', type=str, required=True)
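+ # NOTE: --inversion and --prompt accept either a single .pt/.txt pair (e.g. the files added under output/test_cat/ in this diff) or two folders whose *.pt and *.txt basenames match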
+ parser.add_argument('--task_name', type=str, default='cat2dog') + parser.add_argument('--results_folder', type=str, default='output/test_cat') + parser.add_argument('--num_ddim_steps', type=int, default=50) + parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4') + parser.add_argument('--xa_guidance', default=0.1, type=float) + parser.add_argument('--negative_guidance_scale', default=5.0, type=float) + parser.add_argument('--use_float_16', action='store_true') + + args = parser.parse_args() + + os.makedirs(os.path.join(args.results_folder, "edit"), exist_ok=True) + os.makedirs(os.path.join(args.results_folder, "reconstruction"), exist_ok=True) + + if args.use_float_16: + torch_dtype = torch.float16 + else: + torch_dtype = torch.float32 + + # if the inversion is a folder, the prompt should also be a folder + assert (os.path.isdir(args.inversion)==os.path.isdir(args.prompt)), "If the inversion is a folder, the prompt should also be a folder" + if os.path.isdir(args.inversion): + l_inv_paths = sorted(glob(os.path.join(args.inversion, "*.pt"))) + l_bnames = [os.path.basename(x) for x in l_inv_paths] + l_prompt_paths = [os.path.join(args.prompt, x.replace(".pt",".txt")) for x in l_bnames] + else: + l_inv_paths = [args.inversion] + l_prompt_paths = [args.prompt] + + # Make the editing pipeline + pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda") + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + + + for inv_path, prompt_path in zip(l_inv_paths, l_prompt_paths): + prompt_str = open(prompt_path).read().strip() + rec_pil, edit_pil = pipe(prompt_str, + num_inference_steps=args.num_ddim_steps, + x_in=torch.load(inv_path).unsqueeze(0), + edit_dir=construct_direction(args.task_name), + guidance_amount=args.xa_guidance, + guidance_scale=args.negative_guidance_scale, + negative_prompt=prompt_str # use the unedited prompt for the negative prompt + ) + + bname = os.path.basename(inv_path).split(".")[0] + edit_pil[0].save(os.path.join(args.results_folder, f"edit/{bname}.png")) + rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction/{bname}.png")) diff --git a/src/edit_synthetic.py b/src/edit_synthetic.py new file mode 100644 index 0000000000000000000000000000000000000000..a1c35a005f28ecff7511e42d10afcdaee7f7c5cc --- /dev/null +++ b/src/edit_synthetic.py @@ -0,0 +1,53 @@ +import os, pdb + +import argparse +import numpy as np +import torch +import requests +from PIL import Image + +from diffusers import DDIMScheduler +from utils.edit_directions import construct_direction +from utils.edit_pipeline import EditingPipeline + + +if __name__=="__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--prompt_str', type=str, required=True) + parser.add_argument('--random_seed', type=int, default=0) + parser.add_argument('--task_name', type=str, default='cat2dog') + parser.add_argument('--results_folder', type=str, default='output/test_cat') + parser.add_argument('--num_ddim_steps', type=int, default=50) + parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4') + parser.add_argument('--xa_guidance', default=0.15, type=float) + parser.add_argument('--negative_guidance_scale', default=5.0, type=float) + parser.add_argument('--use_float_16', action='store_true') + args = parser.parse_args() + + os.makedirs(args.results_folder, exist_ok=True) + + if args.use_float_16: + torch_dtype = torch.float16 + else: + torch_dtype = torch.float32 + + # make the input noise map
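+ # NOTE: the noise map is sampled directly on the GPU and only the CUDA RNG is seeded, so a CUDA device is assumed here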
+ torch.cuda.manual_seed(args.random_seed) + x = torch.randn((1,4,64,64), device="cuda") + + # Make the editing pipeline + pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda") + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + + rec_pil, edit_pil = pipe(args.prompt_str, + num_inference_steps=args.num_ddim_steps, + x_in=x, + edit_dir=construct_direction(args.task_name), + guidance_amount=args.xa_guidance, + guidance_scale=args.negative_guidance_scale, + negative_prompt="" # use the empty string for the negative prompt + ) + + edit_pil[0].save(os.path.join(args.results_folder, f"edit.png")) + rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction.png")) diff --git a/src/inversion.py b/src/inversion.py new file mode 100644 index 0000000000000000000000000000000000000000..62c6f9fd587f9154ffd695e0d6fe29dd654ba11d --- /dev/null +++ b/src/inversion.py @@ -0,0 +1,65 @@ +import os, pdb + +import argparse +from glob import glob +import numpy as np +import torch +import requests +from PIL import Image + +from lavis.models import load_model_and_preprocess + +from utils.ddim_inv import DDIMInversion +from utils.scheduler import DDIMInverseScheduler + +if __name__=="__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input_image', type=str, default='assets/test_images/cats/cat_1.png') + parser.add_argument('--results_folder', type=str, default='output/test_cat') + parser.add_argument('--num_ddim_steps', type=int, default=50) + parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4') + parser.add_argument('--use_float_16', action='store_true') + args = parser.parse_args() + + # make the output folders + os.makedirs(os.path.join(args.results_folder, "inversion"), exist_ok=True) + os.makedirs(os.path.join(args.results_folder, "prompt"), exist_ok=True) + + if args.use_float_16: + torch_dtype = torch.float16 + else: + torch_dtype = torch.float32 + + + # load the BLIP model + model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda")) + # make the DDIM inversion pipeline + pipe = DDIMInversion.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda") + pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) + + + # if the input is a folder, collect all the images as a list + if os.path.isdir(args.input_image): + l_img_paths = sorted(glob(os.path.join(args.input_image, "*.png"))) + else: + l_img_paths = [args.input_image] + + + for img_path in l_img_paths: + bname = os.path.basename(img_path).split(".")[0] + img = Image.open(img_path).resize((512,512), Image.Resampling.LANCZOS) + # generate the caption + _image = vis_processors["eval"](img).unsqueeze(0).cuda() + prompt_str = model_blip.generate({"image": _image})[0] + x_inv, x_inv_image, x_dec_img = pipe( + prompt_str, + guidance_scale=1, + num_inversion_steps=args.num_ddim_steps, + img=img, + torch_dtype=torch_dtype + ) + # save the inversion + torch.save(x_inv[0], os.path.join(args.results_folder, f"inversion/{bname}.pt")) + # save the prompt string + with open(os.path.join(args.results_folder, f"prompt/{bname}.txt"), "w") as f: + f.write(prompt_str) diff --git a/src/make_edit_direction.py b/src/make_edit_direction.py new file mode 100644 index 0000000000000000000000000000000000000000..d6307694847e6f98749390ccb02c0b5ef6e2b67f --- /dev/null +++ b/src/make_edit_direction.py @@ -0,0 +1,61 @@ +import os, pdb + +import argparse +import numpy as np
+import torch +import requests +from PIL import Image + +from diffusers import DDIMScheduler +from utils.edit_pipeline import EditingPipeline + + +## convert sentences to sentence embeddings +def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"): + with torch.no_grad(): + l_embeddings = [] + for sent in l_sentences: + text_inputs = tokenizer( + sent, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0] + l_embeddings.append(prompt_embeds) + return torch.concatenate(l_embeddings, dim=0).mean(dim=0).unsqueeze(0) + + +if __name__=="__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--file_source_sentences', required=True) + parser.add_argument('--file_target_sentences', required=True) + parser.add_argument('--output_folder', required=True) + parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4') + args = parser.parse_args() + + # load the model + pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch.float16).to("cuda") + bname_src = os.path.splitext(os.path.basename(args.file_source_sentences))[0] + outf_src = os.path.join(args.output_folder, bname_src+".pt") + if os.path.exists(outf_src): + print(f"Skipping source file {outf_src} as it already exists") + else: + with open(args.file_source_sentences, "r") as f: + l_sents = [x.strip() for x in f.readlines()] + mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda") + print(mean_emb.shape) + torch.save(mean_emb, outf_src) + + bname_tgt = os.path.splitext(os.path.basename(args.file_target_sentences))[0] + outf_tgt = os.path.join(args.output_folder, bname_tgt+".pt") + if os.path.exists(outf_tgt): + print(f"Skipping target file {outf_tgt} as it already exists") + else: + with open(args.file_target_sentences, "r") as f: + l_sents = [x.strip() for x in f.readlines()] + mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda") + print(mean_emb.shape) + torch.save(mean_emb, outf_tgt) diff --git a/src/utils/base_pipeline.py b/src/utils/base_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..ff084fbfc70d90aa70dcbcf516d35ed2882624ec --- /dev/null +++ b/src/utils/base_pipeline.py @@ -0,0 +1,324 @@ + +import torch +import inspect +from packaging import version +from typing import Any, Callable, Dict, List, Optional, Union + +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from diffusers import DiffusionPipeline +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring +from diffusers import StableDiffusionPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + +logger = logging.get_logger(__name__) + + +class BasePipeline(DiffusionPipeline): + _optional_components = ["safety_checker", "feature_extractor"] + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config,
"steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + diff --git a/src/utils/cross_attention.py b/src/utils/cross_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..01dbbcd0dde9c0453b59be34070d6d5d39b3ceb3 --- /dev/null +++ b/src/utils/cross_attention.py @@ -0,0 +1,57 @@ +import torch +from diffusers.models.attention import CrossAttention + +class MyCrossAttnProcessor: + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + # new bookkeeping to save the attn probs + attn.attn_probs = attention_probs + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +""" +A function that prepares a U-Net model for training by enabling gradient computation +for a specified set of parameters and setting the forward pass to be performed by a +custom cross attention processor. + +Parameters: +unet: A U-Net model. + +Returns: +unet: The prepared U-Net model. 
+""" +def prep_unet(unet): + # set the gradients for XA maps to be true + for name, params in unet.named_parameters(): + if 'attn2' in name: + params.requires_grad = True + else: + params.requires_grad = False + # replace the fwd function + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention": + module.set_processor(MyCrossAttnProcessor()) + return unet diff --git a/src/utils/ddim_inv.py b/src/utils/ddim_inv.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc701e2fb0736960b7799965a4ff050cc464e97 --- /dev/null +++ b/src/utils/ddim_inv.py @@ -0,0 +1,140 @@ +import sys +import numpy as np +import torch +import torch.nn.functional as F +from random import randrange +from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from diffusers import DDIMScheduler +from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +sys.path.insert(0, "src/utils") +from base_pipeline import BasePipeline +from cross_attention import prep_unet + + +class DDIMInversion(BasePipeline): + + def auto_corr_loss(self, x, random_shift=True): + B,C,H,W = x.shape + assert B==1 + x = x.squeeze(0) + # x must be shape [C,H,W] now + reg_loss = 0.0 + for ch_idx in range(x.shape[0]): + noise = x[ch_idx][None, None,:,:] + while True: + if random_shift: roll_amount = randrange(noise.shape[2]//2) + else: roll_amount = 1 + reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=2)).mean()**2 + reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=3)).mean()**2 + if noise.shape[2] <= 8: + break + noise = F.avg_pool2d(noise, kernel_size=2) + return reg_loss + + def kl_divergence(self, x): + _mu = x.mean() + _var = x.var() + return _var + _mu**2 - 1 - torch.log(_var+1e-7) + + + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_inversion_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + img=None, # the input image as a PIL image + torch_dtype=torch.float32, + + # inversion regularization parameters + lambda_ac: float = 20.0, + lambda_kl: float = 20.0, + num_reg_steps: int = 5, + num_ac_rolls: int = 5, + ): + + # 0. modify the unet to be useful :D + self.unet = prep_unet(self.unet) + + # set the scheduler to be the Inverse DDIM scheduler + # self.scheduler = MyDDIMScheduler.from_config(self.scheduler.config) + + device = self._execution_device + do_classifier_free_guidance = guidance_scale > 1.0 + self.scheduler.set_timesteps(num_inversion_steps, device=device) + timesteps = self.scheduler.timesteps + + # Encode the input image with the first stage model + x0 = np.array(img)/255 + x0 = torch.from_numpy(x0).type(torch_dtype).permute(2, 0, 1).unsqueeze(dim=0).repeat(1, 1, 1, 1).cuda() + x0 = (x0 - 0.5) * 2. 
+ with torch.no_grad(): + x0_enc = self.vae.encode(x0).latent_dist.sample().to(device, torch_dtype) + latents = x0_enc = 0.18215 * x0_enc + + # Decode and return the image + with torch.no_grad(): + x0_dec = self.decode_latents(x0_enc.detach()) + image_x0_dec = self.numpy_to_pil(x0_dec) + + with torch.no_grad(): + prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt).to(device) + extra_step_kwargs = self.prepare_extra_step_kwargs(None, eta) + + # Do the inversion + num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order # should be 0? + with self.progress_bar(total=num_inversion_steps) as progress_bar: + for i, t in enumerate(timesteps.flip(0)[1:-1]): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + with torch.no_grad(): + noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # regularization of the noise prediction + e_t = noise_pred + for _outer in range(num_reg_steps): + if lambda_ac>0: + for _inner in range(num_ac_rolls): + _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True) + l_ac = self.auto_corr_loss(_var) + l_ac.backward() + _grad = _var.grad.detach()/num_ac_rolls + e_t = e_t - lambda_ac*_grad + if lambda_kl>0: + _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True) + l_kld = self.kl_divergence(_var) + l_kld.backward() + _grad = _var.grad.detach() + e_t = e_t - lambda_kl*_grad + e_t = e_t.detach() + noise_pred = e_t + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, reverse=True, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + + x_inv = latents.detach().clone() + # reconstruct the image + + # 8. Post-processing + image = self.decode_latents(latents.detach()) + image = self.numpy_to_pil(image) + return x_inv, image, image_x0_dec \ No newline at end of file diff --git a/src/utils/edit_directions.py b/src/utils/edit_directions.py new file mode 100644 index 0000000000000000000000000000000000000000..7025e20ad8cae7e8d7cb7854c313e4a7da07d089 --- /dev/null +++ b/src/utils/edit_directions.py @@ -0,0 +1,29 @@ +import os +import torch + + +""" +This function takes in a task name and returns the direction in the embedding space that transforms class A to class B for the given task. + +Parameters: +task_name (str): name of the task for which direction is to be constructed. + +Returns: +torch.Tensor: A tensor representing the direction in the embedding space that transforms class A to class B. 
+ +Examples: +>>> construct_direction("cat2dog") +""" +def construct_direction(task_name): + if task_name=="cat2dog": + emb_dir = f"assets/embeddings_sd_1.4" + embs_a = torch.load(os.path.join(emb_dir, f"cat.pt")) + embs_b = torch.load(os.path.join(emb_dir, f"dog.pt")) + return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0) + elif task_name=="dog2cat": + emb_dir = f"assets/embeddings_sd_1.4" + embs_a = torch.load(os.path.join(emb_dir, f"dog.pt")) + embs_b = torch.load(os.path.join(emb_dir, f"cat.pt")) + return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0) + else: + raise NotImplementedError diff --git a/src/utils/edit_pipeline.py b/src/utils/edit_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..cb0ab3b45943ce48148267f9ec4191f8d7d541f3 --- /dev/null +++ b/src/utils/edit_pipeline.py @@ -0,0 +1,174 @@ +import pdb, sys + +import numpy as np +import torch +from typing import Any, Callable, Dict, List, Optional, Union +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +sys.path.insert(0, "src/utils") +from base_pipeline import BasePipeline +from cross_attention import prep_unet + + +class EditingPipeline(BasePipeline): + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + + # pix2pix parameters + guidance_amount=0.1, + edit_dir=None, + x_in=None, + + ): + + x_in.to(dtype=self.unet.dtype, device=self._execution_device) + + # 0. modify the unet to be useful :D + self.unet = prep_unet(self.unet) + + # 1. setup all caching objects + d_ref_t2attn = {} # reference cross attention maps + + # 2. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # TODO: add the input checker function + # self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + do_classifier_free_guidance = guidance_scale > 1.0 + x_in = x_in.to(dtype=self.unet.dtype, device=self._execution_device) + # 3. Encode input prompt = 2x77x1024 + prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. 
Prepare latent variables + num_channels_latents = self.unet.in_channels + + # randomly sample a latent code if not provided + latents = self.prepare_latents(batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, x_in,) + + latents_init = latents.clone() + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. First Denoising loop for getting the reference cross attention maps + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with torch.no_grad(): + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample + + # add the cross attention map to the dictionary + d_ref_t2attn[t.item()] = {} + for name, module in self.unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention" and 'attn2' in name: + attn_mask = module.attn_probs # size is num_channel,s*s,77 + d_ref_t2attn[t.item()][name] = attn_mask.detach().cpu() + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # make the reference image (reconstruction) + image_rec = self.numpy_to_pil(self.decode_latents(latents.detach())) + + prompt_embeds_edit = prompt_embeds.clone() + #add the edit only to the second prompt, idx 0 is the negative prompt + prompt_embeds_edit[1:2] += edit_dir + + latents = latents_init + # Second denoising loop for editing the text prompt + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + x_in = latent_model_input.detach().clone() + x_in.requires_grad = True + + opt = torch.optim.SGD([x_in], lr=guidance_amount) + + # predict the noise residual + noise_pred = self.unet(x_in,t,encoder_hidden_states=prompt_embeds_edit.detach(),cross_attention_kwargs=cross_attention_kwargs,).sample + + loss = 0.0 + for name, module in self.unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention" and 'attn2' in name: + curr = module.attn_probs # size is num_channel,s*s,77 + ref = d_ref_t2attn[t.item()][name].detach().cuda() + loss += ((curr-ref)**2).sum((1,2)).mean(0) + loss.backward(retain_graph=False) + opt.step() + + # recompute the noise + with torch.no_grad(): + noise_pred = 
self.unet(x_in.detach(),t,encoder_hidden_states=prompt_embeds_edit,cross_attention_kwargs=cross_attention_kwargs,).sample + + latents = x_in.detach().chunk(2)[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + + # 8. Post-processing + image = self.decode_latents(latents.detach()) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image_edit = self.numpy_to_pil(image) + + + return image_rec, image_edit diff --git a/src/utils/scheduler.py b/src/utils/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..282e036db77c48fc12e2756cdfc72b4dcc1cdc30 --- /dev/null +++ b/src/utils/scheduler.py @@ -0,0 +1,289 @@ +# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion +import os, sys, pdb +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, randn_tensor +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. 
+
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
+            prevent singularities.
+
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+    """
+
+    def alpha_bar(time_step):
+        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return torch.tensor(betas)
+
+
+class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Denoising diffusion implicit models (DDIM) extend the denoising procedure introduced in denoising diffusion
+    probabilistic models (DDPMs) with non-Markovian guidance. This inverse variant applies the deterministic DDIM
+    update in the forward (noising) direction, i.e. it maps a sample at the current timestep to a noisier sample at
+    the next timestep, as needed for DDIM inversion.
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.
+
+    For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional):
+            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+        clip_sample (`bool`, default `True`):
+            option to clip predicted sample between -1 and 1 for numerical stability.
+        set_alpha_to_one (`bool`, default `True`):
+            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the value of alpha at step 0.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
+        prediction_type (`str`, default `epsilon`, optional):
+            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
+            https://imagen.research.google/video/paper.pdf)
+    """
+
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+    order = 1
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        clip_sample: bool = True,
+        set_alpha_to_one: bool = True,
+        steps_offset: int = 0,
+        prediction_type: str = "epsilon",
+    ):
+        if trained_betas is not None:
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+        elif beta_schedule == "linear":
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = (
+                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            )
+        elif beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        # At every step in ddim, we are looking into the previous alphas_cumprod
+        # For the final step, there is no previous alphas_cumprod because we are already at 0
+        # `set_alpha_to_one` decides whether we set this parameter simply to one or
+        # whether we use the final alpha of the "non-previous" one.
+        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
+    def _get_variance(self, timestep, prev_timestep):
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+        return variance
+
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+        """
+
+        if num_inference_steps > self.config.num_train_timesteps:
+            raise ValueError(
+                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:"
+                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                f" maximal {self.config.num_train_timesteps} timesteps."
+            )
+
+        self.num_inference_steps = num_inference_steps
+        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+        # creates integer timesteps by multiplying by ratio
+        # casting to int to avoid issues when num_inference_steps is a power of 3
+        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+        self.timesteps = torch.from_numpy(timesteps).to(device)
+        self.timesteps += self.config.steps_offset
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int,
+        sample: torch.FloatTensor,
+        eta: float = 0.0,
+        use_clipped_model_output: bool = False,
+        generator=None,
+        variance_noise: Optional[torch.FloatTensor] = None,
+        return_dict: bool = True,
+        reverse=False
+    ) -> Union[DDIMSchedulerOutput, Tuple]:
+        # Inversion step: the deterministic DDIM update is applied *forward* in time,
+        # so "prev_timestep" is in fact the next, noisier timestep (timestep + step_ratio).
+        e_t = model_output
+
+        x = sample
+        prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
+        # print(timestep, prev_timestep)
+        a_t = alpha_prod_t = self.alphas_cumprod[timestep-1]
+        a_prev = alpha_t_prev = self.alphas_cumprod[prev_timestep-1] if prev_timestep >= 0 else self.final_alpha_cumprod
+        beta_prod_t = 1 - alpha_prod_t
+
+        # estimate the clean sample from the current sample and predicted noise
+        pred_x0 = (x - (1-a_t)**0.5 * e_t) / a_t.sqrt()
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev).sqrt() * e_t
+        # re-noise the estimate towards the next timestep
+        x = a_prev.sqrt()*pred_x0 + dir_xt
+        if not return_dict:
+            return (x,)
+        return DDIMSchedulerOutput(prev_sample=x, pred_original_sample=pred_x0)
+
+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
+    ) -> torch.FloatTensor:
+        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        timesteps = timesteps.to(original_samples.device)
+
+        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+        return noisy_samples
+
+    def get_velocity(
+        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+    ) -> torch.FloatTensor:
+        # Make sure alphas_cumprod and timestep have same device and dtype as sample
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        timesteps = timesteps.to(sample.device)
+
+        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(sample.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+        return velocity
+
+    def __len__(self):
+        return self.config.num_train_timesteps
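
The arithmetic inside DDIMInverseScheduler.step above is the deterministic DDIM update applied in the noising direction. A minimal sketch of just that update, assuming alpha_bar_t and alpha_bar_next denote the cumulative alpha products at the current and at the next, noisier timestep (ddim_inverse_update is a hypothetical name, not part of the diff):

def ddim_inverse_update(x_t, eps, alpha_bar_t, alpha_bar_next):
    # predict the clean sample from the current sample and the predicted noise ...
    pred_x0 = (x_t - (1 - alpha_bar_t) ** 0.5 * eps) / alpha_bar_t ** 0.5
    # ... then re-noise it towards the next (larger) timestep along the same noise direction
    dir_xt = (1 - alpha_bar_next) ** 0.5 * eps
    return alpha_bar_next ** 0.5 * pred_x0 + dir_xt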
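
The inner optimization of the second denoising loop in the edit pipeline can be summarized in isolation. The sketch below is illustrative only: xa_guidance_step, ref_attn, and lr are hypothetical names, and it assumes the UNet's CrossAttention modules have been patched to expose attn_probs, as the pipeline code above relies on.

import torch

def xa_guidance_step(unet, x_in, t, prompt_embeds_edit, ref_attn, lr):
    # One cross-attention-guidance step: nudge the latent so that the attention maps
    # produced under the edited prompt stay close to the reference maps recorded
    # during the reconstruction pass (ref_attn: module name -> saved attention probs).
    x_in = x_in.detach().clone().requires_grad_(True)
    opt = torch.optim.SGD([x_in], lr=lr)
    # the forward pass populates module.attn_probs on the patched CrossAttention blocks
    unet(x_in, t, encoder_hidden_states=prompt_embeds_edit.detach()).sample
    loss = 0.0
    for name, module in unet.named_modules():
        if type(module).__name__ == "CrossAttention" and "attn2" in name:
            curr = module.attn_probs              # (num_channels, s*s, 77)
            ref = ref_attn[name].to(curr.device)
            loss = loss + ((curr - ref) ** 2).sum((1, 2)).mean(0)
    loss.backward()
    opt.step()
    # the caller then recomputes the noise prediction from the updated latent
    return x_in.detach()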
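
The edit itself enters only through the text embedding: index 0 of prompt_embeds holds the negative (unconditional) embedding, index 1 the conditional one, and edit_dir is added to the conditional half only. A shape sketch, assuming the usual SD 1.x CLIP layout of (2, 77, 768); how edit_dir is computed is outside this diff:

# prompt_embeds: (2, 77, 768) = [negative embedding, conditional embedding]
# edit_dir:      (77, 768), a precomputed source-to-target direction in embedding space
prompt_embeds_edit = prompt_embeds.clone()
prompt_embeds_edit[1:2] += edit_dir   # only the conditional row is shifted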