Spaces:
Runtime error
Runtime error
Commit
·
7e0bf18
0
Parent(s):
Duplicate from alvanlii/pix2pix_zero
Browse files- .gitattributes +34 -0
- .gitignore +160 -0
- LICENSE +21 -0
- README.md +15 -0
- app.py +212 -0
- assets/.DS_Store +0 -0
- assets/capy.txt +22 -0
- assets/dogs_with_glasses.txt +29 -0
- assets/embeddings_sd_1.4/capy.pt +3 -0
- assets/embeddings_sd_1.4/cat.pt +3 -0
- assets/embeddings_sd_1.4/dog.pt +3 -0
- assets/embeddings_sd_1.4/dogs_with_glasses.pt +3 -0
- assets/embeddings_sd_1.4/horse.pt +3 -0
- assets/embeddings_sd_1.4/llama.pt +3 -0
- assets/embeddings_sd_1.4/zebra.pt +3 -0
- assets/llama.txt +15 -0
- assets/test_images/cats/cat_1.png +0 -0
- assets/test_images/cats/cat_2.png +0 -0
- assets/test_images/cats/cat_3.png +0 -0
- assets/test_images/cats/cat_4.png +0 -0
- assets/test_images/cats/cat_5.png +0 -0
- assets/test_images/cats/cat_6.png +0 -0
- assets/test_images/cats/cat_7.png +0 -0
- assets/test_images/cats/cat_8.png +0 -0
- assets/test_images/cats/cat_9.png +0 -0
- assets/test_images/dogs/dog_1.png +0 -0
- assets/test_images/dogs/dog_2.png +0 -0
- assets/test_images/dogs/dog_3.png +0 -0
- assets/test_images/dogs/dog_4.png +0 -0
- assets/test_images/dogs/dog_5.png +0 -0
- assets/test_images/dogs/dog_6.png +0 -0
- assets/test_images/dogs/dog_7.png +0 -0
- assets/test_images/dogs/dog_8.png +0 -0
- assets/test_images/dogs/dog_9.png +0 -0
- requirements.txt +7 -0
- src/edit_real.py +65 -0
- src/edit_synthetic.py +52 -0
- src/inversion.py +66 -0
- src/make_edit_direction.py +61 -0
- src/utils/base_pipeline.py +322 -0
- src/utils/cross_attention.py +57 -0
- src/utils/ddim_inv.py +140 -0
- src/utils/edit_directions.py +48 -0
- src/utils/edit_pipeline.py +174 -0
- src/utils/scheduler.py +289 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 pix2pixzero
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Pix2pix Zero
|
3 |
+
emoji: 🌍
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: blue
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.18.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
tags:
|
11 |
+
- making-demos
|
12 |
+
duplicated_from: alvanlii/pix2pix_zero
|
13 |
+
---
|
14 |
+
|
15 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Adobe Research. All rights reserved.
|
2 |
+
# To view a copy of the license, visit LICENSE.md.
|
3 |
+
import os
|
4 |
+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
|
5 |
+
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
|
12 |
+
from lavis.models import load_model_and_preprocess
|
13 |
+
|
14 |
+
from diffusers import DDIMScheduler
|
15 |
+
from src.utils.ddim_inv import DDIMInversion
|
16 |
+
from src.utils.edit_directions import construct_direction
|
17 |
+
from src.utils.scheduler import DDIMInverseScheduler
|
18 |
+
from src.utils.edit_pipeline import EditingPipeline
|
19 |
+
|
20 |
+
def main():
|
21 |
+
NUM_DDIM_STEPS = 50
|
22 |
+
TORCH_DTYPE = torch.float16
|
23 |
+
XA_GUIDANCE = 0.1
|
24 |
+
DIR_SCALE = 1.0
|
25 |
+
MODEL_NAME = 'CompVis/stable-diffusion-v1-4'
|
26 |
+
NEGATIVE_GUIDANCE_SCALE = 5.0
|
27 |
+
DEVICE = "cuda"
|
28 |
+
# if torch.cuda.is_available():
|
29 |
+
# DEVICE = "cuda"
|
30 |
+
# else:
|
31 |
+
# DEVICE = "cpu"
|
32 |
+
# print(f"Using {DEVICE}")
|
33 |
+
|
34 |
+
model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=DEVICE)
|
35 |
+
pipe = EditingPipeline.from_pretrained(MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None).to(DEVICE)
|
36 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
37 |
+
|
38 |
+
inv_pipe = DDIMInversion.from_pretrained(MODEL_NAME, torch_dtype=TORCH_DTYPE, safety_checker=None).to("cuda")
|
39 |
+
inv_pipe.scheduler = DDIMInverseScheduler.from_config(inv_pipe.scheduler.config)
|
40 |
+
|
41 |
+
TASKS = ["dog2cat","cat2dog","horse2zebra","zebra2horse","horse2llama","dog2capy"]
|
42 |
+
TASK_OPTIONS = ["Dog to Cat", "Cat to Dog", "Horse to Zebra", "Zebra to Horse", "Horse to Llama", "Dog to Capy"]
|
43 |
+
|
44 |
+
def edit_real_image(
|
45 |
+
og_img,
|
46 |
+
task,
|
47 |
+
seed,
|
48 |
+
xa_guidance,
|
49 |
+
num_ddim_steps,
|
50 |
+
dir_scale
|
51 |
+
):
|
52 |
+
torch.cuda.manual_seed(seed)
|
53 |
+
|
54 |
+
# do inversion first, get inversion and generated prompt
|
55 |
+
curr_img = og_img.resize((512,512), Image.Resampling.LANCZOS)
|
56 |
+
_image = vis_processors["eval"](curr_img).unsqueeze(0).to(DEVICE)
|
57 |
+
prompt_str = model_blip.generate({"image": _image})[0]
|
58 |
+
x_inv, _, _ = inv_pipe(
|
59 |
+
prompt_str,
|
60 |
+
guidance_scale=1,
|
61 |
+
num_inversion_steps=NUM_DDIM_STEPS,
|
62 |
+
img=curr_img,
|
63 |
+
torch_dtype=TORCH_DTYPE
|
64 |
+
)
|
65 |
+
|
66 |
+
task_str = TASKS[task]
|
67 |
+
|
68 |
+
rec_pil, edit_pil = pipe(
|
69 |
+
prompt_str,
|
70 |
+
num_inference_steps=num_ddim_steps,
|
71 |
+
x_in=x_inv[0].unsqueeze(0),
|
72 |
+
edit_dir=construct_direction(task_str)*dir_scale,
|
73 |
+
guidance_amount=xa_guidance,
|
74 |
+
guidance_scale=NEGATIVE_GUIDANCE_SCALE,
|
75 |
+
negative_prompt=prompt_str # use the unedited prompt for the negative prompt
|
76 |
+
)
|
77 |
+
|
78 |
+
return prompt_str, edit_pil[0]
|
79 |
+
|
80 |
+
|
81 |
+
def edit_real_image_example():
|
82 |
+
test_img = Image.open("./assets/test_images/cats/cat_4.png")
|
83 |
+
seed = 42
|
84 |
+
task = 1
|
85 |
+
prompt_str, edited_img = edit_real_image(test_img, task, seed, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE)
|
86 |
+
return test_img, seed, "Cat to Dog", prompt_str, edited_img, XA_GUIDANCE, NUM_DDIM_STEPS, DIR_SCALE
|
87 |
+
|
88 |
+
|
89 |
+
def edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps):
|
90 |
+
torch.cuda.manual_seed(seed)
|
91 |
+
x = torch.randn((1,4,64,64), device="cuda")
|
92 |
+
|
93 |
+
task_str = TASKS[task]
|
94 |
+
|
95 |
+
rec_pil, edit_pil = pipe(
|
96 |
+
prompt_str,
|
97 |
+
num_inference_steps=num_ddim_steps,
|
98 |
+
x_in=x,
|
99 |
+
edit_dir=construct_direction(task_str),
|
100 |
+
guidance_amount=xa_guidance,
|
101 |
+
guidance_scale=NEGATIVE_GUIDANCE_SCALE,
|
102 |
+
negative_prompt="" # use the empty string for the negative prompt
|
103 |
+
)
|
104 |
+
|
105 |
+
return rec_pil[0], edit_pil[0]
|
106 |
+
|
107 |
+
def edit_synth_image_example():
|
108 |
+
seed = 42
|
109 |
+
task = 1
|
110 |
+
xa_guidance = XA_GUIDANCE
|
111 |
+
num_ddim_steps = NUM_DDIM_STEPS
|
112 |
+
prompt_str = "A cute white cat sitting on top of the fridge"
|
113 |
+
recon_img, edited_img = edit_synthetic_image(seed, task, prompt_str, xa_guidance, num_ddim_steps)
|
114 |
+
return seed, "Cat to Dog", xa_guidance, num_ddim_steps, prompt_str, recon_img, edited_img
|
115 |
+
|
116 |
+
with gr.Blocks() as demo:
|
117 |
+
gr.Markdown("""
|
118 |
+
### Zero-shot Image-to-Image Translation (https://github.com/pix2pixzero/pix2pix-zero)
|
119 |
+
Gaurav Parmar, Krishna Kumar Singh, Richard Zhang, Yijun Li, Jingwan Lu, Jun-Yan Zhu <br/>
|
120 |
+
- For real images:
|
121 |
+
- Upload an image of a dog, cat or horse,
|
122 |
+
- Choose one of the task options to turn it into another animal!
|
123 |
+
- Changing Parameters:
|
124 |
+
- Increase direction scale is it is not cat (or another animal) enough.
|
125 |
+
- If the quality is not high enough, increase num ddim steps.
|
126 |
+
- Increase cross attention guidance to preserve original image structures. <br/>
|
127 |
+
- For synthetic images:
|
128 |
+
- Enter a prompt about dogs/cats/horses
|
129 |
+
- Choose a task option
|
130 |
+
""")
|
131 |
+
with gr.Tab("Real Image"):
|
132 |
+
with gr.Row():
|
133 |
+
seed = gr.Number(value=42, precision=1, label="Seed", interactive=True)
|
134 |
+
real_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True)
|
135 |
+
real_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=1, label="Num DDIM steps", interactive=True)
|
136 |
+
real_edit_dir_scale = gr.Number(value=DIR_SCALE, label="Edit Direction Scale", interactive=True)
|
137 |
+
real_generate_button = gr.Button("Generate")
|
138 |
+
real_load_sample_button = gr.Button("Load Example")
|
139 |
+
|
140 |
+
with gr.Row():
|
141 |
+
task_name = gr.Radio(
|
142 |
+
label='Task Name',
|
143 |
+
choices=TASK_OPTIONS,
|
144 |
+
value=TASK_OPTIONS[0],
|
145 |
+
type="index",
|
146 |
+
show_label=True,
|
147 |
+
interactive=True,
|
148 |
+
)
|
149 |
+
|
150 |
+
with gr.Row():
|
151 |
+
recon_text = gr.Textbox(lines=1, label="Reconstructed Text", interactive=False)
|
152 |
+
with gr.Row():
|
153 |
+
input_image = gr.Image(label="Input Image", type="pil", interactive=True)
|
154 |
+
output_image = gr.Image(label="Output Image", type="pil", interactive=False)
|
155 |
+
|
156 |
+
|
157 |
+
with gr.Tab("Synthetic Images"):
|
158 |
+
with gr.Row():
|
159 |
+
synth_seed = gr.Number(value=42, precision=1, label="Seed", interactive=True)
|
160 |
+
synth_prompt = gr.Textbox(lines=1, label="Prompt", interactive=True)
|
161 |
+
synth_generate_button = gr.Button("Generate")
|
162 |
+
synth_load_sample_button = gr.Button("Load Example")
|
163 |
+
with gr.Row():
|
164 |
+
synth_task_name = gr.Radio(
|
165 |
+
label='Task Name',
|
166 |
+
choices=TASK_OPTIONS,
|
167 |
+
value=TASK_OPTIONS[0],
|
168 |
+
type="index",
|
169 |
+
show_label=True,
|
170 |
+
interactive=True,
|
171 |
+
)
|
172 |
+
synth_xa_guidance = gr.Number(value=XA_GUIDANCE, label="Cross Attention Guidance", interactive=True)
|
173 |
+
synth_num_ddim_steps = gr.Number(value=NUM_DDIM_STEPS, precision=1, label="Num DDIM steps", interactive=True)
|
174 |
+
with gr.Row():
|
175 |
+
synth_input_image = gr.Image(label="Input Image", type="pil", interactive=False)
|
176 |
+
synth_output_image = gr.Image(label="Output Image", type="pil", interactive=False)
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
real_generate_button.click(
|
181 |
+
fn=edit_real_image,
|
182 |
+
inputs=[
|
183 |
+
input_image, task_name, seed, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale
|
184 |
+
],
|
185 |
+
outputs=[recon_text, output_image]
|
186 |
+
)
|
187 |
+
|
188 |
+
real_load_sample_button.click(
|
189 |
+
fn=edit_real_image_example,
|
190 |
+
inputs=[],
|
191 |
+
outputs=[input_image, seed, task_name, recon_text, output_image, real_xa_guidance, real_num_ddim_steps, real_edit_dir_scale]
|
192 |
+
)
|
193 |
+
|
194 |
+
synth_generate_button.click(
|
195 |
+
fn=edit_synthetic_image,
|
196 |
+
inputs=[synth_seed, synth_task_name, synth_prompt, synth_xa_guidance, synth_num_ddim_steps],
|
197 |
+
outputs=[synth_input_image, synth_output_image]
|
198 |
+
)
|
199 |
+
|
200 |
+
synth_load_sample_button.click(
|
201 |
+
fn=edit_synth_image_example,
|
202 |
+
inputs=[],
|
203 |
+
outputs=[seed, synth_task_name, synth_xa_guidance, synth_num_ddim_steps, synth_prompt, synth_input_image, synth_output_image]
|
204 |
+
)
|
205 |
+
|
206 |
+
|
207 |
+
demo.queue(concurrency_count=1)
|
208 |
+
demo.launch(share=False, server_name="0.0.0.0")
|
209 |
+
|
210 |
+
|
211 |
+
if __name__ == "__main__":
|
212 |
+
main()
|
assets/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
assets/capy.txt
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
A capybara is a large rodent with a heavy, barrel-shaped body, short head, and reddish-brown coloration on the upper parts of its body.
|
2 |
+
A capybara's fore and hind limbs are short, and its feet are webbed, which makes it an excellent swimmer.
|
3 |
+
A capybara has large, round ears and small eyes.
|
4 |
+
A capybara's buck teeth are sharp and adapted for cutting vegetation, and its cheek teeth do not make contact while the animal is gnawing.
|
5 |
+
A capybara has a thick tail and a hairy coat.
|
6 |
+
A large, semi-aquatic rodent with a brown coat and a round face is standing in a river.
|
7 |
+
A capybara is lazily swimming in a stream, its round face poking out of the water.
|
8 |
+
A capybara is wading through a stream, its furry body and short ears visible above the surface.
|
9 |
+
A capybara is snuggling up against a fallen tree trunk, its huge size easily visible.
|
10 |
+
A capybara is calmly swimming in a pool, its large body moving gracefully through the water.
|
11 |
+
A capybara is lounging in the grass, its round face and short ears visible above the blades of grass.
|
12 |
+
The capybara has a heavy, barrel-shaped body with short, stocky legs and a blunt snout.
|
13 |
+
A capy's coat is thick and composed of reddish-brown, yellowish-brown, or grayish-brown fur.
|
14 |
+
Capybaras' eyes and ears are small, and its long, coarse fur is sparse on its muzzle, forehead, and inner ears.
|
15 |
+
The capybara's feet are webbed, and its long, stout tail is slightly flattened.
|
16 |
+
A capy has four toes on each front foot and three on each hind foot, with the first two front toes partially webbed.
|
17 |
+
It has a short, thick neck and a short, broad head with small eyes and ears.
|
18 |
+
Its short, thick legs are equipped with long, sharp claws for digging.
|
19 |
+
The capybara has a thick, reddish-brown coat and a short, stocky body.
|
20 |
+
Capys' eyes and ears are small, and its snout is blunt.
|
21 |
+
A capybara has long, coarse fur on its muzzle, forehead, and inner ears, and its long, stout tail is slightly flattened.
|
22 |
+
It is an excellent swimmer, and its large size allows it to reach speeds of up to 35 kilometers per hour in short bursts.
|
assets/dogs_with_glasses.txt
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The dog is wearing small, round, metal glasses.
|
2 |
+
This pup has on a pair of glasses that are circular, metal, and small.
|
3 |
+
This canine is sporting a pair of round glasses, with a frame made of metal.
|
4 |
+
The dog is wearing a pair of spectacles that are small and made of metal.
|
5 |
+
The dog is wearing a pair of eyeglasses that is round, with a metal frame.
|
6 |
+
A small, round pair of glasses is perched on the dog's face.
|
7 |
+
A round, metal pair of eyewear is being worn by the pooch.
|
8 |
+
A tiny, round set of glasses adorns the dog's face.
|
9 |
+
This pup is wearing a small set of glasses with a metal frame.
|
10 |
+
The dog has on a tiny, metal pair of circular glasses.
|
11 |
+
A dog wearing glasses has a distinct appearance, with its eye-wear giving it a unique look.
|
12 |
+
A dog's eyes will often be visible through the lenses of its glasses and its ears are usually perked up and alert.
|
13 |
+
A dog wearing glasses looks smart and stylish.
|
14 |
+
The dog wearing glasses was a source of amusement.
|
15 |
+
The dog wearing glasses seemed to be aware of its fashion statement.
|
16 |
+
The dog wearing glasses was a delightful surprise.
|
17 |
+
The cute pup is sporting a pair of stylish glasses.
|
18 |
+
The dog seems to be quite the trendsetter with its eyewear.
|
19 |
+
The glasses add a unique look to the dog's face.
|
20 |
+
The dog looks wise wearing his glasses.
|
21 |
+
The glasses add a certain flair to the dog's look.
|
22 |
+
The glasses give the dog an air of sophistication.
|
23 |
+
The dog looks dapper wearing his glasses.
|
24 |
+
The dog looks so smart with its glasses.
|
25 |
+
Those glasses make the dog look extra adorable in the picture.
|
26 |
+
The glasses give the dog a distinguished look in the image.
|
27 |
+
Such a sophisticated pup, wearing glasses and looking relaxed.
|
28 |
+
The lenses provide 100% UV protection, keeping your pup's eyes safe from the sun's harmful rays.
|
29 |
+
The glasses can also help protect your pup from debris and dust when they're out and about.
|
assets/embeddings_sd_1.4/capy.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2d8085697317ef41bec4119a54d4cfb5829c11c0cf7e95c9f921ed5cd5c3b6d7
|
3 |
+
size 118946
|
assets/embeddings_sd_1.4/cat.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aa9441dc014d5e86567c5ef165e10b50d2a7b3a68d90686d0cd1006792adf334
|
3 |
+
size 237300
|
assets/embeddings_sd_1.4/dog.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:becf079d61d7f35727bcc0d8506ddcdcddb61e62d611840ff3d18eca7fb6338c
|
3 |
+
size 237300
|
assets/embeddings_sd_1.4/dogs_with_glasses.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:805014d0f5145c1f094cf3550c8c08919d223b5ce2981adc7f0ef3ee7c6086b2
|
3 |
+
size 119049
|
assets/embeddings_sd_1.4/horse.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c5d499299544d11371f84674761292b0512055ef45776c700c0b0da164cbf6c7
|
3 |
+
size 118949
|
assets/embeddings_sd_1.4/llama.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7b4d639d08d038b728430acbcfb3956b002141ddfc8cf8c53fe9ac72b7e7b61c
|
3 |
+
size 118949
|
assets/embeddings_sd_1.4/zebra.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a29f6a11d91f3a276e27326b7623fae9d61a3d253ad430bb868bd40fb7e02fec
|
3 |
+
size 118949
|
assets/llama.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Llamas are large mammals with long necks and long woolly fur coats.
|
2 |
+
Llamas have four legs and two toes on each foot.
|
3 |
+
Llamas have large eyes, long ears, and curved lips.
|
4 |
+
Llamas have a short, thick tail and are typically brown and white in color.
|
5 |
+
Llamas are social animals and live in herds. They communicate with each other using a variety of vocalizations, including humming and spitting.
|
6 |
+
Llamas are also used as pack animals and are often used to carry supplies over long distances.
|
7 |
+
The llama stands in the vibrant meadow, its soft wool shining in the sunlight.
|
8 |
+
The llama looks content in its lush surroundings, its long neck reaching up towards the sky.
|
9 |
+
The majestic llama stands tall in its tranquil habitat, its fluffy coat catching the light of the day.
|
10 |
+
The llama grazes peacefully in the grassy meadow, its gentle eyes gazing across the landscape.
|
11 |
+
The llama is a picture of grace and beauty, its wool glimmering in the sun.
|
12 |
+
The llama strides confidently across the open field, its long legs carrying it through the grass.
|
13 |
+
The llama stands proud and stately amongst the rolling hills, its woolly coat glowing in the light.
|
14 |
+
The llama moves gracefully through the tall grass, its shaggy coat swaying in the wind.
|
15 |
+
The gentle llama stands in the meadow, its long eyelashes batting in the sun.
|
assets/test_images/cats/cat_1.png
ADDED
![]() |
assets/test_images/cats/cat_2.png
ADDED
![]() |
assets/test_images/cats/cat_3.png
ADDED
![]() |
assets/test_images/cats/cat_4.png
ADDED
![]() |
assets/test_images/cats/cat_5.png
ADDED
![]() |
assets/test_images/cats/cat_6.png
ADDED
![]() |
assets/test_images/cats/cat_7.png
ADDED
![]() |
assets/test_images/cats/cat_8.png
ADDED
![]() |
assets/test_images/cats/cat_9.png
ADDED
![]() |
assets/test_images/dogs/dog_1.png
ADDED
![]() |
assets/test_images/dogs/dog_2.png
ADDED
![]() |
assets/test_images/dogs/dog_3.png
ADDED
![]() |
assets/test_images/dogs/dog_4.png
ADDED
![]() |
assets/test_images/dogs/dog_5.png
ADDED
![]() |
assets/test_images/dogs/dog_6.png
ADDED
![]() |
assets/test_images/dogs/dog_7.png
ADDED
![]() |
assets/test_images/dogs/dog_8.png
ADDED
![]() |
assets/test_images/dogs/dog_9.png
ADDED
![]() |
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
gradio
|
3 |
+
accelerate
|
4 |
+
diffusers==0.12.1
|
5 |
+
einops
|
6 |
+
salesforce-lavis
|
7 |
+
opencv-python-headless
|
src/edit_real.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, pdb
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from diffusers import DDIMScheduler
|
10 |
+
from utils.ddim_inv import DDIMInversion
|
11 |
+
from utils.edit_directions import construct_direction
|
12 |
+
from utils.edit_pipeline import EditingPipeline
|
13 |
+
|
14 |
+
|
15 |
+
if __name__=="__main__":
|
16 |
+
parser = argparse.ArgumentParser()
|
17 |
+
parser.add_argument('--inversion', required=True)
|
18 |
+
parser.add_argument('--prompt', type=str, required=True)
|
19 |
+
parser.add_argument('--task_name', type=str, default='cat2dog')
|
20 |
+
parser.add_argument('--results_folder', type=str, default='output/test_cat')
|
21 |
+
parser.add_argument('--num_ddim_steps', type=int, default=50)
|
22 |
+
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
|
23 |
+
parser.add_argument('--xa_guidance', default=0.1, type=float)
|
24 |
+
parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
|
25 |
+
parser.add_argument('--use_float_16', action='store_true')
|
26 |
+
|
27 |
+
args = parser.parse_args()
|
28 |
+
|
29 |
+
os.makedirs(os.path.join(args.results_folder, "edit"), exist_ok=True)
|
30 |
+
os.makedirs(os.path.join(args.results_folder, "reconstruction"), exist_ok=True)
|
31 |
+
|
32 |
+
if args.use_float_16:
|
33 |
+
torch_dtype = torch.float16
|
34 |
+
else:
|
35 |
+
torch_dtype = torch.float32
|
36 |
+
|
37 |
+
# if the inversion is a folder, the prompt should also be a folder
|
38 |
+
assert (os.path.isdir(args.inversion)==os.path.isdir(args.prompt)), "If the inversion is a folder, the prompt should also be a folder"
|
39 |
+
if os.path.isdir(args.inversion):
|
40 |
+
l_inv_paths = sorted(glob(os.path.join(args.inversion, "*.pt")))
|
41 |
+
l_bnames = [os.path.basename(x) for x in l_inv_paths]
|
42 |
+
l_prompt_paths = [os.path.join(args.prompt, x.replace(".pt",".txt")) for x in l_bnames]
|
43 |
+
else:
|
44 |
+
l_inv_paths = [args.inversion]
|
45 |
+
l_prompt_paths = [args.prompt]
|
46 |
+
|
47 |
+
# Make the editing pipeline
|
48 |
+
pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
|
49 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
50 |
+
|
51 |
+
|
52 |
+
for inv_path, prompt_path in zip(l_inv_paths, l_prompt_paths):
|
53 |
+
prompt_str = open(prompt_path).read().strip()
|
54 |
+
rec_pil, edit_pil = pipe(prompt_str,
|
55 |
+
num_inference_steps=args.num_ddim_steps,
|
56 |
+
x_in=torch.load(inv_path).unsqueeze(0),
|
57 |
+
edit_dir=construct_direction(args.task_name),
|
58 |
+
guidance_amount=args.xa_guidance,
|
59 |
+
guidance_scale=args.negative_guidance_scale,
|
60 |
+
negative_prompt=prompt_str # use the unedited prompt for the negative prompt
|
61 |
+
)
|
62 |
+
|
63 |
+
bname = os.path.basename(args.inversion).split(".")[0]
|
64 |
+
edit_pil[0].save(os.path.join(args.results_folder, f"edit/{bname}.png"))
|
65 |
+
rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction/{bname}.png"))
|
src/edit_synthetic.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, pdb
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from diffusers import DDIMScheduler
|
10 |
+
from utils.edit_directions import construct_direction
|
11 |
+
from utils.edit_pipeline import EditingPipeline
|
12 |
+
|
13 |
+
|
14 |
+
if __name__=="__main__":
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument('--prompt_str', type=str, required=True)
|
17 |
+
parser.add_argument('--random_seed', default=0)
|
18 |
+
parser.add_argument('--task_name', type=str, default='cat2dog')
|
19 |
+
parser.add_argument('--results_folder', type=str, default='output/test_cat')
|
20 |
+
parser.add_argument('--num_ddim_steps', type=int, default=50)
|
21 |
+
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
|
22 |
+
parser.add_argument('--xa_guidance', default=0.15, type=float)
|
23 |
+
parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
|
24 |
+
parser.add_argument('--use_float_16', action='store_true')
|
25 |
+
args = parser.parse_args()
|
26 |
+
|
27 |
+
os.makedirs(args.results_folder, exist_ok=True)
|
28 |
+
|
29 |
+
if args.use_float_16:
|
30 |
+
torch_dtype = torch.float16
|
31 |
+
else:
|
32 |
+
torch_dtype = torch.float32
|
33 |
+
|
34 |
+
# make the input noise map
|
35 |
+
torch.cuda.manual_seed(args.random_seed)
|
36 |
+
x = torch.randn((1,4,64,64), device="cuda")
|
37 |
+
|
38 |
+
# Make the editing pipeline
|
39 |
+
pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
|
40 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
41 |
+
|
42 |
+
rec_pil, edit_pil = pipe(args.prompt_str,
|
43 |
+
num_inference_steps=args.num_ddim_steps,
|
44 |
+
x_in=x,
|
45 |
+
edit_dir=construct_direction(args.task_name),
|
46 |
+
guidance_amount=args.xa_guidance,
|
47 |
+
guidance_scale=args.negative_guidance_scale,
|
48 |
+
negative_prompt="" # use the empty string for the negative prompt
|
49 |
+
)
|
50 |
+
|
51 |
+
edit_pil[0].save(os.path.join(args.results_folder, f"edit.png"))
|
52 |
+
rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction.png"))
|
src/inversion.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, pdb
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from lavis.models import load_model_and_preprocess
|
10 |
+
|
11 |
+
from utils.ddim_inv import DDIMInversion
|
12 |
+
from utils.scheduler import DDIMInverseScheduler
|
13 |
+
from utils.edit_pipeline import EditingPipeline
|
14 |
+
|
15 |
+
|
16 |
+
if __name__=="__main__":
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument('--input_image', type=str, default='assets/test_images/cat_a.png')
|
19 |
+
parser.add_argument('--results_folder', type=str, default='output/test_cat')
|
20 |
+
parser.add_argument('--num_ddim_steps', type=int, default=50)
|
21 |
+
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
|
22 |
+
parser.add_argument('--use_float_16', action='store_true')
|
23 |
+
args = parser.parse_args()
|
24 |
+
|
25 |
+
# make the output folders
|
26 |
+
os.makedirs(os.path.join(args.results_folder, "inversion"), exist_ok=True)
|
27 |
+
os.makedirs(os.path.join(args.results_folder, "prompt"), exist_ok=True)
|
28 |
+
|
29 |
+
if args.use_float_16:
|
30 |
+
torch_dtype = torch.float16
|
31 |
+
else:
|
32 |
+
torch_dtype = torch.float32
|
33 |
+
|
34 |
+
|
35 |
+
# load the BLIP model
|
36 |
+
model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda"))
|
37 |
+
# make the DDIM inversion pipeline
|
38 |
+
pipe = DDIMInversion.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
|
39 |
+
pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
|
40 |
+
|
41 |
+
|
42 |
+
# if the input is a folder, collect all the images as a list
|
43 |
+
if os.path.isdir(args.input_image):
|
44 |
+
l_img_paths = sorted(glob(os.path.join(args.input_image, "*.png")))
|
45 |
+
else:
|
46 |
+
l_img_paths = [args.input_image]
|
47 |
+
|
48 |
+
|
49 |
+
for img_path in l_img_paths:
|
50 |
+
bname = os.path.basename(args.input_image).split(".")[0]
|
51 |
+
img = Image.open(args.input_image).resize((512,512), Image.Resampling.LANCZOS)
|
52 |
+
# generate the caption
|
53 |
+
_image = vis_processors["eval"](img).unsqueeze(0).cuda()
|
54 |
+
prompt_str = model_blip.generate({"image": _image})[0]
|
55 |
+
x_inv, x_inv_image, x_dec_img = pipe(
|
56 |
+
prompt_str,
|
57 |
+
guidance_scale=1,
|
58 |
+
num_inversion_steps=args.num_ddim_steps,
|
59 |
+
img=img,
|
60 |
+
torch_dtype=torch_dtype
|
61 |
+
)
|
62 |
+
# save the inversion
|
63 |
+
torch.save(x_inv[0], os.path.join(args.results_folder, f"inversion/{bname}.pt"))
|
64 |
+
# save the prompt string
|
65 |
+
with open(os.path.join(args.results_folder, f"prompt/{bname}.txt"), "w") as f:
|
66 |
+
f.write(prompt_str)
|
src/make_edit_direction.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, pdb
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from diffusers import DDIMScheduler
|
10 |
+
from utils.edit_pipeline import EditingPipeline
|
11 |
+
|
12 |
+
|
13 |
+
## convert sentences to sentence embeddings
|
14 |
+
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
|
15 |
+
with torch.no_grad():
|
16 |
+
l_embeddings = []
|
17 |
+
for sent in l_sentences:
|
18 |
+
text_inputs = tokenizer(
|
19 |
+
sent,
|
20 |
+
padding="max_length",
|
21 |
+
max_length=tokenizer.model_max_length,
|
22 |
+
truncation=True,
|
23 |
+
return_tensors="pt",
|
24 |
+
)
|
25 |
+
text_input_ids = text_inputs.input_ids
|
26 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
|
27 |
+
l_embeddings.append(prompt_embeds)
|
28 |
+
return torch.concatenate(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
|
29 |
+
|
30 |
+
|
31 |
+
if __name__=="__main__":
|
32 |
+
parser = argparse.ArgumentParser()
|
33 |
+
parser.add_argument('--file_source_sentences', required=True)
|
34 |
+
# parser.add_argument('--file_target_sentences', required=True)
|
35 |
+
parser.add_argument('--output_folder', default="./assets/")
|
36 |
+
parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
|
37 |
+
args = parser.parse_args()
|
38 |
+
|
39 |
+
# load the model
|
40 |
+
pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch.float16).to("cuda")
|
41 |
+
bname_src = os.path.basename(args.file_source_sentences).strip(".txt")
|
42 |
+
outf_src = os.path.join(args.output_folder, bname_src+".pt")
|
43 |
+
if os.path.exists(outf_src):
|
44 |
+
print(f"Skipping source file {outf_src} as it already exists")
|
45 |
+
else:
|
46 |
+
with open(args.file_source_sentences, "r") as f:
|
47 |
+
l_sents = [x.strip() for x in f.readlines()]
|
48 |
+
mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
|
49 |
+
print(mean_emb.shape)
|
50 |
+
torch.save(mean_emb, outf_src)
|
51 |
+
|
52 |
+
# bname_tgt = os.path.basename(args.file_target_sentences).strip(".txt")
|
53 |
+
# outf_tgt = os.path.join(args.output_folder, bname_tgt+".pt")
|
54 |
+
# if os.path.exists(outf_tgt):
|
55 |
+
# print(f"Skipping target file {outf_tgt} as it already exists")
|
56 |
+
# else:
|
57 |
+
# with open(args.file_target_sentences, "r") as f:
|
58 |
+
# l_sents = [x.strip() for x in f.readlines()]
|
59 |
+
# mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
|
60 |
+
# print(mean_emb.shape)
|
61 |
+
# torch.save(mean_emb, outf_tgt)
|
src/utils/base_pipeline.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import inspect
|
4 |
+
from packaging import version
|
5 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
6 |
+
|
7 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
8 |
+
from diffusers import DiffusionPipeline
|
9 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
10 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
11 |
+
from diffusers.utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
|
12 |
+
from diffusers import StableDiffusionPipeline
|
13 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
class BasePipeline(DiffusionPipeline):
|
18 |
+
_optional_components = ["safety_checker", "feature_extractor"]
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
vae: AutoencoderKL,
|
22 |
+
text_encoder: CLIPTextModel,
|
23 |
+
tokenizer: CLIPTokenizer,
|
24 |
+
unet: UNet2DConditionModel,
|
25 |
+
scheduler: KarrasDiffusionSchedulers,
|
26 |
+
safety_checker: StableDiffusionSafetyChecker,
|
27 |
+
feature_extractor: CLIPFeatureExtractor,
|
28 |
+
requires_safety_checker: bool = True,
|
29 |
+
):
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
33 |
+
deprecation_message = (
|
34 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
35 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
36 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
37 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
38 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
39 |
+
" file"
|
40 |
+
)
|
41 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
42 |
+
new_config = dict(scheduler.config)
|
43 |
+
new_config["steps_offset"] = 1
|
44 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
45 |
+
|
46 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
47 |
+
deprecation_message = (
|
48 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
49 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
50 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
51 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
52 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
53 |
+
)
|
54 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
55 |
+
new_config = dict(scheduler.config)
|
56 |
+
new_config["clip_sample"] = False
|
57 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
58 |
+
|
59 |
+
# if safety_checker is None and requires_safety_checker:
|
60 |
+
# logger.warning(
|
61 |
+
# f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
62 |
+
# " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
63 |
+
# " results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
64 |
+
# " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
65 |
+
# " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
66 |
+
# " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
67 |
+
# )
|
68 |
+
|
69 |
+
if safety_checker is not None and feature_extractor is None:
|
70 |
+
raise ValueError(
|
71 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
72 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
73 |
+
)
|
74 |
+
|
75 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
76 |
+
version.parse(unet.config._diffusers_version).base_version
|
77 |
+
) < version.parse("0.9.0.dev0")
|
78 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
79 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
80 |
+
deprecation_message = (
|
81 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
82 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
83 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
84 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
85 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
86 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
87 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
88 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
89 |
+
" the `unet/config.json` file"
|
90 |
+
)
|
91 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
92 |
+
new_config = dict(unet.config)
|
93 |
+
new_config["sample_size"] = 64
|
94 |
+
unet._internal_dict = FrozenDict(new_config)
|
95 |
+
|
96 |
+
self.register_modules(
|
97 |
+
vae=vae,
|
98 |
+
text_encoder=text_encoder,
|
99 |
+
tokenizer=tokenizer,
|
100 |
+
unet=unet,
|
101 |
+
scheduler=scheduler,
|
102 |
+
safety_checker=safety_checker,
|
103 |
+
feature_extractor=feature_extractor,
|
104 |
+
)
|
105 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
106 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
107 |
+
|
108 |
+
@property
|
109 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
110 |
+
def _execution_device(self):
|
111 |
+
r"""
|
112 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
113 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
114 |
+
hooks.
|
115 |
+
"""
|
116 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
117 |
+
return self.device
|
118 |
+
for module in self.unet.modules():
|
119 |
+
if (
|
120 |
+
hasattr(module, "_hf_hook")
|
121 |
+
and hasattr(module._hf_hook, "execution_device")
|
122 |
+
and module._hf_hook.execution_device is not None
|
123 |
+
):
|
124 |
+
return torch.device(module._hf_hook.execution_device)
|
125 |
+
return self.device
|
126 |
+
|
127 |
+
|
128 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
129 |
+
def _encode_prompt(
|
130 |
+
self,
|
131 |
+
prompt,
|
132 |
+
device,
|
133 |
+
num_images_per_prompt,
|
134 |
+
do_classifier_free_guidance,
|
135 |
+
negative_prompt=None,
|
136 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
137 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
138 |
+
):
|
139 |
+
r"""
|
140 |
+
Encodes the prompt into text encoder hidden states.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
prompt (`str` or `List[str]`, *optional*):
|
144 |
+
prompt to be encoded
|
145 |
+
device: (`torch.device`):
|
146 |
+
torch device
|
147 |
+
num_images_per_prompt (`int`):
|
148 |
+
number of images that should be generated per prompt
|
149 |
+
do_classifier_free_guidance (`bool`):
|
150 |
+
whether to use classifier free guidance or not
|
151 |
+
negative_ prompt (`str` or `List[str]`, *optional*):
|
152 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
153 |
+
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
154 |
+
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
155 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
156 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
157 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
158 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
159 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
160 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
161 |
+
argument.
|
162 |
+
"""
|
163 |
+
if prompt is not None and isinstance(prompt, str):
|
164 |
+
batch_size = 1
|
165 |
+
elif prompt is not None and isinstance(prompt, list):
|
166 |
+
batch_size = len(prompt)
|
167 |
+
else:
|
168 |
+
batch_size = prompt_embeds.shape[0]
|
169 |
+
|
170 |
+
if prompt_embeds is None:
|
171 |
+
text_inputs = self.tokenizer(
|
172 |
+
prompt,
|
173 |
+
padding="max_length",
|
174 |
+
max_length=self.tokenizer.model_max_length,
|
175 |
+
truncation=True,
|
176 |
+
return_tensors="pt",
|
177 |
+
)
|
178 |
+
text_input_ids = text_inputs.input_ids
|
179 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
180 |
+
|
181 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
182 |
+
text_input_ids, untruncated_ids
|
183 |
+
):
|
184 |
+
removed_text = self.tokenizer.batch_decode(
|
185 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
186 |
+
)
|
187 |
+
# logger.warning(
|
188 |
+
# "The following part of your input was truncated because CLIP can only handle sequences up to"
|
189 |
+
# f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
190 |
+
# )
|
191 |
+
|
192 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
193 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
194 |
+
else:
|
195 |
+
attention_mask = None
|
196 |
+
|
197 |
+
prompt_embeds = self.text_encoder(
|
198 |
+
text_input_ids.to(device),
|
199 |
+
attention_mask=attention_mask,
|
200 |
+
)
|
201 |
+
prompt_embeds = prompt_embeds[0]
|
202 |
+
|
203 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
204 |
+
|
205 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
206 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
207 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
208 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
209 |
+
|
210 |
+
# get unconditional embeddings for classifier free guidance
|
211 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
212 |
+
uncond_tokens: List[str]
|
213 |
+
if negative_prompt is None:
|
214 |
+
uncond_tokens = [""] * batch_size
|
215 |
+
elif type(prompt) is not type(negative_prompt):
|
216 |
+
raise TypeError(
|
217 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
218 |
+
f" {type(prompt)}."
|
219 |
+
)
|
220 |
+
elif isinstance(negative_prompt, str):
|
221 |
+
uncond_tokens = [negative_prompt]
|
222 |
+
elif batch_size != len(negative_prompt):
|
223 |
+
raise ValueError(
|
224 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
225 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
226 |
+
" the batch size of `prompt`."
|
227 |
+
)
|
228 |
+
else:
|
229 |
+
uncond_tokens = negative_prompt
|
230 |
+
|
231 |
+
max_length = prompt_embeds.shape[1]
|
232 |
+
uncond_input = self.tokenizer(
|
233 |
+
uncond_tokens,
|
234 |
+
padding="max_length",
|
235 |
+
max_length=max_length,
|
236 |
+
truncation=True,
|
237 |
+
return_tensors="pt",
|
238 |
+
)
|
239 |
+
|
240 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
241 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
242 |
+
else:
|
243 |
+
attention_mask = None
|
244 |
+
|
245 |
+
negative_prompt_embeds = self.text_encoder(
|
246 |
+
uncond_input.input_ids.to(device),
|
247 |
+
attention_mask=attention_mask,
|
248 |
+
)
|
249 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
250 |
+
|
251 |
+
if do_classifier_free_guidance:
|
252 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
253 |
+
seq_len = negative_prompt_embeds.shape[1]
|
254 |
+
|
255 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
256 |
+
|
257 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
258 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
259 |
+
|
260 |
+
# For classifier free guidance, we need to do two forward passes.
|
261 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
262 |
+
# to avoid doing two forward passes
|
263 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
264 |
+
|
265 |
+
return prompt_embeds
|
266 |
+
|
267 |
+
|
268 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
269 |
+
def decode_latents(self, latents):
|
270 |
+
latents = 1 / 0.18215 * latents
|
271 |
+
image = self.vae.decode(latents).sample
|
272 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
273 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
274 |
+
image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
275 |
+
return image
|
276 |
+
|
277 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
278 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
279 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
280 |
+
raise ValueError(
|
281 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
282 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
283 |
+
)
|
284 |
+
|
285 |
+
if latents is None:
|
286 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
287 |
+
else:
|
288 |
+
latents = latents.to(device)
|
289 |
+
|
290 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
291 |
+
latents = latents * self.scheduler.init_noise_sigma
|
292 |
+
return latents
|
293 |
+
|
294 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
295 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
296 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
297 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
298 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
299 |
+
# and should be between [0, 1]
|
300 |
+
|
301 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
302 |
+
extra_step_kwargs = {}
|
303 |
+
if accepts_eta:
|
304 |
+
extra_step_kwargs["eta"] = eta
|
305 |
+
|
306 |
+
# check if the scheduler accepts generator
|
307 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
308 |
+
if accepts_generator:
|
309 |
+
extra_step_kwargs["generator"] = generator
|
310 |
+
return extra_step_kwargs
|
311 |
+
|
312 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
313 |
+
def run_safety_checker(self, image, device, dtype):
|
314 |
+
if self.safety_checker is not None:
|
315 |
+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
316 |
+
image, has_nsfw_concept = self.safety_checker(
|
317 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
318 |
+
)
|
319 |
+
else:
|
320 |
+
has_nsfw_concept = None
|
321 |
+
return image, has_nsfw_concept
|
322 |
+
|
src/utils/cross_attention.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers.models.attention import CrossAttention
|
3 |
+
|
4 |
+
class MyCrossAttnProcessor:
|
5 |
+
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
6 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
7 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
8 |
+
|
9 |
+
query = attn.to_q(hidden_states)
|
10 |
+
|
11 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
12 |
+
key = attn.to_k(encoder_hidden_states)
|
13 |
+
value = attn.to_v(encoder_hidden_states)
|
14 |
+
|
15 |
+
query = attn.head_to_batch_dim(query)
|
16 |
+
key = attn.head_to_batch_dim(key)
|
17 |
+
value = attn.head_to_batch_dim(value)
|
18 |
+
|
19 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
20 |
+
# new bookkeeping to save the attn probs
|
21 |
+
attn.attn_probs = attention_probs
|
22 |
+
|
23 |
+
hidden_states = torch.bmm(attention_probs, value)
|
24 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
25 |
+
|
26 |
+
# linear proj
|
27 |
+
hidden_states = attn.to_out[0](hidden_states)
|
28 |
+
# dropout
|
29 |
+
hidden_states = attn.to_out[1](hidden_states)
|
30 |
+
|
31 |
+
return hidden_states
|
32 |
+
|
33 |
+
|
34 |
+
"""
|
35 |
+
A function that prepares a U-Net model for training by enabling gradient computation
|
36 |
+
for a specified set of parameters and setting the forward pass to be performed by a
|
37 |
+
custom cross attention processor.
|
38 |
+
|
39 |
+
Parameters:
|
40 |
+
unet: A U-Net model.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
unet: The prepared U-Net model.
|
44 |
+
"""
|
45 |
+
def prep_unet(unet):
|
46 |
+
# set the gradients for XA maps to be true
|
47 |
+
for name, params in unet.named_parameters():
|
48 |
+
if 'attn2' in name:
|
49 |
+
params.requires_grad = True
|
50 |
+
else:
|
51 |
+
params.requires_grad = False
|
52 |
+
# replace the fwd function
|
53 |
+
for name, module in unet.named_modules():
|
54 |
+
module_name = type(module).__name__
|
55 |
+
if module_name == "CrossAttention":
|
56 |
+
module.set_processor(MyCrossAttnProcessor())
|
57 |
+
return unet
|
src/utils/ddim_inv.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from random import randrange
|
6 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
7 |
+
from diffusers import DDIMScheduler
|
8 |
+
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
|
9 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
10 |
+
sys.path.insert(0, "src/utils")
|
11 |
+
from base_pipeline import BasePipeline
|
12 |
+
from cross_attention import prep_unet
|
13 |
+
|
14 |
+
|
15 |
+
class DDIMInversion(BasePipeline):
|
16 |
+
|
17 |
+
def auto_corr_loss(self, x, random_shift=True):
|
18 |
+
B,C,H,W = x.shape
|
19 |
+
assert B==1
|
20 |
+
x = x.squeeze(0)
|
21 |
+
# x must be shape [C,H,W] now
|
22 |
+
reg_loss = 0.0
|
23 |
+
for ch_idx in range(x.shape[0]):
|
24 |
+
noise = x[ch_idx][None, None,:,:]
|
25 |
+
while True:
|
26 |
+
if random_shift: roll_amount = randrange(noise.shape[2]//2)
|
27 |
+
else: roll_amount = 1
|
28 |
+
reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=2)).mean()**2
|
29 |
+
reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=3)).mean()**2
|
30 |
+
if noise.shape[2] <= 8:
|
31 |
+
break
|
32 |
+
noise = F.avg_pool2d(noise, kernel_size=2)
|
33 |
+
return reg_loss
|
34 |
+
|
35 |
+
def kl_divergence(self, x):
|
36 |
+
_mu = x.mean()
|
37 |
+
_var = x.var()
|
38 |
+
return _var + _mu**2 - 1 - torch.log(_var+1e-7)
|
39 |
+
|
40 |
+
|
41 |
+
def __call__(
|
42 |
+
self,
|
43 |
+
prompt: Union[str, List[str]] = None,
|
44 |
+
num_inversion_steps: int = 50,
|
45 |
+
guidance_scale: float = 7.5,
|
46 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
47 |
+
num_images_per_prompt: Optional[int] = 1,
|
48 |
+
eta: float = 0.0,
|
49 |
+
output_type: Optional[str] = "pil",
|
50 |
+
return_dict: bool = True,
|
51 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
52 |
+
img=None, # the input image as a PIL image
|
53 |
+
torch_dtype=torch.float32,
|
54 |
+
|
55 |
+
# inversion regularization parameters
|
56 |
+
lambda_ac: float = 20.0,
|
57 |
+
lambda_kl: float = 20.0,
|
58 |
+
num_reg_steps: int = 5,
|
59 |
+
num_ac_rolls: int = 5,
|
60 |
+
):
|
61 |
+
|
62 |
+
# 0. modify the unet to be useful :D
|
63 |
+
self.unet = prep_unet(self.unet)
|
64 |
+
|
65 |
+
# set the scheduler to be the Inverse DDIM scheduler
|
66 |
+
# self.scheduler = MyDDIMScheduler.from_config(self.scheduler.config)
|
67 |
+
|
68 |
+
device = self._execution_device
|
69 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
70 |
+
self.scheduler.set_timesteps(num_inversion_steps, device=device)
|
71 |
+
timesteps = self.scheduler.timesteps
|
72 |
+
|
73 |
+
# Encode the input image with the first stage model
|
74 |
+
x0 = np.array(img)/255
|
75 |
+
x0 = torch.from_numpy(x0).type(torch_dtype).permute(2, 0, 1).unsqueeze(dim=0).repeat(1, 1, 1, 1).cuda()
|
76 |
+
x0 = (x0 - 0.5) * 2.
|
77 |
+
with torch.no_grad():
|
78 |
+
x0_enc = self.vae.encode(x0).latent_dist.sample().to(device, torch_dtype)
|
79 |
+
latents = x0_enc = 0.18215 * x0_enc
|
80 |
+
|
81 |
+
# Decode and return the image
|
82 |
+
with torch.no_grad():
|
83 |
+
x0_dec = self.decode_latents(x0_enc.detach())
|
84 |
+
image_x0_dec = self.numpy_to_pil(x0_dec)
|
85 |
+
|
86 |
+
with torch.no_grad():
|
87 |
+
prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt).to(device)
|
88 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(None, eta)
|
89 |
+
|
90 |
+
# Do the inversion
|
91 |
+
num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order # should be 0?
|
92 |
+
with self.progress_bar(total=num_inversion_steps) as progress_bar:
|
93 |
+
for i, t in enumerate(timesteps.flip(0)[1:-1]):
|
94 |
+
# expand the latents if we are doing classifier free guidance
|
95 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
96 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
97 |
+
|
98 |
+
# predict the noise residual
|
99 |
+
with torch.no_grad():
|
100 |
+
noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample
|
101 |
+
|
102 |
+
# perform guidance
|
103 |
+
if do_classifier_free_guidance:
|
104 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
105 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
106 |
+
|
107 |
+
# regularization of the noise prediction
|
108 |
+
e_t = noise_pred
|
109 |
+
for _outer in range(num_reg_steps):
|
110 |
+
if lambda_ac>0:
|
111 |
+
for _inner in range(num_ac_rolls):
|
112 |
+
_var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
|
113 |
+
l_ac = self.auto_corr_loss(_var)
|
114 |
+
l_ac.backward()
|
115 |
+
_grad = _var.grad.detach()/num_ac_rolls
|
116 |
+
e_t = e_t - lambda_ac*_grad
|
117 |
+
if lambda_kl>0:
|
118 |
+
_var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
|
119 |
+
l_kld = self.kl_divergence(_var)
|
120 |
+
l_kld.backward()
|
121 |
+
_grad = _var.grad.detach()
|
122 |
+
e_t = e_t - lambda_kl*_grad
|
123 |
+
e_t = e_t.detach()
|
124 |
+
noise_pred = e_t
|
125 |
+
|
126 |
+
# compute the previous noisy sample x_t -> x_t-1
|
127 |
+
latents = self.scheduler.step(noise_pred, t, latents, reverse=True, **extra_step_kwargs).prev_sample
|
128 |
+
|
129 |
+
# call the callback, if provided
|
130 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
131 |
+
progress_bar.update()
|
132 |
+
|
133 |
+
|
134 |
+
x_inv = latents.detach().clone()
|
135 |
+
# reconstruct the image
|
136 |
+
|
137 |
+
# 8. Post-processing
|
138 |
+
image = self.decode_latents(latents.detach())
|
139 |
+
image = self.numpy_to_pil(image)
|
140 |
+
return x_inv, image, image_x0_dec
|
src/utils/edit_directions.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
"""
|
6 |
+
This function takes in a task name and returns the direction in the embedding space that transforms class A to class B for the given task.
|
7 |
+
|
8 |
+
Parameters:
|
9 |
+
task_name (str): name of the task for which direction is to be constructed.
|
10 |
+
|
11 |
+
Returns:
|
12 |
+
torch.Tensor: A tensor representing the direction in the embedding space that transforms class A to class B.
|
13 |
+
|
14 |
+
Examples:
|
15 |
+
>>> construct_direction("cat2dog")
|
16 |
+
"""
|
17 |
+
def construct_direction(task_name):
|
18 |
+
emb_dir = f"assets/embeddings_sd_1.4"
|
19 |
+
if task_name=="cat2dog":
|
20 |
+
embs_a = torch.load(os.path.join(emb_dir, f"cat.pt"))
|
21 |
+
embs_b = torch.load(os.path.join(emb_dir, f"dog.pt"))
|
22 |
+
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
|
23 |
+
elif task_name=="dog2cat":
|
24 |
+
embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
|
25 |
+
embs_b = torch.load(os.path.join(emb_dir, f"cat.pt"))
|
26 |
+
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
|
27 |
+
elif task_name=="horse2zebra":
|
28 |
+
embs_a = torch.load(os.path.join(emb_dir, f"horse.pt"))
|
29 |
+
embs_b = torch.load(os.path.join(emb_dir, f"zebra.pt"))
|
30 |
+
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
|
31 |
+
elif task_name=="zebra2horse":
|
32 |
+
embs_a = torch.load(os.path.join(emb_dir, f"zebra.pt"))
|
33 |
+
embs_b = torch.load(os.path.join(emb_dir, f"horse.pt"))
|
34 |
+
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
|
35 |
+
elif task_name=="horse2llama":
|
36 |
+
embs_a = torch.load(os.path.join(emb_dir, f"horse.pt"))
|
37 |
+
embs_b = torch.load(os.path.join(emb_dir, f"llama.pt"))
|
38 |
+
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
|
39 |
+
elif task_name=="dog2capy":
|
40 |
+
embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
|
41 |
+
embs_b = torch.load(os.path.join(emb_dir, f"capy.pt"))
|
42 |
+
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
|
43 |
+
elif task_name=='dogglasses':
|
44 |
+
embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
|
45 |
+
embs_b = torch.load(os.path.join(emb_dir, f"dogs_with_glasses.pt"))
|
46 |
+
return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
|
47 |
+
else:
|
48 |
+
raise NotImplementedError
|
src/utils/edit_pipeline.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb, sys
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
6 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
7 |
+
sys.path.insert(0, "src/utils")
|
8 |
+
from base_pipeline import BasePipeline
|
9 |
+
from cross_attention import prep_unet
|
10 |
+
|
11 |
+
|
12 |
+
class EditingPipeline(BasePipeline):
|
13 |
+
def __call__(
|
14 |
+
self,
|
15 |
+
prompt: Union[str, List[str]] = None,
|
16 |
+
height: Optional[int] = None,
|
17 |
+
width: Optional[int] = None,
|
18 |
+
num_inference_steps: int = 50,
|
19 |
+
guidance_scale: float = 7.5,
|
20 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
21 |
+
num_images_per_prompt: Optional[int] = 1,
|
22 |
+
eta: float = 0.0,
|
23 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
24 |
+
latents: Optional[torch.FloatTensor] = None,
|
25 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
26 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
27 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
28 |
+
|
29 |
+
# pix2pix parameters
|
30 |
+
guidance_amount=0.1,
|
31 |
+
edit_dir=None,
|
32 |
+
x_in=None,
|
33 |
+
|
34 |
+
):
|
35 |
+
|
36 |
+
x_in.to(dtype=self.unet.dtype, device=self._execution_device)
|
37 |
+
|
38 |
+
# 0. modify the unet to be useful :D
|
39 |
+
self.unet = prep_unet(self.unet)
|
40 |
+
|
41 |
+
# 1. setup all caching objects
|
42 |
+
d_ref_t2attn = {} # reference cross attention maps
|
43 |
+
|
44 |
+
# 2. Default height and width to unet
|
45 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
46 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
47 |
+
|
48 |
+
# TODO: add the input checker function
|
49 |
+
# self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds )
|
50 |
+
|
51 |
+
# 2. Define call parameters
|
52 |
+
if prompt is not None and isinstance(prompt, str):
|
53 |
+
batch_size = 1
|
54 |
+
elif prompt is not None and isinstance(prompt, list):
|
55 |
+
batch_size = len(prompt)
|
56 |
+
else:
|
57 |
+
batch_size = prompt_embeds.shape[0]
|
58 |
+
|
59 |
+
device = self._execution_device
|
60 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
61 |
+
x_in = x_in.to(dtype=self.unet.dtype, device=self._execution_device)
|
62 |
+
# 3. Encode input prompt = 2x77x1024
|
63 |
+
prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,)
|
64 |
+
|
65 |
+
# 4. Prepare timesteps
|
66 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
67 |
+
timesteps = self.scheduler.timesteps
|
68 |
+
|
69 |
+
# 5. Prepare latent variables
|
70 |
+
num_channels_latents = self.unet.in_channels
|
71 |
+
|
72 |
+
# randomly sample a latent code if not provided
|
73 |
+
latents = self.prepare_latents(batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, x_in,)
|
74 |
+
|
75 |
+
latents_init = latents.clone()
|
76 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
77 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
78 |
+
|
79 |
+
# 7. First Denoising loop for getting the reference cross attention maps
|
80 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
81 |
+
with torch.no_grad():
|
82 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
83 |
+
for i, t in enumerate(timesteps):
|
84 |
+
# expand the latents if we are doing classifier free guidance
|
85 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
86 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
87 |
+
|
88 |
+
# predict the noise residual
|
89 |
+
noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,).sample
|
90 |
+
|
91 |
+
# add the cross attention map to the dictionary
|
92 |
+
d_ref_t2attn[t.item()] = {}
|
93 |
+
for name, module in self.unet.named_modules():
|
94 |
+
module_name = type(module).__name__
|
95 |
+
if module_name == "CrossAttention" and 'attn2' in name:
|
96 |
+
attn_mask = module.attn_probs # size is num_channel,s*s,77
|
97 |
+
d_ref_t2attn[t.item()][name] = attn_mask.detach().cpu()
|
98 |
+
|
99 |
+
# perform guidance
|
100 |
+
if do_classifier_free_guidance:
|
101 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
102 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
103 |
+
|
104 |
+
# compute the previous noisy sample x_t -> x_t-1
|
105 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
106 |
+
|
107 |
+
# call the callback, if provided
|
108 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
109 |
+
progress_bar.update()
|
110 |
+
|
111 |
+
# make the reference image (reconstruction)
|
112 |
+
image_rec = self.numpy_to_pil(self.decode_latents(latents.detach()))
|
113 |
+
|
114 |
+
prompt_embeds_edit = prompt_embeds.clone()
|
115 |
+
#add the edit only to the second prompt, idx 0 is the negative prompt
|
116 |
+
prompt_embeds_edit[1:2] += edit_dir
|
117 |
+
|
118 |
+
latents = latents_init
|
119 |
+
# Second denoising loop for editing the text prompt
|
120 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
121 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
122 |
+
for i, t in enumerate(timesteps):
|
123 |
+
# expand the latents if we are doing classifier free guidance
|
124 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
125 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
126 |
+
|
127 |
+
x_in = latent_model_input.detach().clone()
|
128 |
+
x_in.requires_grad = True
|
129 |
+
|
130 |
+
opt = torch.optim.SGD([x_in], lr=guidance_amount)
|
131 |
+
|
132 |
+
# predict the noise residual
|
133 |
+
noise_pred = self.unet(x_in,t,encoder_hidden_states=prompt_embeds_edit.detach(),cross_attention_kwargs=cross_attention_kwargs,).sample
|
134 |
+
|
135 |
+
loss = 0.0
|
136 |
+
for name, module in self.unet.named_modules():
|
137 |
+
module_name = type(module).__name__
|
138 |
+
if module_name == "CrossAttention" and 'attn2' in name:
|
139 |
+
curr = module.attn_probs # size is num_channel,s*s,77
|
140 |
+
ref = d_ref_t2attn[t.item()][name].detach().cuda()
|
141 |
+
loss += ((curr-ref)**2).sum((1,2)).mean(0)
|
142 |
+
loss.backward(retain_graph=False)
|
143 |
+
opt.step()
|
144 |
+
|
145 |
+
# recompute the noise
|
146 |
+
with torch.no_grad():
|
147 |
+
noise_pred = self.unet(x_in.detach(),t,encoder_hidden_states=prompt_embeds_edit,cross_attention_kwargs=cross_attention_kwargs,).sample
|
148 |
+
|
149 |
+
latents = x_in.detach().chunk(2)[0]
|
150 |
+
|
151 |
+
# perform guidance
|
152 |
+
if do_classifier_free_guidance:
|
153 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
154 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
155 |
+
|
156 |
+
# compute the previous noisy sample x_t -> x_t-1
|
157 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
158 |
+
|
159 |
+
# call the callback, if provided
|
160 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
161 |
+
progress_bar.update()
|
162 |
+
|
163 |
+
|
164 |
+
# 8. Post-processing
|
165 |
+
image = self.decode_latents(latents.detach())
|
166 |
+
|
167 |
+
# 9. Run safety checker
|
168 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
169 |
+
|
170 |
+
# 10. Convert to PIL
|
171 |
+
image_edit = self.numpy_to_pil(image)
|
172 |
+
|
173 |
+
|
174 |
+
return image_rec, image_edit
|
src/utils/scheduler.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
16 |
+
# and https://github.com/hojonathanho/diffusion
|
17 |
+
import os, sys, pdb
|
18 |
+
import math
|
19 |
+
from dataclasses import dataclass
|
20 |
+
from typing import List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
|
25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
26 |
+
from diffusers.utils import BaseOutput, randn_tensor
|
27 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
|
32 |
+
class DDIMSchedulerOutput(BaseOutput):
|
33 |
+
"""
|
34 |
+
Output class for the scheduler's step function output.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
38 |
+
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
39 |
+
denoising loop.
|
40 |
+
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
41 |
+
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
|
42 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
43 |
+
"""
|
44 |
+
|
45 |
+
prev_sample: torch.FloatTensor
|
46 |
+
pred_original_sample: Optional[torch.FloatTensor] = None
|
47 |
+
|
48 |
+
|
49 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
50 |
+
"""
|
51 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
52 |
+
(1-beta) over time from t = [0,1].
|
53 |
+
|
54 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
55 |
+
to that part of the diffusion process.
|
56 |
+
|
57 |
+
|
58 |
+
Args:
|
59 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
60 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
61 |
+
prevent singularities.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
65 |
+
"""
|
66 |
+
|
67 |
+
def alpha_bar(time_step):
|
68 |
+
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
69 |
+
|
70 |
+
betas = []
|
71 |
+
for i in range(num_diffusion_timesteps):
|
72 |
+
t1 = i / num_diffusion_timesteps
|
73 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
74 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
75 |
+
return torch.tensor(betas)
|
76 |
+
|
77 |
+
|
78 |
+
class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
79 |
+
"""
|
80 |
+
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
|
81 |
+
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
|
82 |
+
|
83 |
+
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
84 |
+
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
85 |
+
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
86 |
+
[`~SchedulerMixin.from_pretrained`] functions.
|
87 |
+
|
88 |
+
For more details, see the original paper: https://arxiv.org/abs/2010.02502
|
89 |
+
|
90 |
+
Args:
|
91 |
+
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
92 |
+
beta_start (`float`): the starting `beta` value of inference.
|
93 |
+
beta_end (`float`): the final `beta` value.
|
94 |
+
beta_schedule (`str`):
|
95 |
+
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
96 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
97 |
+
trained_betas (`np.ndarray`, optional):
|
98 |
+
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
99 |
+
clip_sample (`bool`, default `True`):
|
100 |
+
option to clip predicted sample between -1 and 1 for numerical stability.
|
101 |
+
set_alpha_to_one (`bool`, default `True`):
|
102 |
+
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
|
103 |
+
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
104 |
+
otherwise it uses the value of alpha at step 0.
|
105 |
+
steps_offset (`int`, default `0`):
|
106 |
+
an offset added to the inference steps. You can use a combination of `offset=1` and
|
107 |
+
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
108 |
+
stable diffusion.
|
109 |
+
prediction_type (`str`, default `epsilon`, optional):
|
110 |
+
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
111 |
+
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
112 |
+
https://imagen.research.google/video/paper.pdf)
|
113 |
+
"""
|
114 |
+
|
115 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
116 |
+
order = 1
|
117 |
+
|
118 |
+
@register_to_config
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
num_train_timesteps: int = 1000,
|
122 |
+
beta_start: float = 0.0001,
|
123 |
+
beta_end: float = 0.02,
|
124 |
+
beta_schedule: str = "linear",
|
125 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
126 |
+
clip_sample: bool = True,
|
127 |
+
set_alpha_to_one: bool = True,
|
128 |
+
steps_offset: int = 0,
|
129 |
+
prediction_type: str = "epsilon",
|
130 |
+
):
|
131 |
+
if trained_betas is not None:
|
132 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
133 |
+
elif beta_schedule == "linear":
|
134 |
+
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
135 |
+
elif beta_schedule == "scaled_linear":
|
136 |
+
# this schedule is very specific to the latent diffusion model.
|
137 |
+
self.betas = (
|
138 |
+
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
139 |
+
)
|
140 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
141 |
+
# Glide cosine schedule
|
142 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
143 |
+
else:
|
144 |
+
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
145 |
+
|
146 |
+
self.alphas = 1.0 - self.betas
|
147 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
148 |
+
|
149 |
+
# At every step in ddim, we are looking into the previous alphas_cumprod
|
150 |
+
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
151 |
+
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
152 |
+
# whether we use the final alpha of the "non-previous" one.
|
153 |
+
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
154 |
+
|
155 |
+
# standard deviation of the initial noise distribution
|
156 |
+
self.init_noise_sigma = 1.0
|
157 |
+
|
158 |
+
# setable values
|
159 |
+
self.num_inference_steps = None
|
160 |
+
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
161 |
+
|
162 |
+
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
163 |
+
"""
|
164 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
165 |
+
current timestep.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
sample (`torch.FloatTensor`): input sample
|
169 |
+
timestep (`int`, optional): current timestep
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
`torch.FloatTensor`: scaled input sample
|
173 |
+
"""
|
174 |
+
return sample
|
175 |
+
|
176 |
+
def _get_variance(self, timestep, prev_timestep):
|
177 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
178 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
179 |
+
beta_prod_t = 1 - alpha_prod_t
|
180 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
181 |
+
|
182 |
+
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
183 |
+
|
184 |
+
return variance
|
185 |
+
|
186 |
+
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
187 |
+
"""
|
188 |
+
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
num_inference_steps (`int`):
|
192 |
+
the number of diffusion steps used when generating samples with a pre-trained model.
|
193 |
+
"""
|
194 |
+
|
195 |
+
if num_inference_steps > self.config.num_train_timesteps:
|
196 |
+
raise ValueError(
|
197 |
+
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
198 |
+
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
199 |
+
f" maximal {self.config.num_train_timesteps} timesteps."
|
200 |
+
)
|
201 |
+
|
202 |
+
self.num_inference_steps = num_inference_steps
|
203 |
+
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
204 |
+
# creates integer timesteps by multiplying by ratio
|
205 |
+
# casting to int to avoid issues when num_inference_step is power of 3
|
206 |
+
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
|
207 |
+
self.timesteps = torch.from_numpy(timesteps).to(device)
|
208 |
+
self.timesteps += self.config.steps_offset
|
209 |
+
|
210 |
+
def step(
|
211 |
+
self,
|
212 |
+
model_output: torch.FloatTensor,
|
213 |
+
timestep: int,
|
214 |
+
sample: torch.FloatTensor,
|
215 |
+
eta: float = 0.0,
|
216 |
+
use_clipped_model_output: bool = False,
|
217 |
+
generator=None,
|
218 |
+
variance_noise: Optional[torch.FloatTensor] = None,
|
219 |
+
return_dict: bool = True,
|
220 |
+
reverse=False
|
221 |
+
) -> Union[DDIMSchedulerOutput, Tuple]:
|
222 |
+
|
223 |
+
|
224 |
+
e_t = model_output
|
225 |
+
|
226 |
+
x = sample
|
227 |
+
prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
|
228 |
+
# print(timestep, prev_timestep)
|
229 |
+
a_t = alpha_prod_t = self.alphas_cumprod[timestep-1]
|
230 |
+
a_prev = alpha_t_prev = self.alphas_cumprod[prev_timestep-1] if prev_timestep >= 0 else self.final_alpha_cumprod
|
231 |
+
beta_prod_t = 1 - alpha_prod_t
|
232 |
+
|
233 |
+
pred_x0 = (x - (1-a_t)**0.5 * e_t) / a_t.sqrt()
|
234 |
+
# direction pointing to x_t
|
235 |
+
dir_xt = (1. - a_prev).sqrt() * e_t
|
236 |
+
x = a_prev.sqrt()*pred_x0 + dir_xt
|
237 |
+
if not return_dict:
|
238 |
+
return (x,)
|
239 |
+
return DDIMSchedulerOutput(prev_sample=x, pred_original_sample=pred_x0)
|
240 |
+
|
241 |
+
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
def add_noise(
|
246 |
+
self,
|
247 |
+
original_samples: torch.FloatTensor,
|
248 |
+
noise: torch.FloatTensor,
|
249 |
+
timesteps: torch.IntTensor,
|
250 |
+
) -> torch.FloatTensor:
|
251 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
252 |
+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
253 |
+
timesteps = timesteps.to(original_samples.device)
|
254 |
+
|
255 |
+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
256 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
257 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
258 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
259 |
+
|
260 |
+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
261 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
262 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
263 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
264 |
+
|
265 |
+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
266 |
+
return noisy_samples
|
267 |
+
|
268 |
+
def get_velocity(
|
269 |
+
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
|
270 |
+
) -> torch.FloatTensor:
|
271 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
272 |
+
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
|
273 |
+
timesteps = timesteps.to(sample.device)
|
274 |
+
|
275 |
+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
276 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
277 |
+
while len(sqrt_alpha_prod.shape) < len(sample.shape):
|
278 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
279 |
+
|
280 |
+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
281 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
282 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
|
283 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
284 |
+
|
285 |
+
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
286 |
+
return velocity
|
287 |
+
|
288 |
+
def __len__(self):
|
289 |
+
return self.config.num_train_timesteps
|