alvanlii committed on
Commit 7e0bf18
0 Parent(s):

Duplicate from alvanlii/pix2pix_zero

Files changed (45)
  1. .gitattributes +34 -0
  2. .gitignore +160 -0
  3. LICENSE +21 -0
  4. README.md +15 -0
  5. app.py +212 -0
  6. assets/.DS_Store +0 -0
  7. assets/capy.txt +22 -0
  8. assets/dogs_with_glasses.txt +29 -0
  9. assets/embeddings_sd_1.4/capy.pt +3 -0
  10. assets/embeddings_sd_1.4/cat.pt +3 -0
  11. assets/embeddings_sd_1.4/dog.pt +3 -0
  12. assets/embeddings_sd_1.4/dogs_with_glasses.pt +3 -0
  13. assets/embeddings_sd_1.4/horse.pt +3 -0
  14. assets/embeddings_sd_1.4/llama.pt +3 -0
  15. assets/embeddings_sd_1.4/zebra.pt +3 -0
  16. assets/llama.txt +15 -0
  17. assets/test_images/cats/cat_1.png +0 -0
  18. assets/test_images/cats/cat_2.png +0 -0
  19. assets/test_images/cats/cat_3.png +0 -0
  20. assets/test_images/cats/cat_4.png +0 -0
  21. assets/test_images/cats/cat_5.png +0 -0
  22. assets/test_images/cats/cat_6.png +0 -0
  23. assets/test_images/cats/cat_7.png +0 -0
  24. assets/test_images/cats/cat_8.png +0 -0
  25. assets/test_images/cats/cat_9.png +0 -0
  26. assets/test_images/dogs/dog_1.png +0 -0
  27. assets/test_images/dogs/dog_2.png +0 -0
  28. assets/test_images/dogs/dog_3.png +0 -0
  29. assets/test_images/dogs/dog_4.png +0 -0
  30. assets/test_images/dogs/dog_5.png +0 -0
  31. assets/test_images/dogs/dog_6.png +0 -0
  32. assets/test_images/dogs/dog_7.png +0 -0
  33. assets/test_images/dogs/dog_8.png +0 -0
  34. assets/test_images/dogs/dog_9.png +0 -0
  35. requirements.txt +7 -0
  36. src/edit_real.py +65 -0
  37. src/edit_synthetic.py +52 -0
  38. src/inversion.py +66 -0
  39. src/make_edit_direction.py +61 -0
  40. src/utils/base_pipeline.py +322 -0
  41. src/utils/cross_attention.py +57 -0
  42. src/utils/ddim_inv.py +140 -0
  43. src/utils/edit_directions.py +48 -0
  44. src/utils/edit_pipeline.py +174 -0
  45. src/utils/scheduler.py +289 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 pix2pixzero
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,15 @@
1
+ ---
2
+ title: Pix2pix Zero
3
+ emoji: 🌍
4
+ colorFrom: pink
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.18.0
8
+ app_file: app.py
9
+ pinned: false
10
+ tags:
11
+ - making-demos
12
+ duplicated_from: alvanlii/pix2pix_zero
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,212 @@
1
+ # Copyright 2023 Adobe Research. All rights reserved.
2
+ # To view a copy of the license, visit LICENSE.md.
3
+ import os
4
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
5
+
6
+ from PIL import Image
7
+
8
+ import torch
9
+
10
+ import gradio as gr
11
+
12
+ from lavis.models import load_model_and_preprocess
13
+
14
+ from diffusers import DDIMScheduler
15
+ from src.utils.ddim_inv import DDIMInversion
16
+ from src.utils.edit_directions import construct_direction
17
+ from src.utils.scheduler import DDIMInverseScheduler
18
+ from src.utils.edit_pipeline import EditingPipeline
19
+
20
+ def main():
21
+ NUM_DDIM_STEPS = 50
22
+ TORCH_DTYPE = torch.float16
23
+ XA_GUIDANCE = 0.1
24
+ DIR_SCALE = 1.0
25
+ MODEL_NAME = 'CompVis/stable-diffusion-v1-4'
26
+ NEGATIVE_GUIDANCE_SCALE = 5.0
27
+ DEVICE = "cuda"
28
+ # if torch.cuda.is_available():
29
+ # DEVICE = "cuda"
30
+ # else:
31
+ # DEVICE = "cpu"
32
+ # print(f"Using {DEVICE}")
33
+
34
+ model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=DEVICE)
35
+ pipe = EditingPipeline.from_pretrained(MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None).to(DEVICE)
36
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
37
+
38
+ inv_pipe = DDIMInversion.from_pretrained(MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None).to("cuda")
39
+ inv_pipe.scheduler = DDIMInverseScheduler.from_config(inv_pipe.scheduler.config)
40
+
41
+ TASKS = ["dog2cat","cat2dog","horse2zebra","zebra2horse","horse2llama","dog2capy"]
42
+ TASK_OPTIONS = ["Dog to Cat", "Cat to Dog", "Horse to Zebra", "Zebra to Horse", "Horse to Llama", "Dog to Capy"]
43
+
44
+ def edit_real_image(
45
+ og_img,
46
+ task,
47
+ seed,
48
+ xa_guidance,
49
+ num_ddim_steps,
50
+ dir_scale
51
+ ):
52
+ torch.cuda.manual_seed(seed)
53
+
54
+ # do inversion first, get inversion and generated prompt
55
+ curr_img = og_img.resize((512,512), Image.Resampling.LANCZOS)
56
+ _image = vis_processors["eval"](curr_img).unsqueeze(0).to(DEVICE)
57
+ prompt_str = model_blip.generate({"image": _image})[0]
58
+ x_inv, _, _ = inv_pipe(
59
+ prompt_str,
60
+ guidance_scale=1,
61
+ num_inversion_steps=NUM_DDIM_STEPS,
62
+ img=curr_img,
63
+ torch_dtype=TORCH_DTYPE
64
+ )
65
+
66
+ task_str = TASKS[task]
67
+
68
+ rec_pil, edit_pil = pipe(
69
+ prompt_str,
70
+ num_inference_steps=num_ddim_steps,
71
+ x_in=x_inv[0].unsqueeze(0),
72
+ edit_dir=construct_direction(task_str)*dir_scale,
73
+ guidance_amount=xa_guidance,
74
+ guidance_scale=NEGATIVE_GUIDANCE_SCALE,
75
+ negative_prompt=prompt_str # use the unedited prompt for the negative prompt
76
+ )
77
+
78
+ return prompt_str, edit_pil[0]
79
+
80
+
81
+ def edit_real_image_example():
82
+ test_img = Image.open("./assets/test_images/cats/cat_4.png")
83
+ seed = 42
84
+ task = 1
85
+ prompt_str, edited_img = edit_real_image(test_img, task, seed, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE)
86
+ return test_img, seed, "Cat to Dog", prompt_str, edited_img, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE
87
+
88
+
89
+ def edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps):
90
+ torch.cuda.manual_seed(seed)
91
+ x = torch.randn((1,4,64,64), device="cuda")
92
+
93
+ task_str = TASKS[task]
94
+
95
+ rec_pil, edit_pil = pipe(
96
+ prompt_str,
97
+ num_inference_steps=num_ddim_steps,
98
+ x_in=x,
99
+ edit_dir=construct_direction(task_str),
100
+ guidance_amount=xa_guidance,
101
+ guidance_scale=NEGATIVE_GUIDANCE_SCALE,
102
+ negative_prompt="" # use the empty string for the negative prompt
103
+ )
104
+
105
+ return rec_pil[0], edit_pil[0]
106
+
107
+ def edit_synth_image_example():
108
+ seed = 42
109
+ task = 1
110
+ xa_guidance = XA_GUIDANCE
111
+ num_ddim_steps = NUM_DDIM_STEPS
112
+ prompt_str = "A cute white cat sitting on top of the fridge"
113
+ recon_img, edited_img = edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps)
114
+ return seed, "Cat to Dog", xa_guidance, num_ddim_steps, prompt_str, recon_img, edited_img
115
+
116
+ with gr.Blocks() as demo:
117
+ gr.Markdown("""
118
+ ### Zero-shot Image-to-Image Translation (https://github.com/pix2pixzero/pix2pix-zero)
119
+ Gaurav Parmar, Krishna Kumar Singh, Richard Zhang, Yijun Li, Jingwan Lu, Jun-Yan Zhu <br/>
120
+ - For real images:
121
+ - Upload an image of a dog, cat or horse,
122
+ - Choose one of the task options to turn it into another animal!
123
+ - Changing Parameters:
124
+ - Increase direction scale is it is not cat (or another animal) enough.
125
+ - If the quality is not high enough, increase num ddim steps.
126
+ - Increase cross attention guidance to preserve original image structures. <br/>
127
+ - For synthetic images:
128
+ - Enter a prompt about dogs/cats/horses
129
+ - Choose a task option
130
+ """)
131
+ with gr.Tab("Real Image"):
132
+ with gr.Row():
133
+ seed = gr.Number(value=42, precision=1, label="Seed", interactive=True)
134
+ real_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True)
135
+ real_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=1, label="Num DDIM steps", interactive=True)
136
+ real_edit_dir_scale = gr.Number(value=DIR_SCALE, label="Edit Direction Scale", interactive=True)
137
+ real_generate_button = gr.Button("Generate")
138
+ real_load_sample_button = gr.Button("Load Example")
139
+
140
+ with gr.Row():
141
+ task_name = gr.Radio(
142
+ label='Task Name',
143
+ choices=TASK_OPTIONS,
144
+ value=TASK_OPTIONS[0],
145
+ type="index",
146
+ show_label=True,
147
+ interactive=True,
148
+ )
149
+
150
+ with gr.Row():
151
+ recon_text = gr.Textbox(lines=1, label="Reconstructed Text", interactive=False)
152
+ with gr.Row():
153
+ input_image = gr.Image(label="Input Image", type="pil", interactive=True)
154
+ output_image = gr.Image(label="Output Image", type="pil", interactive=False)
155
+
156
+
157
+ with gr.Tab("Synthetic Images"):
158
+ with gr.Row():
159
+ synth_seed = gr.Number(value=42, precision=1, label="Seed", interactive=True)
160
+ synth_prompt = gr.Textbox(lines=1, label="Prompt", interactive=True)
161
+ synth_generate_button = gr.Button("Generate")
162
+ synth_load_sample_button = gr.Button("Load Example")
163
+ with gr.Row():
164
+ synth_task_name = gr.Radio(
165
+ label='Task Name',
166
+ choices=TASK_OPTIONS,
167
+ value=TASK_OPTIONS[0],
168
+ type="index",
169
+ show_label=True,
170
+ interactive=True,
171
+ )
172
+ synth_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True)
173
+ synth_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=1, label="Num DDIM steps", interactive=True)
174
+ with gr.Row():
175
+ synth_input_image = gr.Image(label="Input Image", type="pil", interactive=False)
176
+ synth_output_image = gr.Image(label="Output Image", type="pil", interactive=False)
177
+
178
+
179
+
180
+ real_generate_button.click(
181
+ fn=edit_real_image,
182
+ inputs=[
183
+ input_image, task_name, seed, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale
184
+ ],
185
+ outputs=[recon_text, output_image]
186
+ )
187
+
188
+ real_load_sample_button.click(
189
+ fn=edit_real_image_example,
190
+ inputs=[],
191
+ outputs=[input_image, seed, task_name, recon_text, output_image, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale]
192
+ )
193
+
194
+ synth_generate_button.click(
195
+ fn=edit_synthetic_image,
196
+ inputs=[synth_seed, synth_task_name, synth_prompt, synth_xa_guidance, synth_num_ddim_steps],
197
+ outputs=[synth_input_image, synth_output_image]
198
+ )
199
+
200
+ synth_load_sample_button.click(
201
+ fn=edit_synth_image_example,
202
+ inputs=[],
203
+ outputs=[synth_seed, synth_task_name, synth_xa_guidance, synth_num_ddim_steps, synth_prompt, synth_input_image, synth_output_image]
204
+ )
205
+
206
+
207
+ demo.queue(concurrency_count=1)
208
+ demo.launch(share=False, server_name="0.0.0.0")
209
+
210
+
211
+ if __name__ == "__main__":
212
+ main()
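The parameter hints in the Markdown block of app.py map directly onto the two pipeline calls above. A minimal sketch of the same real-image path without the Gradio UI (illustrative only, not part of this commit; it assumes `model_blip`, `vis_processors`, `inv_pipe` and `pipe` have already been created exactly as in `main()`):

    # Sketch: caption -> DDIM inversion -> edit, mirroring edit_real_image() above.
    import torch
    from PIL import Image
    from src.utils.edit_directions import construct_direction

    def edit_one_image(path, task_str="cat2dog", seed=42,
                       xa_guidance=0.1, num_ddim_steps=50, dir_scale=1.0):
        torch.cuda.manual_seed(seed)
        img = Image.open(path).resize((512, 512), Image.Resampling.LANCZOS)

        # caption the input with BLIP, then invert it into a DDIM latent
        blip_in = vis_processors["eval"](img).unsqueeze(0).to("cuda")
        prompt = model_blip.generate({"image": blip_in})[0]
        x_inv, _, _ = inv_pipe(prompt, guidance_scale=1,
                               num_inversion_steps=num_ddim_steps,
                               img=img, torch_dtype=torch.float16)

        # regenerate from the inverted latent while steering along the edit direction;
        # xa_guidance is the cross-attention guidance, dir_scale the direction scale
        _, edit_pil = pipe(prompt,
                           num_inference_steps=num_ddim_steps,
                           x_in=x_inv[0].unsqueeze(0),
                           edit_dir=construct_direction(task_str) * dir_scale,
                           guidance_amount=xa_guidance,
                           guidance_scale=5.0,
                           negative_prompt=prompt)
        return edit_pil[0]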
assets/.DS_Store ADDED
Binary file (6.15 kB)
assets/capy.txt ADDED
@@ -0,0 +1,22 @@
1
+ A capybara is a large rodent with a heavy, barrel-shaped body, short head, and reddish-brown coloration on the upper parts of its body.
2
+ A capybara's fore and hind limbs are short, and its feet are webbed, which makes it an excellent swimmer.
3
+ A capybara has large, round ears and small eyes.
4
+ A capybara's buck teeth are sharp and adapted for cutting vegetation, and its cheek teeth do not make contact while the animal is gnawing.
5
+ A capybara has a thick tail and a hairy coat.
6
+ A large, semi-aquatic rodent with a brown coat and a round face is standing in a river.
7
+ A capybara is lazily swimming in a stream, its round face poking out of the water.
8
+ A capybara is wading through a stream, its furry body and short ears visible above the surface.
9
+ A capybara is snuggling up against a fallen tree trunk, its huge size easily visible.
10
+ A capybara is calmly swimming in a pool, its large body moving gracefully through the water.
11
+ A capybara is lounging in the grass, its round face and short ears visible above the blades of grass.
12
+ The capybara has a heavy, barrel-shaped body with short, stocky legs and a blunt snout.
13
+ A capy's coat is thick and composed of reddish-brown, yellowish-brown, or grayish-brown fur.
14
+ Capybaras' eyes and ears are small, and its long, coarse fur is sparse on its muzzle, forehead, and inner ears.
15
+ The capybara's feet are webbed, and its long, stout tail is slightly flattened.
16
+ A capy has four toes on each front foot and three on each hind foot, with the first two front toes partially webbed.
17
+ It has a short, thick neck and a short, broad head with small eyes and ears.
18
+ Its short, thick legs are equipped with long, sharp claws for digging.
19
+ The capybara has a thick, reddish-brown coat and a short, stocky body.
20
+ Capys' eyes and ears are small, and its snout is blunt.
21
+ A capybara has long, coarse fur on its muzzle, forehead, and inner ears, and its long, stout tail is slightly flattened.
22
+ It is an excellent swimmer, and its large size allows it to reach speeds of up to 35 kilometers per hour in short bursts.
assets/dogs_with_glasses.txt ADDED
@@ -0,0 +1,29 @@
1
+ The dog is wearing small, round, metal glasses.
2
+ This pup has on a pair of glasses that are circular, metal, and small.
3
+ This canine is sporting a pair of round glasses, with a frame made of metal.
4
+ The dog is wearing a pair of spectacles that are small and made of metal.
5
+ The dog is wearing a pair of eyeglasses that is round, with a metal frame.
6
+ A small, round pair of glasses is perched on the dog's face.
7
+ A round, metal pair of eyewear is being worn by the pooch.
8
+ A tiny, round set of glasses adorns the dog's face.
9
+ This pup is wearing a small set of glasses with a metal frame.
10
+ The dog has on a tiny, metal pair of circular glasses.
11
+ A dog wearing glasses has a distinct appearance, with its eye-wear giving it a unique look.
12
+ A dog's eyes will often be visible through the lenses of its glasses and its ears are usually perked up and alert.
13
+ A dog wearing glasses looks smart and stylish.
14
+ The dog wearing glasses was a source of amusement.
15
+ The dog wearing glasses seemed to be aware of its fashion statement.
16
+ The dog wearing glasses was a delightful surprise.
17
+ The cute pup is sporting a pair of stylish glasses.
18
+ The dog seems to be quite the trendsetter with its eyewear.
19
+ The glasses add a unique look to the dog's face.
20
+ The dog looks wise wearing his glasses.
21
+ The glasses add a certain flair to the dog's look.
22
+ The glasses give the dog an air of sophistication.
23
+ The dog looks dapper wearing his glasses.
24
+ The dog looks so smart with its glasses.
25
+ Those glasses make the dog look extra adorable in the picture.
26
+ The glasses give the dog a distinguished look in the image.
27
+ Such a sophisticated pup, wearing glasses and looking relaxed.
28
+ The lenses provide 100% UV protection, keeping your pup's eyes safe from the sun's harmful rays.
29
+ The glasses can also help protect your pup from debris and dust when they're out and about.
assets/embeddings_sd_1.4/capy.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d8085697317ef41bec4119a54d4cfb5829c11c0cf7e95c9f921ed5cd5c3b6d7
3
+ size 118946
assets/embeddings_sd_1.4/cat.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa9441dc014d5e86567c5ef165e10b50d2a7b3a68d90686d0cd1006792adf334
3
+ size 237300
assets/embeddings_sd_1.4/dog.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:becf079d61d7f35727bcc0d8506ddcdcddb61e62d611840ff3d18eca7fb6338c
3
+ size 237300
assets/embeddings_sd_1.4/dogs_with_glasses.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:805014d0f5145c1f094cf3550c8c08919d223b5ce2981adc7f0ef3ee7c6086b2
3
+ size 119049
assets/embeddings_sd_1.4/horse.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5d499299544d11371f84674761292b0512055ef45776c700c0b0da164cbf6c7
3
+ size 118949
assets/embeddings_sd_1.4/llama.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b4d639d08d038b728430acbcfb3956b002141ddfc8cf8c53fe9ac72b7e7b61c
3
+ size 118949
assets/embeddings_sd_1.4/zebra.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a29f6a11d91f3a276e27326b7623fae9d61a3d253ad430bb868bd40fb7e02fec
3
+ size 118949
assets/llama.txt ADDED
@@ -0,0 +1,15 @@
1
+ Llamas are large mammals with long necks and long woolly fur coats.
2
+ Llamas have four legs and two toes on each foot.
3
+ Llamas have large eyes, long ears, and curved lips.
4
+ Llamas have a short, thick tail and are typically brown and white in color.
5
+ Llamas are social animals and live in herds. They communicate with each other using a variety of vocalizations, including humming and spitting.
6
+ Llamas are also used as pack animals and are often used to carry supplies over long distances.
7
+ The llama stands in the vibrant meadow, its soft wool shining in the sunlight.
8
+ The llama looks content in its lush surroundings, its long neck reaching up towards the sky.
9
+ The majestic llama stands tall in its tranquil habitat, its fluffy coat catching the light of the day.
10
+ The llama grazes peacefully in the grassy meadow, its gentle eyes gazing across the landscape.
11
+ The llama is a picture of grace and beauty, its wool glimmering in the sun.
12
+ The llama strides confidently across the open field, its long legs carrying it through the grass.
13
+ The llama stands proud and stately amongst the rolling hills, its woolly coat glowing in the light.
14
+ The llama moves gracefully through the tall grass, its shaggy coat swaying in the wind.
15
+ The gentle llama stands in the meadow, its long eyelashes batting in the sun.
assets/test_images/cats/cat_1.png ADDED
assets/test_images/cats/cat_2.png ADDED
assets/test_images/cats/cat_3.png ADDED
assets/test_images/cats/cat_4.png ADDED
assets/test_images/cats/cat_5.png ADDED
assets/test_images/cats/cat_6.png ADDED
assets/test_images/cats/cat_7.png ADDED
assets/test_images/cats/cat_8.png ADDED
assets/test_images/cats/cat_9.png ADDED
assets/test_images/dogs/dog_1.png ADDED
assets/test_images/dogs/dog_2.png ADDED
assets/test_images/dogs/dog_3.png ADDED
assets/test_images/dogs/dog_4.png ADDED
assets/test_images/dogs/dog_5.png ADDED
assets/test_images/dogs/dog_6.png ADDED
assets/test_images/dogs/dog_7.png ADDED
assets/test_images/dogs/dog_8.png ADDED
assets/test_images/dogs/dog_9.png ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ transformers
2
+ gradio
3
+ accelerate
4
+ diffusers==0.12.1
5
+ einops
6
+ salesforce-lavis
7
+ opencv-python-headless
src/edit_real.py ADDED
@@ -0,0 +1,65 @@
1
+ import os, pdb, glob
2
+
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ import requests
7
+ from PIL import Image
8
+
9
+ from diffusers import DDIMScheduler
10
+ from utils.ddim_inv import DDIMInversion
11
+ from utils.edit_directions import construct_direction
12
+ from utils.edit_pipeline import EditingPipeline
13
+
14
+
15
+ if __name__=="__main__":
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument('--inversion', required=True)
18
+ parser.add_argument('--prompt', type=str, required=True)
19
+ parser.add_argument('--task_name', type=str, default='cat2dog')
20
+ parser.add_argument('--results_folder', type=str, default='output/test_cat')
21
+ parser.add_argument('--num_ddim_steps', type=int, default=50)
22
+ parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
23
+ parser.add_argument('--xa_guidance', default=0.1, type=float)
24
+ parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
25
+ parser.add_argument('--use_float_16', action='store_true')
26
+
27
+ args = parser.parse_args()
28
+
29
+ os.makedirs(os.path.join(args.results_folder, "edit"), exist_ok=True)
30
+ os.makedirs(os.path.join(args.results_folder, "reconstruction"), exist_ok=True)
31
+
32
+ if args.use_float_16:
33
+ torch_dtype = torch.float16
34
+ else:
35
+ torch_dtype = torch.float32
36
+
37
+ # if the inversion is a folder, the prompt should also be a folder
38
+ assert (os.path.isdir(args.inversion)==os.path.isdir(args.prompt)), "If the inversion is a folder, the prompt should also be a folder"
39
+ if os.path.isdir(args.inversion):
40
+ l_inv_paths = sorted(glob.glob(os.path.join(args.inversion, "*.pt")))
41
+ l_bnames = [os.path.basename(x) for x in l_inv_paths]
42
+ l_prompt_paths = [os.path.join(args.prompt, x.replace(".pt",".txt")) for x in l_bnames]
43
+ else:
44
+ l_inv_paths = [args.inversion]
45
+ l_prompt_paths = [args.prompt]
46
+
47
+ # Make the editing pipeline
48
+ pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
49
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
50
+
51
+
52
+ for inv_path, prompt_path in zip(l_inv_paths, l_prompt_paths):
53
+ prompt_str = open(prompt_path).read().strip()
54
+ rec_pil, edit_pil = pipe(prompt_str,
55
+ num_inference_steps=args.num_ddim_steps,
56
+ x_in=torch.load(inv_path).unsqueeze(0),
57
+ edit_dir=construct_direction(args.task_name),
58
+ guidance_amount=args.xa_guidance,
59
+ guidance_scale=args.negative_guidance_scale,
60
+ negative_prompt=prompt_str # use the unedited prompt for the negative prompt
61
+ )
62
+
63
+ bname = os.path.basename(inv_path).split(".")[0]
64
+ edit_pil[0].save(os.path.join(args.results_folder, f"edit/{bname}.png"))
65
+ rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction/{bname}.png"))
src/edit_synthetic.py ADDED
@@ -0,0 +1,52 @@
1
+ import os, pdb
2
+
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ import requests
7
+ from PIL import Image
8
+
9
+ from diffusers import DDIMScheduler
10
+ from utils.edit_directions import construct_direction
11
+ from utils.edit_pipeline import EditingPipeline
12
+
13
+
14
+ if __name__=="__main__":
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument('--prompt_str', type=str, required=True)
17
+ parser.add_argument('--random_seed', type=int, default=0)
18
+ parser.add_argument('--task_name', type=str, default='cat2dog')
19
+ parser.add_argument('--results_folder', type=str, default='output/test_cat')
20
+ parser.add_argument('--num_ddim_steps', type=int, default=50)
21
+ parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
22
+ parser.add_argument('--xa_guidance', default=0.15, type=float)
23
+ parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
24
+ parser.add_argument('--use_float_16', action='store_true')
25
+ args = parser.parse_args()
26
+
27
+ os.makedirs(args.results_folder, exist_ok=True)
28
+
29
+ if args.use_float_16:
30
+ torch_dtype = torch.float16
31
+ else:
32
+ torch_dtype = torch.float32
33
+
34
+ # make the input noise map
35
+ torch.cuda.manual_seed(args.random_seed)
36
+ x = torch.randn((1,4,64,64), device="cuda")
37
+
38
+ # Make the editing pipeline
39
+ pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
40
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
41
+
42
+ rec_pil, edit_pil = pipe(args.prompt_str,
43
+ num_inference_steps=args.num_ddim_steps,
44
+ x_in=x,
45
+ edit_dir=construct_direction(args.task_name),
46
+ guidance_amount=args.xa_guidance,
47
+ guidance_scale=args.negative_guidance_scale,
48
+ negative_prompt="" # use the empty string for the negative prompt
49
+ )
50
+
51
+ edit_pil[0].save(os.path.join(args.results_folder, f"edit.png"))
52
+ rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction.png"))
src/inversion.py ADDED
@@ -0,0 +1,66 @@
1
+ import os, pdb, glob
2
+
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ import requests
7
+ from PIL import Image
8
+
9
+ from lavis.models import load_model_and_preprocess
10
+
11
+ from utils.ddim_inv import DDIMInversion
12
+ from utils.scheduler import DDIMInverseScheduler
13
+ from utils.edit_pipeline import EditingPipeline
14
+
15
+
16
+ if __name__=="__main__":
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument('--input_image', type=str, default='assets/test_images/cat_a.png')
19
+ parser.add_argument('--results_folder', type=str, default='output/test_cat')
20
+ parser.add_argument('--num_ddim_steps', type=int, default=50)
21
+ parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
22
+ parser.add_argument('--use_float_16', action='store_true')
23
+ args = parser.parse_args()
24
+
25
+ # make the output folders
26
+ os.makedirs(os.path.join(args.results_folder, "inversion"), exist_ok=True)
27
+ os.makedirs(os.path.join(args.results_folder, "prompt"), exist_ok=True)
28
+
29
+ if args.use_float_16:
30
+ torch_dtype = torch.float16
31
+ else:
32
+ torch_dtype = torch.float32
33
+
34
+
35
+ # load the BLIP model
36
+ model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda"))
37
+ # make the DDIM inversion pipeline
38
+ pipe = DDIMInversion.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
39
+ pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
40
+
41
+
42
+ # if the input is a folder, collect all the images as a list
43
+ if os.path.isdir(args.input_image):
44
+ l_img_paths = sorted(glob.glob(os.path.join(args.input_image, "*.png")))
45
+ else:
46
+ l_img_paths = [args.input_image]
47
+
48
+
49
+ for img_path in l_img_paths:
50
+ bname = os.path.basename(args.input_image).split(".")[0]
51
+ img = Image.open(args.input_image).resize((512,512), Image.Resampling.LANCZOS)
52
+ # generate the caption
53
+ _image = vis_processors["eval"](img).unsqueeze(0).cuda()
54
+ prompt_str = model_blip.generate({"image": _image})[0]
55
+ x_inv, x_inv_image, x_dec_img = pipe(
56
+ prompt_str,
57
+ guidance_scale=1,
58
+ num_inversion_steps=args.num_ddim_steps,
59
+ img=img,
60
+ torch_dtype=torch_dtype
61
+ )
62
+ # save the inversion
63
+ torch.save(x_inv[0], os.path.join(args.results_folder, f"inversion/{bname}.pt"))
64
+ # save the prompt string
65
+ with open(os.path.join(args.results_folder, f"prompt/{bname}.txt"), "w") as f:
66
+ f.write(prompt_str)
src/make_edit_direction.py ADDED
@@ -0,0 +1,61 @@
1
+ import os, pdb
2
+
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ import requests
7
+ from PIL import Image
8
+
9
+ from diffusers import DDIMScheduler
10
+ from utils.edit_pipeline import EditingPipeline
11
+
12
+
13
+ ## convert sentences to sentence embeddings
14
+ def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
15
+ with torch.no_grad():
16
+ l_embeddings = []
17
+ for sent in l_sentences:
18
+ text_inputs = tokenizer(
19
+ sent,
20
+ padding="max_length",
21
+ max_length=tokenizer.model_max_length,
22
+ truncation=True,
23
+ return_tensors="pt",
24
+ )
25
+ text_input_ids = text_inputs.input_ids
26
+ prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
27
+ l_embeddings.append(prompt_embeds)
28
+ return torch.concatenate(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
29
+
30
+
31
+ if __name__=="__main__":
32
+ parser = argparse.ArgumentParser()
33
+ parser.add_argument('--file_source_sentences', required=True)
34
+ # parser.add_argument('--file_target_sentences', required=True)
35
+ parser.add_argument('--output_folder', default="./assets/")
36
+ parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
37
+ args = parser.parse_args()
38
+
39
+ # load the model
40
+ pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch.float16).to("cuda")
41
+ bname_src = os.path.splitext(os.path.basename(args.file_source_sentences))[0]
42
+ outf_src = os.path.join(args.output_folder, bname_src+".pt")
43
+ if os.path.exists(outf_src):
44
+ print(f"Skipping source file {outf_src} as it already exists")
45
+ else:
46
+ with open(args.file_source_sentences, "r") as f:
47
+ l_sents = [x.strip() for x in f.readlines()]
48
+ mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
49
+ print(mean_emb.shape)
50
+ torch.save(mean_emb, outf_src)
51
+
52
+ # bname_tgt = os.path.basename(args.file_target_sentences).strip(".txt")
53
+ # outf_tgt = os.path.join(args.output_folder, bname_tgt+".pt")
54
+ # if os.path.exists(outf_tgt):
55
+ # print(f"Skipping target file {outf_tgt} as it already exists")
56
+ # else:
57
+ # with open(args.file_target_sentences, "r") as f:
58
+ # l_sents = [x.strip() for x in f.readlines()]
59
+ # mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
60
+ # print(mean_emb.shape)
61
+ # torch.save(mean_emb, outf_tgt)
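The mean embeddings this script saves (e.g. assets/embeddings_sd_1.4/cat.pt) are what construct_direction in src/utils/edit_directions.py (listed above but not shown in this view) consumes. A minimal sketch of how they are presumably combined, assuming the edit direction is simply the difference between the target and source mean embeddings:

    # Illustrative sketch only; the file layout and the subtraction are assumptions
    # based on how the embeddings are saved above and used in app.py / edit_real.py.
    import torch

    def construct_direction_sketch(task_name, emb_folder="assets/embeddings_sd_1.4"):
        src, tgt = task_name.split("2")                   # e.g. "cat2dog" -> ("cat", "dog")
        emb_src = torch.load(f"{emb_folder}/{src}.pt")    # mean prompt embedding of the source sentences
        emb_tgt = torch.load(f"{emb_folder}/{tgt}.pt")    # mean prompt embedding of the target sentences
        return emb_tgt - emb_src                          # edit direction passed to the pipeline as edit_dir

In app.py this direction is additionally scaled by the "Edit Direction Scale" input (edit_dir=construct_direction(task_str)*dir_scale).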
src/utils/base_pipeline.py ADDED
@@ -0,0 +1,322 @@
1
+
2
+ import torch
3
+ import inspect
4
+ from packaging import version
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+
7
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
8
+ from diffusers import DiffusionPipeline
9
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
10
+ from diffusers.schedulers import KarrasDiffusionSchedulers
11
+ from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
12
+ from diffusers import StableDiffusionPipeline
13
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
14
+
15
+
16
+
17
+ class BasePipeline(DiffusionPipeline):
18
+ _optional_components = ["safety_checker", "feature_extractor"]
19
+ def __init__(
20
+ self,
21
+ vae: AutoencoderKL,
22
+ text_encoder: CLIPTextModel,
23
+ tokenizer: CLIPTokenizer,
24
+ unet: UNet2DConditionModel,
25
+ scheduler: KarrasDiffusionSchedulers,
26
+ safety_checker: StableDiffusionSafetyChecker,
27
+ feature_extractor: CLIPFeatureExtractor,
28
+ requires_safety_checker: bool = True,
29
+ ):
30
+ super().__init__()
31
+
32
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
33
+ deprecation_message = (
34
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
35
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
36
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
37
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
38
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
39
+ " file"
40
+ )
41
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
42
+ new_config = dict(scheduler.config)
43
+ new_config["steps_offset"] = 1
44
+ scheduler._internal_dict = FrozenDict(new_config)
45
+
46
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
47
+ deprecation_message = (
48
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
49
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
50
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
51
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
52
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
53
+ )
54
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
55
+ new_config = dict(scheduler.config)
56
+ new_config["clip_sample"] = False
57
+ scheduler._internal_dict = FrozenDict(new_config)
58
+
59
+ # if safety_checker is None and requires_safety_checker:
60
+ # logger.warning(
61
+ # f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
62
+ # " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
63
+ # " results in services or applications open to the public. Both the diffusers team and Hugging Face"
64
+ # " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
65
+ # " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
66
+ # " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
67
+ # )
68
+
69
+ if safety_checker is not None and feature_extractor is None:
70
+ raise ValueError(
71
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
72
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
73
+ )
74
+
75
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
76
+ version.parse(unet.config._diffusers_version).base_version
77
+ ) < version.parse("0.9.0.dev0")
78
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
79
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
80
+ deprecation_message = (
81
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
82
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
83
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
84
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
85
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
86
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
87
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
88
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
89
+ " the `unet/config.json` file"
90
+ )
91
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
92
+ new_config = dict(unet.config)
93
+ new_config["sample_size"] = 64
94
+ unet._internal_dict = FrozenDict(new_config)
95
+
96
+ self.register_modules(
97
+ vae=vae,
98
+ text_encoder=text_encoder,
99
+ tokenizer=tokenizer,
100
+ unet=unet,
101
+ scheduler=scheduler,
102
+ safety_checker=safety_checker,
103
+ feature_extractor=feature_extractor,
104
+ )
105
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
106
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
107
+
108
+ @property
109
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
110
+ def _execution_device(self):
111
+ r"""
112
+ Returns the device on which the pipeline's models will be executed. After calling
113
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
114
+ hooks.
115
+ """
116
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
117
+ return self.device
118
+ for module in self.unet.modules():
119
+ if (
120
+ hasattr(module, "_hf_hook")
121
+ and hasattr(module._hf_hook, "execution_device")
122
+ and module._hf_hook.execution_device is not None
123
+ ):
124
+ return torch.device(module._hf_hook.execution_device)
125
+ return self.device
126
+
127
+
128
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
129
+ def _encode_prompt(
130
+ self,
131
+ prompt,
132
+ device,
133
+ num_images_per_prompt,
134
+ do_classifier_free_guidance,
135
+ negative_prompt=None,
136
+ prompt_embeds: Optional[torch.FloatTensor] = None,
137
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
138
+ ):
139
+ r"""
140
+ Encodes the prompt into text encoder hidden states.
141
+
142
+ Args:
143
+ prompt (`str` or `List[str]`, *optional*):
144
+ prompt to be encoded
145
+ device: (`torch.device`):
146
+ torch device
147
+ num_images_per_prompt (`int`):
148
+ number of images that should be generated per prompt
149
+ do_classifier_free_guidance (`bool`):
150
+ whether to use classifier free guidance or not
151
+ negative_prompt (`str` or `List[str]`, *optional*):
152
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
153
+ `negative_prompt_embeds` instead.
154
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
155
+ prompt_embeds (`torch.FloatTensor`, *optional*):
156
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
157
+ provided, text embeddings will be generated from `prompt` input argument.
158
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
159
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
160
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
161
+ argument.
162
+ """
163
+ if prompt is not None and isinstance(prompt, str):
164
+ batch_size = 1
165
+ elif prompt is not None and isinstance(prompt, list):
166
+ batch_size = len(prompt)
167
+ else:
168
+ batch_size = prompt_embeds.shape[0]
169
+
170
+ if prompt_embeds is None:
171
+ text_inputs = self.tokenizer(
172
+ prompt,
173
+ padding="max_length",
174
+ max_length=self.tokenizer.model_max_length,
175
+ truncation=True,
176
+ return_tensors="pt",
177
+ )
178
+ text_input_ids = text_inputs.input_ids
179
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
180
+
181
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
182
+ text_input_ids, untruncated_ids
183
+ ):
184
+ removed_text = self.tokenizer.batch_decode(
185
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
186
+ )
187
+ # logger.warning(
188
+ # "The following part of your input was truncated because CLIP can only handle sequences up to"
189
+ # f" {self.tokenizer.model_max_length} tokens: {removed_text}"
190
+ # )
191
+
192
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
193
+ attention_mask = text_inputs.attention_mask.to(device)
194
+ else:
195
+ attention_mask = None
196
+
197
+ prompt_embeds = self.text_encoder(
198
+ text_input_ids.to(device),
199
+ attention_mask=attention_mask,
200
+ )
201
+ prompt_embeds = prompt_embeds[0]
202
+
203
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
204
+
205
+ bs_embed, seq_len, _ = prompt_embeds.shape
206
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
207
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
208
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
209
+
210
+ # get unconditional embeddings for classifier free guidance
211
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
212
+ uncond_tokens: List[str]
213
+ if negative_prompt is None:
214
+ uncond_tokens = [""] * batch_size
215
+ elif type(prompt) is not type(negative_prompt):
216
+ raise TypeError(
217
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
218
+ f" {type(prompt)}."
219
+ )
220
+ elif isinstance(negative_prompt, str):
221
+ uncond_tokens = [negative_prompt]
222
+ elif batch_size != len(negative_prompt):
223
+ raise ValueError(
224
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
225
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
226
+ " the batch size of `prompt`."
227
+ )
228
+ else:
229
+ uncond_tokens = negative_prompt
230
+
231
+ max_length = prompt_embeds.shape[1]
232
+ uncond_input = self.tokenizer(
233
+ uncond_tokens,
234
+ padding="max_length",
235
+ max_length=max_length,
236
+ truncation=True,
237
+ return_tensors="pt",
238
+ )
239
+
240
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
241
+ attention_mask = uncond_input.attention_mask.to(device)
242
+ else:
243
+ attention_mask = None
244
+
245
+ negative_prompt_embeds = self.text_encoder(
246
+ uncond_input.input_ids.to(device),
247
+ attention_mask=attention_mask,
248
+ )
249
+ negative_prompt_embeds = negative_prompt_embeds[0]
250
+
251
+ if do_classifier_free_guidance:
252
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
253
+ seq_len = negative_prompt_embeds.shape[1]
254
+
255
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
256
+
257
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
258
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
259
+
260
+ # For classifier free guidance, we need to do two forward passes.
261
+ # Here we concatenate the unconditional and text embeddings into a single batch
262
+ # to avoid doing two forward passes
263
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
264
+
265
+ return prompt_embeds
266
+
267
+
268
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
269
+ def decode_latents(self, latents):
270
+ latents = 1 / 0.18215 * latents
271
+ image = self.vae.decode(latents).sample
272
+ image = (image / 2 + 0.5).clamp(0, 1)
273
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
274
+ image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
275
+ return image
276
+
277
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
278
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
279
+ if isinstance(generator, list) and len(generator) != batch_size:
280
+ raise ValueError(
281
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
282
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
283
+ )
284
+
285
+ if latents is None:
286
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
287
+ else:
288
+ latents = latents.to(device)
289
+
290
+ # scale the initial noise by the standard deviation required by the scheduler
291
+ latents = latents * self.scheduler.init_noise_sigma
292
+ return latents
293
+
294
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
295
+ def prepare_extra_step_kwargs(self, generator, eta):
296
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
297
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
298
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
299
+ # and should be between [0, 1]
300
+
301
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
302
+ extra_step_kwargs = {}
303
+ if accepts_eta:
304
+ extra_step_kwargs["eta"] = eta
305
+
306
+ # check if the scheduler accepts generator
307
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
308
+ if accepts_generator:
309
+ extra_step_kwargs["generator"] = generator
310
+ return extra_step_kwargs
311
+
312
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
313
+ def run_safety_checker(self, image, device, dtype):
314
+ if self.safety_checker is not None:
315
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
316
+ image, has_nsfw_concept = self.safety_checker(
317
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
318
+ )
319
+ else:
320
+ has_nsfw_concept = None
321
+ return image, has_nsfw_concept
322
+
src/utils/cross_attention.py ADDED
@@ -0,0 +1,57 @@
1
+ import torch
2
+ from diffusers.models.attention import CrossAttention
3
+
4
+ class MyCrossAttnProcessor:
5
+ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
6
+ batch_size, sequence_length, _ = hidden_states.shape
7
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
8
+
9
+ query = attn.to_q(hidden_states)
10
+
11
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
12
+ key = attn.to_k(encoder_hidden_states)
13
+ value = attn.to_v(encoder_hidden_states)
14
+
15
+ query = attn.head_to_batch_dim(query)
16
+ key = attn.head_to_batch_dim(key)
17
+ value = attn.head_to_batch_dim(value)
18
+
19
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
20
+ # new bookkeeping to save the attn probs
21
+ attn.attn_probs = attention_probs
22
+
23
+ hidden_states = torch.bmm(attention_probs, value)
24
+ hidden_states = attn.batch_to_head_dim(hidden_states)
25
+
26
+ # linear proj
27
+ hidden_states = attn.to_out[0](hidden_states)
28
+ # dropout
29
+ hidden_states = attn.to_out[1](hidden_states)
30
+
31
+ return hidden_states
32
+
33
+
34
+ """
35
+ A function that prepares a U-Net model for cross-attention guidance by enabling gradient
36
+ computation only for the cross-attention ('attn2') parameters and setting the forward pass
37
+ to be performed by a custom cross-attention processor that stores the attention probabilities.
38
+
39
+ Parameters:
40
+ unet: A U-Net model.
41
+
42
+ Returns:
43
+ unet: The prepared U-Net model.
44
+ """
45
+ def prep_unet(unet):
46
+ # set the gradients for XA maps to be true
47
+ for name, params in unet.named_parameters():
48
+ if 'attn2' in name:
49
+ params.requires_grad = True
50
+ else:
51
+ params.requires_grad = False
52
+ # replace the fwd function
53
+ for name, module in unet.named_modules():
54
+ module_name = type(module).__name__
55
+ if module_name == "CrossAttention":
56
+ module.set_processor(MyCrossAttnProcessor())
57
+ return unet
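A short usage sketch (illustrative, not part of this commit): DDIMInversion in the next file applies this as self.unet = prep_unet(self.unet), and the stored maps can then be read back from the CrossAttention modules. The attribute name attn_probs is the one set in MyCrossAttnProcessor above; reading it from EditingPipeline is an assumption, since edit_pipeline.py is not shown in this section.

    from src.utils.cross_attention import prep_unet
    from src.utils.edit_pipeline import EditingPipeline

    pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
    pipe.unet = prep_unet(pipe.unet)   # gradients only on the 'attn2' parameters

    # ... after a denoising step, the processor has stored the attention maps ...
    attn_maps = {
        name: module.attn_probs
        for name, module in pipe.unet.named_modules()
        if type(module).__name__ == "CrossAttention" and hasattr(module, "attn_probs")
    }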
src/utils/ddim_inv.py ADDED
@@ -0,0 +1,140 @@
1
+ import sys
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from random import randrange
6
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
7
+ from diffusers import DDIMScheduler
8
+ from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
9
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
10
+ sys.path.insert(0, "src/utils")
11
+ from base_pipeline import BasePipeline
12
+ from cross_attention import prep_unet
13
+
14
+
15
+ class DDIMInversion(BasePipeline):
16
+
17
+ def auto_corr_loss(self, x, random_shift=True):
18
+ B,C,H,W = x.shape
19
+ assert B==1
20
+ x = x.squeeze(0)
21
+ # x must be shape [C,H,W] now
22
+ reg_loss = 0.0
23
+ for ch_idx in range(x.shape[0]):
24
+ noise = x[ch_idx][None, None,:,:]
25
+ while True:
26
+ if random_shift: roll_amount = randrange(noise.shape[2]//2)
27
+ else: roll_amount = 1
28
+ reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=2)).mean()**2
29
+ reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=3)).mean()**2
30
+ if noise.shape[2] <= 8:
31
+ break
32
+ noise = F.avg_pool2d(noise, kernel_size=2)
33
+ return reg_loss
34
+
35
+ def kl_divergence(self, x):
36
+ _mu = x.mean()
37
+ _var = x.var()
38
+ return _var + _mu**2 - 1 - torch.log(_var+1e-7)
39
+
40
+
41
+ def __call__(
42
+ self,
43
+ prompt: Union[str, List[str]] = None,
44
+ num_inversion_steps: int = 50,
45
+ guidance_scale: float = 7.5,
46
+ negative_prompt: Optional[Union[str, List[str]]] = None,
47
+ num_images_per_prompt: Optional[int] = 1,
48
+ eta: float = 0.0,
49
+ output_type: Optional[str] = "pil",
50
+ return_dict: bool = True,
51
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
52
+ img=None, # the input image as a PIL image
53
+ torch_dtype=torch.float32,
54
+
55
+ # inversion regularization parameters
56
+ lambda_ac: float = 20.0,
57
+ lambda_kl: float = 20.0,
58
+ num_reg_steps: int = 5,
59
+ num_ac_rolls: int = 5,
60
+ ):
61
+
62
+ # 0. modify the unet to be useful :D
63
+ self.unet = prep_unet(self.unet)
64
+
65
+ # set the scheduler to be the Inverse DDIM scheduler
66
+ # self.scheduler = MyDDIMScheduler.from_config(self.scheduler.config)
67
+
68
+ device = self._execution_device
69
+ do_classifier_free_guidance = guidance_scale > 1.0
70
+ self.scheduler.set_timesteps(num_inversion_steps, device=device)
71
+ timesteps = self.scheduler.timesteps
72
+
73
+ # Encode the input image with the first stage model
74
+ x0 = np.array(img)/255
75
+ x0 = torch.from_numpy(x0).type(torch_dtype).permute(2, 0, 1).unsqueeze(dim=0).repeat(1, 1, 1, 1).cuda()
76
+ x0 = (x0 - 0.5) * 2.
77
+ with torch.no_grad():
78
+ x0_enc = self.vae.encode(x0).latent_dist.sample().to(device, torch_dtype)
79
+ latents = x0_enc = 0.18215 * x0_enc
80
+
81
+ # Decode and return the image
82
+ with torch.no_grad():
83
+ x0_dec = self.decode_latents(x0_enc.detach())
84
+ image_x0_dec = self.numpy_to_pil(x0_dec)
85
+
86
+ with torch.no_grad():
87
+ prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt).to(device)
88
+ extra_step_kwargs = self.prepare_extra_step_kwargs(None, eta)
89
+
90
+ # Do the inversion
91
+ num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order # should be 0?
92
+ with self.progress_bar(total=num_inversion_steps) as progress_bar:
93
+ for i, t in enumerate(timesteps.flip(0)[1:-1]):
94
+ # expand the latents if we are doing classifier free guidance
95
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
96
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
97
+
98
+ # predict the noise residual
99
+ with torch.no_grad():
100
+ noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample
101
+
102
+ # perform guidance
103
+ if do_classifier_free_guidance:
104
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
105
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
106
+
107
+ # regularization of the noise prediction
108
+ e_t = noise_pred
109
+ for _outer in range(num_reg_steps):
110
+ if lambda_ac>0:
111
+ for _inner in range(num_ac_rolls):
112
+ _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
113
+ l_ac = self.auto_corr_loss(_var)
114
+ l_ac.backward()
115
+ _grad = _var.grad.detach()/num_ac_rolls
116
+ e_t = e_t - lambda_ac*_grad
117
+ if lambda_kl>0:
118
+ _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
119
+ l_kld = self.kl_divergence(_var)
120
+ l_kld.backward()
121
+ _grad = _var.grad.detach()
122
+ e_t = e_t - lambda_kl*_grad
123
+ e_t = e_t.detach()
124
+ noise_pred = e_t
125
+
126
+ # compute the previous noisy sample x_t -> x_t-1
127
+ latents = self.scheduler.step(noise_pred, t, latents, reverse=True, **extra_step_kwargs).prev_sample
128
+
129
+ # call the callback, if provided
130
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
131
+ progress_bar.update()
132
+
133
+
134
+ x_inv = latents.detach().clone()
135
+ # reconstruct the image
136
+
137
+ # 8. Post-processing
138
+ image = self.decode_latents(latents.detach())
139
+ image = self.numpy_to_pil(image)
140
+ return x_inv, image, image_x0_dec
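A rough driver for this inversion pipeline (src/inversion.py is the repository's actual entry point; the checkpoint id, image path, caption, and settings below are assumptions for illustration):

```python
# Sketch: DDIM-invert a real image into a noise map that can later be edited.
import sys
import torch
from PIL import Image

sys.path.insert(0, "src/utils")
from ddim_inv import DDIMInversion
from scheduler import DDIMInverseScheduler

pipe = DDIMInversion.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

img = Image.open("assets/test_images/cats/cat_1.png").convert("RGB").resize((512, 512))
x_inv, x_inv_image, x_dec_img = pipe(
    prompt="a photo of a cat",   # caption of the input image (assumed)
    img=img,
    num_inversion_steps=50,
    guidance_scale=1.0,          # classifier-free guidance is typically disabled for inversion
    num_reg_steps=5,
    num_ac_rolls=5,
)
torch.save(x_inv[0].cpu(), "cat_1_inverted.pt")  # inverted latent, reused by the editing stage
```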
src/utils/edit_directions.py ADDED
@@ -0,0 +1,48 @@
1
+ import os
2
+ import torch
3
+
4
+
5
+ """
6
+ This function takes in a task name and returns the direction in the embedding space that transforms class A to class B for the given task.
7
+
8
+ Parameters:
9
+ task_name (str): name of the task for which direction is to be constructed.
10
+
11
+ Returns:
12
+ torch.Tensor: A tensor representing the direction in the embedding space that transforms class A to class B.
13
+
14
+ Examples:
15
+ >>> construct_direction("cat2dog")
16
+ """
17
+ def construct_direction(task_name):
18
+ emb_dir = f"assets/embeddings_sd_1.4"
19
+ if task_name=="cat2dog":
20
+ embs_a = torch.load(os.path.join(emb_dir, f"cat.pt"))
21
+ embs_b = torch.load(os.path.join(emb_dir, f"dog.pt"))
22
+ return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
23
+ elif task_name=="dog2cat":
24
+ embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
25
+ embs_b = torch.load(os.path.join(emb_dir, f"cat.pt"))
26
+ return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
27
+ elif task_name=="horse2zebra":
28
+ embs_a = torch.load(os.path.join(emb_dir, f"horse.pt"))
29
+ embs_b = torch.load(os.path.join(emb_dir, f"zebra.pt"))
30
+ return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
31
+ elif task_name=="zebra2horse":
32
+ embs_a = torch.load(os.path.join(emb_dir, f"zebra.pt"))
33
+ embs_b = torch.load(os.path.join(emb_dir, f"horse.pt"))
34
+ return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
35
+ elif task_name=="horse2llama":
36
+ embs_a = torch.load(os.path.join(emb_dir, f"horse.pt"))
37
+ embs_b = torch.load(os.path.join(emb_dir, f"llama.pt"))
38
+ return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
39
+ elif task_name=="dog2capy":
40
+ embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
41
+ embs_b = torch.load(os.path.join(emb_dir, f"capy.pt"))
42
+ return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
43
+ elif task_name=='dogglasses':
44
+ embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
45
+ embs_b = torch.load(os.path.join(emb_dir, f"dogs_with_glasses.pt"))
46
+ return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
47
+ else:
48
+ raise NotImplementedError
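Each .pt file under assets/embeddings_sd_1.4 is expected to hold stacked text-encoder embeddings of many sentences describing one class (produced by src/make_edit_direction.py), so the returned direction is simply the difference of the two class means. A quick sanity check, assuming SD 1.4's 77-token x 768-dim embeddings:

```python
# Sketch: load a precomputed edit direction and inspect it.
import sys
sys.path.insert(0, "src/utils")
from edit_directions import construct_direction

edit_dir = construct_direction("cat2dog")
print(edit_dir.shape)  # expected torch.Size([1, 77, 768]); later added to the conditional prompt embedding
```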
src/utils/edit_pipeline.py ADDED
@@ -0,0 +1,174 @@
1
+ import pdb, sys
2
+
3
+ import numpy as np
4
+ import torch
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
7
+ sys.path.insert(0, "src/utils")
8
+ from base_pipeline import BasePipeline
9
+ from cross_attention import prep_unet
10
+
11
+
12
+ class EditingPipeline(BasePipeline):
13
+ def __call__(
14
+ self,
15
+ prompt: Union[str, List[str]] = None,
16
+ height: Optional[int] = None,
17
+ width: Optional[int] = None,
18
+ num_inference_steps: int = 50,
19
+ guidance_scale: float = 7.5,
20
+ negative_prompt: Optional[Union[str, List[str]]] = None,
21
+ num_images_per_prompt: Optional[int] = 1,
22
+ eta: float = 0.0,
23
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
24
+ latents: Optional[torch.FloatTensor] = None,
25
+ prompt_embeds: Optional[torch.FloatTensor] = None,
26
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
27
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
28
+
29
+ # pix2pix parameters
30
+ guidance_amount=0.1,
31
+ edit_dir=None,
32
+ x_in=None,
33
+
34
+ ):
35
+
36
+ x_in = x_in.to(dtype=self.unet.dtype, device=self._execution_device)
37
+
38
+ # 0. modify the unet to be useful :D
39
+ self.unet = prep_unet(self.unet)
40
+
41
+ # 1. setup all caching objects
42
+ d_ref_t2attn = {} # reference cross attention maps
43
+
44
+ # 2. Default height and width to unet
45
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
46
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
47
+
48
+ # TODO: add the input checker function
49
+ # self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds )
50
+
51
+ # 2. Define call parameters
52
+ if prompt is not None and isinstance(prompt, str):
53
+ batch_size = 1
54
+ elif prompt is not None and isinstance(prompt, list):
55
+ batch_size = len(prompt)
56
+ else:
57
+ batch_size = prompt_embeds.shape[0]
58
+
59
+ device = self._execution_device
60
+ do_classifier_free_guidance = guidance_scale > 1.0
61
+ x_in = x_in.to(dtype=self.unet.dtype, device=self._execution_device)
62
+ # 3. Encode input prompt = 2x77x1024
63
+ prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,)
64
+
65
+ # 4. Prepare timesteps
66
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
67
+ timesteps = self.scheduler.timesteps
68
+
69
+ # 5. Prepare latent variables
70
+ num_channels_latents = self.unet.in_channels
71
+
72
+ # randomly sample a latent code if not provided
73
+ latents = self.prepare_latents(batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, x_in,)
74
+
75
+ latents_init = latents.clone()
76
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
77
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
78
+
79
+ # 7. First Denoising loop for getting the reference cross attention maps
80
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
81
+ with torch.no_grad():
82
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
83
+ for i, t in enumerate(timesteps):
84
+ # expand the latents if we are doing classifier free guidance
85
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
86
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
87
+
88
+ # predict the noise residual
89
+ noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample
90
+
91
+ # add the cross attention map to the dictionary
92
+ d_ref_t2attn[t.item()] = {}
93
+ for name, module in self.unet.named_modules():
94
+ module_name = type(module).__name__
95
+ if module_name == "CrossAttention" and 'attn2' in name:
96
+ attn_mask = module.attn_probs # size is (num_channels, s*s, 77)
97
+ d_ref_t2attn[t.item()][name] = attn_mask.detach().cpu()
98
+
99
+ # perform guidance
100
+ if do_classifier_free_guidance:
101
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
102
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
103
+
104
+ # compute the previous noisy sample x_t -> x_t-1
105
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
106
+
107
+ # call the callback, if provided
108
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
109
+ progress_bar.update()
110
+
111
+ # make the reference image (reconstruction)
112
+ image_rec = self.numpy_to_pil(self.decode_latents(latents.detach()))
113
+
114
+ prompt_embeds_edit = prompt_embeds.clone()
115
+ # add the edit only to the second prompt; idx 0 is the negative prompt
116
+ prompt_embeds_edit[1:2] += edit_dir
117
+
118
+ latents = latents_init
119
+ # Second denoising loop for editing the text prompt
120
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
121
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
122
+ for i, t in enumerate(timesteps):
123
+ # expand the latents if we are doing classifier free guidance
124
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
125
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
126
+
127
+ x_in = latent_model_input.detach().clone()
128
+ x_in.requires_grad = True
129
+
130
+ opt = torch.optim.SGD([x_in], lr=guidance_amount)
131
+
132
+ # predict the noise residual
133
+ noise_pred = self.unet(x_in,t,encoder_hidden_states=prompt_embeds_edit.detach(),cross_attention_kwargs=cross_attention_kwargs,).sample
134
+
135
+ loss = 0.0
136
+ for name, module in self.unet.named_modules():
137
+ module_name = type(module).__name__
138
+ if module_name == "CrossAttention" and 'attn2' in name:
139
+ curr = module.attn_probs # size is (num_channels, s*s, 77)
140
+ ref = d_ref_t2attn[t.item()][name].detach().cuda()
141
+ loss += ((curr-ref)**2).sum((1,2)).mean(0)
142
+ loss.backward(retain_graph=False)
143
+ opt.step()
144
+
145
+ # recompute the noise
146
+ with torch.no_grad():
147
+ noise_pred = self.unet(x_in.detach(),t,encoder_hidden_states=prompt_embeds_edit,cross_attention_kwargs=cross_attention_kwargs,).sample
148
+
149
+ latents = x_in.detach().chunk(2)[0]
150
+
151
+ # perform guidance
152
+ if do_classifier_free_guidance:
153
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
154
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
155
+
156
+ # compute the previous noisy sample x_t -> x_t-1
157
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
158
+
159
+ # call the callback, if provided
160
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
161
+ progress_bar.update()
162
+
163
+
164
+ # 8. Post-processing
165
+ image = self.decode_latents(latents.detach())
166
+
167
+ # 9. Run safety checker
168
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
169
+
170
+ # 10. Convert to PIL
171
+ image_edit = self.numpy_to_pil(image)
172
+
173
+
174
+ return image_rec, image_edit
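Putting the pieces together, an edit run might look roughly like the sketch below (mirroring what src/edit_real.py is expected to do; the checkpoint id, file paths, prompt, and guidance values are assumptions, and cat_1_inverted.pt is the latent saved in the inversion sketch above):

```python
# Sketch: reconstruct and edit a previously inverted image with a cat->dog direction.
import sys
import torch
from diffusers import DDIMScheduler

sys.path.insert(0, "src/utils")
from edit_directions import construct_direction
from edit_pipeline import EditingPipeline

pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

x_inv = torch.load("cat_1_inverted.pt").unsqueeze(0)   # [1, 4, 64, 64] inverted latent
edit_dir = construct_direction("cat2dog").to("cuda")   # keep on the same device as the prompt embeddings

rec_pil, edit_pil = pipe(
    prompt="a photo of a cat",          # same caption used during inversion
    negative_prompt="a photo of a cat",
    num_inference_steps=50,
    x_in=x_inv,
    edit_dir=edit_dir,
    guidance_amount=0.1,
    guidance_scale=5.0,
)
rec_pil[0].save("cat_1_rec.png")        # reconstruction from the first denoising loop
edit_pil[0].save("cat_1_cat2dog.png")   # edited result from the second loop
```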
src/utils/scheduler.py ADDED
@@ -0,0 +1,289 @@
1
+ # Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+ import os, sys, pdb
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.utils import BaseOutput, randn_tensor
27
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
28
+
29
+
30
+ @dataclass
31
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
32
+ class DDIMSchedulerOutput(BaseOutput):
33
+ """
34
+ Output class for the scheduler's step function output.
35
+
36
+ Args:
37
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
38
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
39
+ denoising loop.
40
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
42
+ `pred_original_sample` can be used to preview progress or for guidance.
43
+ """
44
+
45
+ prev_sample: torch.FloatTensor
46
+ pred_original_sample: Optional[torch.FloatTensor] = None
47
+
48
+
49
+ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
50
+ """
51
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
52
+ (1-beta) over time from t = [0,1].
53
+
54
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
55
+ to that part of the diffusion process.
56
+
57
+
58
+ Args:
59
+ num_diffusion_timesteps (`int`): the number of betas to produce.
60
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
61
+ prevent singularities.
62
+
63
+ Returns:
64
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
65
+ """
66
+
67
+ def alpha_bar(time_step):
68
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
69
+
70
+ betas = []
71
+ for i in range(num_diffusion_timesteps):
72
+ t1 = i / num_diffusion_timesteps
73
+ t2 = (i + 1) / num_diffusion_timesteps
74
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
75
+ return torch.tensor(betas)
76
+
77
+
78
+ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
79
+ """
80
+ This scheduler inverts the denoising diffusion implicit models (DDIM) update, which extends the procedure introduced in denoising
81
+ diffusion probabilistic models (DDPMs) with non-Markovian guidance.
82
+
83
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
84
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
85
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
86
+ [`~SchedulerMixin.from_pretrained`] functions.
87
+
88
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
89
+
90
+ Args:
91
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
92
+ beta_start (`float`): the starting `beta` value of inference.
93
+ beta_end (`float`): the final `beta` value.
94
+ beta_schedule (`str`):
95
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
96
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
97
+ trained_betas (`np.ndarray`, optional):
98
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
99
+ clip_sample (`bool`, default `True`):
100
+ option to clip predicted sample between -1 and 1 for numerical stability.
101
+ set_alpha_to_one (`bool`, default `True`):
102
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
103
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
104
+ otherwise it uses the value of alpha at step 0.
105
+ steps_offset (`int`, default `0`):
106
+ an offset added to the inference steps. You can use a combination of `offset=1` and
107
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
108
+ stable diffusion.
109
+ prediction_type (`str`, default `epsilon`, optional):
110
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
111
+ process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
112
+ https://imagen.research.google/video/paper.pdf)
113
+ """
114
+
115
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
116
+ order = 1
117
+
118
+ @register_to_config
119
+ def __init__(
120
+ self,
121
+ num_train_timesteps: int = 1000,
122
+ beta_start: float = 0.0001,
123
+ beta_end: float = 0.02,
124
+ beta_schedule: str = "linear",
125
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
126
+ clip_sample: bool = True,
127
+ set_alpha_to_one: bool = True,
128
+ steps_offset: int = 0,
129
+ prediction_type: str = "epsilon",
130
+ ):
131
+ if trained_betas is not None:
132
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
133
+ elif beta_schedule == "linear":
134
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
135
+ elif beta_schedule == "scaled_linear":
136
+ # this schedule is very specific to the latent diffusion model.
137
+ self.betas = (
138
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
139
+ )
140
+ elif beta_schedule == "squaredcos_cap_v2":
141
+ # Glide cosine schedule
142
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
143
+ else:
144
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
145
+
146
+ self.alphas = 1.0 - self.betas
147
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
148
+
149
+ # At every step in ddim, we are looking into the previous alphas_cumprod
150
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
151
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
152
+ # whether we use the final alpha of the "non-previous" one.
153
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
154
+
155
+ # standard deviation of the initial noise distribution
156
+ self.init_noise_sigma = 1.0
157
+
158
+ # setable values
159
+ self.num_inference_steps = None
160
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
161
+
162
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
163
+ """
164
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
165
+ current timestep.
166
+
167
+ Args:
168
+ sample (`torch.FloatTensor`): input sample
169
+ timestep (`int`, optional): current timestep
170
+
171
+ Returns:
172
+ `torch.FloatTensor`: scaled input sample
173
+ """
174
+ return sample
175
+
176
+ def _get_variance(self, timestep, prev_timestep):
177
+ alpha_prod_t = self.alphas_cumprod[timestep]
178
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
179
+ beta_prod_t = 1 - alpha_prod_t
180
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
181
+
182
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
183
+
184
+ return variance
185
+
186
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
187
+ """
188
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
189
+
190
+ Args:
191
+ num_inference_steps (`int`):
192
+ the number of diffusion steps used when generating samples with a pre-trained model.
193
+ """
194
+
195
+ if num_inference_steps > self.config.num_train_timesteps:
196
+ raise ValueError(
197
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
198
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
199
+ f" maximal {self.config.num_train_timesteps} timesteps."
200
+ )
201
+
202
+ self.num_inference_steps = num_inference_steps
203
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
204
+ # creates integer timesteps by multiplying by ratio
205
+ # casting to int to avoid issues when num_inference_step is power of 3
206
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
207
+ self.timesteps = torch.from_numpy(timesteps).to(device)
208
+ self.timesteps += self.config.steps_offset
209
+
210
+ def step(
211
+ self,
212
+ model_output: torch.FloatTensor,
213
+ timestep: int,
214
+ sample: torch.FloatTensor,
215
+ eta: float = 0.0,
216
+ use_clipped_model_output: bool = False,
217
+ generator=None,
218
+ variance_noise: Optional[torch.FloatTensor] = None,
219
+ return_dict: bool = True,
220
+ reverse=False
221
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
222
+
223
+
224
+ e_t = model_output
225
+
226
+ x = sample
227
+ prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
228
+ # print(timestep, prev_timestep)
229
+ a_t = alpha_prod_t = self.alphas_cumprod[timestep-1]
230
+ a_prev = alpha_t_prev = self.alphas_cumprod[prev_timestep-1] if prev_timestep >= 0 else self.final_alpha_cumprod
231
+ beta_prod_t = 1 - alpha_prod_t
232
+
233
+ pred_x0 = (x - (1-a_t)**0.5 * e_t) / a_t.sqrt()
234
+ # direction pointing to x_t
235
+ dir_xt = (1. - a_prev).sqrt() * e_t
236
+ x = a_prev.sqrt()*pred_x0 + dir_xt
237
+ if not return_dict:
238
+ return (x,)
239
+ return DDIMSchedulerOutput(prev_sample=x, pred_original_sample=pred_x0)
240
+
241
+
242
+
243
+
244
+
245
+ def add_noise(
246
+ self,
247
+ original_samples: torch.FloatTensor,
248
+ noise: torch.FloatTensor,
249
+ timesteps: torch.IntTensor,
250
+ ) -> torch.FloatTensor:
251
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
252
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
253
+ timesteps = timesteps.to(original_samples.device)
254
+
255
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
256
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
257
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
258
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
259
+
260
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
261
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
262
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
263
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
264
+
265
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
266
+ return noisy_samples
267
+
268
+ def get_velocity(
269
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
270
+ ) -> torch.FloatTensor:
271
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
272
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
273
+ timesteps = timesteps.to(sample.device)
274
+
275
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
276
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
277
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
278
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
279
+
280
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
281
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
282
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
283
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
284
+
285
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
286
+ return velocity
287
+
288
+ def __len__(self):
289
+ return self.config.num_train_timesteps
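For reference, the update applied by `step` above is the deterministic (eta = 0) DDIM rule run in the noising direction: writing alphas_cumprod as ᾱ and the step size as Δ = num_train_timesteps / num_inference_steps, each call moves the sample from timestep t to t + Δ. A sketch of the math the code implements, not additional source:

```latex
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}},
\qquad
x_{t+\Delta} = \sqrt{\bar{\alpha}_{t+\Delta}}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_{t+\Delta}}\,\epsilon_\theta(x_t, t)
```

Iterating this from the smallest timestep up to the largest carries the encoded image latent back to Gaussian-like noise, which DDIMInversion returns as x_inv and the editing pipeline then consumes as x_in.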