sky24h committed
Commit 910b9ab · 1 Parent(s): e510889

init commit

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.

Files changed (50):
  1. .gitignore +162 -0
  2. README.md +3 -3
  3. app.py +522 -4
  4. ckpt/put ckpt here.txt +0 -0
  5. configs/stable-diffusion/v2-inference-v.yaml +68 -0
  6. configs/stable-diffusion/v2-inference.yaml +67 -0
  7. configs/stable-diffusion/v2-inpainting-inference.yaml +158 -0
  8. configs/stable-diffusion/v2-midas-inference.yaml +74 -0
  9. configs/stable-diffusion/x4-upscaling.yaml +76 -0
  10. gradio/background/bg03.png +0 -0
  11. gradio/background/bg36.png +0 -0
  12. gradio/background/bg52.png +0 -0
  13. gradio/background/bg58.png +0 -0
  14. gradio/background/bg62.png +0 -0
  15. gradio/foreground/fg10_63d22a7f1f5b66e8e5ac28f7.jpg +0 -0
  16. gradio/foreground/fg50_63d22c871f5b66e8e5ac95e1.jpg +0 -0
  17. gradio/foreground/fg88_63d9d508b82cf5cb1db01976.jpg +0 -0
  18. gradio/foreground/fg90_63d9d4a0b82cf5cb1db00800.jpg +0 -0
  19. gradio/foreground/fg92_63d9d6c9b82cf5cb1db05fda.jpg +0 -0
  20. gradio/seg_foreground/fg10_mask.jpg +0 -0
  21. gradio/seg_foreground/fg50_mask.png +0 -0
  22. gradio/seg_foreground/fg88_mask.png +0 -0
  23. gradio/seg_foreground/fg90_mask.png +0 -0
  24. gradio/seg_foreground/fg92_mask.png +0 -0
  25. ldm/data/__init__.py +0 -0
  26. ldm/data/util.py +24 -0
  27. ldm/models/autoencoder.py +219 -0
  28. ldm/models/diffusion/__init__.py +0 -0
  29. ldm/models/diffusion/ddim.py +403 -0
  30. ldm/models/diffusion/ddpm.py +1796 -0
  31. ldm/models/diffusion/dpm_solver/__init__.py +1 -0
  32. ldm/models/diffusion/dpm_solver/dpm_solver.py +1194 -0
  33. ldm/models/diffusion/dpm_solver/sampler.py +252 -0
  34. ldm/models/diffusion/plms.py +244 -0
  35. ldm/models/diffusion/sampling_util.py +22 -0
  36. ldm/modules/attention.py +377 -0
  37. ldm/modules/diffusionmodules/__init__.py +0 -0
  38. ldm/modules/diffusionmodules/model.py +852 -0
  39. ldm/modules/diffusionmodules/openaimodel.py +803 -0
  40. ldm/modules/diffusionmodules/upscaling.py +81 -0
  41. ldm/modules/diffusionmodules/util.py +273 -0
  42. ldm/modules/distributions/__init__.py +0 -0
  43. ldm/modules/distributions/distributions.py +92 -0
  44. ldm/modules/ema.py +80 -0
  45. ldm/modules/encoders/__init__.py +0 -0
  46. ldm/modules/encoders/modules.py +221 -0
  47. ldm/modules/image_degradation/__init__.py +2 -0
  48. ldm/modules/image_degradation/bsrgan.py +730 -0
  49. ldm/modules/image_degradation/bsrgan_light.py +651 -0
  50. ldm/modules/image_degradation/utils/test.png +0 -0
.gitignore ADDED
@@ -0,0 +1,162 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ *.ckpt
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+ cover/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ db.sqlite3-journal
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ .pybuilder/
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ # For a library or package, you might want to ignore these files since the code is
89
+ # intended to run in multiple environments; otherwise, check them in:
90
+ # .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # poetry
100
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
102
+ # commonly ignored for libraries.
103
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104
+ #poetry.lock
105
+
106
+ # pdm
107
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108
+ #pdm.lock
109
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110
+ # in version control.
111
+ # https://pdm.fming.dev/#use-with-ide
112
+ .pdm.toml
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
- title: TF ICON Unofficial
- emoji: 📈
+ title: IF-ICON Unofficial
+ emoji: 🦄
  colorFrom: green
  colorTo: yellow
  sdk: gradio
- sdk_version: 4.44.1
+ sdk_version: 4.39.0
  app_file: app.py
  pinned: false
  license: mit
app.py CHANGED
@@ -1,7 +1,525 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
1
+ import os
2
+ import PIL
3
+ import cv2
4
+ import time
5
+ import torch
6
+ import numpy as np
7
  import gradio as gr
8
+ from PIL import Image
9
+ from torch import autocast
10
+ from contextlib import nullcontext
11
+ from itertools import islice
12
+ from omegaconf import OmegaConf
13
+ from einops import rearrange, repeat
14
+ from pytorch_lightning import seed_everything
15
 
16
+ from ldm.util import instantiate_from_config
17
+ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
18
+ from gradio_image_annotation import image_annotator
19
 
20
+
21
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
22
+ CONFIG_PATH = "./configs/stable-diffusion/v2-inference.yaml"
23
+ CKPT_PATH = "./ckpt/v2-1_512-ema-pruned.ckpt"
24
+ if not os.path.exists(CKPT_PATH):
25
+ # automatically download the checkpoint if it doesn't exist
26
+ print(f"Checkpoint {CKPT_PATH} not found, downloading from huggingface")
27
+ os.system(f"wget -O {CKPT_PATH} https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt")
28
+ CONFIG = OmegaConf.load(CONFIG_PATH)
29
+
30
+
31
+ def load_img(image, SCALE, pad=False, seg_map=False, target_size=None):
32
+ if seg_map:
33
+ # Load the input image and segmentation map
34
+ # image = Image.open(path).convert("RGB")
35
+ # seg_map = Image.open(seg).convert("1")
36
+
37
+ seg_map = seg_map.convert("1")
38
+ # Get the width and height of the original image
39
+ w, h = image.size
40
+
41
+ # Calculate the aspect ratio of the original image
42
+ aspect_ratio = h / w
43
+
44
+ # Determine the new dimensions for resizing the image while maintaining aspect ratio
45
+ if aspect_ratio > 1:
46
+ new_w = int(SCALE * 256 / aspect_ratio)
47
+ new_h = int(SCALE * 256)
48
+ else:
49
+ new_w = int(SCALE * 256)
50
+ new_h = int(SCALE * 256 * aspect_ratio)
51
+
52
+ # Resize the image and the segmentation map to the new dimensions
53
+ image_resize = image.resize((new_w, new_h))
54
+ segmentation_map_resize = cv2.resize(np.array(seg_map).astype(np.uint8), (new_w, new_h), interpolation=cv2.INTER_NEAREST)
55
+
56
+ # Pad the segmentation map to match the target size
57
+ padded_segmentation_map = np.zeros((target_size[1], target_size[0]))
58
+ start_x = (target_size[1] - segmentation_map_resize.shape[0]) // 2
59
+ start_y = (target_size[0] - segmentation_map_resize.shape[1]) // 2
60
+ padded_segmentation_map[start_x : start_x + segmentation_map_resize.shape[0], start_y : start_y + segmentation_map_resize.shape[1]] = (
61
+ segmentation_map_resize
62
+ )
63
+
64
+ # Create a new RGB image with the target size and place the resized image in the center
65
+ padded_image = Image.new("RGB", target_size)
66
+ start_x = (target_size[0] - image_resize.width) // 2
67
+ start_y = (target_size[1] - image_resize.height) // 2
68
+ padded_image.paste(image_resize, (start_x, start_y))
69
+
70
+ # Update the variable "image" to contain the final padded image
71
+ image = padded_image
72
+ else:
73
+ # image = Image.open(path).convert("RGB")
74
+ w, h = image.size
75
+ # print(f"loaded input image of size ({w}, {h}) from {path}")
76
+ w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
77
+ w = h = 512
78
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
79
+
80
+ image = np.array(image).astype(np.float32) / 255.0
81
+ image = image[None].transpose(0, 3, 1, 2)
82
+ image = torch.from_numpy(image)
83
+
84
+ if pad or seg_map:
85
+ return 2.0 * image - 1.0, new_w, new_h, padded_segmentation_map
86
+
87
+ return 2.0 * image - 1.0, w, h
88
+
89
+
90
+ def load_model_and_get_prompt_embedding(model, scale, device, prompts, inv=False):
91
+ if inv:
92
+ inv_emb = model.get_learned_conditioning(prompts, inv)
93
+ c = uc = inv_emb
94
+ else:
95
+ inv_emb = None
96
+
97
+ if scale != 1.0:
98
+ uc = model.get_learned_conditioning([""])
99
+ else:
100
+ uc = None
101
+ c = model.get_learned_conditioning(prompts)
102
+
103
+ return c, uc, inv_emb
104
+
105
+
106
+ def chunk(it, size):
107
+ it = iter(it)
108
+ return iter(lambda: tuple(islice(it, size)), ())
109
+
110
+
111
+ def load_model_from_config(config, ckpt, gpu, verbose=False):
112
+ print(f"Loading model from {ckpt}")
113
+ pl_sd = torch.load(ckpt, map_location=gpu)
114
+ if "global_step" in pl_sd:
115
+ print(f"Global Step: {pl_sd['global_step']}")
116
+ sd = pl_sd["state_dict"]
117
+ model = instantiate_from_config(config.model)
118
+ m, u = model.load_state_dict(sd, strict=False)
119
+ if len(m) > 0 and verbose:
120
+ print("missing keys:")
121
+ print(m)
122
+ if len(u) > 0 and verbose:
123
+ print("unexpected keys:")
124
+ print(u)
125
+
126
+ model.eval()
127
+ return model
128
+
129
+
130
+ MODEL = load_model_from_config(CONFIG, CKPT_PATH, DEVICE)
131
+ MODEL.to(device=DEVICE)
132
+
133
+
134
+ # @spaces.GPU(duration=60)
135
+ def tficon(img_with_mask, ref_img, seg, prompt, dpm_order, dpm_steps, tau_a, tau_b, domain, seed, scale):
136
+ init_img = img_with_mask["image"]
137
+ n_samples = 1
138
+ precision = "autocast"
139
+ ddim_eta = 0.0
140
+ dpm_order = int(dpm_order[0])
141
+
142
+ scale = scale
143
+
144
+ device = DEVICE
145
+ model = MODEL
146
+ batch_size = n_samples
147
+ sampler = DPMSolverSampler(model)
148
+
149
+ seed_everything(seed)
150
+ # img = cv2.imread(mask, 0)
151
+ # # Threshold the image to create binary image
152
+ # _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
153
+ # # Find the contours of the white region in the image
154
+ # contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
155
+ # # Find the bounding rectangle of the largest contour
156
+ # x, y, new_w, new_h = cv2.boundingRect(contours[0])
157
+ # Calculate the center of the rectangle
158
+
159
+ bbox = img_with_mask["boxes"][0]
160
+ x = bbox["xmin"]
161
+ y = bbox["ymin"]
162
+ new_w = bbox["xmax"] - bbox["xmin"]
163
+ new_h = bbox["ymax"] - bbox["ymin"]
164
+
165
+ center_x = x + new_w / 2
166
+ center_y = y + new_h / 2
167
+ # Calculate the percentage from the top and left
168
+ center_row_from_top = round(center_y / 512, 2)
169
+ center_col_from_left = round(center_x / 512, 2)
170
+
171
+ aspect_ratio = new_h / new_w
172
+
173
+ if aspect_ratio > 1:
174
+ mask_scale = new_w * aspect_ratio / 256
175
+ mask_scale = new_h / 256
176
+ else:
177
+ mask_scale = new_w / 256
178
+ mask_scale = new_h / (aspect_ratio * 256)
179
+
180
+ # mask_scale = round(mask_scale, 2)
181
+
182
+ # =============================================================================================
183
+
184
+ data = [batch_size * [prompt]]
185
+ # read background image
186
+ init_image, target_width, target_height = load_img(init_img, mask_scale)
187
+ init_image = repeat(init_image.to(device), "1 ... -> b ...", b=batch_size)
188
+ save_image = init_image.clone()
189
+
190
+ # read foreground image and its segmentation map
191
+ ref_image, width, height, segmentation_map = load_img(ref_img, mask_scale, seg_map=seg, target_size=(target_width, target_height))
192
+ ref_image = repeat(ref_image.to(device), "1 ... -> b ...", b=batch_size)
193
+
194
+ segmentation_map_orig = repeat(torch.tensor(segmentation_map)[None, None, ...].to(device), "1 1 ... -> b 4 ...", b=batch_size)
195
+ segmentation_map_save = repeat(torch.tensor(segmentation_map)[None, None, ...].to(device), "1 1 ... -> b 3 ...", b=batch_size)
196
+ segmentation_map = segmentation_map_orig[:, :, ::8, ::8].to(device)
197
+
198
+ top_rr = int((0.5 * (target_height - height)) / target_height * init_image.shape[2]) # xx% from the top
199
+ bottom_rr = int((0.5 * (target_height + height)) / target_height * init_image.shape[2])
200
+ left_rr = int((0.5 * (target_width - width)) / target_width * init_image.shape[3]) # xx% from the left
201
+ right_rr = int((0.5 * (target_width + width)) / target_width * init_image.shape[3])
202
+
203
+ center_row_rm = int(center_row_from_top * target_height)
204
+ center_col_rm = int(center_col_from_left * target_width)
205
+
206
+ step_height2, remainder = divmod(height, 2)
207
+ step_height1 = step_height2 + remainder
208
+ step_width2, remainder = divmod(width, 2)
209
+ step_width1 = step_width2 + remainder
210
+
211
+ # compositing in pixel space for same-domain composition
212
+ save_image[:, :, center_row_rm - step_height1 : center_row_rm + step_height2, center_col_rm - step_width1 : center_col_rm + step_width2] = (
213
+ save_image[
214
+ :, :, center_row_rm - step_height1 : center_row_rm + step_height2, center_col_rm - step_width1 : center_col_rm + step_width2
215
+ ].clone()
216
+ * (1 - segmentation_map_save[:, :, top_rr:bottom_rr, left_rr:right_rr])
217
+ + ref_image[:, :, top_rr:bottom_rr, left_rr:right_rr].clone() * segmentation_map_save[:, :, top_rr:bottom_rr, left_rr:right_rr]
218
+ )
219
+
220
+ # save the mask and the pixel space composited image
221
+ save_mask = torch.zeros_like(init_image)
222
+ save_mask[:, :, center_row_rm - step_height1 : center_row_rm + step_height2, center_col_rm - step_width1 : center_col_rm + step_width2] = 1
223
+
224
+ # image = Image.fromarray(((save_image/torch.max(save_image.max(), abs(save_image.min())) + 1) * 127.5)[0].permute(1,2,0).to(dtype=torch.uint8).cpu().numpy())
225
+ precision_scope = autocast if precision == "autocast" else nullcontext
226
+
227
+ # image composition
228
+ with torch.no_grad():
229
+ with precision_scope("cuda"):
230
+ for prompts in data:
231
+ print(prompts)
232
+ c, uc, inv_emb = load_model_and_get_prompt_embedding(model, scale, device, prompts, inv=True)
233
+
234
+ if domain == "Real Domain": # same domain
235
+ init_image = save_image
236
+
237
+ T1 = time.time()
238
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))
239
+
240
+ # ref's location in ref image in the latent space
241
+ top_rr = int((0.5 * (target_height - height)) / target_height * init_latent.shape[2])
242
+ bottom_rr = int((0.5 * (target_height + height)) / target_height * init_latent.shape[2])
243
+ left_rr = int((0.5 * (target_width - width)) / target_width * init_latent.shape[3])
244
+ right_rr = int((0.5 * (target_width + width)) / target_width * init_latent.shape[3])
245
+
246
+ new_height = bottom_rr - top_rr
247
+ new_width = right_rr - left_rr
248
+
249
+ step_height2, remainder = divmod(new_height, 2)
250
+ step_height1 = step_height2 + remainder
251
+ step_width2, remainder = divmod(new_width, 2)
252
+ step_width1 = step_width2 + remainder
253
+
254
+ center_row_rm = int(center_row_from_top * init_latent.shape[2])
255
+ center_col_rm = int(center_col_from_left * init_latent.shape[3])
256
+
257
+ param = [
258
+ max(0, int(center_row_rm - step_height1)),
259
+ min(init_latent.shape[2] - 1, int(center_row_rm + step_height2)),
260
+ max(0, int(center_col_rm - step_width1)),
261
+ min(init_latent.shape[3] - 1, int(center_col_rm + step_width2)),
262
+ ]
263
+
264
+ ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref_image))
265
+
266
+ shape = [init_latent.shape[1], init_latent.shape[2], init_latent.shape[3]]
267
+ z_enc, _ = sampler.sample(
268
+ steps = dpm_steps,
269
+ inv_emb = inv_emb,
270
+ unconditional_conditioning = uc,
271
+ conditioning = c,
272
+ batch_size = n_samples,
273
+ shape = shape,
274
+ verbose = False,
275
+ unconditional_guidance_scale = scale,
276
+ eta = ddim_eta,
277
+ order = dpm_order,
278
+ x_T = init_latent,
279
+ width = width,
280
+ height = height,
281
+ DPMencode = True,
282
+ )
283
+
284
+ z_ref_enc, _ = sampler.sample(
285
+ steps = dpm_steps,
286
+ inv_emb = inv_emb,
287
+ unconditional_conditioning = uc,
288
+ conditioning = c,
289
+ batch_size = n_samples,
290
+ shape = shape,
291
+ verbose = False,
292
+ unconditional_guidance_scale = scale,
293
+ eta = ddim_eta,
294
+ order = dpm_order,
295
+ x_T = ref_latent,
296
+ DPMencode = True,
297
+ width = width,
298
+ height = height,
299
+ ref = True,
300
+ )
301
+
302
+ samples_orig = z_enc.clone()
303
+
304
+ # inpainting in XOR region of M_seg and M_mask
305
+ z_enc[:, :, param[0] : param[1], param[2] : param[3]] = z_enc[
306
+ :, :, param[0] : param[1], param[2] : param[3]
307
+ ] * segmentation_map[:, :, top_rr:bottom_rr, left_rr:right_rr] + torch.randn(
308
+ (1, 4, bottom_rr - top_rr, right_rr - left_rr), device=device
309
+ ) * (1 - segmentation_map[:, :, top_rr:bottom_rr, left_rr:right_rr])
310
+
311
+ samples_for_cross = samples_orig.clone()
312
+ samples_ref = z_ref_enc.clone()
313
+ samples = z_enc.clone()
314
+
315
+ # noise composition
316
+ if domain == "Cross Domain":
317
+ samples[:, :, param[0] : param[1], param[2] : param[3]] = torch.randn(
318
+ (1, 4, bottom_rr - top_rr, right_rr - left_rr), device=device
319
+ )
320
+ # apply the segmentation mask on the noise
321
+ samples[:, :, param[0] : param[1], param[2] : param[3]] = (
322
+ samples[:, :, param[0] : param[1], param[2] : param[3]].clone()
323
+ * (1 - segmentation_map[:, :, top_rr:bottom_rr, left_rr:right_rr])
324
+ + z_ref_enc[:, :, top_rr:bottom_rr, left_rr:right_rr].clone()
325
+ * segmentation_map[:, :, top_rr:bottom_rr, left_rr:right_rr]
326
+ )
327
+
328
+ mask = torch.zeros_like(z_enc, device=device)
329
+ mask[:, :, param[0] : param[1], param[2] : param[3]] = 1
330
+
331
+ samples, _ = sampler.sample(
332
+ steps = dpm_steps,
333
+ inv_emb = inv_emb,
334
+ conditioning = c,
335
+ batch_size = n_samples,
336
+ shape = shape,
337
+ verbose = False,
338
+ unconditional_guidance_scale = scale,
339
+ unconditional_conditioning = uc,
340
+ eta = ddim_eta,
341
+ order = dpm_order,
342
+ x_T = [samples_orig, samples.clone(), samples_for_cross, samples_ref, samples, init_latent],
343
+ width = width,
344
+ height = height,
345
+ segmentation_map = segmentation_map,
346
+ param = param,
347
+ mask = mask,
348
+ target_height = target_height,
349
+ target_width = target_width,
350
+ center_row_rm = center_row_from_top,
351
+ center_col_rm = center_col_from_left,
352
+ tau_a = tau_a,
353
+ tau_b = tau_b,
354
+ )
355
+
356
+ x_samples = model.decode_first_stage(samples)
357
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
358
+
359
+ T2 = time.time()
360
+ print("Running Time: %s s" % (T2 - T1))
361
+
362
+ for x_sample in x_samples:
363
+ x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
364
+ img = Image.fromarray(x_sample.astype(np.uint8))
365
+ # img.save(os.path.join(sample_path, f"{base_count:05}_{prompts[0]}.png"))
366
+ return img
367
+
368
+
369
+ def read_content(file_path: str) -> str:
370
+ """read the content of target file"""
371
+ with open(file_path, "r", encoding="utf-8") as f:
372
+ content = f.read()
373
+
374
+ return content
375
+
376
+
377
+ example = {}
378
+ ref_dir = "./gradio/foreground"
379
+ image_dir = "./gradio/background"
380
+ seg_dir = "./gradio/seg_foreground"
381
+ image_list = [os.path.join(image_dir, file) for file in os.listdir(image_dir)]
382
+ image_list.sort()
383
+
384
+ ref_list = [os.path.join(ref_dir, file) for file in os.listdir(ref_dir)]
385
+ ref_list.sort()
386
+ seg_list = [os.path.join(seg_dir, file) for file in os.listdir(seg_dir)]
387
+ seg_list.sort()
388
+ reference_list = [[ref_img, ref_mask] for ref_img, ref_mask in zip(ref_list, seg_list)]
389
+
390
+ image_list = [
391
+ {
392
+ "image": image,
393
+ "boxes": [
394
+ {
395
+ "xmin" : 128,
396
+ "ymin" : 128,
397
+ "xmax" : 384,
398
+ "ymax" : 384,
399
+ "label": "Mask",
400
+ "color": (250, 0, 0),
401
+ }
402
+ ],
403
+ }
404
+ for image in image_list
405
+ ]
406
+
407
+
408
+ def update_mask(image):
409
+ print("update mask")
410
+ bbox = image["boxes"][0]
411
+ label = image["boxes"][0]["label"]
412
+ xmin = bbox["xmin"]
413
+ ymin = bbox["ymin"]
414
+ xmax = bbox["xmax"]
415
+ ymax = bbox["ymax"]
416
+ coords = [xmin, ymin, xmax, ymax]
417
+ return (image["image"], [(coords, label)])
418
+
419
+
420
+ if __name__ == "__main__":
421
+ with gr.Blocks() as demo:
422
+ gr.HTML(
423
+ """
424
+ <h1 style="text-align: center; font-size: 32px; font-family: 'Times New Roman', Times, serif;">
425
+ 🦄TF-ICON: Diffusion-Based Training-Free Cross-Domain Image Composition
426
+ </h1>
427
+ <p style="text-align: center; font-size: 20px; font-family: 'Times New Roman', Times, serif;">
428
+ <a style="text-align: center; display:inline-block"
429
+ href="https://shilin-lu.github.io/tf-icon.github.io/">
430
+ <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/paper-page-sm.svg#center"
431
+ alt="Paper Page">
432
+ </a>
433
+ <a style="text-align: center; display:inline-block" href="https://huggingface.co/spaces/sky24h/TF-ICON-unofficial?duplicate=true">
434
+ <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg#center" alt="Duplicate Space">
435
+ </a>
436
+ </p>
437
+ This is an unofficial demo for the paper 'TF-ICON: Diffusion-Based Training-Free Cross-Domain Image Composition'.
438
+ </p>
439
+ """
440
+ )
441
+ with gr.Row():
442
+ with gr.Column():
443
+ # back_image_invisible = gr.Image(elem_id="image_upload", type="pil", label="Background Image", height=512, visible=False)
444
+ image = image_annotator(
445
+ None,
446
+ label_list=["Mask"],
447
+ label_colors=[(255, 0, 0)],
448
+ height=512,
449
+ image_type="pil",
450
+ )
451
+ # back_image_invisible.change(fn=set_image, inputs=[back_image_invisible, image])
452
+
453
+ mask_btn = gr.Button("Generate Mask")
454
+ reference = gr.Image(elem_id="image_upload", type="pil", label="Foreground Image", height=512)
455
+
456
+ with gr.Row():
457
+ # guidance = gr.Slider(label="Guidance scale", value=5, maximum=15,interactive=True)
458
+ steps = gr.Slider(label="Steps", value=50, minimum=2, maximum=75, step=1, interactive=True)
459
+ seed = gr.Slider(0, 10000, label="Seed (0 = random)", value=3407, step=1)
460
+
461
+ with gr.Row():
462
+ tau_a = gr.Slider(
463
+ label="tau_a",
464
+ value=0.4,
465
+ minimum=0.0,
466
+ maximum=1.0,
467
+ step=0.1,
468
+ interactive=True,
469
+ info="Foreground Attention Injection",
470
+ )
471
+ tau_b = gr.Slider(
472
+ label="tau_b", value=0.8, minimum=0.0, maximum=1.0, step=0.1, interactive=True, info="Background Preservation"
473
+ )
474
+
475
+ with gr.Row():
476
+ scale = gr.Slider(
477
+ label="CFG",
478
+ value=2.5,
479
+ minimum=0.0,
480
+ maximum=15.0,
481
+ step=0.5,
482
+ interactive=True,
483
+ info="CFG=2.5 for real domain CFG>=5.0 for cross domain",
484
+ )
485
+ dpm_order = gr.CheckboxGroup(["1", "2", "3"], value="2", label="DPM Solver Order")
486
+
487
+ domain = gr.Radio(
488
+ ["Cross Domain", "Real Domain"],
489
+ value="Real Domain",
490
+ label="Domain",
491
+ info="When background is real image, choose Real Domain; otherwise, choose Cross Domain",
492
+ )
493
+ prompt = gr.Textbox(label="Prompt", info="an oil painting (or a pencil drawing) of a panda") # .style(height=512)
494
+
495
+ btn = gr.Button("Run!") #
496
+
497
+ with gr.Column():
498
+ mask = gr.AnnotatedImage(
499
+ label="Composition Region",
500
+ # info="Setting mask for composition region: first click for the top left corner, second click for the bottom right corner",
501
+ color_map={"Region for Composing Object": "#9987FF", "Click Second Point for Mask": "#f44336"},
502
+ height=512,
503
+ )
504
+
505
+ mask_btn.click(fn=update_mask, inputs=[image], outputs=[mask])
506
+ # image.select(get_select_coordinates, image, mask)
507
+
508
+ seg = gr.Image(elem_id="image_upload", type="pil", label="Segmentation Mask for Foreground", height=512)
509
+
510
+ image_out = gr.Image(label="Output", elem_id="output-img", height=512)
511
+
512
+ # with gr.Group(elem_id="share-btn-container"):
513
+ # community_icon = gr.HTML(community_icon_html, visible=True)
514
+ # loading_icon = gr.HTML(loading_icon_html, visible=True)
515
+ # share_button = gr.Button("Share to community", elem_id="share-btn", visible=True)
516
+
517
+ with gr.Row():
518
+ with gr.Column():
519
+ gr.Examples(image_list, inputs=[image], label="Examples - Background Image", examples_per_page=12)
520
+ with gr.Column():
521
+ gr.Examples(reference_list, inputs=[reference, seg], label="Examples - Foreground Image", examples_per_page=12)
522
+
523
+ btn.click(fn=tficon, inputs=[image, reference, seg, prompt, dpm_order, steps, tau_a, tau_b, domain, seed, scale], outputs=[image_out])
524
+
525
+ demo.queue(max_size=10).launch()
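As a quick orientation to the large app.py diff above, the following condensed sketch (not part of the commit; the helper name load_model is illustrative) shows the model-loading path the app follows: auto-download the Stable Diffusion 2.1-base checkpoint on first run, load the OmegaConf config, build the model with instantiate_from_config, and load the state dict non-strictly.

import os
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

CONFIG_PATH = "./configs/stable-diffusion/v2-inference.yaml"
CKPT_PATH = "./ckpt/v2-1_512-ema-pruned.ckpt"
CKPT_URL = ("https://huggingface.co/stabilityai/stable-diffusion-2-1-base/"
            "resolve/main/v2-1_512-ema-pruned.ckpt")

def load_model(config_path: str = CONFIG_PATH, ckpt_path: str = CKPT_PATH):
    # Fetch the checkpoint on first run, mirroring the wget call in app.py.
    if not os.path.exists(ckpt_path):
        os.system(f"wget -O {ckpt_path} {CKPT_URL}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = OmegaConf.load(config_path)
    state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
    model = instantiate_from_config(config.model)   # builds the LatentDiffusion model
    model.load_state_dict(state_dict, strict=False)  # missing/unexpected keys are tolerated
    return model.eval().to(device)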
ckpt/put ckpt here.txt ADDED
File without changes
configs/stable-diffusion/v2-inference-v.yaml ADDED
@@ -0,0 +1,68 @@
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ parameterization: "v"
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False # we set this to false because this is an inference only config
20
+
21
+ unet_config:
22
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
23
+ params:
24
+ use_checkpoint: True
25
+ use_fp16: True
26
+ image_size: 32 # unused
27
+ in_channels: 4
28
+ out_channels: 4
29
+ model_channels: 320
30
+ attention_resolutions: [ 4, 2, 1 ]
31
+ num_res_blocks: 2
32
+ channel_mult: [ 1, 2, 4, 4 ]
33
+ num_head_channels: 64 # need to fix for flash-attn
34
+ use_spatial_transformer: True
35
+ use_linear_in_transformer: True
36
+ transformer_depth: 1
37
+ context_dim: 1024
38
+ legacy: False
39
+
40
+ first_stage_config:
41
+ target: ldm.models.autoencoder.AutoencoderKL
42
+ params:
43
+ embed_dim: 4
44
+ monitor: val/rec_loss
45
+ ddconfig:
46
+ #attn_type: "vanilla-xformers"
47
+ double_z: true
48
+ z_channels: 4
49
+ resolution: 256
50
+ in_channels: 3
51
+ out_ch: 3
52
+ ch: 128
53
+ ch_mult:
54
+ - 1
55
+ - 2
56
+ - 4
57
+ - 4
58
+ num_res_blocks: 2
59
+ attn_resolutions: []
60
+ dropout: 0.0
61
+ lossconfig:
62
+ target: torch.nn.Identity
63
+
64
+ cond_stage_config:
65
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
66
+ params:
67
+ freeze: True
68
+ layer: "penultimate"
configs/stable-diffusion/v2-inference.yaml ADDED
@@ -0,0 +1,67 @@
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False # we set this to false because this is an inference only config
19
+
20
+ unet_config:
21
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22
+ params:
23
+ use_checkpoint: True
24
+ use_fp16: True
25
+ image_size: 32 # unused
26
+ in_channels: 4
27
+ out_channels: 4
28
+ model_channels: 320
29
+ attention_resolutions: [ 4, 2, 1 ]
30
+ num_res_blocks: 2
31
+ channel_mult: [ 1, 2, 4, 4 ]
32
+ num_head_channels: 64 # need to fix for flash-attn
33
+ use_spatial_transformer: True
34
+ use_linear_in_transformer: True
35
+ transformer_depth: 1
36
+ context_dim: 1024
37
+ legacy: False
38
+
39
+ first_stage_config:
40
+ target: ldm.models.autoencoder.AutoencoderKL
41
+ params:
42
+ embed_dim: 4
43
+ monitor: val/rec_loss
44
+ ddconfig:
45
+ #attn_type: "vanilla-xformers"
46
+ double_z: true
47
+ z_channels: 4
48
+ resolution: 256
49
+ in_channels: 3
50
+ out_ch: 3
51
+ ch: 128
52
+ ch_mult:
53
+ - 1
54
+ - 2
55
+ - 4
56
+ - 4
57
+ num_res_blocks: 2
58
+ attn_resolutions: []
59
+ dropout: 0.0
60
+ lossconfig:
61
+ target: torch.nn.Identity
62
+
63
+ cond_stage_config:
64
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
65
+ params:
66
+ freeze: True
67
+ layer: "penultimate"
configs/stable-diffusion/v2-inpainting-inference.yaml ADDED
@@ -0,0 +1,158 @@
1
+ model:
2
+ base_learning_rate: 5.0e-05
3
+ target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false
15
+ conditioning_key: hybrid
16
+ scale_factor: 0.18215
17
+ monitor: val/loss_simple_ema
18
+ finetune_keys: null
19
+ use_ema: False
20
+
21
+ unet_config:
22
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
23
+ params:
24
+ use_checkpoint: True
25
+ image_size: 32 # unused
26
+ in_channels: 9
27
+ out_channels: 4
28
+ model_channels: 320
29
+ attention_resolutions: [ 4, 2, 1 ]
30
+ num_res_blocks: 2
31
+ channel_mult: [ 1, 2, 4, 4 ]
32
+ num_head_channels: 64 # need to fix for flash-attn
33
+ use_spatial_transformer: True
34
+ use_linear_in_transformer: True
35
+ transformer_depth: 1
36
+ context_dim: 1024
37
+ legacy: False
38
+
39
+ first_stage_config:
40
+ target: ldm.models.autoencoder.AutoencoderKL
41
+ params:
42
+ embed_dim: 4
43
+ monitor: val/rec_loss
44
+ ddconfig:
45
+ #attn_type: "vanilla-xformers"
46
+ double_z: true
47
+ z_channels: 4
48
+ resolution: 256
49
+ in_channels: 3
50
+ out_ch: 3
51
+ ch: 128
52
+ ch_mult:
53
+ - 1
54
+ - 2
55
+ - 4
56
+ - 4
57
+ num_res_blocks: 2
58
+ attn_resolutions: [ ]
59
+ dropout: 0.0
60
+ lossconfig:
61
+ target: torch.nn.Identity
62
+
63
+ cond_stage_config:
64
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
65
+ params:
66
+ freeze: True
67
+ layer: "penultimate"
68
+
69
+
70
+ data:
71
+ target: ldm.data.laion.WebDataModuleFromConfig
72
+ params:
73
+ tar_base: null # for concat as in LAION-A
74
+ p_unsafe_threshold: 0.1
75
+ filter_word_list: "data/filters.yaml"
76
+ max_pwatermark: 0.45
77
+ batch_size: 8
78
+ num_workers: 6
79
+ multinode: True
80
+ min_size: 512
81
+ train:
82
+ shards:
83
+ - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
84
+ - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
85
+ - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
86
+ - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
87
+ - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
88
+ shuffle: 10000
89
+ image_key: jpg
90
+ image_transforms:
91
+ - target: torchvision.transforms.Resize
92
+ params:
93
+ size: 512
94
+ interpolation: 3
95
+ - target: torchvision.transforms.RandomCrop
96
+ params:
97
+ size: 512
98
+ postprocess:
99
+ target: ldm.data.laion.AddMask
100
+ params:
101
+ mode: "512train-large"
102
+ p_drop: 0.25
103
+ # NOTE use enough shards to avoid empty validation loops in workers
104
+ validation:
105
+ shards:
106
+ - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
107
+ shuffle: 0
108
+ image_key: jpg
109
+ image_transforms:
110
+ - target: torchvision.transforms.Resize
111
+ params:
112
+ size: 512
113
+ interpolation: 3
114
+ - target: torchvision.transforms.CenterCrop
115
+ params:
116
+ size: 512
117
+ postprocess:
118
+ target: ldm.data.laion.AddMask
119
+ params:
120
+ mode: "512train-large"
121
+ p_drop: 0.25
122
+
123
+ lightning:
124
+ find_unused_parameters: True
125
+ modelcheckpoint:
126
+ params:
127
+ every_n_train_steps: 5000
128
+
129
+ callbacks:
130
+ metrics_over_trainsteps_checkpoint:
131
+ params:
132
+ every_n_train_steps: 10000
133
+
134
+ image_logger:
135
+ target: main.ImageLogger
136
+ params:
137
+ enable_autocast: False
138
+ disabled: False
139
+ batch_frequency: 1000
140
+ max_images: 4
141
+ increase_log_steps: False
142
+ log_first_step: False
143
+ log_images_kwargs:
144
+ use_ema_scope: False
145
+ inpaint: False
146
+ plot_progressive_rows: False
147
+ plot_diffusion_rows: False
148
+ N: 4
149
+ unconditional_guidance_scale: 5.0
150
+ unconditional_guidance_label: [""]
151
+ ddim_steps: 50 # todo check these out for depth2img,
152
+ ddim_eta: 0.0 # todo check these out for depth2img,
153
+
154
+ trainer:
155
+ benchmark: True
156
+ val_check_interval: 5000000
157
+ num_sanity_val_steps: 0
158
+ accumulate_grad_batches: 1
configs/stable-diffusion/v2-midas-inference.yaml ADDED
@@ -0,0 +1,74 @@
1
+ model:
2
+ base_learning_rate: 5.0e-07
3
+ target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false
15
+ conditioning_key: hybrid
16
+ scale_factor: 0.18215
17
+ monitor: val/loss_simple_ema
18
+ finetune_keys: null
19
+ use_ema: False
20
+
21
+ depth_stage_config:
22
+ target: ldm.modules.midas.api.MiDaSInference
23
+ params:
24
+ model_type: "dpt_hybrid"
25
+
26
+ unet_config:
27
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
28
+ params:
29
+ use_checkpoint: True
30
+ image_size: 32 # unused
31
+ in_channels: 5
32
+ out_channels: 4
33
+ model_channels: 320
34
+ attention_resolutions: [ 4, 2, 1 ]
35
+ num_res_blocks: 2
36
+ channel_mult: [ 1, 2, 4, 4 ]
37
+ num_head_channels: 64 # need to fix for flash-attn
38
+ use_spatial_transformer: True
39
+ use_linear_in_transformer: True
40
+ transformer_depth: 1
41
+ context_dim: 1024
42
+ legacy: False
43
+
44
+ first_stage_config:
45
+ target: ldm.models.autoencoder.AutoencoderKL
46
+ params:
47
+ embed_dim: 4
48
+ monitor: val/rec_loss
49
+ ddconfig:
50
+ #attn_type: "vanilla-xformers"
51
+ double_z: true
52
+ z_channels: 4
53
+ resolution: 256
54
+ in_channels: 3
55
+ out_ch: 3
56
+ ch: 128
57
+ ch_mult:
58
+ - 1
59
+ - 2
60
+ - 4
61
+ - 4
62
+ num_res_blocks: 2
63
+ attn_resolutions: [ ]
64
+ dropout: 0.0
65
+ lossconfig:
66
+ target: torch.nn.Identity
67
+
68
+ cond_stage_config:
69
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
70
+ params:
71
+ freeze: True
72
+ layer: "penultimate"
73
+
74
+
configs/stable-diffusion/x4-upscaling.yaml ADDED
@@ -0,0 +1,76 @@
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
4
+ params:
5
+ parameterization: "v"
6
+ low_scale_key: "lr"
7
+ linear_start: 0.0001
8
+ linear_end: 0.02
9
+ num_timesteps_cond: 1
10
+ log_every_t: 200
11
+ timesteps: 1000
12
+ first_stage_key: "jpg"
13
+ cond_stage_key: "txt"
14
+ image_size: 128
15
+ channels: 4
16
+ cond_stage_trainable: false
17
+ conditioning_key: "hybrid-adm"
18
+ monitor: val/loss_simple_ema
19
+ scale_factor: 0.08333
20
+ use_ema: False
21
+
22
+ low_scale_config:
23
+ target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
24
+ params:
25
+ noise_schedule_config: # image space
26
+ linear_start: 0.0001
27
+ linear_end: 0.02
28
+ max_noise_level: 350
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ use_checkpoint: True
34
+ num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
35
+ image_size: 128
36
+ in_channels: 7
37
+ out_channels: 4
38
+ model_channels: 256
39
+ attention_resolutions: [ 2,4,8]
40
+ num_res_blocks: 2
41
+ channel_mult: [ 1, 2, 2, 4]
42
+ disable_self_attentions: [True, True, True, False]
43
+ disable_middle_self_attn: False
44
+ num_heads: 8
45
+ use_spatial_transformer: True
46
+ transformer_depth: 1
47
+ context_dim: 1024
48
+ legacy: False
49
+ use_linear_in_transformer: True
50
+
51
+ first_stage_config:
52
+ target: ldm.models.autoencoder.AutoencoderKL
53
+ params:
54
+ embed_dim: 4
55
+ ddconfig:
56
+ # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though)
57
+ double_z: True
58
+ z_channels: 4
59
+ resolution: 256
60
+ in_channels: 3
61
+ out_ch: 3
62
+ ch: 128
63
+ ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
64
+ num_res_blocks: 2
65
+ attn_resolutions: [ ]
66
+ dropout: 0.0
67
+
68
+ lossconfig:
69
+ target: torch.nn.Identity
70
+
71
+ cond_stage_config:
72
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
73
+ params:
74
+ freeze: True
75
+ layer: "penultimate"
76
+
gradio/background/bg03.png ADDED
gradio/background/bg36.png ADDED
gradio/background/bg52.png ADDED
gradio/background/bg58.png ADDED
gradio/background/bg62.png ADDED
gradio/foreground/fg10_63d22a7f1f5b66e8e5ac28f7.jpg ADDED
gradio/foreground/fg50_63d22c871f5b66e8e5ac95e1.jpg ADDED
gradio/foreground/fg88_63d9d508b82cf5cb1db01976.jpg ADDED
gradio/foreground/fg90_63d9d4a0b82cf5cb1db00800.jpg ADDED
gradio/foreground/fg92_63d9d6c9b82cf5cb1db05fda.jpg ADDED
gradio/seg_foreground/fg10_mask.jpg ADDED
gradio/seg_foreground/fg50_mask.png ADDED
gradio/seg_foreground/fg88_mask.png ADDED
gradio/seg_foreground/fg90_mask.png ADDED
gradio/seg_foreground/fg92_mask.png ADDED
ldm/data/__init__.py ADDED
File without changes
ldm/data/util.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+
+ from ldm.modules.midas.api import load_midas_transform
+
+
+ class AddMiDaS(object):
+     def __init__(self, model_type):
+         super().__init__()
+         self.transform = load_midas_transform(model_type)
+
+     def pt2np(self, x):
+         x = ((x + 1.0) * .5).detach().cpu().numpy()
+         return x
+
+     def np2pt(self, x):
+         x = torch.from_numpy(x) * 2 - 1.
+         return x
+
+     def __call__(self, sample):
+         # sample['jpg'] is tensor hwc in [-1, 1] at this point
+         x = self.pt2np(sample['jpg'])
+         x = self.transform({"image": x})["image"]
+         sample['midas_in'] = x
+         return sample
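A hypothetical usage sketch (not part of the commit, and it assumes ldm.modules.midas.api is importable, which is not among the 50 files shown in this truncated view): AddMiDaS is a dict-style transform that reads sample["jpg"] as an HWC tensor in [-1, 1] and writes the MiDaS-ready array under sample["midas_in"].

import torch
from ldm.data.util import AddMiDaS

add_midas = AddMiDaS(model_type="dpt_hybrid")         # model_type as in v2-midas-inference.yaml
sample = {"jpg": torch.rand(512, 512, 3) * 2 - 1.0}   # HWC tensor in [-1, 1]
sample = add_midas(sample)                             # adds the preprocessed array to "midas_in"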
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,219 @@
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+
6
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
7
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
8
+
9
+ from ldm.util import instantiate_from_config
10
+ from ldm.modules.ema import LitEma
11
+
12
+
13
+ class AutoencoderKL(pl.LightningModule):
14
+ def __init__(self,
15
+ ddconfig,
16
+ lossconfig,
17
+ embed_dim,
18
+ ckpt_path=None,
19
+ ignore_keys=[],
20
+ image_key="image",
21
+ colorize_nlabels=None,
22
+ monitor=None,
23
+ ema_decay=None,
24
+ learn_logvar=False
25
+ ):
26
+ super().__init__()
27
+ self.learn_logvar = learn_logvar
28
+ self.image_key = image_key
29
+ self.encoder = Encoder(**ddconfig)
30
+ self.decoder = Decoder(**ddconfig)
31
+ self.loss = instantiate_from_config(lossconfig)
32
+ assert ddconfig["double_z"]
33
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
34
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35
+ self.embed_dim = embed_dim
36
+ if colorize_nlabels is not None:
37
+ assert type(colorize_nlabels)==int
38
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
39
+ if monitor is not None:
40
+ self.monitor = monitor
41
+
42
+ self.use_ema = ema_decay is not None
43
+ if self.use_ema:
44
+ self.ema_decay = ema_decay
45
+ assert 0. < ema_decay < 1.
46
+ self.model_ema = LitEma(self, decay=ema_decay)
47
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
48
+
49
+ if ckpt_path is not None:
50
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
51
+
52
+ def init_from_ckpt(self, path, ignore_keys=list()):
53
+ sd = torch.load(path, map_location="cpu")["state_dict"]
54
+ keys = list(sd.keys())
55
+ for k in keys:
56
+ for ik in ignore_keys:
57
+ if k.startswith(ik):
58
+ print("Deleting key {} from state_dict.".format(k))
59
+ del sd[k]
60
+ self.load_state_dict(sd, strict=False)
61
+ print(f"Restored from {path}")
62
+
63
+ @contextmanager
64
+ def ema_scope(self, context=None):
65
+ if self.use_ema:
66
+ self.model_ema.store(self.parameters())
67
+ self.model_ema.copy_to(self)
68
+ if context is not None:
69
+ print(f"{context}: Switched to EMA weights")
70
+ try:
71
+ yield None
72
+ finally:
73
+ if self.use_ema:
74
+ self.model_ema.restore(self.parameters())
75
+ if context is not None:
76
+ print(f"{context}: Restored training weights")
77
+
78
+ def on_train_batch_end(self, *args, **kwargs):
79
+ if self.use_ema:
80
+ self.model_ema(self)
81
+
82
+ def encode(self, x):
83
+ h = self.encoder(x)
84
+ moments = self.quant_conv(h)
85
+ posterior = DiagonalGaussianDistribution(moments)
86
+ return posterior
87
+
88
+ def decode(self, z):
89
+ z = self.post_quant_conv(z)
90
+ dec = self.decoder(z)
91
+ return dec
92
+
93
+ def forward(self, input, sample_posterior=True):
94
+ posterior = self.encode(input)
95
+ if sample_posterior:
96
+ z = posterior.sample()
97
+ else:
98
+ z = posterior.mode()
99
+ dec = self.decode(z)
100
+ return dec, posterior
101
+
102
+ def get_input(self, batch, k):
103
+ x = batch[k]
104
+ if len(x.shape) == 3:
105
+ x = x[..., None]
106
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
107
+ return x
108
+
109
+ def training_step(self, batch, batch_idx, optimizer_idx):
110
+ inputs = self.get_input(batch, self.image_key)
111
+ reconstructions, posterior = self(inputs)
112
+
113
+ if optimizer_idx == 0:
114
+ # train encoder+decoder+logvar
115
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
116
+ last_layer=self.get_last_layer(), split="train")
117
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
118
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
119
+ return aeloss
120
+
121
+ if optimizer_idx == 1:
122
+ # train the discriminator
123
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
124
+ last_layer=self.get_last_layer(), split="train")
125
+
126
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
127
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
128
+ return discloss
129
+
130
+ def validation_step(self, batch, batch_idx):
131
+ log_dict = self._validation_step(batch, batch_idx)
132
+ with self.ema_scope():
133
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
134
+ return log_dict
135
+
136
+ def _validation_step(self, batch, batch_idx, postfix=""):
137
+ inputs = self.get_input(batch, self.image_key)
138
+ reconstructions, posterior = self(inputs)
139
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
140
+ last_layer=self.get_last_layer(), split="val"+postfix)
141
+
142
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
143
+ last_layer=self.get_last_layer(), split="val"+postfix)
144
+
145
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
146
+ self.log_dict(log_dict_ae)
147
+ self.log_dict(log_dict_disc)
148
+ return self.log_dict
149
+
150
+ def configure_optimizers(self):
151
+ lr = self.learning_rate
152
+ ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
153
+ self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
154
+ if self.learn_logvar:
155
+ print(f"{self.__class__.__name__}: Learning logvar")
156
+ ae_params_list.append(self.loss.logvar)
157
+ opt_ae = torch.optim.Adam(ae_params_list,
158
+ lr=lr, betas=(0.5, 0.9))
159
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
160
+ lr=lr, betas=(0.5, 0.9))
161
+ return [opt_ae, opt_disc], []
162
+
163
+ def get_last_layer(self):
164
+ return self.decoder.conv_out.weight
165
+
166
+ @torch.no_grad()
167
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
168
+ log = dict()
169
+ x = self.get_input(batch, self.image_key)
170
+ x = x.to(self.device)
171
+ if not only_inputs:
172
+ xrec, posterior = self(x)
173
+ if x.shape[1] > 3:
174
+ # colorize with random projection
175
+ assert xrec.shape[1] > 3
176
+ x = self.to_rgb(x)
177
+ xrec = self.to_rgb(xrec)
178
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
179
+ log["reconstructions"] = xrec
180
+ if log_ema or self.use_ema:
181
+ with self.ema_scope():
182
+ xrec_ema, posterior_ema = self(x)
183
+ if x.shape[1] > 3:
184
+ # colorize with random projection
185
+ assert xrec_ema.shape[1] > 3
186
+ xrec_ema = self.to_rgb(xrec_ema)
187
+ log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
188
+ log["reconstructions_ema"] = xrec_ema
189
+ log["inputs"] = x
190
+ return log
191
+
192
+ def to_rgb(self, x):
193
+ assert self.image_key == "segmentation"
194
+ if not hasattr(self, "colorize"):
195
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
196
+ x = F.conv2d(x, weight=self.colorize)
197
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
198
+ return x
199
+
200
+
201
+ class IdentityFirstStage(torch.nn.Module):
202
+ def __init__(self, *args, vq_interface=False, **kwargs):
203
+ self.vq_interface = vq_interface
204
+ super().__init__()
205
+
206
+ def encode(self, x, *args, **kwargs):
207
+ return x
208
+
209
+ def decode(self, x, *args, **kwargs):
210
+ return x
211
+
212
+ def quantize(self, x, *args, **kwargs):
213
+ if self.vq_interface:
214
+ return x, None, [None, None, None]
215
+ return x
216
+
217
+ def forward(self, x, *args, **kwargs):
218
+ return x
219
+
ldm/models/diffusion/__init__.py ADDED
File without changes
ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,403 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
8
+
9
+
10
+ class DDIMSampler(object):
11
+ def __init__(self, model, schedule="linear", **kwargs):
12
+ super().__init__()
13
+ self.model = model
14
+ self.ddpm_num_timesteps = model.num_timesteps
15
+ self.schedule = schedule
16
+
17
+ def register_buffer(self, name, attr):
18
+ if type(attr) == torch.Tensor:
19
+ if attr.device != torch.device("cuda"):
20
+ attr = attr.to(torch.device("cuda"))
21
+ setattr(self, name, attr)
22
+
23
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
24
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
25
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
26
+ alphas_cumprod = self.model.alphas_cumprod
27
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
28
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
29
+
30
+ self.register_buffer('betas', to_torch(self.model.betas))
31
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
32
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
33
+
34
+ # calculations for diffusion q(x_t | x_{t-1}) and others
35
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
36
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
37
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
38
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
40
+
41
+ # ddim sampling parameters
42
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev, ddim_alphas_next = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
43
+ ddim_timesteps=self.ddim_timesteps,
44
+ eta=ddim_eta,verbose=verbose)
45
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
46
+ self.register_buffer('ddim_alphas', ddim_alphas)
47
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
48
+ self.register_buffer('ddim_alphas_next', ddim_alphas_next)
49
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
50
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
51
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev))
52
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
53
+ return self.ddim_timesteps
54
+
55
+ @torch.no_grad()
56
+ def sample(self,
57
+ S,
58
+ batch_size,
59
+ shape,
60
+ conditioning=None,
61
+ callback=None,
62
+ normals_sequence=None,
63
+ img_callback=None,
64
+ quantize_x0=False,
65
+ eta=0.,
66
+ mask=None,
67
+ x0=None,
68
+ temperature=1.,
69
+ noise_dropout=0.,
70
+ score_corrector=None,
71
+ corrector_kwargs=None,
72
+ verbose=True,
73
+ x_T=None,
74
+ log_every_t=100,
75
+ unconditional_guidance_scale=1.,
76
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
77
+ dynamic_threshold=None,
78
+ ucg_schedule=None,
79
+ encode=False,
80
+ **kwargs
81
+ ):
82
+ if conditioning is not None:
83
+ if isinstance(conditioning, dict):
84
+ ctmp = conditioning[list(conditioning.keys())[0]]
85
+ while isinstance(ctmp, list): ctmp = ctmp[0]
86
+ cbs = ctmp.shape[0]
87
+ if cbs != batch_size:
88
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
89
+
90
+ elif isinstance(conditioning, list):
91
+ for ctmp in conditioning:
92
+ if ctmp.shape[0] != batch_size:
93
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
94
+
95
+ else:
96
+ if conditioning.shape[0] != batch_size:
97
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
98
+
99
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
100
+ # sampling
101
+ C, H, W = shape
102
+ size = (batch_size, C, H, W)
103
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
104
+
105
+ samples, intermediates = self.ddim_sampling(conditioning, size,
106
+ callback=callback,
107
+ img_callback=img_callback,
108
+ quantize_denoised=quantize_x0,
109
+ mask=mask, x0=x0,
110
+ ddim_use_original_steps=False,
111
+ noise_dropout=noise_dropout,
112
+ temperature=temperature,
113
+ score_corrector=score_corrector,
114
+ corrector_kwargs=corrector_kwargs,
115
+ x_T=x_T,
116
+ log_every_t=log_every_t,
117
+ unconditional_guidance_scale=unconditional_guidance_scale,
118
+ unconditional_conditioning=unconditional_conditioning,
119
+ dynamic_threshold=dynamic_threshold,
120
+ ucg_schedule=ucg_schedule,
121
+ encode=encode
122
+ )
123
+ return samples, intermediates
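+ # Illustrative usage (a sketch only; `cond`, `uncond` and the latent shape are assumptions
+ # that depend on the conditioning stage and model config):
+ # sampler = DDIMSampler(model)
+ # samples, _ = sampler.sample(S=50, batch_size=1, shape=(4, 64, 64), conditioning=cond,
+ # unconditional_guidance_scale=7.5, unconditional_conditioning=uncond, eta=0.0)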
124
+
125
+ @torch.no_grad()
126
+ def ddim_sampling(self, cond, shape,
127
+ x_T=None, ddim_use_original_steps=False,
128
+ callback=None, timesteps=None, quantize_denoised=False,
129
+ mask=None, x0=None, img_callback=None, log_every_t=100,
130
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
131
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
132
+ ucg_schedule=None, encode=False):
133
+ device = self.model.betas.device
134
+ b = shape[0]
135
+ if x_T is None:
136
+ img = torch.randn(shape, device=device)
137
+ else:
138
+ img = x_T
139
+
140
+ if timesteps is None:
141
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
142
+ elif timesteps is not None and not ddim_use_original_steps:
143
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
144
+ timesteps = self.ddim_timesteps[:subset_end]
145
+
146
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
147
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
148
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
149
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
150
+
151
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
152
+
153
+ for i, step in enumerate(iterator):
154
+ index = total_steps - i - 1
155
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
156
+
157
+ if mask is not None:
158
+ assert x0 is not None
159
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
160
+ img = img_orig * mask + (1. - mask) * img
161
+
162
+ if ucg_schedule is not None:
163
+ assert len(ucg_schedule) == len(time_range)
164
+ unconditional_guidance_scale = ucg_schedule[i]
165
+
166
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
167
+ quantize_denoised=quantize_denoised, temperature=temperature,
168
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
169
+ corrector_kwargs=corrector_kwargs,
170
+ unconditional_guidance_scale=unconditional_guidance_scale,
171
+ unconditional_conditioning=unconditional_conditioning,
172
+ dynamic_threshold=dynamic_threshold,
173
+ encode=encode)
174
+ img, pred_x0 = outs
175
+ if callback: callback(i)
176
+ if img_callback: img_callback(pred_x0, i)
177
+
178
+ if index % log_every_t == 0 or index == total_steps - 1:
179
+ intermediates['x_inter'].append(img)
180
+ intermediates['pred_x0'].append(pred_x0)
181
+
182
+ return img, intermediates
183
+
184
+ @torch.no_grad()
185
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
186
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1.,
187
+ unconditional_conditioning=None, dynamic_threshold=None, encode=False, encode_uncon=False, decode_uncon=False,
188
+ controller=None, inject=False, ref_init=None):
189
+ b, *_, device = *x.shape, x.device
190
+
191
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
192
+ model_output = self.model.apply_model(x, t, c, encode=encode, encode_uncon=encode_uncon, decode_uncon=decode_uncon, controller=controller, inject=inject)
193
+ else:
194
+ if ref_init is None:
195
+ x_in = torch.cat([x] * 2)
196
+ else:
197
+ x_in = torch.cat([x, x, ref_init, ref_init], dim=0)
198
+ unconditional_conditioning = torch.cat([unconditional_conditioning] * 2)
199
+ c = torch.cat([c] * 2)
200
+
201
+ t_in = torch.cat([t] * 2)
202
+ if isinstance(c, int):
203
+ c_in = None
204
+ else:
205
+ if isinstance(c, dict):
206
+ assert isinstance(unconditional_conditioning, dict)
207
+ c_in = dict()
208
+ for k in c:
209
+ if isinstance(c[k], list):
210
+ c_in[k] = [torch.cat([
211
+ unconditional_conditioning[k][i],
212
+ c[k][i]]) for i in range(len(c[k]))]
213
+ else:
214
+ c_in[k] = torch.cat([
215
+ unconditional_conditioning[k],
216
+ c[k]])
217
+ elif isinstance(c, list):
218
+ c_in = list()
219
+ assert isinstance(unconditional_conditioning, list)
220
+ for i in range(len(c)):
221
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
222
+ else:
223
+ c_in = torch.cat([unconditional_conditioning, c])
224
+
225
+ if ref_init is None:
226
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in, encode=encode, encode_uncon=encode_uncon,
227
+ decode_uncon=decode_uncon, controller=controller, inject=inject).chunk(2)
228
+ else:
229
+ model_uncond, model_t, _, _ = self.model.apply_model(x_in, t_in, c_in, encode=encode, encode_uncon=encode_uncon,
230
+ decode_uncon=decode_uncon, controller=controller, inject=inject).chunk(4)
231
+
232
+ # If only the decode-side sampling is made unconditional, the output tends to degenerate to nothing;
+ # making both the encode and decode passes unconditional keeps the two balanced.
233
+ if encode_uncon == True and decode_uncon == True:
234
+ model_output = model_uncond
235
+
236
+ elif encode_uncon == True and decode_uncon == False:
237
+ if encode:
238
+ model_output = model_uncond
239
+ else:
240
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
241
+
242
+ elif encode_uncon == False and decode_uncon == False:
243
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
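+ # Note (added comment): standard classifier-free guidance,
+ # eps = eps_uncond + scale * (eps_cond - eps_uncond); scale == 1 reduces to the conditional prediction.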
244
+
245
+
246
+ if self.model.parameterization == "v":
247
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
248
+ else:
249
+ e_t = model_output
250
+
251
+ if score_corrector is not None:
252
+ assert self.model.parameterization == "eps", 'not implemented'
253
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
254
+
255
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
256
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
257
+ alphas_next = self.model.alphas_cumprod_next if use_original_steps else self.ddim_alphas_next
258
+
259
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
260
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
261
+ # select parameters corresponding to the currently considered timestep
262
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
263
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
264
+ a_next = torch.full((b, 1, 1, 1), alphas_next[index], device=device)
265
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
266
+ # sigma_t = torch.full((b, 1, 1, 1), 1, device=device)
267
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
268
+
269
+ # current prediction for x_0
270
+ if self.model.parameterization != "v":
271
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
272
+ else:
273
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
274
+
275
+ if quantize_denoised:
276
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
277
+
278
+ if dynamic_threshold is not None:
279
+ raise NotImplementedError()
280
+
281
+ # direction pointing to x_t
282
+ if encode == False:
283
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
284
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
285
+ if noise_dropout > 0.:
286
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
287
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt #+ noise
288
+ else:
289
+ # direction pointing to x_t for forward
290
+ dir_xt = (1. - a_next).sqrt() * e_t
291
+ # noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
292
+ # if noise_dropout > 0.:
293
+ # noise = torch.nn.functional.dropout(noise, p=noise_dropout)
294
+ x_prev = a_next.sqrt() * pred_x0 + dir_xt
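+ # Note (added comment): the decode branch above follows the deterministic DDIM update
+ # x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t^2) * e_t (the stochastic noise
+ # term is intentionally commented out); the encode branch runs the same update towards
+ # a_next, i.e. DDIM inversion of the latent towards noise.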
295
+
296
+ return x_prev, pred_x0
297
+
298
+ @torch.no_grad()
299
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
300
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
301
+ num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
302
+
303
+ assert t_enc <= num_reference_steps
304
+ num_steps = t_enc
305
+
306
+ if use_original_steps:
307
+ alphas_next = self.alphas_cumprod[:num_steps]
308
+ alphas = self.alphas_cumprod_prev[:num_steps]
309
+ else:
310
+ alphas_next = self.ddim_alphas[:num_steps]
311
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
312
+
313
+ x_next = x0
314
+ intermediates = []
315
+ inter_steps = []
316
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
317
+ t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
318
+ if unconditional_guidance_scale == 1.:
319
+ noise_pred = self.model.apply_model(x_next, t, c)
320
+ else:
321
+ assert unconditional_conditioning is not None
322
+ e_t_uncond, noise_pred = torch.chunk(
323
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
324
+ torch.cat((unconditional_conditioning, c))), 2)
325
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
326
+
327
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
328
+ weighted_noise_pred = alphas_next[i].sqrt() * (
329
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
330
+ x_next = xt_weighted + weighted_noise_pred
331
+ if return_intermediates and i % (
332
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
333
+ intermediates.append(x_next)
334
+ inter_steps.append(i)
335
+ elif return_intermediates and i >= num_steps - 2:
336
+ intermediates.append(x_next)
337
+ inter_steps.append(i)
338
+ if callback: callback(i)
339
+
340
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
341
+ if return_intermediates:
342
+ out.update({'intermediates': intermediates})
343
+ return x_next, out
344
+
345
+ @torch.no_grad()
346
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
347
+ # fast, but does not allow for exact reconstruction
348
+ # t serves as an index to gather the correct alphas
349
+ if use_original_steps:
350
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
351
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
352
+ else:
353
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
354
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
355
+
356
+ if noise is None:
357
+ noise = torch.randn_like(x0)
358
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
359
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
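+ # Note (added comment): closed-form forward diffusion q(x_t | x_0):
+ # x_t = sqrt(alphas_cumprod_t) * x_0 + sqrt(1 - alphas_cumprod_t) * noise.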
360
+
361
+ @torch.no_grad()
362
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
363
+ use_original_steps=False, callback=None, encode=False, encode_uncon=True, decode_uncon=True, controller=None):
364
+
365
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
366
+ timesteps = timesteps[:t_start]
367
+
368
+ if encode:
369
+ time_range = timesteps
370
+ else:
371
+ time_range = np.flip(timesteps)
372
+
373
+ total_steps = timesteps.shape[0]
374
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
375
+
376
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
377
+ x_dec = x_latent
378
+ for i, step in enumerate(iterator):
379
+ if encode:
380
+ index = i
381
+ else:
382
+ index = total_steps - i - 1
383
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
384
+
385
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
386
+ unconditional_guidance_scale=unconditional_guidance_scale,
387
+ unconditional_conditioning=unconditional_conditioning, encode=encode,
388
+ encode_uncon=encode_uncon, decode_uncon=decode_uncon, controller=controller)
389
+ if callback: callback(i)
390
+ return x_dec
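+ # Illustrative inversion-based editing flow (a sketch only; variable names are assumptions):
+ # z = model.get_first_stage_encoding(model.encode_first_stage(img))
+ # z_enc, _ = sampler.encode(z, c, t_enc, unconditional_guidance_scale=1.0)
+ # z_dec = sampler.decode(z_enc, c_new, t_start=t_enc,
+ # unconditional_guidance_scale=scale, unconditional_conditioning=uc)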
391
+
392
+ @torch.no_grad()
393
+ def decode_one_step(self, x_latent, cond, ts, index, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
394
+ use_original_steps=False, callback=None, encode=False, encode_uncon=True, decode_uncon=True,
395
+ controller=None, inject=False, ref_init=None):
396
+
397
+ x_dec, _ = self.p_sample_ddim(x_latent, cond, ts, index=index, use_original_steps=use_original_steps,
398
+ unconditional_guidance_scale=unconditional_guidance_scale,
399
+ unconditional_conditioning=unconditional_conditioning, encode=encode,
400
+ encode_uncon=encode_uncon, decode_uncon=decode_uncon, controller=controller,
401
+ inject=inject, ref_init=ref_init)
402
+ # if callback: callback(i)
403
+ return x_dec
ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1796 @@
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import pytorch_lightning as pl
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+ from einops import rearrange, repeat
15
+ from contextlib import contextmanager, nullcontext
16
+ from functools import partial
17
+ import itertools
18
+ from tqdm import tqdm
19
+ from torchvision.utils import make_grid
20
+ from pytorch_lightning.utilities.distributed import rank_zero_only
21
+ from omegaconf import ListConfig
22
+
23
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
24
+ from ldm.modules.ema import LitEma
25
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
26
+ from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
27
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
28
+ from ldm.models.diffusion.ddim import DDIMSampler
29
+
30
+
31
+ __conditioning_keys__ = {'concat': 'c_concat',
32
+ 'crossattn': 'c_crossattn',
33
+ 'adm': 'y'}
34
+
35
+
36
+ def disabled_train(self, mode=True):
37
+ """Overwrite model.train with this function to make sure train/eval mode
38
+ does not change anymore."""
39
+ return self
40
+
41
+
42
+ def uniform_on_device(r1, r2, shape, device):
43
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
44
+
45
+
46
+ class DDPM(pl.LightningModule):
47
+ # classic DDPM with Gaussian diffusion, in image space
48
+ def __init__(self,
49
+ unet_config,
50
+ timesteps=1000,
51
+ beta_schedule="linear",
52
+ loss_type="l2",
53
+ ckpt_path=None,
54
+ ignore_keys=[],
55
+ load_only_unet=False,
56
+ monitor="val/loss",
57
+ use_ema=True,
58
+ first_stage_key="image",
59
+ image_size=256,
60
+ channels=3,
61
+ log_every_t=100,
62
+ clip_denoised=True,
63
+ linear_start=1e-4,
64
+ linear_end=2e-2,
65
+ cosine_s=8e-3,
66
+ given_betas=None,
67
+ original_elbo_weight=0.,
68
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
69
+ l_simple_weight=1.,
70
+ conditioning_key=None,
71
+ parameterization="eps", # all assuming fixed variance schedules
72
+ scheduler_config=None,
73
+ use_positional_encodings=False,
74
+ learn_logvar=False,
75
+ logvar_init=0.,
76
+ make_it_fit=False,
77
+ ucg_training=None,
78
+ reset_ema=False,
79
+ reset_num_ema_updates=False,
80
+ ):
81
+ super().__init__()
82
+ assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
83
+ self.parameterization = parameterization
84
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
85
+ self.cond_stage_model = None
86
+ self.clip_denoised = clip_denoised
87
+ self.log_every_t = log_every_t
88
+ self.first_stage_key = first_stage_key
89
+ self.image_size = image_size # try conv?
90
+ self.channels = channels
91
+ self.use_positional_encodings = use_positional_encodings
92
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
93
+ count_params(self.model, verbose=True)
94
+ self.use_ema = use_ema
95
+ if self.use_ema:
96
+ self.model_ema = LitEma(self.model)
97
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
98
+
99
+ self.use_scheduler = scheduler_config is not None
100
+ if self.use_scheduler:
101
+ self.scheduler_config = scheduler_config
102
+
103
+ self.v_posterior = v_posterior
104
+ self.original_elbo_weight = original_elbo_weight
105
+ self.l_simple_weight = l_simple_weight
106
+
107
+ if monitor is not None:
108
+ self.monitor = monitor
109
+ self.make_it_fit = make_it_fit
110
+ if reset_ema: assert exists(ckpt_path)
111
+ if ckpt_path is not None:
112
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
113
+ if reset_ema:
114
+ assert self.use_ema
115
+ print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
116
+ self.model_ema = LitEma(self.model)
117
+ if reset_num_ema_updates:
118
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
119
+ assert self.use_ema
120
+ self.model_ema.reset_num_updates()
121
+
122
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
123
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
124
+
125
+ self.loss_type = loss_type
126
+
127
+ self.learn_logvar = learn_logvar
128
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
129
+ if self.learn_logvar:
130
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
131
+
132
+ self.ucg_training = ucg_training or dict()
133
+ if self.ucg_training:
134
+ self.ucg_prng = np.random.RandomState()
135
+
136
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
137
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
138
+ if exists(given_betas):
139
+ betas = given_betas
140
+ else:
141
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
142
+ cosine_s=cosine_s)
143
+ alphas = 1. - betas
144
+ alphas_cumprod = np.cumprod(alphas, axis=0)
145
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
146
+
147
+ timesteps, = betas.shape
148
+ self.num_timesteps = int(timesteps)
149
+ self.linear_start = linear_start
150
+ self.linear_end = linear_end
151
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
152
+
153
+ to_torch = partial(torch.tensor, dtype=torch.float32)
154
+
155
+ self.register_buffer('betas', to_torch(betas))
156
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
157
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
158
+
159
+ # calculations for diffusion q(x_t | x_{t-1}) and others
160
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
161
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
162
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
163
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
164
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
165
+
166
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
167
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
168
+ 1. - alphas_cumprod) + self.v_posterior * betas
169
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
170
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
171
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
172
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
173
+ self.register_buffer('posterior_mean_coef1', to_torch(
174
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
175
+ self.register_buffer('posterior_mean_coef2', to_torch(
176
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
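+ # Note (added comment): together these give the DDPM posterior mean
+ # mu_tilde(x_t, x_0) = posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t
+ # of q(x_{t-1} | x_t, x_0), used by q_posterior() below.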
177
+
178
+ if self.parameterization == "eps":
179
+ lvlb_weights = self.betas ** 2 / (
180
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
181
+ elif self.parameterization == "x0":
182
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
183
+ elif self.parameterization == "v":
184
+ lvlb_weights = torch.ones_like(self.betas ** 2 / (
185
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
186
+ else:
187
+ raise NotImplementedError("mu not supported")
188
+ lvlb_weights[0] = lvlb_weights[1]
189
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
190
+ assert not torch.isnan(self.lvlb_weights).any()
191
+
192
+ @contextmanager
193
+ def ema_scope(self, context=None):
194
+ if self.use_ema:
195
+ self.model_ema.store(self.model.parameters())
196
+ self.model_ema.copy_to(self.model)
197
+ if context is not None:
198
+ print(f"{context}: Switched to EMA weights")
199
+ try:
200
+ yield None
201
+ finally:
202
+ if self.use_ema:
203
+ self.model_ema.restore(self.model.parameters())
204
+ if context is not None:
205
+ print(f"{context}: Restored training weights")
206
+
207
+ @torch.no_grad()
208
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
209
+ sd = torch.load(path, map_location="cpu")
210
+ if "state_dict" in list(sd.keys()):
211
+ sd = sd["state_dict"]
212
+ keys = list(sd.keys())
213
+ for k in keys:
214
+ for ik in ignore_keys:
215
+ if k.startswith(ik):
216
+ print("Deleting key {} from state_dict.".format(k))
217
+ del sd[k]
218
+ if self.make_it_fit:
219
+ n_params = len([name for name, _ in
220
+ itertools.chain(self.named_parameters(),
221
+ self.named_buffers())])
222
+ for name, param in tqdm(
223
+ itertools.chain(self.named_parameters(),
224
+ self.named_buffers()),
225
+ desc="Fitting old weights to new weights",
226
+ total=n_params
227
+ ):
228
+ if name not in sd:
229
+ continue
230
+ old_shape = sd[name].shape
231
+ new_shape = param.shape
232
+ assert len(old_shape) == len(new_shape)
233
+ if len(new_shape) > 2:
234
+ # we only modify first two axes
235
+ assert new_shape[2:] == old_shape[2:]
236
+ # assumes first axis corresponds to output dim
237
+ if not new_shape == old_shape:
238
+ new_param = param.clone()
239
+ old_param = sd[name]
240
+ if len(new_shape) == 1:
241
+ for i in range(new_param.shape[0]):
242
+ new_param[i] = old_param[i % old_shape[0]]
243
+ elif len(new_shape) >= 2:
244
+ for i in range(new_param.shape[0]):
245
+ for j in range(new_param.shape[1]):
246
+ new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
247
+
248
+ n_used_old = torch.ones(old_shape[1])
249
+ for j in range(new_param.shape[1]):
250
+ n_used_old[j % old_shape[1]] += 1
251
+ n_used_new = torch.zeros(new_shape[1])
252
+ for j in range(new_param.shape[1]):
253
+ n_used_new[j] = n_used_old[j % old_shape[1]]
254
+
255
+ n_used_new = n_used_new[None, :]
256
+ while len(n_used_new.shape) < len(new_shape):
257
+ n_used_new = n_used_new.unsqueeze(-1)
258
+ new_param /= n_used_new
259
+
260
+ sd[name] = new_param
261
+
262
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
263
+ sd, strict=False)
264
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
265
+ if len(missing) > 0:
266
+ print(f"Missing Keys:\n {missing}")
267
+ if len(unexpected) > 0:
268
+ print(f"\nUnexpected Keys:\n {unexpected}")
269
+
270
+ def q_mean_variance(self, x_start, t):
271
+ """
272
+ Get the distribution q(x_t | x_0).
273
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
274
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
275
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
276
+ """
277
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
278
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
279
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
280
+ return mean, variance, log_variance
281
+
282
+ def predict_start_from_noise(self, x_t, t, noise):
283
+ return (
284
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
285
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
286
+ )
287
+
288
+ def predict_start_from_z_and_v(self, x_t, t, v):
289
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
290
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
291
+ return (
292
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
293
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
294
+ )
295
+
296
+ def predict_eps_from_z_and_v(self, x_t, t, v):
297
+ return (
298
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
299
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
300
+ )
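+ # Note (added comment): with the v-parameterization v = sqrt(a_t) * eps - sqrt(1 - a_t) * x_0,
+ # the two helpers above invert that definition:
+ # x_0 = sqrt(a_t) * z - sqrt(1 - a_t) * v and eps = sqrt(a_t) * v + sqrt(1 - a_t) * z.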
301
+
302
+ def q_posterior(self, x_start, x_t, t):
303
+ posterior_mean = (
304
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
305
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
306
+ )
307
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
308
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
309
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
310
+
311
+ def p_mean_variance(self, x, t, clip_denoised: bool):
312
+ model_out = self.model(x, t)
313
+ if self.parameterization == "eps":
314
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
315
+ elif self.parameterization == "x0":
316
+ x_recon = model_out
317
+ if clip_denoised:
318
+ x_recon.clamp_(-1., 1.)
319
+
320
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
321
+ return model_mean, posterior_variance, posterior_log_variance
322
+
323
+ @torch.no_grad()
324
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
325
+ b, *_, device = *x.shape, x.device
326
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
327
+ noise = noise_like(x.shape, device, repeat_noise)
328
+ # no noise when t == 0
329
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
330
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
331
+
332
+ @torch.no_grad()
333
+ def p_sample_loop(self, shape, return_intermediates=False):
334
+ device = self.betas.device
335
+ b = shape[0]
336
+ img = torch.randn(shape, device=device)
337
+ intermediates = [img]
338
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
339
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
340
+ clip_denoised=self.clip_denoised)
341
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
342
+ intermediates.append(img)
343
+ if return_intermediates:
344
+ return img, intermediates
345
+ return img
346
+
347
+ @torch.no_grad()
348
+ def sample(self, batch_size=16, return_intermediates=False):
349
+ image_size = self.image_size
350
+ channels = self.channels
351
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
352
+ return_intermediates=return_intermediates)
353
+
354
+ def q_sample(self, x_start, t, noise=None):
355
+ noise = default(noise, lambda: torch.randn_like(x_start))
356
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
357
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
358
+
359
+ def get_v(self, x, noise, t):
360
+ return (
361
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
362
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
363
+ )
364
+
365
+ def get_loss(self, pred, target, mean=True):
366
+ if self.loss_type == 'l1':
367
+ loss = (target - pred).abs()
368
+ if mean:
369
+ loss = loss.mean()
370
+ elif self.loss_type == 'l2':
371
+ if mean:
372
+ loss = torch.nn.functional.mse_loss(target, pred)
373
+ else:
374
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
375
+ else:
376
+ raise NotImplementedError("unknown loss type '{loss_type}'")
377
+
378
+ return loss
379
+
380
+ def p_losses(self, x_start, t, noise=None):
381
+ noise = default(noise, lambda: torch.randn_like(x_start))
382
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
383
+ model_out = self.model(x_noisy, t)
384
+
385
+ loss_dict = {}
386
+ if self.parameterization == "eps":
387
+ target = noise
388
+ elif self.parameterization == "x0":
389
+ target = x_start
390
+ elif self.parameterization == "v":
391
+ target = self.get_v(x_start, noise, t)
392
+ else:
393
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
394
+
395
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
396
+
397
+ log_prefix = 'train' if self.training else 'val'
398
+
399
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
400
+ loss_simple = loss.mean() * self.l_simple_weight
401
+
402
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
403
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
404
+
405
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
406
+
407
+ loss_dict.update({f'{log_prefix}/loss': loss})
408
+
409
+ return loss, loss_dict
410
+
411
+ def forward(self, x, *args, **kwargs):
412
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
413
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
414
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
415
+ return self.p_losses(x, t, *args, **kwargs)
416
+
417
+ def get_input(self, batch, k):
418
+ x = batch[k]
419
+ if len(x.shape) == 3:
420
+ x = x[..., None]
421
+ x = rearrange(x, 'b h w c -> b c h w')
422
+ x = x.to(memory_format=torch.contiguous_format).float()
423
+ return x
424
+
425
+ def shared_step(self, batch):
426
+ x = self.get_input(batch, self.first_stage_key)
427
+ loss, loss_dict = self(x)
428
+ return loss, loss_dict
429
+
430
+ def training_step(self, batch, batch_idx):
431
+ for k in self.ucg_training:
432
+ p = self.ucg_training[k]["p"]
433
+ val = self.ucg_training[k]["val"]
434
+ if val is None:
435
+ val = ""
436
+ for i in range(len(batch[k])):
437
+ if self.ucg_prng.choice(2, p=[1 - p, p]):
438
+ batch[k][i] = val
439
+
440
+ loss, loss_dict = self.shared_step(batch)
441
+
442
+ self.log_dict(loss_dict, prog_bar=True,
443
+ logger=True, on_step=True, on_epoch=True)
444
+
445
+ self.log("global_step", self.global_step,
446
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
447
+
448
+ if self.use_scheduler:
449
+ lr = self.optimizers().param_groups[0]['lr']
450
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
451
+
452
+ return loss
453
+
454
+ @torch.no_grad()
455
+ def validation_step(self, batch, batch_idx):
456
+ _, loss_dict_no_ema = self.shared_step(batch)
457
+ with self.ema_scope():
458
+ _, loss_dict_ema = self.shared_step(batch)
459
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
460
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
461
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
462
+
463
+ def on_train_batch_end(self, *args, **kwargs):
464
+ if self.use_ema:
465
+ self.model_ema(self.model)
466
+
467
+ def _get_rows_from_list(self, samples):
468
+ n_imgs_per_row = len(samples)
469
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
470
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
471
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
472
+ return denoise_grid
473
+
474
+ @torch.no_grad()
475
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
476
+ log = dict()
477
+ x = self.get_input(batch, self.first_stage_key)
478
+ N = min(x.shape[0], N)
479
+ n_row = min(x.shape[0], n_row)
480
+ x = x.to(self.device)[:N]
481
+ log["inputs"] = x
482
+
483
+ # get diffusion row
484
+ diffusion_row = list()
485
+ x_start = x[:n_row]
486
+
487
+ for t in range(self.num_timesteps):
488
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
489
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
490
+ t = t.to(self.device).long()
491
+ noise = torch.randn_like(x_start)
492
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
493
+ diffusion_row.append(x_noisy)
494
+
495
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
496
+
497
+ if sample:
498
+ # get denoise row
499
+ with self.ema_scope("Plotting"):
500
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
501
+
502
+ log["samples"] = samples
503
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
504
+
505
+ if return_keys:
506
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
507
+ return log
508
+ else:
509
+ return {key: log[key] for key in return_keys}
510
+ return log
511
+
512
+ def configure_optimizers(self):
513
+ lr = self.learning_rate
514
+ params = list(self.model.parameters())
515
+ if self.learn_logvar:
516
+ params = params + [self.logvar]
517
+ opt = torch.optim.AdamW(params, lr=lr)
518
+ return opt
519
+
520
+
521
+ class LatentDiffusion(DDPM):
522
+ """main class"""
523
+
524
+ def __init__(self,
525
+ first_stage_config,
526
+ cond_stage_config,
527
+ num_timesteps_cond=None,
528
+ cond_stage_key="image",
529
+ cond_stage_trainable=False,
530
+ concat_mode=True,
531
+ cond_stage_forward=None,
532
+ conditioning_key=None,
533
+ scale_factor=1.0,
534
+ scale_by_std=False,
535
+ force_null_conditioning=False,
536
+ *args, **kwargs):
537
+ self.force_null_conditioning = force_null_conditioning
538
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
539
+ self.scale_by_std = scale_by_std
540
+ assert self.num_timesteps_cond <= kwargs['timesteps']
541
+ # for backwards compatibility after implementation of DiffusionWrapper
542
+ if conditioning_key is None:
543
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
544
+ if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
545
+ conditioning_key = None
546
+ ckpt_path = kwargs.pop("ckpt_path", None)
547
+ reset_ema = kwargs.pop("reset_ema", False)
548
+ reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
549
+ ignore_keys = kwargs.pop("ignore_keys", [])
550
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
551
+ self.concat_mode = concat_mode
552
+ self.cond_stage_trainable = cond_stage_trainable
553
+ self.cond_stage_key = cond_stage_key
554
+ try:
555
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
556
+ except Exception:
557
+ self.num_downs = 0
558
+ if not scale_by_std:
559
+ self.scale_factor = scale_factor
560
+ else:
561
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
562
+ self.instantiate_first_stage(first_stage_config)
563
+ self.instantiate_cond_stage(cond_stage_config)
564
+ self.cond_stage_forward = cond_stage_forward
565
+ self.clip_denoised = False
566
+ self.bbox_tokenizer = None
567
+
568
+ self.restarted_from_ckpt = False
569
+ if ckpt_path is not None:
570
+ self.init_from_ckpt(ckpt_path, ignore_keys)
571
+ self.restarted_from_ckpt = True
572
+ if reset_ema:
573
+ assert self.use_ema
574
+ print(
575
+ f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
576
+ self.model_ema = LitEma(self.model)
577
+ if reset_num_ema_updates:
578
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
579
+ assert self.use_ema
580
+ self.model_ema.reset_num_updates()
581
+
582
+ def make_cond_schedule(self, ):
583
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
584
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
585
+ self.cond_ids[:self.num_timesteps_cond] = ids
586
+
587
+ @rank_zero_only
588
+ @torch.no_grad()
589
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
590
+ # only for very first batch
591
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
592
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
593
+ # set rescale weight to 1./std of encodings
594
+ print("### USING STD-RESCALING ###")
595
+ x = super().get_input(batch, self.first_stage_key)
596
+ x = x.to(self.device)
597
+ encoder_posterior = self.encode_first_stage(x)
598
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
599
+ del self.scale_factor
600
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
601
+ print(f"setting self.scale_factor to {self.scale_factor}")
602
+ print("### USING STD-RESCALING ###")
603
+
604
+ def register_schedule(self,
605
+ given_betas=None, beta_schedule="linear", timesteps=1000,
606
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
607
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
608
+
609
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
610
+ if self.shorten_cond_schedule:
611
+ self.make_cond_schedule()
612
+
613
+ def instantiate_first_stage(self, config):
614
+ model = instantiate_from_config(config)
615
+ self.first_stage_model = model.eval()
616
+ self.first_stage_model.train = disabled_train
617
+ for param in self.first_stage_model.parameters():
618
+ param.requires_grad = False
619
+
620
+ def instantiate_cond_stage(self, config):
621
+ if not self.cond_stage_trainable:
622
+ if config == "__is_first_stage__":
623
+ print("Using first stage also as cond stage.")
624
+ self.cond_stage_model = self.first_stage_model
625
+ elif config == "__is_unconditional__":
626
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
627
+ self.cond_stage_model = None
628
+ # self.be_unconditional = True
629
+ else:
630
+ model = instantiate_from_config(config)
631
+ self.cond_stage_model = model.eval()
632
+ self.cond_stage_model.train = disabled_train
633
+ for param in self.cond_stage_model.parameters():
634
+ param.requires_grad = False
635
+ else:
636
+ assert config != '__is_first_stage__'
637
+ assert config != '__is_unconditional__'
638
+ model = instantiate_from_config(config)
639
+ self.cond_stage_model = model
640
+
641
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
642
+ denoise_row = []
643
+ for zd in tqdm(samples, desc=desc):
644
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
645
+ force_not_quantize=force_no_decoder_quantization))
646
+ n_imgs_per_row = len(denoise_row)
647
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
648
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
649
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
650
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
651
+ return denoise_grid
652
+
653
+ def get_first_stage_encoding(self, encoder_posterior):
654
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
655
+ z = encoder_posterior.sample()
656
+ elif isinstance(encoder_posterior, torch.Tensor):
657
+ z = encoder_posterior
658
+ else:
659
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
660
+ return self.scale_factor * z
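+ # Note (added comment): latents are multiplied by scale_factor here (typically 0.18215 in the
+ # released Stable Diffusion v1/v2 configs) so the diffusion model sees roughly unit-variance inputs.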
661
+
662
+ def get_learned_conditioning(self, c, inv=False):
663
+
664
+ if self.cond_stage_forward is None:
665
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
666
+ c = self.cond_stage_model.encode(c, inv, device=self.device)
667
+ if isinstance(c, DiagonalGaussianDistribution):
668
+ c = c.mode()
669
+ else:
670
+ c = self.cond_stage_model(c)
671
+ else:
672
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
673
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
674
+ return c
675
+
676
+ def meshgrid(self, h, w):
677
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
678
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
679
+
680
+ arr = torch.cat([y, x], dim=-1)
681
+ return arr
682
+
683
+ def delta_border(self, h, w):
684
+ """
685
+ :param h: height
686
+ :param w: width
687
+ :return: normalized distance to image border,
688
+ with min distance = 0 at border and max dist = 0.5 at image center
689
+ """
690
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
691
+ arr = self.meshgrid(h, w) / lower_right_corner
692
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
693
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
694
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
695
+ return edge_dist
696
+
697
+ def get_weighting(self, h, w, Ly, Lx, device):
698
+ weighting = self.delta_border(h, w)
699
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
700
+ self.split_input_params["clip_max_weight"], )
701
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
702
+
703
+ if self.split_input_params["tie_braker"]:
704
+ L_weighting = self.delta_border(Ly, Lx)
705
+ L_weighting = torch.clip(L_weighting,
706
+ self.split_input_params["clip_min_tie_weight"],
707
+ self.split_input_params["clip_max_tie_weight"])
708
+
709
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
710
+ weighting = weighting * L_weighting
711
+ return weighting
712
+
713
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
714
+ """
715
+ :param x: img of size (bs, c, h, w)
716
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
717
+ """
718
+ bs, nc, h, w = x.shape
719
+
720
+ # number of crops in image
721
+ Ly = (h - kernel_size[0]) // stride[0] + 1
722
+ Lx = (w - kernel_size[1]) // stride[1] + 1
723
+
724
+ if uf == 1 and df == 1:
725
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
726
+ unfold = torch.nn.Unfold(**fold_params)
727
+
728
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
729
+
730
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
731
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
732
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
733
+
734
+ elif uf > 1 and df == 1:
735
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
736
+ unfold = torch.nn.Unfold(**fold_params)
737
+
738
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
739
+ dilation=1, padding=0,
740
+ stride=(stride[0] * uf, stride[1] * uf))
741
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
742
+
743
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
744
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
745
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
746
+
747
+ elif df > 1 and uf == 1:
748
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
749
+ unfold = torch.nn.Unfold(**fold_params)
750
+
751
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
752
+ dilation=1, padding=0,
753
+ stride=(stride[0] // df, stride[1] // df))
754
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
755
+
756
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
757
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
758
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
759
+
760
+ else:
761
+ raise NotImplementedError
762
+
763
+ return fold, unfold, normalization, weighting
764
+
765
+ @torch.no_grad()
766
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
767
+ cond_key=None, return_original_cond=False, bs=None, return_x=False):
768
+ x = super().get_input(batch, k)
769
+ if bs is not None:
770
+ x = x[:bs]
771
+ x = x.to(self.device)
772
+ encoder_posterior = self.encode_first_stage(x)
773
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
774
+
775
+ if self.model.conditioning_key is not None and not self.force_null_conditioning:
776
+ if cond_key is None:
777
+ cond_key = self.cond_stage_key
778
+ if cond_key != self.first_stage_key:
779
+ if cond_key in ['caption', 'coordinates_bbox', "txt"]:
780
+ xc = batch[cond_key]
781
+ elif cond_key in ['class_label', 'cls']:
782
+ xc = batch
783
+ else:
784
+ xc = super().get_input(batch, cond_key).to(self.device)
785
+ else:
786
+ xc = x
787
+ if not self.cond_stage_trainable or force_c_encode:
788
+ if isinstance(xc, dict) or isinstance(xc, list):
789
+ c = self.get_learned_conditioning(xc)
790
+ else:
791
+ c = self.get_learned_conditioning(xc.to(self.device))
792
+ else:
793
+ c = xc
794
+ if bs is not None:
795
+ c = c[:bs]
796
+
797
+ if self.use_positional_encodings:
798
+ pos_x, pos_y = self.compute_latent_shifts(batch)
799
+ ckey = __conditioning_keys__[self.model.conditioning_key]
800
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
801
+
802
+ else:
803
+ c = None
804
+ xc = None
805
+ if self.use_positional_encodings:
806
+ pos_x, pos_y = self.compute_latent_shifts(batch)
807
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
808
+ out = [z, c]
809
+ if return_first_stage_outputs:
810
+ xrec = self.decode_first_stage(z)
811
+ out.extend([x, xrec])
812
+ if return_x:
813
+ out.extend([x])
814
+ if return_original_cond:
815
+ out.append(xc)
816
+ return out
817
+
818
+ @torch.no_grad()
819
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
820
+ if predict_cids:
821
+ if z.dim() == 4:
822
+ z = torch.argmax(z.exp(), dim=1).long()
823
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
824
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
825
+
826
+ z = 1. / self.scale_factor * z
827
+ return self.first_stage_model.decode(z)
828
+
829
+ @torch.no_grad()
830
+ def encode_first_stage(self, x):
831
+ return self.first_stage_model.encode(x)
832
+
833
+ def shared_step(self, batch, **kwargs):
834
+ x, c = self.get_input(batch, self.first_stage_key)
835
+ loss = self(x, c)
836
+ return loss
837
+
838
+ def forward(self, x, c, *args, **kwargs):
839
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
840
+ if self.model.conditioning_key is not None:
841
+ assert c is not None
842
+ if self.cond_stage_trainable:
843
+ c = self.get_learned_conditioning(c)
844
+ if self.shorten_cond_schedule: # TODO: drop this option
845
+ tc = self.cond_ids[t].to(self.device)
846
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
847
+ return self.p_losses(x, c, t, *args, **kwargs)
848
+
849
+ def apply_model(self, x_noisy, t, cond, return_ids=False, encode=False, encode_uncon=False, decode_uncon=False, controller=None, inject=False):
850
+ if isinstance(cond, dict):
851
+ # hybrid case, cond is expected to be a dict
852
+ pass
853
+ else:
854
+ if not isinstance(cond, list):
855
+ cond = [cond]
856
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
857
+ cond = {key: cond}
858
+
859
+ x_recon = self.model(x_noisy, t, **cond, encode=encode, encode_uncon=encode_uncon, decode_uncon=decode_uncon, controller=controller, inject=inject)
860
+
861
+ if isinstance(x_recon, tuple) and not return_ids:
862
+ return x_recon[0]
863
+ else:
864
+ return x_recon
865
+
866
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
867
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
868
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
869
+
870
+ def _prior_bpd(self, x_start):
871
+ """
872
+ Get the prior KL term for the variational lower-bound, measured in
873
+ bits-per-dim.
874
+ This term can't be optimized, as it only depends on the encoder.
875
+ :param x_start: the [N x C x ...] tensor of inputs.
876
+ :return: a batch of [N] KL values (in bits), one per batch element.
877
+ """
878
+ batch_size = x_start.shape[0]
879
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
880
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
881
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
882
+ return mean_flat(kl_prior) / np.log(2.0)
883
+
884
+ def p_losses(self, x_start, cond, t, noise=None):
885
+ noise = default(noise, lambda: torch.randn_like(x_start))
886
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
887
+ model_output = self.apply_model(x_noisy, t, cond)
888
+
889
+ loss_dict = {}
890
+ prefix = 'train' if self.training else 'val'
891
+
892
+ if self.parameterization == "x0":
893
+ target = x_start
894
+ elif self.parameterization == "eps":
895
+ target = noise
896
+ elif self.parameterization == "v":
897
+ target = self.get_v(x_start, noise, t)
898
+ else:
899
+ raise NotImplementedError()
900
+
901
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
902
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
903
+
904
+ logvar_t = self.logvar[t].to(self.device)
905
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
906
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
907
+ if self.learn_logvar:
908
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
909
+ loss_dict.update({'logvar': self.logvar.data.mean()})
910
+
911
+ loss = self.l_simple_weight * loss.mean()
912
+
913
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
914
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
915
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
916
+ loss += (self.original_elbo_weight * loss_vlb)
917
+ loss_dict.update({f'{prefix}/loss': loss})
918
+
919
+ return loss, loss_dict
920
+
921
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
922
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
923
+ t_in = t
924
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
925
+
926
+ if score_corrector is not None:
927
+ assert self.parameterization == "eps"
928
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
929
+
930
+ if return_codebook_ids:
931
+ model_out, logits = model_out
932
+
933
+ if self.parameterization == "eps":
934
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
935
+ elif self.parameterization == "x0":
936
+ x_recon = model_out
937
+ else:
938
+ raise NotImplementedError()
939
+
940
+ if clip_denoised:
941
+ x_recon.clamp_(-1., 1.)
942
+ if quantize_denoised:
943
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
944
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
945
+ if return_codebook_ids:
946
+ return model_mean, posterior_variance, posterior_log_variance, logits
947
+ elif return_x0:
948
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
949
+ else:
950
+ return model_mean, posterior_variance, posterior_log_variance
951
+
952
+ @torch.no_grad()
953
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
954
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
955
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
956
+ b, *_, device = *x.shape, x.device
957
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
958
+ return_codebook_ids=return_codebook_ids,
959
+ quantize_denoised=quantize_denoised,
960
+ return_x0=return_x0,
961
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
962
+ if return_codebook_ids:
963
+ raise DeprecationWarning("Support dropped.")
964
+ model_mean, _, model_log_variance, logits = outputs
965
+ elif return_x0:
966
+ model_mean, _, model_log_variance, x0 = outputs
967
+ else:
968
+ model_mean, _, model_log_variance = outputs
969
+
970
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
971
+ if noise_dropout > 0.:
972
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
973
+ # no noise when t == 0
974
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
975
+
976
+ if return_codebook_ids:
977
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
978
+ if return_x0:
979
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
980
+ else:
981
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
982
+
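A minimal standalone sketch of the ancestral update performed above (hypothetical helper; the real mean and log-variance come from p_mean_variance):

import torch

def toy_ancestral_step(model_mean, model_log_variance, t):
    noise = torch.randn_like(model_mean)
    # no noise is added at the final step t == 0
    nonzero_mask = (t > 0).float().reshape(-1, *([1] * (model_mean.dim() - 1)))
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise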
983
+ @torch.no_grad()
984
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
985
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
986
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
987
+ log_every_t=None):
988
+ if not log_every_t:
989
+ log_every_t = self.log_every_t
990
+ timesteps = self.num_timesteps
991
+ if batch_size is not None:
992
+ b = batch_size if batch_size is not None else shape[0]
993
+ shape = [batch_size] + list(shape)
994
+ else:
995
+ b = batch_size = shape[0]
996
+ if x_T is None:
997
+ img = torch.randn(shape, device=self.device)
998
+ else:
999
+ img = x_T
1000
+ intermediates = []
1001
+ if cond is not None:
1002
+ if isinstance(cond, dict):
1003
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1004
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1005
+ else:
1006
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1007
+
1008
+ if start_T is not None:
1009
+ timesteps = min(timesteps, start_T)
1010
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1011
+ total=timesteps) if verbose else reversed(
1012
+ range(0, timesteps))
1013
+ if type(temperature) == float:
1014
+ temperature = [temperature] * timesteps
1015
+
1016
+ for i in iterator:
1017
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1018
+ if self.shorten_cond_schedule:
1019
+ assert self.model.conditioning_key != 'hybrid'
1020
+ tc = self.cond_ids[ts].to(cond.device)
1021
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1022
+
1023
+ img, x0_partial = self.p_sample(img, cond, ts,
1024
+ clip_denoised=self.clip_denoised,
1025
+ quantize_denoised=quantize_denoised, return_x0=True,
1026
+ temperature=temperature[i], noise_dropout=noise_dropout,
1027
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1028
+ if mask is not None:
1029
+ assert x0 is not None
1030
+ img_orig = self.q_sample(x0, ts)
1031
+ img = img_orig * mask + (1. - mask) * img
1032
+
1033
+ if i % log_every_t == 0 or i == timesteps - 1:
1034
+ intermediates.append(x0_partial)
1035
+ if callback: callback(i)
1036
+ if img_callback: img_callback(img, i)
1037
+ return img, intermediates
1038
+
1039
+ @torch.no_grad()
1040
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
1041
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1042
+ mask=None, x0=None, img_callback=None, start_T=None,
1043
+ log_every_t=None):
1044
+
1045
+ if not log_every_t:
1046
+ log_every_t = self.log_every_t
1047
+ device = self.betas.device
1048
+ b = shape[0]
1049
+ if x_T is None:
1050
+ img = torch.randn(shape, device=device)
1051
+ else:
1052
+ img = x_T
1053
+
1054
+ intermediates = [img]
1055
+ if timesteps is None:
1056
+ timesteps = self.num_timesteps
1057
+
1058
+ if start_T is not None:
1059
+ timesteps = min(timesteps, start_T)
1060
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1061
+ range(0, timesteps))
1062
+
1063
+ if mask is not None:
1064
+ assert x0 is not None
1065
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1066
+
1067
+ for i in iterator:
1068
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1069
+ if self.shorten_cond_schedule:
1070
+ assert self.model.conditioning_key != 'hybrid'
1071
+ tc = self.cond_ids[ts].to(cond.device)
1072
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1073
+
1074
+ img = self.p_sample(img, cond, ts,
1075
+ clip_denoised=self.clip_denoised,
1076
+ quantize_denoised=quantize_denoised)
1077
+ if mask is not None:
1078
+ img_orig = self.q_sample(x0, ts)
1079
+ img = img_orig * mask + (1. - mask) * img
1080
+
1081
+ if i % log_every_t == 0 or i == timesteps - 1:
1082
+ intermediates.append(img)
1083
+ if callback: callback(i)
1084
+ if img_callback: img_callback(img, i)
1085
+
1086
+ if return_intermediates:
1087
+ return img, intermediates
1088
+ return img
1089
+
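A minimal standalone sketch of the mask blend used in the sampling loops above: where mask == 1 the re-noised original latent is kept, where mask == 0 the freshly denoised latent is used (toy helper, hypothetical name):

import torch

def toy_blend(img_denoised, img_orig_noised, mask):
    # mask broadcasts over channels; 1 = keep original region, 0 = generated region
    return img_orig_noised * mask + (1.0 - mask) * img_denoised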
1090
+ @torch.no_grad()
1091
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1092
+ verbose=True, timesteps=None, quantize_denoised=False,
1093
+ mask=None, x0=None, shape=None, **kwargs):
1094
+ if shape is None:
1095
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
1096
+ if cond is not None:
1097
+ if isinstance(cond, dict):
1098
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1099
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1100
+ else:
1101
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1102
+ return self.p_sample_loop(cond,
1103
+ shape,
1104
+ return_intermediates=return_intermediates, x_T=x_T,
1105
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1106
+ mask=mask, x0=x0)
1107
+
1108
+ @torch.no_grad()
1109
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
1110
+ if ddim:
1111
+ ddim_sampler = DDIMSampler(self)
1112
+ shape = (self.channels, self.image_size, self.image_size)
1113
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
1114
+ shape, cond, verbose=False, **kwargs)
1115
+
1116
+ else:
1117
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1118
+ return_intermediates=True, **kwargs)
1119
+
1120
+ return samples, intermediates
1121
+
1122
+ @torch.no_grad()
1123
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
1124
+ if null_label is not None:
1125
+ xc = null_label
1126
+ if isinstance(xc, ListConfig):
1127
+ xc = list(xc)
1128
+ if isinstance(xc, dict) or isinstance(xc, list):
1129
+ c = self.get_learned_conditioning(xc)
1130
+ else:
1131
+ if hasattr(xc, "to"):
1132
+ xc = xc.to(self.device)
1133
+ c = self.get_learned_conditioning(xc)
1134
+ else:
1135
+ if self.cond_stage_key in ["class_label", "cls"]:
1136
+ xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
1137
+ return self.get_learned_conditioning(xc)
1138
+ else:
1139
+ raise NotImplementedError("todo")
1140
+ if isinstance(c, list): # in case the encoder gives us a list
1141
+ for i in range(len(c)):
1142
+ c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
1143
+ else:
1144
+ c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
1145
+ return c
1146
+
1147
+ @torch.no_grad()
1148
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
1149
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1150
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
1151
+ use_ema_scope=True,
1152
+ **kwargs):
1153
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
1154
+ use_ddim = ddim_steps is not None
1155
+
1156
+ log = dict()
1157
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1158
+ return_first_stage_outputs=True,
1159
+ force_c_encode=True,
1160
+ return_original_cond=True,
1161
+ bs=N)
1162
+ N = min(x.shape[0], N)
1163
+ n_row = min(x.shape[0], n_row)
1164
+ log["inputs"] = x
1165
+ log["reconstruction"] = xrec
1166
+ if self.model.conditioning_key is not None:
1167
+ if hasattr(self.cond_stage_model, "decode"):
1168
+ xc = self.cond_stage_model.decode(c)
1169
+ log["conditioning"] = xc
1170
+ elif self.cond_stage_key in ["caption", "txt"]:
1171
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1172
+ log["conditioning"] = xc
1173
+ elif self.cond_stage_key in ['class_label', "cls"]:
1174
+ try:
1175
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1176
+ log['conditioning'] = xc
1177
+ except KeyError:
1178
+ # probably no "human_label" in batch
1179
+ pass
1180
+ elif isimage(xc):
1181
+ log["conditioning"] = xc
1182
+ if ismap(xc):
1183
+ log["original_conditioning"] = self.to_rgb(xc)
1184
+
1185
+ if plot_diffusion_rows:
1186
+ # get diffusion row
1187
+ diffusion_row = list()
1188
+ z_start = z[:n_row]
1189
+ for t in range(self.num_timesteps):
1190
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1191
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1192
+ t = t.to(self.device).long()
1193
+ noise = torch.randn_like(z_start)
1194
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1195
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1196
+
1197
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1198
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1199
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1200
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1201
+ log["diffusion_row"] = diffusion_grid
1202
+
1203
+ if sample:
1204
+ # get denoise row
1205
+ with ema_scope("Sampling"):
1206
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1207
+ ddim_steps=ddim_steps, eta=ddim_eta)
1208
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1209
+ x_samples = self.decode_first_stage(samples)
1210
+ log["samples"] = x_samples
1211
+ if plot_denoise_rows:
1212
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1213
+ log["denoise_row"] = denoise_grid
1214
+
1215
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1216
+ self.first_stage_model, IdentityFirstStage):
1217
+ # also display when quantizing x0 while sampling
1218
+ with ema_scope("Plotting Quantized Denoised"):
1219
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1220
+ ddim_steps=ddim_steps, eta=ddim_eta,
1221
+ quantize_denoised=True)
1222
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1223
+ # quantize_denoised=True)
1224
+ x_samples = self.decode_first_stage(samples.to(self.device))
1225
+ log["samples_x0_quantized"] = x_samples
1226
+
1227
+ if unconditional_guidance_scale > 1.0:
1228
+ uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1229
+ if self.model.conditioning_key == "crossattn-adm":
1230
+ uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
1231
+ with ema_scope("Sampling with classifier-free guidance"):
1232
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1233
+ ddim_steps=ddim_steps, eta=ddim_eta,
1234
+ unconditional_guidance_scale=unconditional_guidance_scale,
1235
+ unconditional_conditioning=uc,
1236
+ )
1237
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
1238
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1239
+
1240
+ if inpaint:
1241
+ # make a simple center square
1242
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1243
+ mask = torch.ones(N, h, w).to(self.device)
1244
+ # zeros will be filled in
1245
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1246
+ mask = mask[:, None, ...]
1247
+ with ema_scope("Plotting Inpaint"):
1248
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
1249
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1250
+ x_samples = self.decode_first_stage(samples.to(self.device))
1251
+ log["samples_inpainting"] = x_samples
1252
+ log["mask"] = mask
1253
+
1254
+ # outpaint
1255
+ mask = 1. - mask
1256
+ with ema_scope("Plotting Outpaint"):
1257
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
1258
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1259
+ x_samples = self.decode_first_stage(samples.to(self.device))
1260
+ log["samples_outpainting"] = x_samples
1261
+
1262
+ if plot_progressive_rows:
1263
+ with ema_scope("Plotting Progressives"):
1264
+ img, progressives = self.progressive_denoising(c,
1265
+ shape=(self.channels, self.image_size, self.image_size),
1266
+ batch_size=N)
1267
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1268
+ log["progressive_row"] = prog_row
1269
+
1270
+ if return_keys:
1271
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1272
+ return log
1273
+ else:
1274
+ return {key: log[key] for key in return_keys}
1275
+ return log
1276
+
1277
+ def configure_optimizers(self):
1278
+ lr = self.learning_rate
1279
+ params = list(self.model.parameters())
1280
+ if self.cond_stage_trainable:
1281
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1282
+ params = params + list(self.cond_stage_model.parameters())
1283
+ if self.learn_logvar:
1284
+ print('Diffusion model optimizing logvar')
1285
+ params.append(self.logvar)
1286
+ opt = torch.optim.AdamW(params, lr=lr)
1287
+ if self.use_scheduler:
1288
+ assert 'target' in self.scheduler_config
1289
+ scheduler = instantiate_from_config(self.scheduler_config)
1290
+
1291
+ print("Setting up LambdaLR scheduler...")
1292
+ scheduler = [
1293
+ {
1294
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1295
+ 'interval': 'step',
1296
+ 'frequency': 1
1297
+ }]
1298
+ return [opt], scheduler
1299
+ return opt
1300
+
1301
+ @torch.no_grad()
1302
+ def to_rgb(self, x):
1303
+ x = x.float()
1304
+ if not hasattr(self, "colorize"):
1305
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1306
+ x = nn.functional.conv2d(x, weight=self.colorize)
1307
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1308
+ return x
1309
+
1310
+
1311
+ class DiffusionWrapper(pl.LightningModule):
1312
+ def __init__(self, diff_model_config, conditioning_key):
1313
+ super().__init__()
1314
+ self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
1315
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1316
+ self.conditioning_key = conditioning_key
1317
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
1318
+
1319
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, encode=False, encode_uncon=False, decode_uncon=False, controller=None, inject=False):
1320
+ if self.conditioning_key is None:
1321
+ out = self.diffusion_model(x, t, encode=encode, encode_uncon=encode_uncon, decode_uncon=decode_uncon, controller=controller, inject=inject)
1322
+ elif self.conditioning_key == 'concat':
1323
+ xc = torch.cat([x] + c_concat, dim=1)
1324
+ out = self.diffusion_model(xc, t)
1325
+ elif self.conditioning_key == 'crossattn':
1326
+ if not self.sequential_cross_attn:
1327
+ cc = torch.cat(c_crossattn, 1)
1328
+ else:
1329
+ cc = c_crossattn
1330
+ out = self.diffusion_model(x, t, context=cc, encode=encode, encode_uncon=encode_uncon, decode_uncon=decode_uncon, controller=controller, inject=inject)
1331
+ elif self.conditioning_key == 'hybrid':
1332
+ xc = torch.cat([x] + c_concat, dim=1)
1333
+ cc = torch.cat(c_crossattn, 1)
1334
+ out = self.diffusion_model(xc, t, context=cc)
1335
+ elif self.conditioning_key == 'hybrid-adm':
1336
+ assert c_adm is not None
1337
+ xc = torch.cat([x] + c_concat, dim=1)
1338
+ cc = torch.cat(c_crossattn, 1)
1339
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm)
1340
+ elif self.conditioning_key == 'crossattn-adm':
1341
+ assert c_adm is not None
1342
+ cc = torch.cat(c_crossattn, 1)
1343
+ out = self.diffusion_model(x, t, context=cc, y=c_adm)
1344
+ elif self.conditioning_key == 'adm':
1345
+ cc = c_crossattn[0]
1346
+ out = self.diffusion_model(x, t, y=cc)
1347
+ else:
1348
+ raise NotImplementedError()
1349
+
1350
+ return out
1351
+
1352
+
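A minimal sketch of the tensor assembly done by the 'hybrid' branch above, with dummy shapes (illustrative only; the real call goes through apply_model with a conditioning dict):

import torch

x = torch.randn(1, 4, 64, 64)              # noisy latent
c_concat = [torch.randn(1, 5, 64, 64)]     # e.g. resized mask + masked-image latent
c_crossattn = [torch.randn(1, 77, 1024)]   # e.g. text embeddings
xc = torch.cat([x] + c_concat, dim=1)      # channel-wise concat fed to the UNet
cc = torch.cat(c_crossattn, 1)             # cross-attention context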
1353
+ class LatentUpscaleDiffusion(LatentDiffusion):
1354
+ def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
1355
+ super().__init__(*args, **kwargs)
1356
+ # assumes that neither the cond_stage nor the low_scale_model contain trainable params
1357
+ assert not self.cond_stage_trainable
1358
+ self.instantiate_low_stage(low_scale_config)
1359
+ self.low_scale_key = low_scale_key
1360
+ self.noise_level_key = noise_level_key
1361
+
1362
+ def instantiate_low_stage(self, config):
1363
+ model = instantiate_from_config(config)
1364
+ self.low_scale_model = model.eval()
1365
+ self.low_scale_model.train = disabled_train
1366
+ for param in self.low_scale_model.parameters():
1367
+ param.requires_grad = False
1368
+
1369
+ @torch.no_grad()
1370
+ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
1371
+ if not log_mode:
1372
+ z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
1373
+ else:
1374
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1375
+ force_c_encode=True, return_original_cond=True, bs=bs)
1376
+ x_low = batch[self.low_scale_key][:bs]
1377
+ x_low = rearrange(x_low, 'b h w c -> b c h w')
1378
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
1379
+ zx, noise_level = self.low_scale_model(x_low)
1380
+ if self.noise_level_key is not None:
1381
+ # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
1382
+ raise NotImplementedError('TODO')
1383
+
1384
+ all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
1385
+ if log_mode:
1386
+ # TODO: maybe disable if too expensive
1387
+ x_low_rec = self.low_scale_model.decode(zx)
1388
+ return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
1389
+ return z, all_conds
1390
+
1391
+ @torch.no_grad()
1392
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1393
+ plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
1394
+ unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
1395
+ **kwargs):
1396
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
1397
+ use_ddim = ddim_steps is not None
1398
+
1399
+ log = dict()
1400
+ z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
1401
+ log_mode=True)
1402
+ N = min(x.shape[0], N)
1403
+ n_row = min(x.shape[0], n_row)
1404
+ log["inputs"] = x
1405
+ log["reconstruction"] = xrec
1406
+ log["x_lr"] = x_low
1407
+ log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
1408
+ if self.model.conditioning_key is not None:
1409
+ if hasattr(self.cond_stage_model, "decode"):
1410
+ xc = self.cond_stage_model.decode(c)
1411
+ log["conditioning"] = xc
1412
+ elif self.cond_stage_key in ["caption", "txt"]:
1413
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1414
+ log["conditioning"] = xc
1415
+ elif self.cond_stage_key in ['class_label', 'cls']:
1416
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1417
+ log['conditioning'] = xc
1418
+ elif isimage(xc):
1419
+ log["conditioning"] = xc
1420
+ if ismap(xc):
1421
+ log["original_conditioning"] = self.to_rgb(xc)
1422
+
1423
+ if plot_diffusion_rows:
1424
+ # get diffusion row
1425
+ diffusion_row = list()
1426
+ z_start = z[:n_row]
1427
+ for t in range(self.num_timesteps):
1428
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1429
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1430
+ t = t.to(self.device).long()
1431
+ noise = torch.randn_like(z_start)
1432
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1433
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1434
+
1435
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1436
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1437
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1438
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1439
+ log["diffusion_row"] = diffusion_grid
1440
+
1441
+ if sample:
1442
+ # get denoise row
1443
+ with ema_scope("Sampling"):
1444
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1445
+ ddim_steps=ddim_steps, eta=ddim_eta)
1446
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1447
+ x_samples = self.decode_first_stage(samples)
1448
+ log["samples"] = x_samples
1449
+ if plot_denoise_rows:
1450
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1451
+ log["denoise_row"] = denoise_grid
1452
+
1453
+ if unconditional_guidance_scale > 1.0:
1454
+ uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1455
+ # TODO explore better "unconditional" choices for the other keys
1456
+ # maybe guide away from empty text label and highest noise level and maximally degraded zx?
1457
+ uc = dict()
1458
+ for k in c:
1459
+ if k == "c_crossattn":
1460
+ assert isinstance(c[k], list) and len(c[k]) == 1
1461
+ uc[k] = [uc_tmp]
1462
+ elif k == "c_adm": # todo: only run with text-based guidance?
1463
+ assert isinstance(c[k], torch.Tensor)
1464
+ #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
1465
+ uc[k] = c[k]
1466
+ elif isinstance(c[k], list):
1467
+ uc[k] = [c[k][i] for i in range(len(c[k]))]
1468
+ else:
1469
+ uc[k] = c[k]
1470
+
1471
+ with ema_scope("Sampling with classifier-free guidance"):
1472
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1473
+ ddim_steps=ddim_steps, eta=ddim_eta,
1474
+ unconditional_guidance_scale=unconditional_guidance_scale,
1475
+ unconditional_conditioning=uc,
1476
+ )
1477
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
1478
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1479
+
1480
+ if plot_progressive_rows:
1481
+ with ema_scope("Plotting Progressives"):
1482
+ img, progressives = self.progressive_denoising(c,
1483
+ shape=(self.channels, self.image_size, self.image_size),
1484
+ batch_size=N)
1485
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1486
+ log["progressive_row"] = prog_row
1487
+
1488
+ return log
1489
+
1490
+
1491
+ class LatentFinetuneDiffusion(LatentDiffusion):
1492
+ """
1493
+ Basis for different finetunes, such as inpainting or depth2image
1494
+ To disable finetuning mode, set finetune_keys to None
1495
+ """
1496
+
1497
+ def __init__(self,
1498
+ concat_keys: tuple,
1499
+ finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
1500
+ "model_ema.diffusion_modelinput_blocks00weight"
1501
+ ),
1502
+ keep_finetune_dims=4,
1503
+ # if model was trained without concat mode before and we would like to keep these channels
1504
+ c_concat_log_start=None, # to log reconstruction of c_concat codes
1505
+ c_concat_log_end=None,
1506
+ *args, **kwargs
1507
+ ):
1508
+ ckpt_path = kwargs.pop("ckpt_path", None)
1509
+ ignore_keys = kwargs.pop("ignore_keys", list())
1510
+ super().__init__(*args, **kwargs)
1511
+ self.finetune_keys = finetune_keys
1512
+ self.concat_keys = concat_keys
1513
+ self.keep_dims = keep_finetune_dims
1514
+ self.c_concat_log_start = c_concat_log_start
1515
+ self.c_concat_log_end = c_concat_log_end
1516
+ if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
1517
+ if exists(ckpt_path):
1518
+ self.init_from_ckpt(ckpt_path, ignore_keys)
1519
+
1520
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
1521
+ sd = torch.load(path, map_location="cpu")
1522
+ if "state_dict" in list(sd.keys()):
1523
+ sd = sd["state_dict"]
1524
+ keys = list(sd.keys())
1525
+ for k in keys:
1526
+ for ik in ignore_keys:
1527
+ if k.startswith(ik):
1528
+ print("Deleting key {} from state_dict.".format(k))
1529
+ del sd[k]
1530
+
1531
+ # make it explicit, finetune by including extra input channels
1532
+ if exists(self.finetune_keys) and k in self.finetune_keys:
1533
+ new_entry = None
1534
+ for name, param in self.named_parameters():
1535
+ if name in self.finetune_keys:
1536
+ print(
1537
+ f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
1538
+ new_entry = torch.zeros_like(param) # zero init
1539
+ assert exists(new_entry), 'did not find matching parameter to modify'
1540
+ new_entry[:, :self.keep_dims, ...] = sd[k]
1541
+ sd[k] = new_entry
1542
+
1543
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
1544
+ sd, strict=False)
1545
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
1546
+ if len(missing) > 0:
1547
+ print(f"Missing Keys: {missing}")
1548
+ if len(unexpected) > 0:
1549
+ print(f"Unexpected Keys: {unexpected}")
1550
+
1551
+ @torch.no_grad()
1552
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1553
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1554
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
1555
+ use_ema_scope=True,
1556
+ **kwargs):
1557
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
1558
+ use_ddim = ddim_steps is not None
1559
+
1560
+ log = dict()
1561
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
1562
+ c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
1563
+ N = min(x.shape[0], N)
1564
+ n_row = min(x.shape[0], n_row)
1565
+ log["inputs"] = x
1566
+ log["reconstruction"] = xrec
1567
+ if self.model.conditioning_key is not None:
1568
+ if hasattr(self.cond_stage_model, "decode"):
1569
+ xc = self.cond_stage_model.decode(c)
1570
+ log["conditioning"] = xc
1571
+ elif self.cond_stage_key in ["caption", "txt"]:
1572
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1573
+ log["conditioning"] = xc
1574
+ elif self.cond_stage_key in ['class_label', 'cls']:
1575
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1576
+ log['conditioning'] = xc
1577
+ elif isimage(xc):
1578
+ log["conditioning"] = xc
1579
+ if ismap(xc):
1580
+ log["original_conditioning"] = self.to_rgb(xc)
1581
+
1582
+ if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
1583
+ log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
1584
+
1585
+ if plot_diffusion_rows:
1586
+ # get diffusion row
1587
+ diffusion_row = list()
1588
+ z_start = z[:n_row]
1589
+ for t in range(self.num_timesteps):
1590
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1591
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1592
+ t = t.to(self.device).long()
1593
+ noise = torch.randn_like(z_start)
1594
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1595
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1596
+
1597
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1598
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1599
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1600
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1601
+ log["diffusion_row"] = diffusion_grid
1602
+
1603
+ if sample:
1604
+ # get denoise row
1605
+ with ema_scope("Sampling"):
1606
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
1607
+ batch_size=N, ddim=use_ddim,
1608
+ ddim_steps=ddim_steps, eta=ddim_eta)
1609
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1610
+ x_samples = self.decode_first_stage(samples)
1611
+ log["samples"] = x_samples
1612
+ if plot_denoise_rows:
1613
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1614
+ log["denoise_row"] = denoise_grid
1615
+
1616
+ if unconditional_guidance_scale > 1.0:
1617
+ uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1618
+ uc_cat = c_cat
1619
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
1620
+ with ema_scope("Sampling with classifier-free guidance"):
1621
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
1622
+ batch_size=N, ddim=use_ddim,
1623
+ ddim_steps=ddim_steps, eta=ddim_eta,
1624
+ unconditional_guidance_scale=unconditional_guidance_scale,
1625
+ unconditional_conditioning=uc_full,
1626
+ )
1627
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
1628
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1629
+
1630
+ return log
1631
+
1632
+
1633
+ class LatentInpaintDiffusion(LatentFinetuneDiffusion):
1634
+ """
1635
+ can either run as pure inpainting model (only concat mode) or with mixed conditionings,
1636
+ e.g. mask as concat and text via cross-attn.
1637
+ To disable finetuning mode, set finetune_keys to None
1638
+ """
1639
+
1640
+ def __init__(self,
1641
+ concat_keys=("mask", "masked_image"),
1642
+ masked_image_key="masked_image",
1643
+ *args, **kwargs
1644
+ ):
1645
+ super().__init__(concat_keys, *args, **kwargs)
1646
+ self.masked_image_key = masked_image_key
1647
+ assert self.masked_image_key in concat_keys
1648
+
1649
+ @torch.no_grad()
1650
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
1651
+ # note: restricted to non-trainable encoders currently
1652
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
1653
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1654
+ force_c_encode=True, return_original_cond=True, bs=bs)
1655
+
1656
+ assert exists(self.concat_keys)
1657
+ c_cat = list()
1658
+ for ck in self.concat_keys:
1659
+ cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
1660
+ if bs is not None:
1661
+ cc = cc[:bs]
1662
+ cc = cc.to(self.device)
1663
+ bchw = z.shape
1664
+ if ck != self.masked_image_key:
1665
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
1666
+ else:
1667
+ cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
1668
+ c_cat.append(cc)
1669
+ c_cat = torch.cat(c_cat, dim=1)
1670
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1671
+ if return_first_stage_outputs:
1672
+ return z, all_conds, x, xrec, xc
1673
+ return z, all_conds
1674
+
1675
+ @torch.no_grad()
1676
+ def log_images(self, *args, **kwargs):
1677
+ log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
1678
+ log["masked_image"] = rearrange(args[0]["masked_image"],
1679
+ 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
1680
+ return log
1681
+
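A minimal sketch, with dummy tensors, of the concat conditioning built by get_input above (illustrative shapes; the real masked image is encoded by the first-stage model rather than randomly drawn):

import torch

z = torch.randn(1, 4, 64, 64)                              # target latent
mask = torch.rand(1, 1, 512, 512)                          # pixel-space mask
mask_lr = torch.nn.functional.interpolate(mask, size=z.shape[-2:])
masked_image_latent = torch.randn(1, 4, 64, 64)            # stand-in for encode_first_stage(masked_image)
c_cat = torch.cat([mask_lr, masked_image_latent], dim=1)   # 5-channel c_concat
all_conds = {"c_concat": [c_cat], "c_crossattn": [torch.randn(1, 77, 1024)]}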
1682
+
1683
+ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
1684
+ """
1685
+ condition on monocular depth estimation
1686
+ """
1687
+
1688
+ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
1689
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
1690
+ self.depth_model = instantiate_from_config(depth_stage_config)
1691
+ self.depth_stage_key = concat_keys[0]
1692
+
1693
+ @torch.no_grad()
1694
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
1695
+ # note: restricted to non-trainable encoders currently
1696
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
1697
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1698
+ force_c_encode=True, return_original_cond=True, bs=bs)
1699
+
1700
+ assert exists(self.concat_keys)
1701
+ assert len(self.concat_keys) == 1
1702
+ c_cat = list()
1703
+ for ck in self.concat_keys:
1704
+ cc = batch[ck]
1705
+ if bs is not None:
1706
+ cc = cc[:bs]
1707
+ cc = cc.to(self.device)
1708
+ cc = self.depth_model(cc)
1709
+ cc = torch.nn.functional.interpolate(
1710
+ cc,
1711
+ size=z.shape[2:],
1712
+ mode="bicubic",
1713
+ align_corners=False,
1714
+ )
1715
+
1716
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
1717
+ keepdim=True)
1718
+ cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
1719
+ c_cat.append(cc)
1720
+ c_cat = torch.cat(c_cat, dim=1)
1721
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1722
+ if return_first_stage_outputs:
1723
+ return z, all_conds, x, xrec, xc
1724
+ return z, all_conds
1725
+
1726
+ @torch.no_grad()
1727
+ def log_images(self, *args, **kwargs):
1728
+ log = super().log_images(*args, **kwargs)
1729
+ depth = self.depth_model(args[0][self.depth_stage_key])
1730
+ depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
1731
+ torch.amax(depth, dim=[1, 2, 3], keepdim=True)
1732
+ log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
1733
+ return log
1734
+
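A minimal standalone sketch of the per-sample depth normalisation used above, mapping the depth prediction into [-1, 1] (dummy depth tensors):

import torch

depth = torch.rand(2, 1, 64, 64) * 10.0
d_min = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
d_max = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
depth_norm = 2.0 * (depth - d_min) / (d_max - d_min + 0.001) - 1.0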
1735
+
1736
+ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
1737
+ """
1738
+ condition on low-res image (and optionally on some spatial noise augmentation)
1739
+ """
1740
+ def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
1741
+ low_scale_config=None, low_scale_key=None, *args, **kwargs):
1742
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
1743
+ self.reshuffle_patch_size = reshuffle_patch_size
1744
+ self.low_scale_model = None
1745
+ if low_scale_config is not None:
1746
+ print("Initializing a low-scale model")
1747
+ assert exists(low_scale_key)
1748
+ self.instantiate_low_stage(low_scale_config)
1749
+ self.low_scale_key = low_scale_key
1750
+
1751
+ def instantiate_low_stage(self, config):
1752
+ model = instantiate_from_config(config)
1753
+ self.low_scale_model = model.eval()
1754
+ self.low_scale_model.train = disabled_train
1755
+ for param in self.low_scale_model.parameters():
1756
+ param.requires_grad = False
1757
+
1758
+ @torch.no_grad()
1759
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
1760
+ # note: restricted to non-trainable encoders currently
1761
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
1762
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1763
+ force_c_encode=True, return_original_cond=True, bs=bs)
1764
+
1765
+ assert exists(self.concat_keys)
1766
+ assert len(self.concat_keys) == 1
1767
+ # optionally make spatial noise_level here
1768
+ c_cat = list()
1769
+ noise_level = None
1770
+ for ck in self.concat_keys:
1771
+ cc = batch[ck]
1772
+ cc = rearrange(cc, 'b h w c -> b c h w')
1773
+ if exists(self.reshuffle_patch_size):
1774
+ assert isinstance(self.reshuffle_patch_size, int)
1775
+ cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
1776
+ p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
1777
+ if bs is not None:
1778
+ cc = cc[:bs]
1779
+ cc = cc.to(self.device)
1780
+ if exists(self.low_scale_model) and ck == self.low_scale_key:
1781
+ cc, noise_level = self.low_scale_model(cc)
1782
+ c_cat.append(cc)
1783
+ c_cat = torch.cat(c_cat, dim=1)
1784
+ if exists(noise_level):
1785
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
1786
+ else:
1787
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1788
+ if return_first_stage_outputs:
1789
+ return z, all_conds, x, xrec, xc
1790
+ return z, all_conds
1791
+
1792
+ @torch.no_grad()
1793
+ def log_images(self, *args, **kwargs):
1794
+ log = super().log_images(*args, **kwargs)
1795
+ log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
1796
+ return log
ldm/models/diffusion/dpm_solver/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sampler import DPMSolverSampler
ldm/models/diffusion/dpm_solver/dpm_solver.py ADDED
@@ -0,0 +1,1194 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ import ptp_scripts.ptp_utils as ptp_utils
5
+ from tqdm import tqdm
6
+
7
+
8
+ class NoiseScheduleVP:
9
+ def __init__(
10
+ self,
11
+ schedule='discrete',
12
+ betas=None,
13
+ alphas_cumprod=None,
14
+ continuous_beta_0=0.1,
15
+ continuous_beta_1=20.,
16
+ ):
17
+ """Create a wrapper class for the forward SDE (VP type).
18
+ ***
19
+ Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
20
+ We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
21
+ ***
22
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
23
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
24
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
25
+ log_alpha_t = self.marginal_log_mean_coeff(t)
26
+ sigma_t = self.marginal_std(t)
27
+ lambda_t = self.marginal_lambda(t)
28
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
29
+ t = self.inverse_lambda(lambda_t)
30
+ ===============================================================
31
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
32
+ 1. For discrete-time DPMs:
33
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
34
+ t_i = (i + 1) / N
35
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
36
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
37
+ Args:
38
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
39
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
40
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
41
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
42
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
43
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
44
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
45
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
46
+ and
47
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
48
+ 2. For continuous-time DPMs:
49
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
50
+ schedule are the default settings in DDPM and improved-DDPM:
51
+ Args:
52
+ beta_min: A `float` number. The smallest beta for the linear schedule.
53
+ beta_max: A `float` number. The largest beta for the linear schedule.
54
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
55
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
56
+ T: A `float` number. The ending time of the forward process.
57
+ ===============================================================
58
+ Args:
59
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
60
+ 'linear' or 'cosine' for continuous-time DPMs.
61
+ Returns:
62
+ A wrapper object of the forward SDE (VP type).
63
+
64
+ ===============================================================
65
+ Example:
66
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
67
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
68
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
69
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
70
+ # For continuous-time DPMs (VPSDE), linear schedule:
71
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
72
+ """
73
+
74
+ if schedule not in ['discrete', 'linear', 'cosine']:
75
+ raise ValueError(
76
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
77
+ schedule))
78
+
79
+ self.schedule = schedule
80
+ if schedule == 'discrete':
81
+ if betas is not None:
82
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
83
+ else:
84
+ assert alphas_cumprod is not None
85
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
86
+ self.total_N = len(log_alphas)
87
+ self.T = 1.
88
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
89
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
90
+ else:
91
+ self.total_N = 1000
92
+ self.beta_0 = continuous_beta_0
93
+ self.beta_1 = continuous_beta_1
94
+ self.cosine_s = 0.008
95
+ self.cosine_beta_max = 999.
96
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
97
+ 1. + self.cosine_s) / math.pi - self.cosine_s
98
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
99
+ self.schedule = schedule
100
+ if schedule == 'cosine':
101
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
102
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
103
+ self.T = 0.9946
104
+ else:
105
+ self.T = 1.
106
+
107
+ def marginal_log_mean_coeff(self, t):
108
+ """
109
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
110
+ """
111
+ if self.schedule == 'discrete':
112
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
113
+ self.log_alpha_array.to(t.device)).reshape((-1))
114
+ elif self.schedule == 'linear':
115
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
116
+ elif self.schedule == 'cosine':
117
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
118
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
119
+ return log_alpha_t
120
+
121
+ def marginal_alpha(self, t):
122
+ """
123
+ Compute alpha_t of a given continuous-time label t in [0, T].
124
+ """
125
+ return torch.exp(self.marginal_log_mean_coeff(t))
126
+
127
+ def marginal_std(self, t):
128
+ """
129
+ Compute sigma_t of a given continuous-time label t in [0, T].
130
+ """
131
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
132
+
133
+ def marginal_lambda(self, t):
134
+ """
135
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
136
+ """
137
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
138
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
139
+ return log_mean_coeff - log_std
140
+
141
+ def inverse_lambda(self, lamb):
142
+ """
143
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
144
+ """
145
+ if self.schedule == 'linear':
146
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
147
+ Delta = self.beta_0 ** 2 + tmp
148
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
149
+ elif self.schedule == 'discrete':
150
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
151
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
152
+ torch.flip(self.t_array.to(lamb.device), [1]))
153
+ return t.reshape((-1,))
154
+ else:
155
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
156
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
157
+ 1. + self.cosine_s) / math.pi - self.cosine_s
158
+ t = t_fn(log_alpha)
159
+ return t
160
+
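A minimal standalone sketch of the quantities described in the docstring above, worked out for a single toy value of the DDPM cumulative alpha (illustrative only; real schedules pass their own betas or alphas_cumprod):

import torch

alpha_bar = torch.tensor(0.25)                         # \hat{alpha}_n from a discrete DDPM schedule
alpha_t = torch.sqrt(alpha_bar)                        # alpha_{t_n} = sqrt(\hat{alpha}_n)
sigma_t = torch.sqrt(1. - alpha_bar)                   # marginal std of q_{t|0}
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)     # half-logSNR, as returned by marginal_lambda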
161
+
162
+ def model_wrapper(
163
+ model,
164
+ noise_schedule,
165
+ model_type="noise",
166
+ model_kwargs={},
167
+ guidance_type="uncond",
168
+ condition=None,
169
+ unconditional_condition=None,
170
+ guidance_scale=1.,
171
+ classifier_fn=None,
172
+ classifier_kwargs={},
173
+ ):
174
+ """Create a wrapper function for the noise prediction model.
175
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
176
+ first wrap the model function into a noise prediction model that accepts the continuous time as the input.
177
+ We support four types of the diffusion model by setting `model_type`:
178
+ 1. "noise": noise prediction model. (Trained by predicting noise).
179
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
180
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
181
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
182
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
183
+ arXiv preprint arXiv:2202.00512 (2022).
184
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
185
+ arXiv preprint arXiv:2210.02303 (2022).
186
+
187
+ 4. "score": marginal score function. (Trained by denoising score matching).
188
+ Note that the score function and the noise prediction model follow a simple relationship:
189
+ ```
190
+ noise(x_t, t) = -sigma_t * score(x_t, t)
191
+ ```
192
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
193
+ 1. "uncond": unconditional sampling by DPMs.
194
+ The input `model` has the following format:
195
+ ``
196
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
197
+ ``
198
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
199
+ The input `model` has the following format:
200
+ ``
201
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
202
+ ``
203
+ The input `classifier_fn` has the following format:
204
+ ``
205
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
206
+ ``
207
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
208
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
209
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
210
+ The input `model` has the following format:
211
+ ``
212
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
213
+ ``
214
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
215
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
216
+ arXiv preprint arXiv:2207.12598 (2022).
217
+
218
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
219
+ or continuous-time labels (i.e. epsilon to T).
220
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
221
+ ``
222
+ def model_fn(x, t_continuous) -> noise:
223
+ t_input = get_model_input_time(t_continuous)
224
+ return noise_pred(model, x, t_input, **model_kwargs)
225
+ ``
226
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
227
+ ===============================================================
228
+ Args:
229
+ model: A diffusion model with the corresponding format described above.
230
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
231
+ model_type: A `str`. The parameterization type of the diffusion model.
232
+ "noise" or "x_start" or "v" or "score".
233
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
234
+ guidance_type: A `str`. The type of the guidance for sampling.
235
+ "uncond" or "classifier" or "classifier-free".
236
+ condition: A pytorch tensor. The condition for the guided sampling.
237
+ Only used for "classifier" or "classifier-free" guidance type.
238
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
239
+ Only used for "classifier-free" guidance type.
240
+ guidance_scale: A `float`. The scale for the guided sampling.
241
+ classifier_fn: A classifier function. Only used for the classifier guidance.
242
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
243
+ Returns:
244
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
245
+ """
246
+
247
+ def get_model_input_time(t_continuous):
248
+ """
249
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
250
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
251
+ For continuous-time DPMs, we just use `t_continuous`.
252
+ """
253
+ if noise_schedule.schedule == 'discrete':
254
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
255
+ else:
256
+ return t_continuous
257
+
258
+ def noise_pred_fn(x, t_continuous, cond=None, DPMencode=False, controller=None, inject=False):
259
+ if t_continuous.reshape((-1,)).shape[0] == 1:
260
+ t_continuous = t_continuous.expand((x.shape[0]))
261
+ t_input = get_model_input_time(t_continuous)
262
+ if cond is None:
263
+ output = model(x, t_input, **model_kwargs)
264
+ else:
265
+ output = model(x, t_input, cond, DPMencode, controller=controller, inject=inject, **model_kwargs)
266
+ if model_type == "noise":
267
+ return output
268
+ elif model_type == "x_start":
269
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
270
+ dims = x.dim()
271
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
272
+ elif model_type == "v":
273
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
274
+ dims = x.dim()
275
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
276
+ elif model_type == "score":
277
+ sigma_t = noise_schedule.marginal_std(t_continuous)
278
+ dims = x.dim()
279
+ return -expand_dims(sigma_t, dims) * output
280
+
281
+ def cond_grad_fn(x, t_input):
282
+ """
283
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
284
+ """
285
+ with torch.enable_grad():
286
+ x_in = x.detach().requires_grad_(True)
287
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
288
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
289
+
290
+ def model_fn(x, t_continuous, DPMencode=False, controller=None, inject=False, ref_init=None):
291
+ """
292
+ The noise prediction model function that is used for DPM-Solver.
293
+ """
294
+ if t_continuous.reshape((-1,)).shape[0] == 1:
295
+ t_continuous = t_continuous.expand((x.shape[0]))
296
+ if guidance_type == "uncond":
297
+ return noise_pred_fn(x, t_continuous)
298
+ elif guidance_type == "classifier":
299
+ assert classifier_fn is not None
300
+ t_input = get_model_input_time(t_continuous)
301
+ cond_grad = cond_grad_fn(x, t_input)
302
+ sigma_t = noise_schedule.marginal_std(t_continuous)
303
+ noise = noise_pred_fn(x, t_continuous)
304
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
305
+ elif guidance_type == "classifier-free":
306
+ if guidance_scale == 1. or unconditional_condition is None:
307
+ return noise_pred_fn(x, t_continuous, cond=condition, DPMencode=DPMencode, controller=controller, inject=inject)
308
+ else:
309
+ if ref_init is None:
310
+ x_in = torch.cat([x] * 2)
311
+ c_in = torch.cat([unconditional_condition, condition])
312
+ else:
313
+ x_in = torch.cat([x, x, ref_init, ref_init], dim=0)
314
+ uc = torch.cat([unconditional_condition] * 2)
315
+ c = torch.cat([condition] * 2)
316
+ c_in = torch.cat([uc, c])
317
+
318
+ # x_in = torch.cat([x] * 2)
319
+ t_in = torch.cat([t_continuous] * 2)
320
+
321
+ if ref_init is None:
322
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in, DPMencode=DPMencode, controller=controller, inject=inject).chunk(2)
323
+ else:
324
+ noise_uncond, noise, _, _ = noise_pred_fn(x_in, t_in, cond=c_in, DPMencode=DPMencode, controller=controller, inject=inject).chunk(4)
325
+
326
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
327
+
328
+ assert model_type in ["noise", "x_start", "v"]
329
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
330
+ return model_fn
331
+
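A minimal sketch of the classifier-free guidance combination performed inside model_fn above (dummy tensors; the real unconditional and conditional outputs come from the wrapped UNet):

import torch

noise_uncond = torch.randn(1, 4, 64, 64)
noise_cond = torch.randn(1, 4, 64, 64)
guidance_scale = 7.5
noise_guided = noise_uncond + guidance_scale * (noise_cond - noise_uncond)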
332
+
333
+ class DPM_Solver:
334
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
335
+ """Construct a DPM-Solver.
336
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
337
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
338
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
339
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
340
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
341
+ Args:
342
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
343
+ ``
344
+ def model_fn(x, t_continuous):
345
+ return noise
346
+ ``
347
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
348
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
349
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
350
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
351
+
352
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Raphael Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
353
+ """
354
+ self.model = model_fn
355
+ self.noise_schedule = noise_schedule
356
+ self.predict_x0 = predict_x0
357
+ self.thresholding = thresholding
358
+ self.max_val = max_val
359
+
360
+ def noise_prediction_fn(self, x, t, DPMencode=False, controller=None, inject=False, ref_init=None):
361
+ """
362
+ Return the noise prediction model.
363
+ """
364
+ # ptp_utils.register_attention_control(self.model, controller)
365
+ return self.model(x, t, DPMencode=DPMencode, controller=controller, inject=inject, ref_init=ref_init)
366
+
367
+ def data_prediction_fn(self, x, t, DPMencode=False, controller=None, inject=False, ref_init=None):
368
+ """
369
+ Return the data prediction model (with thresholding).
370
+ """
371
+ noise = self.noise_prediction_fn(x, t, DPMencode=DPMencode, controller=controller, inject=inject, ref_init=ref_init)
372
+ dims = x.dim()
373
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
374
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
375
+ if self.thresholding:
376
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
377
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
378
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
379
+ x0 = torch.clamp(x0, -s, s) / s
380
+ return x0
381
+
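A minimal sketch of the dynamic thresholding applied in `data_prediction_fn` above (per-sample quantile clipping from the Imagen paper); the tensor values are made up:

import torch

x0 = torch.randn(4, 3, 16, 16) * 2.0          # hypothetical x0 prediction
p, max_val = 0.995, 1.0
s = torch.quantile(torch.abs(x0).reshape(x0.shape[0], -1), p, dim=1)
s = torch.maximum(s, max_val * torch.ones_like(s))[(...,) + (None,) * 3]
x0_clipped = torch.clamp(x0, -s, s) / s        # values now lie in [-1, 1]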
382
+ def model_fn(self, x, t, DPMencode=False, controller=None, inject=False, ref_init=None):
383
+ """
384
+ Convert the model to the noise prediction model or the data prediction model.
385
+ """
386
+ if self.predict_x0:
387
+ return self.data_prediction_fn(x, t, DPMencode=DPMencode, controller=controller, inject=inject, ref_init=ref_init)
388
+ else:
389
+ return self.noise_prediction_fn(x, t, DPMencode=DPMencode, controller=controller, inject=inject, ref_init=ref_init)
390
+
391
+ def get_time_steps(self, skip_type, t_T, t_0, N, device, DPMencode=False):
392
+ """Compute the intermediate time steps for sampling.
393
+ Args:
394
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
395
+ - 'logSNR': uniform logSNR for the time steps.
396
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
397
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
398
+ t_T: A `float`. The starting time of the sampling (default is T).
399
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
400
+ N: An `int`. The total number of time step intervals.
401
+ device: A torch device.
402
+ Returns:
403
+ A pytorch tensor of the time steps, with the shape (N + 1,).
404
+ """
405
+ if skip_type == 'logSNR':
406
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
407
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
408
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
409
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
410
+ elif skip_type == 'time_uniform':
411
+ if DPMencode:
412
+ return torch.linspace(t_0, t_T, N + 1).to(device)
413
+ else:
414
+ return torch.linspace(t_T, t_0, N + 1).to(device)
415
+
416
+ elif skip_type == 'time_quadratic':
417
+ t_order = 2
418
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
419
+ return t
420
+ else:
421
+ raise ValueError(
422
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
423
+
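For intuition, with `skip_type='time_uniform'` the schedule above is a plain linear spacing between `t_T` and `t_0`, reversed when `DPMencode=True` so that the inversion pass walks from `t_0` up to `t_T` (a minimal sketch with made-up endpoints):

import torch

t_T, t_0, N = 1.0, 1e-3, 5
sampling_steps  = torch.linspace(t_T, t_0, N + 1)   # decreasing: denoising
inversion_steps = torch.linspace(t_0, t_T, N + 1)   # increasing: DPMencode=True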
424
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device, DPMencode=False):
425
+ """
426
+ Get the order of each step for sampling by the singlestep DPM-Solver.
427
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
428
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
429
+ - If order == 1:
430
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
431
+ - If order == 2:
432
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
433
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
434
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
435
+ - If order == 3:
436
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
437
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
438
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
439
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
440
+ ============================================
441
+ Args:
442
+ order: A `int`. The max order for the solver (2 or 3).
443
+ steps: A `int`. The total number of function evaluations (NFE).
444
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
445
+ - 'logSNR': uniform logSNR for the time steps.
446
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
447
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
448
+ t_T: A `float`. The starting time of the sampling (default is T).
449
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
450
+ device: A torch device.
451
+ Returns:
452
+ orders: A list of the solver order of each step.
453
+ """
454
+ if order == 3:
455
+ K = steps // 3 + 1
456
+ if steps % 3 == 0:
457
+ orders = [3, ] * (K - 2) + [2, 1]
458
+ elif steps % 3 == 1:
459
+ orders = [3, ] * (K - 1) + [1]
460
+ else:
461
+ orders = [3, ] * (K - 1) + [2]
462
+ elif order == 2:
463
+ if steps % 2 == 0:
464
+ K = steps // 2
465
+ orders = [2, ] * K
466
+ else:
467
+ K = steps // 2 + 1
468
+ orders = [2, ] * (K - 1) + [1]
469
+ elif order == 1:
470
+ K = 1
471
+ orders = [1, ] * steps
472
+ else:
473
+ raise ValueError("'order' must be '1' or '2' or '3'.")
474
+ if skip_type == 'logSNR':
475
+ # To reproduce the results in DPM-Solver paper
476
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device, DPMencode=DPMencode)
477
+ else:
478
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device, DPMencode=DPMencode)[
479
+ torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)]
480
+ return timesteps_outer, orders
481
+
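A small stand-alone sketch of the order-splitting rule above (`split_orders` is a made-up name; it mirrors the branching in this method):

def split_orders(steps, order):
    if order == 3:
        K = steps // 3 + 1
        if steps % 3 == 0:
            return [3] * (K - 2) + [2, 1]
        elif steps % 3 == 1:
            return [3] * (K - 1) + [1]
        return [3] * (K - 1) + [2]
    if order == 2:
        return [2] * (steps // 2) if steps % 2 == 0 else [2] * (steps // 2) + [1]
    return [1] * steps

assert split_orders(20, 3) == [3] * 6 + [2]   # 6*3 + 2 = 20 function evaluations
assert sum(split_orders(15, 3)) == 15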
482
+ def denoise_to_zero_fn(self, x, s):
483
+ """
484
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty with a first-order discretization.
485
+ """
486
+ return self.data_prediction_fn(x, s)
487
+
488
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False, DPMencode=False):
489
+ """
490
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
491
+ Args:
492
+ x: A pytorch tensor. The initial value at time `s`.
493
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
494
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
495
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
496
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
497
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
498
+ Returns:
499
+ x_t: A pytorch tensor. The approximated solution at time `t`.
500
+ """
501
+ ns = self.noise_schedule
502
+ dims = x.dim()
503
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
504
+ h = lambda_t - lambda_s
505
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
506
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
507
+ alpha_t = torch.exp(log_alpha_t)
508
+
509
+ if self.predict_x0:
510
+ phi_1 = torch.expm1(-h)
511
+ if model_s is None:
512
+ model_s = self.model_fn(x, s, DPMencode=DPMencode)
513
+ x_t = (
514
+ expand_dims(sigma_t / sigma_s, dims) * x
515
+ - expand_dims(alpha_t * phi_1, dims) * model_s
516
+ )
517
+ if return_intermediate:
518
+ return x_t, {'model_s': model_s}
519
+ else:
520
+ return x_t
521
+ else:
522
+ phi_1 = torch.expm1(h)
523
+ if model_s is None:
524
+ model_s = self.model_fn(x, s, DPMencode=DPMencode)
525
+ x_t = (
526
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
527
+ - expand_dims(sigma_t * phi_1, dims) * model_s
528
+ )
529
+ if return_intermediate:
530
+ return x_t, {'model_s': model_s}
531
+ else:
532
+ return x_t
533
+
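A scalar sketch of the first-order (data-prediction / DPM-Solver++) update above, x_t = (sigma_t / sigma_s) * x - alpha_t * expm1(-h) * x0_pred with h = lambda_t - lambda_s; all numbers are made up:

import torch

alpha_s, sigma_s = torch.tensor(0.60), torch.tensor(0.80)
alpha_t, sigma_t = torch.tensor(0.87), torch.tensor(0.50)
lambda_s, lambda_t = torch.log(alpha_s / sigma_s), torch.log(alpha_t / sigma_t)
h = lambda_t - lambda_s                      # log-SNR increment
x, x0_pred = torch.tensor(1.0), torch.tensor(0.9)
x_t = (sigma_t / sigma_s) * x - alpha_t * torch.expm1(-h) * x0_pred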
534
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
535
+ solver_type='dpm_solver'):
536
+ """
537
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
538
+ Args:
539
+ x: A pytorch tensor. The initial value at time `s`.
540
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
541
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
542
+ r1: A `float`. The hyperparameter of the second-order solver.
543
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
544
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
545
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
546
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
547
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
548
+ Returns:
549
+ x_t: A pytorch tensor. The approximated solution at time `t`.
550
+ """
551
+ if solver_type not in ['dpm_solver', 'taylor']:
552
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
553
+ if r1 is None:
554
+ r1 = 0.5
555
+ ns = self.noise_schedule
556
+ dims = x.dim()
557
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
558
+ h = lambda_t - lambda_s
559
+ lambda_s1 = lambda_s + r1 * h
560
+ s1 = ns.inverse_lambda(lambda_s1)
561
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
562
+ s1), ns.marginal_log_mean_coeff(t)
563
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
564
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
565
+
566
+ if self.predict_x0:
567
+ phi_11 = torch.expm1(-r1 * h)
568
+ phi_1 = torch.expm1(-h)
569
+
570
+ if model_s is None:
571
+ model_s = self.model_fn(x, s)
572
+ x_s1 = (
573
+ expand_dims(sigma_s1 / sigma_s, dims) * x
574
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
575
+ )
576
+ model_s1 = self.model_fn(x_s1, s1)
577
+ if solver_type == 'dpm_solver':
578
+ x_t = (
579
+ expand_dims(sigma_t / sigma_s, dims) * x
580
+ - expand_dims(alpha_t * phi_1, dims) * model_s
581
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
582
+ )
583
+ elif solver_type == 'taylor':
584
+ x_t = (
585
+ expand_dims(sigma_t / sigma_s, dims) * x
586
+ - expand_dims(alpha_t * phi_1, dims) * model_s
587
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
588
+ model_s1 - model_s)
589
+ )
590
+ else:
591
+ phi_11 = torch.expm1(r1 * h)
592
+ phi_1 = torch.expm1(h)
593
+
594
+ if model_s is None:
595
+ model_s = self.model_fn(x, s)
596
+ x_s1 = (
597
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
598
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
599
+ )
600
+ model_s1 = self.model_fn(x_s1, s1)
601
+ if solver_type == 'dpm_solver':
602
+ x_t = (
603
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
604
+ - expand_dims(sigma_t * phi_1, dims) * model_s
605
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
606
+ )
607
+ elif solver_type == 'taylor':
608
+ x_t = (
609
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
610
+ - expand_dims(sigma_t * phi_1, dims) * model_s
611
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
612
+ )
613
+ if return_intermediate:
614
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
615
+ else:
616
+ return x_t
617
+
618
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
619
+ return_intermediate=False, solver_type='dpm_solver'):
620
+ """
621
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
622
+ Args:
623
+ x: A pytorch tensor. The initial value at time `s`.
624
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
625
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
626
+ r1: A `float`. The hyperparameter of the third-order solver.
627
+ r2: A `float`. The hyperparameter of the third-order solver.
628
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
629
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
630
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
631
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
632
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
633
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
634
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
635
+ Returns:
636
+ x_t: A pytorch tensor. The approximated solution at time `t`.
637
+ """
638
+ if solver_type not in ['dpm_solver', 'taylor']:
639
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
640
+ if r1 is None:
641
+ r1 = 1. / 3.
642
+ if r2 is None:
643
+ r2 = 2. / 3.
644
+ ns = self.noise_schedule
645
+ dims = x.dim()
646
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
647
+ h = lambda_t - lambda_s
648
+ lambda_s1 = lambda_s + r1 * h
649
+ lambda_s2 = lambda_s + r2 * h
650
+ s1 = ns.inverse_lambda(lambda_s1)
651
+ s2 = ns.inverse_lambda(lambda_s2)
652
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
653
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
654
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
655
+ s2), ns.marginal_std(t)
656
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
657
+
658
+ if self.predict_x0:
659
+ phi_11 = torch.expm1(-r1 * h)
660
+ phi_12 = torch.expm1(-r2 * h)
661
+ phi_1 = torch.expm1(-h)
662
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
663
+ phi_2 = phi_1 / h + 1.
664
+ phi_3 = phi_2 / h - 0.5
665
+
666
+ if model_s is None:
667
+ model_s = self.model_fn(x, s)
668
+ if model_s1 is None:
669
+ x_s1 = (
670
+ expand_dims(sigma_s1 / sigma_s, dims) * x
671
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
672
+ )
673
+ model_s1 = self.model_fn(x_s1, s1)
674
+ x_s2 = (
675
+ expand_dims(sigma_s2 / sigma_s, dims) * x
676
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
677
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
678
+ )
679
+ model_s2 = self.model_fn(x_s2, s2)
680
+ if solver_type == 'dpm_solver':
681
+ x_t = (
682
+ expand_dims(sigma_t / sigma_s, dims) * x
683
+ - expand_dims(alpha_t * phi_1, dims) * model_s
684
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
685
+ )
686
+ elif solver_type == 'taylor':
687
+ D1_0 = (1. / r1) * (model_s1 - model_s)
688
+ D1_1 = (1. / r2) * (model_s2 - model_s)
689
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
690
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
691
+ x_t = (
692
+ expand_dims(sigma_t / sigma_s, dims) * x
693
+ - expand_dims(alpha_t * phi_1, dims) * model_s
694
+ + expand_dims(alpha_t * phi_2, dims) * D1
695
+ - expand_dims(alpha_t * phi_3, dims) * D2
696
+ )
697
+ else:
698
+ phi_11 = torch.expm1(r1 * h)
699
+ phi_12 = torch.expm1(r2 * h)
700
+ phi_1 = torch.expm1(h)
701
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
702
+ phi_2 = phi_1 / h - 1.
703
+ phi_3 = phi_2 / h - 0.5
704
+
705
+ if model_s is None:
706
+ model_s = self.model_fn(x, s)
707
+ if model_s1 is None:
708
+ x_s1 = (
709
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
710
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
711
+ )
712
+ model_s1 = self.model_fn(x_s1, s1)
713
+ x_s2 = (
714
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
715
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
716
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
717
+ )
718
+ model_s2 = self.model_fn(x_s2, s2)
719
+ if solver_type == 'dpm_solver':
720
+ x_t = (
721
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
722
+ - expand_dims(sigma_t * phi_1, dims) * model_s
723
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
724
+ )
725
+ elif solver_type == 'taylor':
726
+ D1_0 = (1. / r1) * (model_s1 - model_s)
727
+ D1_1 = (1. / r2) * (model_s2 - model_s)
728
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
729
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
730
+ x_t = (
731
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
732
+ - expand_dims(sigma_t * phi_1, dims) * model_s
733
+ - expand_dims(sigma_t * phi_2, dims) * D1
734
+ - expand_dims(sigma_t * phi_3, dims) * D2
735
+ )
736
+
737
+ if return_intermediate:
738
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
739
+ else:
740
+ return x_t
741
+
742
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver", DPMencode=False):
743
+ """
744
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
745
+ Args:
746
+ x: A pytorch tensor. The initial value at time `s`.
747
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
748
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
749
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
750
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
751
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
752
+ Returns:
753
+ x_t: A pytorch tensor. The approximated solution at time `t`.
754
+ """
755
+ if solver_type not in ['dpm_solver', 'taylor']:
756
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
757
+ ns = self.noise_schedule
758
+ dims = x.dim()
759
+ model_prev_1, model_prev_0 = model_prev_list
760
+ t_prev_1, t_prev_0 = t_prev_list
761
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
762
+ t_prev_0), ns.marginal_lambda(t)
763
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
764
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
765
+ alpha_t = torch.exp(log_alpha_t)
766
+
767
+ h_0 = lambda_prev_0 - lambda_prev_1
768
+ h = lambda_t - lambda_prev_0
769
+ r0 = h_0 / h
770
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
771
+ if self.predict_x0:
772
+ if solver_type == 'dpm_solver':
773
+ x_t = (
774
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
775
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
776
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
777
+ )
778
+ elif solver_type == 'taylor':
779
+ x_t = (
780
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
781
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
782
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
783
+ )
784
+ else:
785
+ if solver_type == 'dpm_solver':
786
+ x_t = (
787
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
788
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
789
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
790
+ )
791
+ elif solver_type == 'taylor':
792
+ x_t = (
793
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
794
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
795
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
796
+ )
797
+ return x_t
798
+
799
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver', DPMencode=False):
800
+ """
801
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
802
+ Args:
803
+ x: A pytorch tensor. The initial value at time `s`.
804
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
805
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
806
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
807
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
808
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
809
+ Returns:
810
+ x_t: A pytorch tensor. The approximated solution at time `t`.
811
+ """
812
+ ns = self.noise_schedule
813
+ dims = x.dim()
814
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
815
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
816
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
817
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
818
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
819
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
820
+ alpha_t = torch.exp(log_alpha_t)
821
+
822
+ h_1 = lambda_prev_1 - lambda_prev_2
823
+ h_0 = lambda_prev_0 - lambda_prev_1
824
+ h = lambda_t - lambda_prev_0
825
+ r0, r1 = h_0 / h, h_1 / h
826
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
827
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
828
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
829
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
830
+ if self.predict_x0:
831
+ x_t = (
832
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
833
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
834
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
835
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
836
+ )
837
+ else:
838
+ x_t = (
839
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
840
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
841
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
842
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
843
+ )
844
+ return x_t
845
+
846
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
847
+ r2=None):
848
+ """
849
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
850
+ Args:
851
+ x: A pytorch tensor. The initial value at time `s`.
852
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
853
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
854
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
855
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
856
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
857
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
858
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
859
+ r2: A `float`. The hyperparameter of the third-order solver.
860
+ Returns:
861
+ x_t: A pytorch tensor. The approximated solution at time `t`.
862
+ """
863
+ if order == 1:
864
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
865
+ elif order == 2:
866
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
867
+ solver_type=solver_type, r1=r1)
868
+ elif order == 3:
869
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
870
+ solver_type=solver_type, r1=r1, r2=r2)
871
+ else:
872
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
873
+
874
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver', DPMencode=False):
875
+ """
876
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
877
+ Args:
878
+ x: A pytorch tensor. The initial value at time `s`.
879
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
880
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
881
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
882
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
883
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
884
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
885
+ Returns:
886
+ x_t: A pytorch tensor. The approximated solution at time `t`.
887
+ """
888
+ if order == 1:
889
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1], DPMencode=DPMencode)
890
+ elif order == 2:
891
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type, DPMencode=DPMencode)
892
+ elif order == 3:
893
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type, DPMencode=DPMencode)
894
+ else:
895
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
896
+
897
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
898
+ solver_type='dpm_solver'):
899
+ """
900
+ The adaptive step size solver based on singlestep DPM-Solver.
901
+ Args:
902
+ x: A pytorch tensor. The initial value at time `t_T`.
903
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
904
+ t_T: A `float`. The starting time of the sampling (default is T).
905
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
906
+ h_init: A `float`. The initial step size (for logSNR).
907
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
908
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
909
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
910
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
911
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
912
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
913
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
914
+ Returns:
915
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
916
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
917
+ """
918
+ ns = self.noise_schedule
919
+ s = t_T * torch.ones((x.shape[0],)).to(x)
920
+ lambda_s = ns.marginal_lambda(s)
921
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
922
+ h = h_init * torch.ones_like(s).to(x)
923
+ x_prev = x
924
+ nfe = 0
925
+ if order == 2:
926
+ r1 = 0.5
927
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
928
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
929
+ solver_type=solver_type,
930
+ **kwargs)
931
+ elif order == 3:
932
+ r1, r2 = 1. / 3., 2. / 3.
933
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
934
+ return_intermediate=True,
935
+ solver_type=solver_type)
936
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
937
+ solver_type=solver_type,
938
+ **kwargs)
939
+ else:
940
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
941
+ while torch.abs((s - t_0)).mean() > t_err:
942
+ t = ns.inverse_lambda(lambda_s + h)
943
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
944
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
945
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
946
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
947
+ E = norm_fn((x_higher - x_lower) / delta).max()
948
+ if torch.all(E <= 1.):
949
+ x = x_higher
950
+ s = t
951
+ x_prev = x_lower
952
+ lambda_s = ns.marginal_lambda(s)
953
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
954
+ nfe += order
955
+ print('adaptive solver nfe', nfe)
956
+ return x
957
+
958
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
959
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
960
+ atol=0.0078, rtol=0.05, DPMencode=False
961
+ ):
962
+ """
963
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
964
+ =====================================================
965
+ We support the following algorithms for both noise prediction model and data prediction model:
966
+ - 'singlestep':
967
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
968
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
969
+ The total number of function evaluations (NFE) == `steps`.
970
+ Given a fixed NFE == `steps`, the sampling procedure is:
971
+ - If `order` == 1:
972
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
973
+ - If `order` == 2:
974
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
975
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
976
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
977
+ - If `order` == 3:
978
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
979
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
980
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
981
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
982
+ - 'multistep':
983
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
984
+ We initialize the first `order` values by lower order multistep solvers.
985
+ Given a fixed NFE == `steps`, the sampling procedure is:
986
+ Denote K = steps.
987
+ - If `order` == 1:
988
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
989
+ - If `order` == 2:
990
+ - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
991
+ - If `order` == 3:
992
+ - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
993
+ - 'singlestep_fixed':
994
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
995
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
996
+ - 'adaptive':
997
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
998
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
999
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
1000
+ (NFE) and the sample quality.
1001
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1002
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1003
+ =====================================================
1004
+ Some advice on choosing the algorithm:
1005
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1006
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
1007
+ e.g.
1008
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
1009
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1010
+ skip_type='time_uniform', method='singlestep')
1011
+ - For **guided sampling with large guidance scale** by DPMs:
1012
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
1013
+ e.g.
1014
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
1015
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1016
+ skip_type='time_uniform', method='multistep')
1017
+ We support three types of `skip_type`:
1018
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**
1019
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
1020
+ - 'time_quadratic': quadratic time for the time steps.
1021
+ =====================================================
1022
+ Args:
1023
+ x: A pytorch tensor. The initial value at time `t_start`
1024
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1025
+ steps: A `int`. The total number of function evaluations (NFE).
1026
+ t_start: A `float`. The starting time of the sampling.
1027
+ If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
1028
+ t_end: A `float`. The ending time of the sampling.
1029
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1030
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1031
+ For discrete-time DPMs:
1032
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1033
+ For continuous-time DPMs:
1034
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1035
+ order: A `int`. The order of DPM-Solver.
1036
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1037
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1038
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1039
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1040
+ This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1041
+ score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID
1042
+ for diffusion models sampled by diffusion SDEs on low-resolution images
1043
+ (such as CIFAR-10). However, we observed that it does not matter for
1044
+ high-resolution images. As it needs an additional NFE, we do not recommend
1045
+ it for high-resolution images.
1046
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1047
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1048
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1049
+ (especially for steps <= 10), so we recommend setting it to `True`.
1050
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
1051
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1052
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1053
+ Returns:
1054
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1055
+ """
1056
+
1057
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1058
+ t_T = self.noise_schedule.T if t_start is None else t_start
1059
+
1060
+ device = x.device
1061
+ if method == 'adaptive':
1062
+ with torch.no_grad():
1063
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
1064
+ solver_type=solver_type)
1065
+ elif method == 'multistep':
1066
+ assert steps >= order
1067
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device, DPMencode=DPMencode)
1068
+ assert timesteps.shape[0] - 1 == steps
1069
+ with torch.no_grad():
1070
+ vec_t = timesteps[0].expand((x.shape[0]))
1071
+ model_prev_list = [self.model_fn(x, vec_t, DPMencode=DPMencode)]
1072
+ t_prev_list = [vec_t]
1073
+ # Init the first `order` values by lower order multistep DPM-Solver.
1074
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
1075
+ vec_t = timesteps[init_order].expand(x.shape[0])
1076
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
1077
+ solver_type=solver_type, DPMencode=DPMencode)
1078
+ model_prev_list.append(self.model_fn(x, vec_t, DPMencode=DPMencode))
1079
+ t_prev_list.append(vec_t)
1080
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1081
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
1082
+ vec_t = timesteps[step].expand(x.shape[0])
1083
+ if lower_order_final and steps < 15:
1084
+ step_order = min(order, steps + 1 - step)
1085
+ else:
1086
+ step_order = order
1087
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
1088
+ solver_type=solver_type, DPMencode=DPMencode)
1089
+ for i in range(order - 1):
1090
+ t_prev_list[i] = t_prev_list[i + 1]
1091
+ model_prev_list[i] = model_prev_list[i + 1]
1092
+ t_prev_list[-1] = vec_t
1093
+ # We do not need to evaluate the final model value.
1094
+ if step < steps:
1095
+ model_prev_list[-1] = self.model_fn(x, vec_t, DPMencode=DPMencode)
1096
+ elif method in ['singlestep', 'singlestep_fixed']:
1097
+ if method == 'singlestep':
1098
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
1099
+ skip_type=skip_type,
1100
+ t_T=t_T, t_0=t_0,
1101
+ device=device)
1102
+ elif method == 'singlestep_fixed':
1103
+ K = steps // order
1104
+ orders = [order, ] * K
1105
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1106
+ for i, order in enumerate(orders):
1107
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
1108
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
1109
+ N=order, device=device)
1110
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1111
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
1112
+ h = lambda_inner[-1] - lambda_inner[0]
1113
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1114
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1115
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
1116
+ if denoise_to_zero:
1117
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
1118
+ return x
1119
+
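A hedged end-to-end usage sketch of the classes above, mirroring how the sampler file below wires them together; `unet`, `alphas_cumprod`, `cond_emb` and `uncond_emb` are placeholders, not objects from this repository:

import torch

alphas_cumprod = torch.linspace(0.9999, 0.0001, 1000)    # placeholder schedule
cond_emb, uncond_emb = torch.randn(1, 77, 768), torch.randn(1, 77, 768)
unet = lambda x, t, c: torch.randn_like(x)                # placeholder network

ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
model_fn = model_wrapper(
    lambda x, t, c, DPMencode, controller, inject: unet(x, t, c),
    ns,
    model_type="noise",
    guidance_type="classifier-free",
    condition=cond_emb,
    unconditional_condition=uncond_emb,
    guidance_scale=7.5,
)
solver = DPM_Solver(model_fn, ns, predict_x0=True)
x_T = torch.randn(1, 4, 64, 64)
sample = solver.sample(x_T, steps=20, order=2, skip_type='time_uniform', method='multistep')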
1120
+ def sample_one_step(self, data, step, steps, order=3, lower_order_final=True, solver_type='dpm_solver',
1121
+ atol=0.0078, rtol=0.05, DPMencode=False, controller=None, inject=False, ref_init=None):
1122
+
1123
+ vec_t = data['timesteps'][step].expand(data['x'].shape[0])
1124
+ if lower_order_final and steps < 15:
1125
+ step_order = min(order, steps + 1 - step)
1126
+ else:
1127
+ step_order = order
1128
+ data['x'] = self.multistep_dpm_solver_update(data['x'], data['model_prev_list'], data['t_prev_list'], vec_t, step_order,
1129
+ solver_type=solver_type, DPMencode=DPMencode)
1130
+ for i in range(order - 1):
1131
+ data['t_prev_list'][i] = data['t_prev_list'][i + 1]
1132
+ data['model_prev_list'][i] = data['model_prev_list'][i + 1]
1133
+ data['t_prev_list'][-1] = vec_t
1134
+ # We do not need to evaluate the final model value.
1135
+ if step < steps:
1136
+ data['model_prev_list'][-1] = self.model_fn(data['x'], vec_t, DPMencode=DPMencode, controller=controller, inject=inject, ref_init=ref_init)
1137
+
1138
+ return {'x': data['x'], 'model_prev_list': data['model_prev_list'], 't_prev_list': data['t_prev_list'], 'timesteps': data['timesteps']}
1139
+
1140
+ #############################################################
1141
+ # other utility functions
1142
+ #############################################################
1143
+
1144
+ def interpolate_fn(x, xp, yp):
1145
+ """
1146
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1147
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1148
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
1149
+ Args:
1150
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1151
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1152
+ yp: PyTorch tensor with shape [C, K].
1153
+ Returns:
1154
+ The function values f(x), with shape [N, C].
1155
+ """
1156
+ N, K = x.shape[0], xp.shape[1]
1157
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1158
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1159
+ x_idx = torch.argmin(x_indices, dim=2)
1160
+ cand_start_idx = x_idx - 1
1161
+ start_idx = torch.where(
1162
+ torch.eq(x_idx, 0),
1163
+ torch.tensor(1, device=x.device),
1164
+ torch.where(
1165
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1166
+ ),
1167
+ )
1168
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1169
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1170
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1171
+ start_idx2 = torch.where(
1172
+ torch.eq(x_idx, 0),
1173
+ torch.tensor(0, device=x.device),
1174
+ torch.where(
1175
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1176
+ ),
1177
+ )
1178
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1179
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1180
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1181
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1182
+ return cand
1183
+
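A minimal usage example for `interpolate_fn` above (shapes follow the docstring: `x` is [N, C], `xp` and `yp` are [C, K]):

import torch

x  = torch.tensor([[0.5], [1.5], [4.0]])   # [N=3, C=1] query points
xp = torch.tensor([[1.0, 2.0, 3.0]])       # [C=1, K=3] keypoint x-values
yp = torch.tensor([[10.0, 20.0, 30.0]])    # [C=1, K=3] keypoint y-values
y = interpolate_fn(x, xp, yp)
# y ≈ [[5.], [15.], [40.]]: linear between keypoints, linearly extrapolated outside.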
1184
+
1185
+ def expand_dims(v, dims):
1186
+ """
1187
+ Expand the tensor `v` to the dim `dims`.
1188
+ Args:
1189
+ `v`: a PyTorch tensor with shape [N].
1190
+ `dim`: a `int`.
1191
+ Returns:
1192
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1193
+ """
1194
+ return v[(...,) + (None,) * (dims - 1)]
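A minimal example for `expand_dims` above:

import torch

v = torch.tensor([1.0, 2.0, 3.0, 4.0])   # shape [N] = [4]
assert expand_dims(v, 4).shape == (4, 1, 1, 1)
# Lets a per-sample scalar (e.g. sigma_t) broadcast against an [N, C, H, W] latent.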
ldm/models/diffusion/dpm_solver/sampler.py ADDED
@@ -0,0 +1,252 @@
1
+ """SAMPLING ONLY."""
2
+ import torch
3
+ import ptp_scripts.ptp_scripts as ptp
4
+ import ptp_scripts.ptp_utils as ptp_utils
5
+ # from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
6
+ from scripts.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
7
+ from tqdm import tqdm
8
+
9
+ MODEL_TYPES = {
10
+ "eps": "noise",
11
+ "v": "v"
12
+ }
13
+
14
+
15
+ class DPMSolverSampler(object):
16
+ def __init__(self, model, **kwargs):
17
+ super().__init__()
18
+ self.model = model
19
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
20
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
21
+
22
+ def register_buffer(self, name, attr):
23
+ if type(attr) == torch.Tensor:
24
+ if attr.device != self.model.device:
25
+ attr = attr.to(self.model.device)
26
+ setattr(self, name, attr)
27
+
28
+ @torch.no_grad()
29
+ def sample(self,
30
+ steps,
31
+ batch_size,
32
+ shape,
33
+ conditioning=None,
34
+ inv_emb=None,
35
+ callback=None,
36
+ normals_sequence=None,
37
+ img_callback=None,
38
+ quantize_x0=False,
39
+ eta=0.,
40
+ mask=None,
41
+ x0=None,
42
+ temperature=1.,
43
+ noise_dropout=0.,
44
+ score_corrector=None,
45
+ corrector_kwargs=None,
46
+ verbose=True,
47
+ x_T=None,
48
+ log_every_t=100,
49
+ unconditional_guidance_scale=1.,
50
+ unconditional_conditioning=None,
51
+ t_start=None,
52
+ t_end=None,
53
+ DPMencode=False,
54
+ order=3,
55
+ width=None,
56
+ height=None,
57
+ ref=False,
58
+ top=None,
59
+ left=None,
60
+ bottom=None,
61
+ right=None,
62
+ segmentation_map=None,
63
+ param=None,
64
+ target_height=None,
65
+ target_width=None,
66
+ center_row_rm=None,
67
+ center_col_rm=None,
68
+ tau_a=0.4,
69
+ tau_b=0.8,
70
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
71
+ **kwargs
72
+ ):
73
+ if conditioning is not None:
74
+ if isinstance(conditioning, dict):
75
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
76
+ if cbs != batch_size:
77
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
78
+ else:
79
+ if conditioning.shape[0] != batch_size:
80
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
81
+
82
+ # sampling
83
+ C, H, W = shape
84
+ size = (batch_size, C, H, W)
85
+
86
+ # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {steps}')
87
+
88
+ device = self.model.betas.device
89
+ if x_T is None:
90
+ x = torch.randn(size, device=device)
91
+ else:
92
+ x = x_T
93
+
94
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
95
+
96
+ if DPMencode:
97
+ # x_T is not a list
98
+ model_fn = model_wrapper(
99
+ lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=None, inject=inject),
100
+ ns,
101
+ model_type=MODEL_TYPES[self.model.parameterization],
102
+ guidance_type="classifier-free",
103
+ condition=inv_emb,
104
+ unconditional_condition=inv_emb,
105
+ guidance_scale=unconditional_guidance_scale,
106
+ )
107
+
108
+ dpm_solver = DPM_Solver(model_fn, ns)
109
+ data, _ = self.low_order_sample(x, dpm_solver, steps, order, t_start, t_end, device, DPMencode=DPMencode)
110
+
111
+ for step in range(order, steps + 1):
112
+ data = dpm_solver.sample_one_step(data, step, steps, order=order, DPMencode=DPMencode)
113
+
114
+ return data['x'].to(device), None
115
+ else:
116
+ # x_T is a list
117
+ model_fn_decode = model_wrapper(
118
+ lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=controller, inject=inject),
119
+ ns,
120
+ model_type=MODEL_TYPES[self.model.parameterization],
121
+ guidance_type="classifier-free",
122
+ condition=inv_emb,
123
+ unconditional_condition=inv_emb,
124
+ guidance_scale=unconditional_guidance_scale,
125
+ )
126
+ model_fn_gen = model_wrapper(
127
+ lambda x, t, c, DPMencode, controller, inject: self.model.apply_model(x, t, c, encode=DPMencode, controller=controller, inject=inject),
128
+ ns,
129
+ model_type=MODEL_TYPES[self.model.parameterization],
130
+ guidance_type="classifier-free",
131
+ condition=conditioning,
132
+ unconditional_condition=unconditional_conditioning,
133
+ guidance_scale=unconditional_guidance_scale,
134
+ )
135
+
136
+ orig_controller = ptp.AttentionStore()
137
+ ref_controller = ptp.AttentionStore()
138
+ cross_controller = ptp.AttentionStore()
139
+ gen_controller = ptp.AttentionStore()
140
+ Inject_controller = ptp.AttentionStore()
141
+
142
+ dpm_solver_decode = DPM_Solver(model_fn_decode, ns)
143
+ dpm_solver_gen = DPM_Solver(model_fn_gen, ns)
144
+
145
+ # decoded background
146
+ ptp_utils.register_attention_control(self.model, orig_controller, center_row_rm, center_col_rm, target_height, target_width,
147
+ width, height, top, left, bottom, right, segmentation_map=segmentation_map[0, 0].clone())
148
+ orig, orig_controller = self.low_order_sample(x[0], dpm_solver_decode, steps, order, t_start, t_end, device, DPMencode=DPMencode, controller=orig_controller)
149
+ # decoded reference
150
+ ptp_utils.register_attention_control(self.model, ref_controller, center_row_rm, center_col_rm, target_height, target_width,
151
+ width, height, top, left, bottom, right, segmentation_map=segmentation_map[0, 0].clone())
152
+ ref, ref_controller = self.low_order_sample(x[3], dpm_solver_decode, steps, order, t_start, t_end, device, DPMencode=DPMencode, controller=ref_controller)
153
+
154
+ # decode for cross-attention
155
+ ptp_utils.register_attention_control(self.model, cross_controller, center_row_rm, center_col_rm, target_height, target_width,
156
+ width, height, top, left, bottom, right, segmentation_map=segmentation_map[0, 0].clone(), pseudo_cross=True)
157
+ cross, cross_controller = self.low_order_sample(x[2], dpm_solver_decode, steps, order, t_start, t_end, device, DPMencode=DPMencode,
158
+ controller=cross_controller, ref_init=ref['x'].clone())
159
+
160
+ # generation
161
+ Inject_controller = [orig_controller, ref_controller, cross_controller]
162
+ ptp_utils.register_attention_control(self.model, gen_controller, center_row_rm, center_col_rm, target_height, target_width,
163
+ width, height, top, left, bottom, right, segmentation_map=segmentation_map[0, 0].clone(), inject_bg=True)
164
+ gen, _ = self.low_order_sample(x[4], dpm_solver_gen, steps, order, t_start, t_end, device,
165
+ DPMencode=DPMencode, controller=Inject_controller, inject=True)
166
+
167
+ for i in range(len(orig['model_prev_list'])):
168
+ blended = orig['model_prev_list'][i].clone()
169
+ blended[:, :, param[0] : param[1], param[2] : param[3]] \
170
+ = gen['model_prev_list'][i][:, :, param[0] : param[1], param[2] : param[3]].clone()
171
+ gen['model_prev_list'][i] = blended.clone()
172
+
173
+ del orig_controller, ref_controller, cross_controller, gen_controller, Inject_controller
174
+
175
+ orig_controller = ptp.AttentionStore()
176
+ ref_controller = ptp.AttentionStore()
177
+ cross_controller = ptp.AttentionStore()
178
+ gen_controller = ptp.AttentionStore()
179
+
180
+ for step in range(order, steps + 1):
181
+ # decoded background
182
+ ptp_utils.register_attention_control(self.model, orig_controller, center_row_rm, center_col_rm, target_height, target_width,
183
+ width, height, top, left, bottom, right, segmentation_map=segmentation_map[0, 0].clone())
184
+ orig = dpm_solver_decode.sample_one_step(orig, step, steps, order=order, DPMencode=DPMencode)
185
+
186
+ # decode for cross-attention
187
+ ptp_utils.register_attention_control(self.model, cross_controller, center_row_rm, center_col_rm, target_height, target_width,
188
+ width, height, top, left, bottom, right, segmentation_map=segmentation_map[0, 0].clone(), pseudo_cross=True)
189
+ cross['x'] = orig['x']
190
+ cross = dpm_solver_decode.sample_one_step(cross, step, steps, order=order, DPMencode=DPMencode, ref_init=ref['x'].clone())
191
+
192
+ if step < int(tau_a*(steps) + 1 - order):
193
+ inject = True
194
+ # decoded reference
195
+ ptp_utils.register_attention_control(self.model, ref_controller, center_row_rm, center_col_rm, target_height, target_width,
196
+ width, height, top, left, bottom, right, segmentation_map=segmentation_map[0, 0].clone())
197
+ ref = dpm_solver_decode.sample_one_step(ref, step, steps, order=order, DPMencode=DPMencode)
198
+ controller = [orig_controller, ref_controller, cross_controller]
199
+ else:
200
+ inject = False
201
+ controller = [orig_controller, None, cross_controller]
202
+
203
+ if step < int(0.4*(steps) + 1 - order):
204
+ inject_bg = True
205
+ else:
206
+ inject_bg = False
207
+
208
+ # generation
209
+ ptp_utils.register_attention_control(self.model, gen_controller, center_row_rm, center_col_rm, target_height, target_width, width, height,
210
+ top, left, bottom, right, segmentation_map=segmentation_map[0, 0].clone(), inject_bg=inject_bg)
211
+ gen = dpm_solver_gen.sample_one_step(gen, step, steps, order=order, DPMencode=DPMencode, controller=controller, inject=inject)
212
+
213
+ if step < int(tau_b*(steps) + 1 - order):
214
+ blended = orig['x'].clone()
215
+ blended[:, :, param[0] : param[1], param[2] : param[3]] \
216
+ = gen['x'][:, :, param[0] : param[1], param[2] : param[3]].clone()
217
+ gen['x'] = blended.clone()
218
+
219
+ del orig_controller, ref_controller, cross_controller, gen_controller, controller
220
+ return gen['x'].to(device), None
221
+
222
+
223
+ def low_order_sample(self, x, dpm_solver, steps, order, t_start, t_end, device, DPMencode=False, controller=None, inject=False, ref_init=None):
224
+
225
+ t_0 = 1. / dpm_solver.noise_schedule.total_N if t_end is None else t_end
226
+ t_T = dpm_solver.noise_schedule.T if t_start is None else t_start
227
+
228
+ total_controller = []
229
+ assert steps >= order
230
+ timesteps = dpm_solver.get_time_steps(skip_type="time_uniform", t_T=t_T, t_0=t_0, N=steps, device=device, DPMencode=DPMencode)
231
+ assert timesteps.shape[0] - 1 == steps
232
+ with torch.no_grad():
233
+ vec_t = timesteps[0].expand((x.shape[0]))
234
+ model_prev_list = [dpm_solver.model_fn(x, vec_t, DPMencode=DPMencode,
235
+ controller=[controller[0][0], controller[1][0], controller[2][0]] if isinstance(controller, list) else controller,
236
+ inject=inject, ref_init=ref_init)]
237
+
238
+ total_controller.append(controller)
239
+ t_prev_list = [vec_t]
240
+ # Init the first `order` values by lower order multistep DPM-Solver.
241
+ for init_order in range(1, order):
242
+ vec_t = timesteps[init_order].expand(x.shape[0])
243
+ x = dpm_solver.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
244
+ solver_type='dpmsolver', DPMencode=DPMencode)
245
+ model_prev_list.append(dpm_solver.model_fn(x, vec_t, DPMencode=DPMencode,
246
+ controller=[controller[0][init_order], controller[1][init_order], controller[2][init_order]] if isinstance(controller, list) else controller,
247
+ inject=inject, ref_init=ref_init))
248
+ total_controller.append(controller)
249
+ t_prev_list.append(vec_t)
250
+
251
+ return {'x': x, 'model_prev_list': model_prev_list, 't_prev_list': t_prev_list, 'timesteps':timesteps}, total_controller
252
+
ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,244 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+ from ldm.models.diffusion.sampling_util import norm_thresholding
10
+
11
+
12
+ class PLMSSampler(object):
13
+ def __init__(self, model, schedule="linear", **kwargs):
14
+ super().__init__()
15
+ self.model = model
16
+ self.ddpm_num_timesteps = model.num_timesteps
17
+ self.schedule = schedule
18
+
19
+ def register_buffer(self, name, attr):
20
+ if type(attr) == torch.Tensor:
21
+ if attr.device != torch.device("cuda"):
22
+ attr = attr.to(torch.device("cuda"))
23
+ setattr(self, name, attr)
24
+
25
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26
+ if ddim_eta != 0:
27
+ raise ValueError('ddim_eta must be 0 for PLMS')
28
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30
+ alphas_cumprod = self.model.alphas_cumprod
31
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
33
+
34
+ self.register_buffer('betas', to_torch(self.model.betas))
35
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
37
+
38
+ # calculations for diffusion q(x_t | x_{t-1}) and others
39
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44
+
45
+ # ddim sampling parameters
46
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47
+ ddim_timesteps=self.ddim_timesteps,
48
+ eta=ddim_eta,verbose=verbose)
49
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
50
+ self.register_buffer('ddim_alphas', ddim_alphas)
51
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57
+
58
+ @torch.no_grad()
59
+ def sample(self,
60
+ S,
61
+ batch_size,
62
+ shape,
63
+ conditioning=None,
64
+ callback=None,
65
+ normals_sequence=None,
66
+ img_callback=None,
67
+ quantize_x0=False,
68
+ eta=0.,
69
+ mask=None,
70
+ x0=None,
71
+ temperature=1.,
72
+ noise_dropout=0.,
73
+ score_corrector=None,
74
+ corrector_kwargs=None,
75
+ verbose=True,
76
+ x_T=None,
77
+ log_every_t=100,
78
+ unconditional_guidance_scale=1.,
79
+ unconditional_conditioning=None,
80
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81
+ dynamic_threshold=None,
82
+ **kwargs
83
+ ):
84
+ if conditioning is not None:
85
+ if isinstance(conditioning, dict):
86
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
87
+ if cbs != batch_size:
88
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
89
+ else:
90
+ if conditioning.shape[0] != batch_size:
91
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
92
+
93
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
94
+ # sampling
95
+ C, H, W = shape
96
+ size = (batch_size, C, H, W)
97
+ print(f'Data shape for PLMS sampling is {size}')
98
+
99
+ samples, intermediates = self.plms_sampling(conditioning, size,
100
+ callback=callback,
101
+ img_callback=img_callback,
102
+ quantize_denoised=quantize_x0,
103
+ mask=mask, x0=x0,
104
+ ddim_use_original_steps=False,
105
+ noise_dropout=noise_dropout,
106
+ temperature=temperature,
107
+ score_corrector=score_corrector,
108
+ corrector_kwargs=corrector_kwargs,
109
+ x_T=x_T,
110
+ log_every_t=log_every_t,
111
+ unconditional_guidance_scale=unconditional_guidance_scale,
112
+ unconditional_conditioning=unconditional_conditioning,
113
+ dynamic_threshold=dynamic_threshold,
114
+ )
115
+ return samples, intermediates
116
+
117
+ @torch.no_grad()
118
+ def plms_sampling(self, cond, shape,
119
+ x_T=None, ddim_use_original_steps=False,
120
+ callback=None, timesteps=None, quantize_denoised=False,
121
+ mask=None, x0=None, img_callback=None, log_every_t=100,
122
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
123
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
124
+ dynamic_threshold=None):
125
+ device = self.model.betas.device
126
+ b = shape[0]
127
+ if x_T is None:
128
+ img = torch.randn(shape, device=device)
129
+ else:
130
+ img = x_T
131
+
132
+ if timesteps is None:
133
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
134
+ elif timesteps is not None and not ddim_use_original_steps:
135
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
136
+ timesteps = self.ddim_timesteps[:subset_end]
137
+
138
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
139
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
140
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
141
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
142
+
143
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
144
+ old_eps = []
145
+
146
+ for i, step in enumerate(iterator):
147
+ index = total_steps - i - 1
148
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
149
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
150
+
151
+ if mask is not None:
152
+ assert x0 is not None
153
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
154
+ img = img_orig * mask + (1. - mask) * img
155
+
156
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
157
+ quantize_denoised=quantize_denoised, temperature=temperature,
158
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
159
+ corrector_kwargs=corrector_kwargs,
160
+ unconditional_guidance_scale=unconditional_guidance_scale,
161
+ unconditional_conditioning=unconditional_conditioning,
162
+ old_eps=old_eps, t_next=ts_next,
163
+ dynamic_threshold=dynamic_threshold)
164
+ img, pred_x0, e_t = outs
165
+ old_eps.append(e_t)
166
+ if len(old_eps) >= 4:
167
+ old_eps.pop(0)
168
+ if callback: callback(i)
169
+ if img_callback: img_callback(pred_x0, i)
170
+
171
+ if index % log_every_t == 0 or index == total_steps - 1:
172
+ intermediates['x_inter'].append(img)
173
+ intermediates['pred_x0'].append(pred_x0)
174
+
175
+ return img, intermediates
176
+
177
+ @torch.no_grad()
178
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
179
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
180
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
181
+ dynamic_threshold=None):
182
+ b, *_, device = *x.shape, x.device
183
+
184
+ def get_model_output(x, t):
185
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
186
+ e_t = self.model.apply_model(x, t, c)
187
+ else:
188
+ x_in = torch.cat([x] * 2)
189
+ t_in = torch.cat([t] * 2)
190
+ c_in = torch.cat([unconditional_conditioning, c])
191
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
192
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
193
+
194
+ if score_corrector is not None:
195
+ assert self.model.parameterization == "eps"
196
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
197
+
198
+ return e_t
199
+
200
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
201
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
202
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
203
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
204
+
205
+ def get_x_prev_and_pred_x0(e_t, index):
206
+ # select parameters corresponding to the currently considered timestep
207
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
208
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
209
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
210
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
211
+
212
+ # current prediction for x_0
213
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
214
+ if quantize_denoised:
215
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
216
+ if dynamic_threshold is not None:
217
+ pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
218
+ # direction pointing to x_t
219
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
220
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
221
+ if noise_dropout > 0.:
222
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
223
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
224
+ return x_prev, pred_x0
225
+
226
+ e_t = get_model_output(x, t)
227
+ if len(old_eps) == 0:
228
+ # Pseudo Improved Euler (2nd order)
229
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
230
+ e_t_next = get_model_output(x_prev, t_next)
231
+ e_t_prime = (e_t + e_t_next) / 2
232
+ elif len(old_eps) == 1:
233
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
234
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
235
+ elif len(old_eps) == 2:
236
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
237
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
238
+ elif len(old_eps) >= 3:
239
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
240
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
241
+
242
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
243
+
244
+ return x_prev, pred_x0, e_t
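The `e_t_prime` branches above apply the 1st- through 4th-order Adams-Bashforth (pseudo linear multistep) weights to the stored noise predictions in `old_eps`. A toy sketch of the same weights on a scalar ODE `dy/dt = -y`, just to show how the coefficient tables are used (the real sampler uses a Heun-style corrector for the very first step rather than the plain Euler fallback here):

```python
import math

def ab_step(f_hist, f_cur):
    # Same coefficient tables as e_t_prime in p_sample_plms above.
    if len(f_hist) == 0:
        return f_cur                                   # plain Euler fallback (toy only)
    if len(f_hist) == 1:
        return (3 * f_cur - f_hist[-1]) / 2
    if len(f_hist) == 2:
        return (23 * f_cur - 16 * f_hist[-1] + 5 * f_hist[-2]) / 12
    return (55 * f_cur - 59 * f_hist[-1] + 37 * f_hist[-2] - 9 * f_hist[-3]) / 24

y, t, h, hist = 1.0, 0.0, 0.01, []
for _ in range(100):
    f = -y                        # the derivative plays the role of the noise prediction e_t
    y += h * ab_step(hist, f)
    hist.append(f)
    hist = hist[-3:]              # keep at most three old values, like old_eps
    t += h
print(y, math.exp(-t))            # both roughly 0.3679
```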
ldm/models/diffusion/sampling_util.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def append_dims(x, target_dims):
6
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions.
7
+ From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
8
+ dims_to_append = target_dims - x.ndim
9
+ if dims_to_append < 0:
10
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
11
+ return x[(...,) + (None,) * dims_to_append]
12
+
13
+
14
+ def norm_thresholding(x0, value):
15
+ s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
16
+ return x0 * (value / s)
17
+
18
+
19
+ def spatial_norm_thresholding(x0, value):
20
+ # b c h w
21
+ s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
22
+ return x0 * (value / s)
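A short usage sketch of the helpers above: `norm_thresholding` caps the per-sample RMS of a predicted `x0` at `value`, which is what `p_sample_plms` applies when `dynamic_threshold` is set (the shapes below are illustrative):

```python
import torch
from ldm.models.diffusion.sampling_util import norm_thresholding

x0 = torch.randn(2, 4, 64, 64) * 3.0       # deliberately over-scaled x0 prediction
clipped = norm_thresholding(x0, value=1.0)

rms = clipped.pow(2).flatten(1).mean(1).sqrt()
print(rms)                                  # ~tensor([1.0, 1.0]): per-sample RMS is now capped at `value`
```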
ldm/modules/attention.py ADDED
@@ -0,0 +1,377 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+ from typing import Optional, Any
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ from ldm.modules.diffusionmodules.util import checkpoint
11
+ from PIL import Image
12
+
13
+ try:
14
+ import xformers
15
+ import xformers.ops
16
+ XFORMERS_IS_AVAILBLE = False # must stay False here: the injection hooks used below are only implemented in the vanilla CrossAttention
17
+ except:
18
+ XFORMERS_IS_AVAILBLE = False
19
+
20
+
21
+ def exists(val):
22
+ return val is not None
23
+
24
+
25
+ def uniq(arr):
26
+ return{el: True for el in arr}.keys()
27
+
28
+
29
+ def default(val, d):
30
+ if exists(val):
31
+ return val
32
+ return d() if isfunction(d) else d
33
+
34
+
35
+ def max_neg_value(t):
36
+ return -torch.finfo(t.dtype).max
37
+
38
+
39
+ def init_(tensor):
40
+ dim = tensor.shape[-1]
41
+ std = 1 / math.sqrt(dim)
42
+ tensor.uniform_(-std, std)
43
+ return tensor
44
+
45
+
46
+ # feedforward
47
+ class GEGLU(nn.Module):
48
+ def __init__(self, dim_in, dim_out):
49
+ super().__init__()
50
+ self.proj = nn.Linear(dim_in, dim_out * 2)
51
+
52
+ def forward(self, x):
53
+ x, gate = self.proj(x).chunk(2, dim=-1)
54
+ return x * F.gelu(gate)
55
+
56
+
57
+ class FeedForward(nn.Module):
58
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
59
+ super().__init__()
60
+ inner_dim = int(dim * mult)
61
+ dim_out = default(dim_out, dim)
62
+ project_in = nn.Sequential(
63
+ nn.Linear(dim, inner_dim),
64
+ nn.GELU()
65
+ ) if not glu else GEGLU(dim, inner_dim)
66
+
67
+ self.net = nn.Sequential(
68
+ project_in,
69
+ nn.Dropout(dropout),
70
+ nn.Linear(inner_dim, dim_out)
71
+ )
72
+
73
+ def forward(self, x):
74
+ return self.net(x)
75
+
76
+
77
+ def zero_module(module):
78
+ """
79
+ Zero out the parameters of a module and return it.
80
+ """
81
+ for p in module.parameters():
82
+ p.detach().zero_()
83
+ return module
84
+
85
+
86
+ def Normalize(in_channels):
87
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
88
+
89
+
90
+ class SpatialSelfAttention(nn.Module):
91
+ def __init__(self, in_channels):
92
+ super().__init__()
93
+ self.in_channels = in_channels
94
+
95
+ self.norm = Normalize(in_channels)
96
+ self.q = torch.nn.Conv2d(in_channels,
97
+ in_channels,
98
+ kernel_size=1,
99
+ stride=1,
100
+ padding=0)
101
+ self.k = torch.nn.Conv2d(in_channels,
102
+ in_channels,
103
+ kernel_size=1,
104
+ stride=1,
105
+ padding=0)
106
+ self.v = torch.nn.Conv2d(in_channels,
107
+ in_channels,
108
+ kernel_size=1,
109
+ stride=1,
110
+ padding=0)
111
+ self.proj_out = torch.nn.Conv2d(in_channels,
112
+ in_channels,
113
+ kernel_size=1,
114
+ stride=1,
115
+ padding=0)
116
+
117
+ def forward(self, x):
118
+ h_ = x
119
+ h_ = self.norm(h_)
120
+ q = self.q(h_)
121
+ k = self.k(h_)
122
+ v = self.v(h_)
123
+
124
+ # compute attention
125
+ b,c,h,w = q.shape
126
+ q = rearrange(q, 'b c h w -> b (h w) c')
127
+ k = rearrange(k, 'b c h w -> b c (h w)')
128
+ w_ = torch.einsum('bij,bjk->bik', q, k)
129
+
130
+ w_ = w_ * (int(c)**(-0.5))
131
+ w_ = torch.nn.functional.softmax(w_, dim=2)
132
+
133
+ # attend to values
134
+ v = rearrange(v, 'b c h w -> b c (h w)')
135
+ w_ = rearrange(w_, 'b i j -> b j i')
136
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
137
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
138
+ h_ = self.proj_out(h_)
139
+
140
+ return x+h_
141
+
142
+
143
+ class CrossAttention(nn.Module):
144
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
145
+ super().__init__()
146
+ inner_dim = dim_head * heads
147
+ context_dim = default(context_dim, query_dim)
148
+
149
+ self.scale = dim_head ** -0.5
150
+ self.heads = heads
151
+
152
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
153
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
154
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
155
+
156
+ self.to_out = nn.Sequential(
157
+ nn.Linear(inner_dim, query_dim),
158
+ nn.Dropout(dropout)
159
+ )
160
+
161
+ def forward(self, x, context=None, mask=None, encode=False, controller_for_inject=None, inject=False, layernum=None, main_height=None, main_width=None):
162
+ h = self.heads
163
+
164
+ q = self.to_q(x)
165
+ context = default(context, x)
166
+ k = self.to_k(context)
167
+ v = self.to_v(context)
168
+
169
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
170
+
171
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
172
+ del q, k
173
+
174
+ if exists(mask):
175
+ mask = rearrange(mask, 'b ... -> b (...)')
176
+ max_neg_value = -torch.finfo(sim.dtype).max
177
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
178
+ sim.masked_fill_(~mask, max_neg_value)
179
+
180
+ # a = ((sim.mean(0).mean(1).resize(64,64)/torch.max(sim.max(), abs(sim.min())) + 1)*127.5).cpu().numpy().astype(np.uint8)
181
+ # image = Image.fromarray(a)
182
+ # image.resize((512,512)).save('2.jpg')
183
+
184
+ # u, s, vh = np.linalg.svd(sim.mean(0).cpu().numpy().astype(np.float32) - np.mean(sim.mean(0).cpu().numpy().astype(np.float32), axis=1, keepdims=True))
185
+ # images = []
186
+ # for i in range(3):
187
+ # image = u[:,i].reshape(64, 64)
188
+ # image = image - image.min()
189
+ # image = 255 * image / image.max()
190
+ # image = np.expand_dims(image, axis=2).astype(np.uint8)
191
+ # images.append(image)
192
+
193
+ # final = np.dstack(images)
194
+ # final = Image.fromarray(final).resize((256, 256))
195
+ # final = np.array(final)
196
+ # import ptp_scripts.ptp_utils as ptp_utils
197
+ # ptp_utils.view_images(final)
198
+
199
+ # attention, what we cannot get enough of
200
+ sim = sim.softmax(dim=-1)
201
+ # sim = sim.sigmoid()
202
+
203
+ out = einsum('b i j, b j d -> b i d', sim, v)
204
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
205
+ return self.to_out(out)
206
+
207
+
208
+ class MemoryEfficientCrossAttention(nn.Module):
209
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
210
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
211
+ super().__init__()
212
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
213
+ f"{heads} heads.")
214
+ inner_dim = dim_head * heads
215
+ context_dim = default(context_dim, query_dim)
216
+
217
+ self.heads = heads
218
+ self.dim_head = dim_head
219
+
220
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
221
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
222
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
223
+
224
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
225
+ self.attention_op: Optional[Any] = None
226
+
227
+ def forward(self, x, context=None, mask=None):
228
+ q = self.to_q(x)
229
+ context = default(context, x)
230
+ k = self.to_k(context)
231
+ v = self.to_v(context)
232
+
233
+ b, _, _ = q.shape
234
+ q, k, v = map(
235
+ lambda t: t.unsqueeze(3)
236
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
237
+ .permute(0, 2, 1, 3)
238
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
239
+ .contiguous(),
240
+ (q, k, v),
241
+ )
242
+
243
+ # actually compute the attention, what we cannot get enough of
244
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
245
+
246
+ if exists(mask):
247
+ raise NotImplementedError
248
+ out = (
249
+ out.unsqueeze(0)
250
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
251
+ .permute(0, 2, 1, 3)
252
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
253
+ )
254
+ return self.to_out(out)
255
+
256
+
257
+ class BasicTransformerBlock(nn.Module):
258
+ ATTENTION_MODES = {
259
+ "softmax": CrossAttention, # vanilla attention
260
+ "softmax-xformers": MemoryEfficientCrossAttention
261
+ }
262
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
263
+ disable_self_attn=False):
264
+ super().__init__()
265
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
266
+ assert attn_mode in self.ATTENTION_MODES
267
+ attn_cls = self.ATTENTION_MODES[attn_mode]
268
+ self.disable_self_attn = disable_self_attn
269
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
270
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
271
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
272
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
273
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
274
+ self.norm1 = nn.LayerNorm(dim)
275
+ self.norm2 = nn.LayerNorm(dim)
276
+ self.norm3 = nn.LayerNorm(dim)
277
+ self.checkpoint = checkpoint
278
+
279
+ def forward(self, x, context=None, encode=False, encode_uncon=False, decode_uncon=False, controller=None, inject=False, layernum=0, h=None, w=None):
280
+ return checkpoint(self._forward, (x, context, encode, encode_uncon, decode_uncon, controller, inject, layernum, h, w), self.parameters(), self.checkpoint)
281
+
282
+ def _forward(self, x, context=None, encode=False, encode_uncon=False, decode_uncon=False, controller=None, inject=False, layernum=0, h=None, w=None):
283
+
284
+ if encode_uncon == True and decode_uncon == True:
285
+ # pass
286
+ x = self.attn1(self.norm1(x), context=None, encode=encode) + x
287
+ x = self.attn1(self.norm1(x), context=None, encode=encode) + x # if you add more layers, remember to update register_attention_control accordingly
288
+
289
+ elif encode_uncon == True and decode_uncon == False:
290
+ if encode:
291
+ x = self.attn1(self.norm1(x), context=None, encode=encode) + x
292
+ x = self.attn1(self.norm1(x), context=None, encode=encode) + x # if you add more layers, remember to update register_attention_control accordingly
293
+ else:
294
+ x = self.attn1(self.norm1(x), context=context
295
+ if self.disable_self_attn else None, controller_for_inject=controller, inject=inject, layernum=layernum) + x
296
+ x = self.attn1(self.norm1(x), context=context
297
+ if self.disable_self_attn else None, controller_for_inject=controller, inject=inject, layernum=layernum+1) + x
298
+ x = self.attn2(self.norm2(x), context=context) + x
299
+
300
+ elif encode_uncon == False and decode_uncon == False:
301
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, encode=encode,
302
+ controller_for_inject=controller, inject=inject, layernum=layernum, main_height=h, main_width=w) + x
303
+ x = self.attn2(self.norm2(x), context=context, encode=encode) + x
304
+ # pass
305
+
306
+ x = self.ff(self.norm3(x)) + x
307
+ return x
308
+
309
+
310
+ class SpatialTransformer(nn.Module):
311
+ """
312
+ Transformer block for image-like data.
313
+ First, project the input (aka embedding)
314
+ and reshape to b, t, d.
315
+ Then apply standard transformer action.
316
+ Finally, reshape to image
317
+ NEW: use_linear for more efficiency instead of the 1x1 convs
318
+ """
319
+ def __init__(self, in_channels, n_heads, d_head,
320
+ depth=1, dropout=0., context_dim=None,
321
+ disable_self_attn=False, use_linear=False,
322
+ use_checkpoint=True):
323
+ super().__init__()
324
+ if exists(context_dim) and not isinstance(context_dim, list):
325
+ context_dim = [context_dim]
326
+ self.in_channels = in_channels
327
+ inner_dim = n_heads * d_head
328
+ self.norm = Normalize(in_channels)
329
+ if not use_linear:
330
+ self.proj_in = nn.Conv2d(in_channels,
331
+ inner_dim,
332
+ kernel_size=1,
333
+ stride=1,
334
+ padding=0)
335
+ else:
336
+ self.proj_in = nn.Linear(in_channels, inner_dim)
337
+
338
+ self.transformer_blocks = nn.ModuleList(
339
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
340
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
341
+ for d in range(depth)]
342
+ )
343
+ if not use_linear:
344
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
345
+ in_channels,
346
+ kernel_size=1,
347
+ stride=1,
348
+ padding=0))
349
+ else:
350
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
351
+ self.use_linear = use_linear
352
+
353
+ def forward(self, x, context=None, encode=False, encode_uncon=True, decode_uncon=True, controller=None, inject=False, layernum=0):
354
+ # note: if no context is given, cross-attention defaults to self-attention
355
+ if not isinstance(context, list):
356
+ context = [context]
357
+ b, c, h, w = x.shape
358
+ x_in = x
359
+ x = self.norm(x)
360
+ if not self.use_linear:
361
+ x = self.proj_in(x)
362
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
363
+ if self.use_linear:
364
+ x = self.proj_in(x)
365
+ for i, block in enumerate(self.transformer_blocks):
366
+ x = block(x, context=context[i], encode=encode, encode_uncon=encode_uncon, decode_uncon=decode_uncon,
367
+ controller=controller, inject=inject, layernum=layernum, h=h, w=w)
368
+ if self.use_linear:
369
+ x = self.proj_out(x)
370
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
371
+ if not self.use_linear:
372
+ x = self.proj_out(x)
373
+
374
+ layernum = layernum + 1 # keep this counter in sync with register_recr
375
+
376
+ return x + x_in, layernum
377
+
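Stripped of the controller and injection arguments, `CrossAttention.forward` above reduces to standard multi-head scaled dot-product attention. A minimal sketch of just that computation, with illustrative shapes (4096 latent tokens attending to 77 text tokens):

```python
import torch
from torch import einsum
from einops import rearrange

def sdp_attention(q, k, v, heads, dim_head):
    # q: (b, n_q, heads*dim_head), k and v: (b, n_kv, heads*dim_head)
    scale = dim_head ** -0.5
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=heads), (q, k, v))
    sim = einsum('b i d, b j d -> b i j', q, k) * scale   # attention logits
    attn = sim.softmax(dim=-1)                            # rows sum to 1
    out = einsum('b i j, b j d -> b i d', attn, v)
    return rearrange(out, '(b h) n d -> b n (h d)', h=heads)

q = torch.randn(1, 4096, 8 * 64)   # e.g. projected 64x64 latent tokens
kv = torch.randn(1, 77, 8 * 64)    # e.g. projected CLIP text tokens
print(sdp_attention(q, kv, kv, heads=8, dim_head=64).shape)   # torch.Size([1, 4096, 512])
```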
ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,852 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+ from typing import Optional, Any
8
+
9
+ from ldm.modules.attention import MemoryEfficientCrossAttention
10
+
11
+ try:
12
+ import xformers
13
+ import xformers.ops
14
+ XFORMERS_IS_AVAILBLE = True
15
+ except:
16
+ XFORMERS_IS_AVAILBLE = False
17
+ print("No module 'xformers'. Proceeding without it.")
18
+
19
+
20
+ def get_timestep_embedding(timesteps, embedding_dim):
21
+ """
22
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
23
+ From Fairseq.
24
+ Build sinusoidal embeddings.
25
+ This matches the implementation in tensor2tensor, but differs slightly
26
+ from the description in Section 3.5 of "Attention Is All You Need".
27
+ """
28
+ assert len(timesteps.shape) == 1
29
+
30
+ half_dim = embedding_dim // 2
31
+ emb = math.log(10000) / (half_dim - 1)
32
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
33
+ emb = emb.to(device=timesteps.device)
34
+ emb = timesteps.float()[:, None] * emb[None, :]
35
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
36
+ if embedding_dim % 2 == 1: # zero pad
37
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
38
+ return emb
39
+
40
+
41
+ def nonlinearity(x):
42
+ # swish
43
+ return x*torch.sigmoid(x)
44
+
45
+
46
+ def Normalize(in_channels, num_groups=32):
47
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
48
+
49
+
50
+ class Upsample(nn.Module):
51
+ def __init__(self, in_channels, with_conv):
52
+ super().__init__()
53
+ self.with_conv = with_conv
54
+ if self.with_conv:
55
+ self.conv = torch.nn.Conv2d(in_channels,
56
+ in_channels,
57
+ kernel_size=3,
58
+ stride=1,
59
+ padding=1)
60
+
61
+ def forward(self, x):
62
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
63
+ if self.with_conv:
64
+ x = self.conv(x)
65
+ return x
66
+
67
+
68
+ class Downsample(nn.Module):
69
+ def __init__(self, in_channels, with_conv):
70
+ super().__init__()
71
+ self.with_conv = with_conv
72
+ if self.with_conv:
73
+ # no asymmetric padding in torch conv, must do it ourselves
74
+ self.conv = torch.nn.Conv2d(in_channels,
75
+ in_channels,
76
+ kernel_size=3,
77
+ stride=2,
78
+ padding=0)
79
+
80
+ def forward(self, x):
81
+ if self.with_conv:
82
+ pad = (0,1,0,1)
83
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
84
+ x = self.conv(x)
85
+ else:
86
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
87
+ return x
88
+
89
+
90
+ class ResnetBlock(nn.Module):
91
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
92
+ dropout, temb_channels=512):
93
+ super().__init__()
94
+ self.in_channels = in_channels
95
+ out_channels = in_channels if out_channels is None else out_channels
96
+ self.out_channels = out_channels
97
+ self.use_conv_shortcut = conv_shortcut
98
+
99
+ self.norm1 = Normalize(in_channels)
100
+ self.conv1 = torch.nn.Conv2d(in_channels,
101
+ out_channels,
102
+ kernel_size=3,
103
+ stride=1,
104
+ padding=1)
105
+ if temb_channels > 0:
106
+ self.temb_proj = torch.nn.Linear(temb_channels,
107
+ out_channels)
108
+ self.norm2 = Normalize(out_channels)
109
+ self.dropout = torch.nn.Dropout(dropout)
110
+ self.conv2 = torch.nn.Conv2d(out_channels,
111
+ out_channels,
112
+ kernel_size=3,
113
+ stride=1,
114
+ padding=1)
115
+ if self.in_channels != self.out_channels:
116
+ if self.use_conv_shortcut:
117
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
118
+ out_channels,
119
+ kernel_size=3,
120
+ stride=1,
121
+ padding=1)
122
+ else:
123
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
124
+ out_channels,
125
+ kernel_size=1,
126
+ stride=1,
127
+ padding=0)
128
+
129
+ def forward(self, x, temb):
130
+ h = x
131
+ h = self.norm1(h)
132
+ h = nonlinearity(h)
133
+ h = self.conv1(h)
134
+
135
+ if temb is not None:
136
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
137
+
138
+ h = self.norm2(h)
139
+ h = nonlinearity(h)
140
+ h = self.dropout(h)
141
+ h = self.conv2(h)
142
+
143
+ if self.in_channels != self.out_channels:
144
+ if self.use_conv_shortcut:
145
+ x = self.conv_shortcut(x)
146
+ else:
147
+ x = self.nin_shortcut(x)
148
+
149
+ return x+h
150
+
151
+
152
+ class AttnBlock(nn.Module):
153
+ def __init__(self, in_channels):
154
+ super().__init__()
155
+ self.in_channels = in_channels
156
+
157
+ self.norm = Normalize(in_channels)
158
+ self.q = torch.nn.Conv2d(in_channels,
159
+ in_channels,
160
+ kernel_size=1,
161
+ stride=1,
162
+ padding=0)
163
+ self.k = torch.nn.Conv2d(in_channels,
164
+ in_channels,
165
+ kernel_size=1,
166
+ stride=1,
167
+ padding=0)
168
+ self.v = torch.nn.Conv2d(in_channels,
169
+ in_channels,
170
+ kernel_size=1,
171
+ stride=1,
172
+ padding=0)
173
+ self.proj_out = torch.nn.Conv2d(in_channels,
174
+ in_channels,
175
+ kernel_size=1,
176
+ stride=1,
177
+ padding=0)
178
+
179
+ def forward(self, x):
180
+ h_ = x
181
+ h_ = self.norm(h_)
182
+ q = self.q(h_)
183
+ k = self.k(h_)
184
+ v = self.v(h_)
185
+
186
+ # compute attention
187
+ b,c,h,w = q.shape
188
+ q = q.reshape(b,c,h*w)
189
+ q = q.permute(0,2,1) # b,hw,c
190
+ k = k.reshape(b,c,h*w) # b,c,hw
191
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
192
+ w_ = w_ * (int(c)**(-0.5))
193
+ w_ = torch.nn.functional.softmax(w_, dim=2)
194
+
195
+ # attend to values
196
+ v = v.reshape(b,c,h*w)
197
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
198
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
199
+ h_ = h_.reshape(b,c,h,w)
200
+
201
+ h_ = self.proj_out(h_)
202
+
203
+ return x+h_
204
+
205
+ class MemoryEfficientAttnBlock(nn.Module):
206
+ """
207
+ Uses xformers efficient implementation,
208
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
209
+ Note: this is a single-head self-attention operation
210
+ """
211
+ #
212
+ def __init__(self, in_channels):
213
+ super().__init__()
214
+ self.in_channels = in_channels
215
+
216
+ self.norm = Normalize(in_channels)
217
+ self.q = torch.nn.Conv2d(in_channels,
218
+ in_channels,
219
+ kernel_size=1,
220
+ stride=1,
221
+ padding=0)
222
+ self.k = torch.nn.Conv2d(in_channels,
223
+ in_channels,
224
+ kernel_size=1,
225
+ stride=1,
226
+ padding=0)
227
+ self.v = torch.nn.Conv2d(in_channels,
228
+ in_channels,
229
+ kernel_size=1,
230
+ stride=1,
231
+ padding=0)
232
+ self.proj_out = torch.nn.Conv2d(in_channels,
233
+ in_channels,
234
+ kernel_size=1,
235
+ stride=1,
236
+ padding=0)
237
+ self.attention_op: Optional[Any] = None
238
+
239
+ def forward(self, x):
240
+ h_ = x
241
+ h_ = self.norm(h_)
242
+ q = self.q(h_)
243
+ k = self.k(h_)
244
+ v = self.v(h_)
245
+
246
+ # compute attention
247
+ B, C, H, W = q.shape
248
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
249
+
250
+ q, k, v = map(
251
+ lambda t: t.unsqueeze(3)
252
+ .reshape(B, t.shape[1], 1, C)
253
+ .permute(0, 2, 1, 3)
254
+ .reshape(B * 1, t.shape[1], C)
255
+ .contiguous(),
256
+ (q, k, v),
257
+ )
258
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
259
+
260
+ out = (
261
+ out.unsqueeze(0)
262
+ .reshape(B, 1, out.shape[1], C)
263
+ .permute(0, 2, 1, 3)
264
+ .reshape(B, out.shape[1], C)
265
+ )
266
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
267
+ out = self.proj_out(out)
268
+ return x+out
269
+
270
+
271
+ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
272
+ def forward(self, x, context=None, mask=None):
273
+ b, c, h, w = x.shape
274
+ x = rearrange(x, 'b c h w -> b (h w) c')
275
+ out = super().forward(x, context=context, mask=mask)
276
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
277
+ return x + out
278
+
279
+
280
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
281
+ assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
282
+ if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
283
+ attn_type = "vanilla-xformers"
284
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
285
+ if attn_type == "vanilla":
286
+ assert attn_kwargs is None
287
+ return AttnBlock(in_channels)
288
+ elif attn_type == "vanilla-xformers":
289
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
290
+ return MemoryEfficientAttnBlock(in_channels)
291
+ elif attn_type == "memory-efficient-cross-attn":
292
+ attn_kwargs["query_dim"] = in_channels
293
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
294
+ elif attn_type == "none":
295
+ return nn.Identity(in_channels)
296
+ else:
297
+ raise NotImplementedError()
298
+
299
+
300
+ class Model(nn.Module):
301
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
302
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
303
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
304
+ super().__init__()
305
+ if use_linear_attn: attn_type = "linear"
306
+ self.ch = ch
307
+ self.temb_ch = self.ch*4
308
+ self.num_resolutions = len(ch_mult)
309
+ self.num_res_blocks = num_res_blocks
310
+ self.resolution = resolution
311
+ self.in_channels = in_channels
312
+
313
+ self.use_timestep = use_timestep
314
+ if self.use_timestep:
315
+ # timestep embedding
316
+ self.temb = nn.Module()
317
+ self.temb.dense = nn.ModuleList([
318
+ torch.nn.Linear(self.ch,
319
+ self.temb_ch),
320
+ torch.nn.Linear(self.temb_ch,
321
+ self.temb_ch),
322
+ ])
323
+
324
+ # downsampling
325
+ self.conv_in = torch.nn.Conv2d(in_channels,
326
+ self.ch,
327
+ kernel_size=3,
328
+ stride=1,
329
+ padding=1)
330
+
331
+ curr_res = resolution
332
+ in_ch_mult = (1,)+tuple(ch_mult)
333
+ self.down = nn.ModuleList()
334
+ for i_level in range(self.num_resolutions):
335
+ block = nn.ModuleList()
336
+ attn = nn.ModuleList()
337
+ block_in = ch*in_ch_mult[i_level]
338
+ block_out = ch*ch_mult[i_level]
339
+ for i_block in range(self.num_res_blocks):
340
+ block.append(ResnetBlock(in_channels=block_in,
341
+ out_channels=block_out,
342
+ temb_channels=self.temb_ch,
343
+ dropout=dropout))
344
+ block_in = block_out
345
+ if curr_res in attn_resolutions:
346
+ attn.append(make_attn(block_in, attn_type=attn_type))
347
+ down = nn.Module()
348
+ down.block = block
349
+ down.attn = attn
350
+ if i_level != self.num_resolutions-1:
351
+ down.downsample = Downsample(block_in, resamp_with_conv)
352
+ curr_res = curr_res // 2
353
+ self.down.append(down)
354
+
355
+ # middle
356
+ self.mid = nn.Module()
357
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
358
+ out_channels=block_in,
359
+ temb_channels=self.temb_ch,
360
+ dropout=dropout)
361
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
362
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
363
+ out_channels=block_in,
364
+ temb_channels=self.temb_ch,
365
+ dropout=dropout)
366
+
367
+ # upsampling
368
+ self.up = nn.ModuleList()
369
+ for i_level in reversed(range(self.num_resolutions)):
370
+ block = nn.ModuleList()
371
+ attn = nn.ModuleList()
372
+ block_out = ch*ch_mult[i_level]
373
+ skip_in = ch*ch_mult[i_level]
374
+ for i_block in range(self.num_res_blocks+1):
375
+ if i_block == self.num_res_blocks:
376
+ skip_in = ch*in_ch_mult[i_level]
377
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
378
+ out_channels=block_out,
379
+ temb_channels=self.temb_ch,
380
+ dropout=dropout))
381
+ block_in = block_out
382
+ if curr_res in attn_resolutions:
383
+ attn.append(make_attn(block_in, attn_type=attn_type))
384
+ up = nn.Module()
385
+ up.block = block
386
+ up.attn = attn
387
+ if i_level != 0:
388
+ up.upsample = Upsample(block_in, resamp_with_conv)
389
+ curr_res = curr_res * 2
390
+ self.up.insert(0, up) # prepend to get consistent order
391
+
392
+ # end
393
+ self.norm_out = Normalize(block_in)
394
+ self.conv_out = torch.nn.Conv2d(block_in,
395
+ out_ch,
396
+ kernel_size=3,
397
+ stride=1,
398
+ padding=1)
399
+
400
+ def forward(self, x, t=None, context=None):
401
+ #assert x.shape[2] == x.shape[3] == self.resolution
402
+ if context is not None:
403
+ # assume aligned context, cat along channel axis
404
+ x = torch.cat((x, context), dim=1)
405
+ if self.use_timestep:
406
+ # timestep embedding
407
+ assert t is not None
408
+ temb = get_timestep_embedding(t, self.ch)
409
+ temb = self.temb.dense[0](temb)
410
+ temb = nonlinearity(temb)
411
+ temb = self.temb.dense[1](temb)
412
+ else:
413
+ temb = None
414
+
415
+ # downsampling
416
+ hs = [self.conv_in(x)]
417
+ for i_level in range(self.num_resolutions):
418
+ for i_block in range(self.num_res_blocks):
419
+ h = self.down[i_level].block[i_block](hs[-1], temb)
420
+ if len(self.down[i_level].attn) > 0:
421
+ h = self.down[i_level].attn[i_block](h)
422
+ hs.append(h)
423
+ if i_level != self.num_resolutions-1:
424
+ hs.append(self.down[i_level].downsample(hs[-1]))
425
+
426
+ # middle
427
+ h = hs[-1]
428
+ h = self.mid.block_1(h, temb)
429
+ h = self.mid.attn_1(h)
430
+ h = self.mid.block_2(h, temb)
431
+
432
+ # upsampling
433
+ for i_level in reversed(range(self.num_resolutions)):
434
+ for i_block in range(self.num_res_blocks+1):
435
+ h = self.up[i_level].block[i_block](
436
+ torch.cat([h, hs.pop()], dim=1), temb)
437
+ if len(self.up[i_level].attn) > 0:
438
+ h = self.up[i_level].attn[i_block](h)
439
+ if i_level != 0:
440
+ h = self.up[i_level].upsample(h)
441
+
442
+ # end
443
+ h = self.norm_out(h)
444
+ h = nonlinearity(h)
445
+ h = self.conv_out(h)
446
+ return h
447
+
448
+ def get_last_layer(self):
449
+ return self.conv_out.weight
450
+
451
+
452
+ class Encoder(nn.Module):
453
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
454
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
455
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
456
+ **ignore_kwargs):
457
+ super().__init__()
458
+ if use_linear_attn: attn_type = "linear"
459
+ self.ch = ch
460
+ self.temb_ch = 0
461
+ self.num_resolutions = len(ch_mult)
462
+ self.num_res_blocks = num_res_blocks
463
+ self.resolution = resolution
464
+ self.in_channels = in_channels
465
+
466
+ # downsampling
467
+ self.conv_in = torch.nn.Conv2d(in_channels,
468
+ self.ch,
469
+ kernel_size=3,
470
+ stride=1,
471
+ padding=1)
472
+
473
+ curr_res = resolution
474
+ in_ch_mult = (1,)+tuple(ch_mult)
475
+ self.in_ch_mult = in_ch_mult
476
+ self.down = nn.ModuleList()
477
+ for i_level in range(self.num_resolutions):
478
+ block = nn.ModuleList()
479
+ attn = nn.ModuleList()
480
+ block_in = ch*in_ch_mult[i_level]
481
+ block_out = ch*ch_mult[i_level]
482
+ for i_block in range(self.num_res_blocks):
483
+ block.append(ResnetBlock(in_channels=block_in,
484
+ out_channels=block_out,
485
+ temb_channels=self.temb_ch,
486
+ dropout=dropout))
487
+ block_in = block_out
488
+ if curr_res in attn_resolutions:
489
+ attn.append(make_attn(block_in, attn_type=attn_type))
490
+ down = nn.Module()
491
+ down.block = block
492
+ down.attn = attn
493
+ if i_level != self.num_resolutions-1:
494
+ down.downsample = Downsample(block_in, resamp_with_conv)
495
+ curr_res = curr_res // 2
496
+ self.down.append(down)
497
+
498
+ # middle
499
+ self.mid = nn.Module()
500
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
501
+ out_channels=block_in,
502
+ temb_channels=self.temb_ch,
503
+ dropout=dropout)
504
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
505
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
506
+ out_channels=block_in,
507
+ temb_channels=self.temb_ch,
508
+ dropout=dropout)
509
+
510
+ # end
511
+ self.norm_out = Normalize(block_in)
512
+ self.conv_out = torch.nn.Conv2d(block_in,
513
+ 2*z_channels if double_z else z_channels,
514
+ kernel_size=3,
515
+ stride=1,
516
+ padding=1)
517
+
518
+ def forward(self, x):
519
+ # timestep embedding
520
+ temb = None
521
+
522
+ # downsampling
523
+ hs = [self.conv_in(x)]
524
+ for i_level in range(self.num_resolutions):
525
+ for i_block in range(self.num_res_blocks):
526
+ h = self.down[i_level].block[i_block](hs[-1], temb)
527
+ if len(self.down[i_level].attn) > 0:
528
+ h = self.down[i_level].attn[i_block](h)
529
+ hs.append(h)
530
+ if i_level != self.num_resolutions-1:
531
+ hs.append(self.down[i_level].downsample(hs[-1]))
532
+
533
+ # middle
534
+ h = hs[-1]
535
+ h = self.mid.block_1(h, temb)
536
+ h = self.mid.attn_1(h)
537
+ h = self.mid.block_2(h, temb)
538
+
539
+ # end
540
+ h = self.norm_out(h)
541
+ h = nonlinearity(h)
542
+ h = self.conv_out(h)
543
+ return h
544
+
545
+
546
+ class Decoder(nn.Module):
547
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
548
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
549
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
550
+ attn_type="vanilla", **ignorekwargs):
551
+ super().__init__()
552
+ if use_linear_attn: attn_type = "linear"
553
+ self.ch = ch
554
+ self.temb_ch = 0
555
+ self.num_resolutions = len(ch_mult)
556
+ self.num_res_blocks = num_res_blocks
557
+ self.resolution = resolution
558
+ self.in_channels = in_channels
559
+ self.give_pre_end = give_pre_end
560
+ self.tanh_out = tanh_out
561
+
562
+ # compute in_ch_mult, block_in and curr_res at lowest res
563
+ in_ch_mult = (1,)+tuple(ch_mult)
564
+ block_in = ch*ch_mult[self.num_resolutions-1]
565
+ curr_res = resolution // 2**(self.num_resolutions-1)
566
+ self.z_shape = (1,z_channels,curr_res,curr_res)
567
+ print("Working with z of shape {} = {} dimensions.".format(
568
+ self.z_shape, np.prod(self.z_shape)))
569
+
570
+ # z to block_in
571
+ self.conv_in = torch.nn.Conv2d(z_channels,
572
+ block_in,
573
+ kernel_size=3,
574
+ stride=1,
575
+ padding=1)
576
+
577
+ # middle
578
+ self.mid = nn.Module()
579
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
580
+ out_channels=block_in,
581
+ temb_channels=self.temb_ch,
582
+ dropout=dropout)
583
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
584
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
585
+ out_channels=block_in,
586
+ temb_channels=self.temb_ch,
587
+ dropout=dropout)
588
+
589
+ # upsampling
590
+ self.up = nn.ModuleList()
591
+ for i_level in reversed(range(self.num_resolutions)):
592
+ block = nn.ModuleList()
593
+ attn = nn.ModuleList()
594
+ block_out = ch*ch_mult[i_level]
595
+ for i_block in range(self.num_res_blocks+1):
596
+ block.append(ResnetBlock(in_channels=block_in,
597
+ out_channels=block_out,
598
+ temb_channels=self.temb_ch,
599
+ dropout=dropout))
600
+ block_in = block_out
601
+ if curr_res in attn_resolutions:
602
+ attn.append(make_attn(block_in, attn_type=attn_type))
603
+ up = nn.Module()
604
+ up.block = block
605
+ up.attn = attn
606
+ if i_level != 0:
607
+ up.upsample = Upsample(block_in, resamp_with_conv)
608
+ curr_res = curr_res * 2
609
+ self.up.insert(0, up) # prepend to get consistent order
610
+
611
+ # end
612
+ self.norm_out = Normalize(block_in)
613
+ self.conv_out = torch.nn.Conv2d(block_in,
614
+ out_ch,
615
+ kernel_size=3,
616
+ stride=1,
617
+ padding=1)
618
+
619
+ def forward(self, z):
620
+ #assert z.shape[1:] == self.z_shape[1:]
621
+ self.last_z_shape = z.shape
622
+
623
+ # timestep embedding
624
+ temb = None
625
+
626
+ # z to block_in
627
+ h = self.conv_in(z)
628
+
629
+ # middle
630
+ h = self.mid.block_1(h, temb)
631
+ h = self.mid.attn_1(h)
632
+ h = self.mid.block_2(h, temb)
633
+
634
+ # upsampling
635
+ for i_level in reversed(range(self.num_resolutions)):
636
+ for i_block in range(self.num_res_blocks+1):
637
+ h = self.up[i_level].block[i_block](h, temb)
638
+ if len(self.up[i_level].attn) > 0:
639
+ h = self.up[i_level].attn[i_block](h)
640
+ if i_level != 0:
641
+ h = self.up[i_level].upsample(h)
642
+
643
+ # end
644
+ if self.give_pre_end:
645
+ return h
646
+
647
+ h = self.norm_out(h)
648
+ h = nonlinearity(h)
649
+ h = self.conv_out(h)
650
+ if self.tanh_out:
651
+ h = torch.tanh(h)
652
+ return h
653
+
654
+
655
+ class SimpleDecoder(nn.Module):
656
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
657
+ super().__init__()
658
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
659
+ ResnetBlock(in_channels=in_channels,
660
+ out_channels=2 * in_channels,
661
+ temb_channels=0, dropout=0.0),
662
+ ResnetBlock(in_channels=2 * in_channels,
663
+ out_channels=4 * in_channels,
664
+ temb_channels=0, dropout=0.0),
665
+ ResnetBlock(in_channels=4 * in_channels,
666
+ out_channels=2 * in_channels,
667
+ temb_channels=0, dropout=0.0),
668
+ nn.Conv2d(2*in_channels, in_channels, 1),
669
+ Upsample(in_channels, with_conv=True)])
670
+ # end
671
+ self.norm_out = Normalize(in_channels)
672
+ self.conv_out = torch.nn.Conv2d(in_channels,
673
+ out_channels,
674
+ kernel_size=3,
675
+ stride=1,
676
+ padding=1)
677
+
678
+ def forward(self, x):
679
+ for i, layer in enumerate(self.model):
680
+ if i in [1,2,3]:
681
+ x = layer(x, None)
682
+ else:
683
+ x = layer(x)
684
+
685
+ h = self.norm_out(x)
686
+ h = nonlinearity(h)
687
+ x = self.conv_out(h)
688
+ return x
689
+
690
+
691
+ class UpsampleDecoder(nn.Module):
692
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
693
+ ch_mult=(2,2), dropout=0.0):
694
+ super().__init__()
695
+ # upsampling
696
+ self.temb_ch = 0
697
+ self.num_resolutions = len(ch_mult)
698
+ self.num_res_blocks = num_res_blocks
699
+ block_in = in_channels
700
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
701
+ self.res_blocks = nn.ModuleList()
702
+ self.upsample_blocks = nn.ModuleList()
703
+ for i_level in range(self.num_resolutions):
704
+ res_block = []
705
+ block_out = ch * ch_mult[i_level]
706
+ for i_block in range(self.num_res_blocks + 1):
707
+ res_block.append(ResnetBlock(in_channels=block_in,
708
+ out_channels=block_out,
709
+ temb_channels=self.temb_ch,
710
+ dropout=dropout))
711
+ block_in = block_out
712
+ self.res_blocks.append(nn.ModuleList(res_block))
713
+ if i_level != self.num_resolutions - 1:
714
+ self.upsample_blocks.append(Upsample(block_in, True))
715
+ curr_res = curr_res * 2
716
+
717
+ # end
718
+ self.norm_out = Normalize(block_in)
719
+ self.conv_out = torch.nn.Conv2d(block_in,
720
+ out_channels,
721
+ kernel_size=3,
722
+ stride=1,
723
+ padding=1)
724
+
725
+ def forward(self, x):
726
+ # upsampling
727
+ h = x
728
+ for k, i_level in enumerate(range(self.num_resolutions)):
729
+ for i_block in range(self.num_res_blocks + 1):
730
+ h = self.res_blocks[i_level][i_block](h, None)
731
+ if i_level != self.num_resolutions - 1:
732
+ h = self.upsample_blocks[k](h)
733
+ h = self.norm_out(h)
734
+ h = nonlinearity(h)
735
+ h = self.conv_out(h)
736
+ return h
737
+
738
+
739
+ class LatentRescaler(nn.Module):
740
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
741
+ super().__init__()
742
+ # residual block, interpolate, residual block
743
+ self.factor = factor
744
+ self.conv_in = nn.Conv2d(in_channels,
745
+ mid_channels,
746
+ kernel_size=3,
747
+ stride=1,
748
+ padding=1)
749
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
750
+ out_channels=mid_channels,
751
+ temb_channels=0,
752
+ dropout=0.0) for _ in range(depth)])
753
+ self.attn = AttnBlock(mid_channels)
754
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
755
+ out_channels=mid_channels,
756
+ temb_channels=0,
757
+ dropout=0.0) for _ in range(depth)])
758
+
759
+ self.conv_out = nn.Conv2d(mid_channels,
760
+ out_channels,
761
+ kernel_size=1,
762
+ )
763
+
764
+ def forward(self, x):
765
+ x = self.conv_in(x)
766
+ for block in self.res_block1:
767
+ x = block(x, None)
768
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
769
+ x = self.attn(x)
770
+ for block in self.res_block2:
771
+ x = block(x, None)
772
+ x = self.conv_out(x)
773
+ return x
774
+
775
+
776
+ class MergedRescaleEncoder(nn.Module):
777
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
778
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
779
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
780
+ super().__init__()
781
+ intermediate_chn = ch * ch_mult[-1]
782
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
783
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
784
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
785
+ out_ch=None)
786
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
787
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
788
+
789
+ def forward(self, x):
790
+ x = self.encoder(x)
791
+ x = self.rescaler(x)
792
+ return x
793
+
794
+
795
+ class MergedRescaleDecoder(nn.Module):
796
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
797
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
798
+ super().__init__()
799
+ tmp_chn = z_channels*ch_mult[-1]
800
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
801
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
802
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
803
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
804
+ out_channels=tmp_chn, depth=rescale_module_depth)
805
+
806
+ def forward(self, x):
807
+ x = self.rescaler(x)
808
+ x = self.decoder(x)
809
+ return x
810
+
811
+
812
+ class Upsampler(nn.Module):
813
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
814
+ super().__init__()
815
+ assert out_size >= in_size
816
+ num_blocks = int(np.log2(out_size//in_size))+1
817
+ factor_up = 1.+ (out_size % in_size)
818
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
819
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
820
+ out_channels=in_channels)
821
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
822
+ attn_resolutions=[], in_channels=None, ch=in_channels,
823
+ ch_mult=[ch_mult for _ in range(num_blocks)])
824
+
825
+ def forward(self, x):
826
+ x = self.rescaler(x)
827
+ x = self.decoder(x)
828
+ return x
829
+
830
+
831
+ class Resize(nn.Module):
832
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
833
+ super().__init__()
834
+ self.with_conv = learned
835
+ self.mode = mode
836
+ if self.with_conv:
837
+ print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
838
+ raise NotImplementedError()
839
+ assert in_channels is not None
840
+ # no asymmetric padding in torch conv, must do it ourselves
841
+ self.conv = torch.nn.Conv2d(in_channels,
842
+ in_channels,
843
+ kernel_size=4,
844
+ stride=2,
845
+ padding=1)
846
+
847
+ def forward(self, x, scale_factor=1.0):
848
+ if scale_factor==1.0:
849
+ return x
850
+ else:
851
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
852
+ return x
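As a quick sanity check on `get_timestep_embedding` defined near the top of this file: it maps a batch of integer timesteps to sinusoidal features, sines in the first half of the channels and cosines in the second half (the `embedding_dim` below is only an example):

```python
import torch
from ldm.modules.diffusionmodules.model import get_timestep_embedding

t = torch.tensor([0, 10, 500, 999])                # one diffusion timestep per batch element
emb = get_timestep_embedding(t, embedding_dim=320)
print(emb.shape)                                    # torch.Size([4, 320])
print(float(emb.abs().max()) <= 1.0)                # True: pure sin/cos features
```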
ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,803 @@
1
+ from abc import abstractmethod
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch as th
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from ldm.modules.diffusionmodules.util import (
10
+ checkpoint,
11
+ conv_nd,
12
+ linear,
13
+ avg_pool_nd,
14
+ zero_module,
15
+ normalization,
16
+ timestep_embedding,
17
+ )
18
+ from ldm.modules.attention import SpatialTransformer
19
+ from ldm.util import exists
20
+
21
+
22
+ # dummy replace
23
+ def convert_module_to_f16(x):
24
+ pass
25
+
26
+ def convert_module_to_f32(x):
27
+ pass
28
+
29
+
30
+ ## go
31
+ class AttentionPool2d(nn.Module):
32
+ """
33
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ spacial_dim: int,
39
+ embed_dim: int,
40
+ num_heads_channels: int,
41
+ output_dim: int = None,
42
+ ):
43
+ super().__init__()
44
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
45
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
46
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
47
+ self.num_heads = embed_dim // num_heads_channels
48
+ self.attention = QKVAttention(self.num_heads)
49
+
50
+ def forward(self, x):
51
+ b, c, *_spatial = x.shape
52
+ x = x.reshape(b, c, -1) # NC(HW)
53
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
54
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
55
+ x = self.qkv_proj(x)
56
+ x = self.attention(x)
57
+ x = self.c_proj(x)
58
+ return x[:, :, 0]
59
+
60
+
61
+ class TimestepBlock(nn.Module):
62
+ """
63
+ Any module where forward() takes timestep embeddings as a second argument.
64
+ """
65
+
66
+ @abstractmethod
67
+ def forward(self, x, emb):
68
+ """
69
+ Apply the module to `x` given `emb` timestep embeddings.
70
+ """
71
+
72
+
73
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
74
+ """
75
+ A sequential module that passes timestep embeddings to the children that
76
+ support it as an extra input.
77
+ """
78
+
79
+ def forward(self, x, emb, context=None, encode=False, encode_uncon=False, decode_uncon=False, controller=None, inject=False, layernum=0):
80
+ for layer in self:
81
+ if isinstance(layer, TimestepBlock):
82
+ if x.shape[0] == 4:
83
+ x1 = layer(x[:2], emb)
84
+ x2 = layer(x[2:], emb)
85
+ x = th.cat([x1, x2], dim=0)
86
+ else:
87
+ x = layer(x, emb)
88
+ elif isinstance(layer, SpatialTransformer):
89
+ x, layernum = layer(x, context, encode=encode, encode_uncon=encode_uncon, decode_uncon=decode_uncon,
90
+ controller=controller, inject=inject, layernum=layernum)
91
+ else:
92
+ x = layer(x)
93
+ return x, layernum
94
+
95
+
96
+ class Upsample(nn.Module):
97
+ """
98
+ An upsampling layer with an optional convolution.
99
+ :param channels: channels in the inputs and outputs.
100
+ :param use_conv: a bool determining if a convolution is applied.
101
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
102
+ upsampling occurs in the inner-two dimensions.
103
+ """
104
+
105
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
106
+ super().__init__()
107
+ self.channels = channels
108
+ self.out_channels = out_channels or channels
109
+ self.use_conv = use_conv
110
+ self.dims = dims
111
+ if use_conv:
112
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
113
+
114
+ def forward(self, x):
115
+ assert x.shape[1] == self.channels
116
+ if self.dims == 3:
117
+ x = F.interpolate(
118
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
119
+ )
120
+ else:
121
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
122
+ if self.use_conv:
123
+ x = self.conv(x)
124
+ return x
125
+
126
+ class TransposedUpsample(nn.Module):
127
+ 'Learned 2x upsampling without padding'
128
+ def __init__(self, channels, out_channels=None, ks=5):
129
+ super().__init__()
130
+ self.channels = channels
131
+ self.out_channels = out_channels or channels
132
+
133
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
134
+
135
+ def forward(self,x):
136
+ return self.up(x)
137
+
138
+
139
+ class Downsample(nn.Module):
140
+ """
141
+ A downsampling layer with an optional convolution.
142
+ :param channels: channels in the inputs and outputs.
143
+ :param use_conv: a bool determining if a convolution is applied.
144
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
145
+ downsampling occurs in the inner-two dimensions.
146
+ """
147
+
148
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
149
+ super().__init__()
150
+ self.channels = channels
151
+ self.out_channels = out_channels or channels
152
+ self.use_conv = use_conv
153
+ self.dims = dims
154
+ stride = 2 if dims != 3 else (1, 2, 2)
155
+ if use_conv:
156
+ self.op = conv_nd(
157
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
158
+ )
159
+ else:
160
+ assert self.channels == self.out_channels
161
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
162
+
163
+ def forward(self, x):
164
+ assert x.shape[1] == self.channels
165
+ return self.op(x)
166
+
167
+
168
+ class ResBlock(TimestepBlock):
169
+ """
170
+ A residual block that can optionally change the number of channels.
171
+ :param channels: the number of input channels.
172
+ :param emb_channels: the number of timestep embedding channels.
173
+ :param dropout: the rate of dropout.
174
+ :param out_channels: if specified, the number of out channels.
175
+ :param use_conv: if True and out_channels is specified, use a spatial
176
+ convolution instead of a smaller 1x1 convolution to change the
177
+ channels in the skip connection.
178
+ :param dims: determines if the signal is 1D, 2D, or 3D.
179
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
180
+ :param up: if True, use this block for upsampling.
181
+ :param down: if True, use this block for downsampling.
182
+ """
183
+
184
+ def __init__(
185
+ self,
186
+ channels,
187
+ emb_channels,
188
+ dropout,
189
+ out_channels=None,
190
+ use_conv=False,
191
+ use_scale_shift_norm=False,
192
+ dims=2,
193
+ use_checkpoint=False,
194
+ up=False,
195
+ down=False,
196
+ ):
197
+ super().__init__()
198
+ self.channels = channels
199
+ self.emb_channels = emb_channels
200
+ self.dropout = dropout
201
+ self.out_channels = out_channels or channels
202
+ self.use_conv = use_conv
203
+ self.use_checkpoint = use_checkpoint
204
+ self.use_scale_shift_norm = use_scale_shift_norm
205
+
206
+ self.in_layers = nn.Sequential(
207
+ normalization(channels),
208
+ nn.SiLU(),
209
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
210
+ )
211
+
212
+ self.updown = up or down
213
+
214
+ if up:
215
+ self.h_upd = Upsample(channels, False, dims)
216
+ self.x_upd = Upsample(channels, False, dims)
217
+ elif down:
218
+ self.h_upd = Downsample(channels, False, dims)
219
+ self.x_upd = Downsample(channels, False, dims)
220
+ else:
221
+ self.h_upd = self.x_upd = nn.Identity()
222
+
223
+ self.emb_layers = nn.Sequential(
224
+ nn.SiLU(),
225
+ linear(
226
+ emb_channels,
227
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
228
+ ),
229
+ )
230
+ self.out_layers = nn.Sequential(
231
+ normalization(self.out_channels),
232
+ nn.SiLU(),
233
+ nn.Dropout(p=dropout),
234
+ zero_module(
235
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
236
+ ),
237
+ )
238
+
239
+ if self.out_channels == channels:
240
+ self.skip_connection = nn.Identity()
241
+ elif use_conv:
242
+ self.skip_connection = conv_nd(
243
+ dims, channels, self.out_channels, 3, padding=1
244
+ )
245
+ else:
246
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
247
+
248
+ def forward(self, x, emb):
249
+ """
250
+ Apply the block to a Tensor, conditioned on a timestep embedding.
251
+ :param x: an [N x C x ...] Tensor of features.
252
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
253
+ :return: an [N x C x ...] Tensor of outputs.
254
+ """
255
+ return checkpoint(
256
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
257
+ )
258
+
259
+
260
+ def _forward(self, x, emb):
261
+ if self.updown:
262
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
263
+ h = in_rest(x)
264
+ h = self.h_upd(h)
265
+ x = self.x_upd(x)
266
+ h = in_conv(h)
267
+ else:
268
+ h = self.in_layers(x)
269
+ emb_out = self.emb_layers(emb).type(h.dtype)
270
+ while len(emb_out.shape) < len(h.shape):
271
+ emb_out = emb_out[..., None]
272
+ if self.use_scale_shift_norm:
273
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
274
+ scale, shift = th.chunk(emb_out, 2, dim=1)
275
+ h = out_norm(h) * (1 + scale) + shift
276
+ h = out_rest(h)
277
+ else:
278
+ h = h + emb_out
279
+ h = self.out_layers(h)
280
+ return self.skip_connection(x) + h
281
+
282
+
283
+ class AttentionBlock(nn.Module):
284
+ """
285
+ An attention block that allows spatial positions to attend to each other.
286
+ Originally ported from here, but adapted to the N-d case.
287
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
288
+ """
289
+
290
+ def __init__(
291
+ self,
292
+ channels,
293
+ num_heads=1,
294
+ num_head_channels=-1,
295
+ use_checkpoint=False,
296
+ use_new_attention_order=False,
297
+ ):
298
+ super().__init__()
299
+ self.channels = channels
300
+ if num_head_channels == -1:
301
+ self.num_heads = num_heads
302
+ else:
303
+ assert (
304
+ channels % num_head_channels == 0
305
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
306
+ self.num_heads = channels // num_head_channels
307
+ self.use_checkpoint = use_checkpoint
308
+ self.norm = normalization(channels)
309
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
310
+ if use_new_attention_order:
311
+ # split qkv before split heads
312
+ self.attention = QKVAttention(self.num_heads)
313
+ else:
314
+ # split heads before split qkv
315
+ self.attention = QKVAttentionLegacy(self.num_heads)
316
+
317
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
318
+
319
+ def forward(self, x):
320
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
321
+ #return pt_checkpoint(self._forward, x) # pytorch
322
+
323
+ def _forward(self, x):
324
+ b, c, *spatial = x.shape
325
+ x = x.reshape(b, c, -1)
326
+ qkv = self.qkv(self.norm(x))
327
+ h = self.attention(qkv)
328
+ h = self.proj_out(h)
329
+ return (x + h).reshape(b, c, *spatial)
330
+
331
+
332
+ def count_flops_attn(model, _x, y):
333
+ """
334
+ A counter for the `thop` package to count the operations in an
335
+ attention operation.
336
+ Meant to be used like:
337
+ macs, params = thop.profile(
338
+ model,
339
+ inputs=(inputs, timestamps),
340
+ custom_ops={QKVAttention: QKVAttention.count_flops},
341
+ )
342
+ """
343
+ b, c, *spatial = y[0].shape
344
+ num_spatial = int(np.prod(spatial))
345
+ # We perform two matmuls with the same number of ops.
346
+ # The first computes the weight matrix, the second computes
347
+ # the combination of the value vectors.
348
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
349
+ model.total_ops += th.DoubleTensor([matmul_ops])
350
+
351
+
352
+ class QKVAttentionLegacy(nn.Module):
353
+ """
354
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
355
+ """
356
+
357
+ def __init__(self, n_heads):
358
+ super().__init__()
359
+ self.n_heads = n_heads
360
+
361
+ def forward(self, qkv):
362
+ """
363
+ Apply QKV attention.
364
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
365
+ :return: an [N x (H * C) x T] tensor after attention.
366
+ """
367
+ bs, width, length = qkv.shape
368
+ assert width % (3 * self.n_heads) == 0
369
+ ch = width // (3 * self.n_heads)
370
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
371
+ scale = 1 / math.sqrt(math.sqrt(ch))
372
+ weight = th.einsum(
373
+ "bct,bcs->bts", q * scale, k * scale
374
+ ) # More stable with f16 than dividing afterwards
375
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
376
+ a = th.einsum("bts,bcs->bct", weight, v)
377
+ return a.reshape(bs, -1, length)
378
+
379
+ @staticmethod
380
+ def count_flops(model, _x, y):
381
+ return count_flops_attn(model, _x, y)
382
+
383
+
384
+ class QKVAttention(nn.Module):
385
+ """
386
+ A module which performs QKV attention and splits in a different order.
387
+ """
388
+
389
+ def __init__(self, n_heads):
390
+ super().__init__()
391
+ self.n_heads = n_heads
392
+
393
+ def forward(self, qkv):
394
+ """
395
+ Apply QKV attention.
396
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
397
+ :return: an [N x (H * C) x T] tensor after attention.
398
+ """
399
+ bs, width, length = qkv.shape
400
+ assert width % (3 * self.n_heads) == 0
401
+ ch = width // (3 * self.n_heads)
402
+ q, k, v = qkv.chunk(3, dim=1)
403
+ scale = 1 / math.sqrt(math.sqrt(ch))
404
+ weight = th.einsum(
405
+ "bct,bcs->bts",
406
+ (q * scale).view(bs * self.n_heads, ch, length),
407
+ (k * scale).view(bs * self.n_heads, ch, length),
408
+ ) # More stable with f16 than dividing afterwards
409
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
410
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
411
+ return a.reshape(bs, -1, length)
412
+
413
+ @staticmethod
414
+ def count_flops(model, _x, y):
415
+ return count_flops_attn(model, _x, y)
416
+
417
+
418
+ class UNetModel(nn.Module):
419
+ """
420
+ The full UNet model with attention and timestep embedding.
421
+ :param in_channels: channels in the input Tensor.
422
+ :param model_channels: base channel count for the model.
423
+ :param out_channels: channels in the output Tensor.
424
+ :param num_res_blocks: number of residual blocks per downsample.
425
+ :param attention_resolutions: a collection of downsample rates at which
426
+ attention will take place. May be a set, list, or tuple.
427
+ For example, if this contains 4, then at 4x downsampling, attention
428
+ will be used.
429
+ :param dropout: the dropout probability.
430
+ :param channel_mult: channel multiplier for each level of the UNet.
431
+ :param conv_resample: if True, use learned convolutions for upsampling and
432
+ downsampling.
433
+ :param dims: determines if the signal is 1D, 2D, or 3D.
434
+ :param num_classes: if specified (as an int), then this model will be
435
+ class-conditional with `num_classes` classes.
436
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
437
+ :param num_heads: the number of attention heads in each attention layer.
438
+ :param num_head_channels: if specified, ignore num_heads and instead use
439
+ a fixed channel width per attention head.
440
+ :param num_heads_upsample: works with num_heads to set a different number
441
+ of heads for upsampling. Deprecated.
442
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
443
+ :param resblock_updown: use residual blocks for up/downsampling.
444
+ :param use_new_attention_order: use a different attention pattern for potentially
445
+ increased efficiency.
446
+ """
447
+
448
+ def __init__(
449
+ self,
450
+ image_size,
451
+ in_channels,
452
+ model_channels,
453
+ out_channels,
454
+ num_res_blocks,
455
+ attention_resolutions,
456
+ dropout=0,
457
+ channel_mult=(1, 2, 4, 8),
458
+ conv_resample=True,
459
+ dims=2,
460
+ num_classes=None,
461
+ use_checkpoint=False,
462
+ use_fp16=False,
463
+ num_heads=-1,
464
+ num_head_channels=-1,
465
+ num_heads_upsample=-1,
466
+ use_scale_shift_norm=False,
467
+ resblock_updown=False,
468
+ use_new_attention_order=False,
469
+ use_spatial_transformer=False, # custom transformer support
470
+ transformer_depth=1, # custom transformer support
471
+ context_dim=None, # custom transformer support
472
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
473
+ legacy=True,
474
+ disable_self_attentions=None,
475
+ num_attention_blocks=None,
476
+ disable_middle_self_attn=False,
477
+ use_linear_in_transformer=False,
478
+ ):
479
+ super().__init__()
480
+ if use_spatial_transformer:
481
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
482
+
483
+ if context_dim is not None:
484
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
485
+ from omegaconf.listconfig import ListConfig
486
+ if type(context_dim) == ListConfig:
487
+ context_dim = list(context_dim)
488
+
489
+ if num_heads_upsample == -1:
490
+ num_heads_upsample = num_heads
491
+
492
+ if num_heads == -1:
493
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
494
+
495
+ if num_head_channels == -1:
496
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
497
+
498
+ self.image_size = image_size
499
+ self.in_channels = in_channels
500
+ self.model_channels = model_channels
501
+ self.out_channels = out_channels
502
+ if isinstance(num_res_blocks, int):
503
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
504
+ else:
505
+ if len(num_res_blocks) != len(channel_mult):
506
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
507
+ "as a list/tuple (per-level) with the same length as channel_mult")
508
+ self.num_res_blocks = num_res_blocks
509
+ if disable_self_attentions is not None:
510
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
511
+ assert len(disable_self_attentions) == len(channel_mult)
512
+ if num_attention_blocks is not None:
513
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
514
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
515
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
516
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
517
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
518
+ f"attention will still not be set.")
519
+
520
+ self.attention_resolutions = attention_resolutions
521
+ self.dropout = dropout
522
+ self.channel_mult = channel_mult
523
+ self.conv_resample = conv_resample
524
+ self.num_classes = num_classes
525
+ self.use_checkpoint = use_checkpoint
526
+ self.dtype = th.float16 if use_fp16 else th.float32
527
+ self.num_heads = num_heads
528
+ self.num_head_channels = num_head_channels
529
+ self.num_heads_upsample = num_heads_upsample
530
+ self.predict_codebook_ids = n_embed is not None
531
+
532
+ time_embed_dim = model_channels * 4
533
+ self.time_embed = nn.Sequential(
534
+ linear(model_channels, time_embed_dim),
535
+ nn.SiLU(),
536
+ linear(time_embed_dim, time_embed_dim),
537
+ )
538
+
539
+ if self.num_classes is not None:
540
+ if isinstance(self.num_classes, int):
541
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
542
+ elif self.num_classes == "continuous":
543
+ print("setting up linear c_adm embedding layer")
544
+ self.label_emb = nn.Linear(1, time_embed_dim)
545
+ else:
546
+ raise ValueError()
547
+
548
+ self.input_blocks = nn.ModuleList(
549
+ [
550
+ TimestepEmbedSequential(
551
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
552
+ )
553
+ ]
554
+ )
555
+ self._feature_size = model_channels
556
+ input_block_chans = [model_channels]
557
+ ch = model_channels
558
+ ds = 1
559
+ for level, mult in enumerate(channel_mult):
560
+ for nr in range(self.num_res_blocks[level]):
561
+ layers = [
562
+ ResBlock(
563
+ ch,
564
+ time_embed_dim,
565
+ dropout,
566
+ out_channels=mult * model_channels,
567
+ dims=dims,
568
+ use_checkpoint=use_checkpoint,
569
+ use_scale_shift_norm=use_scale_shift_norm,
570
+ )
571
+ ]
572
+ ch = mult * model_channels
573
+ if ds in attention_resolutions:
574
+ if num_head_channels == -1:
575
+ dim_head = ch // num_heads
576
+ else:
577
+ num_heads = ch // num_head_channels
578
+ dim_head = num_head_channels
579
+ if legacy:
580
+ #num_heads = 1
581
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
582
+ if exists(disable_self_attentions):
583
+ disabled_sa = disable_self_attentions[level]
584
+ else:
585
+ disabled_sa = False
586
+ # disabled_sa = True
587
+
588
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
589
+ layers.append(
590
+ AttentionBlock(
591
+ ch,
592
+ use_checkpoint=use_checkpoint,
593
+ num_heads=num_heads,
594
+ num_head_channels=dim_head,
595
+ use_new_attention_order=use_new_attention_order,
596
+ ) if not use_spatial_transformer else SpatialTransformer(
597
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
598
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
599
+ use_checkpoint=use_checkpoint
600
+ )
601
+ )
602
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
603
+ self._feature_size += ch
604
+ input_block_chans.append(ch)
605
+ if level != len(channel_mult) - 1:
606
+ out_ch = ch
607
+ self.input_blocks.append(
608
+ TimestepEmbedSequential(
609
+ ResBlock(
610
+ ch,
611
+ time_embed_dim,
612
+ dropout,
613
+ out_channels=out_ch,
614
+ dims=dims,
615
+ use_checkpoint=use_checkpoint,
616
+ use_scale_shift_norm=use_scale_shift_norm,
617
+ down=True,
618
+ )
619
+ if resblock_updown
620
+ else Downsample(
621
+ ch, conv_resample, dims=dims, out_channels=out_ch
622
+ )
623
+ )
624
+ )
625
+ ch = out_ch
626
+ input_block_chans.append(ch)
627
+ ds *= 2
628
+ self._feature_size += ch
629
+
630
+ if num_head_channels == -1:
631
+ dim_head = ch // num_heads
632
+ else:
633
+ num_heads = ch // num_head_channels
634
+ dim_head = num_head_channels
635
+ if legacy:
636
+ #num_heads = 1
637
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
638
+ self.middle_block = TimestepEmbedSequential(
639
+ ResBlock(
640
+ ch,
641
+ time_embed_dim,
642
+ dropout,
643
+ dims=dims,
644
+ use_checkpoint=use_checkpoint,
645
+ use_scale_shift_norm=use_scale_shift_norm,
646
+ ),
647
+ AttentionBlock(
648
+ ch,
649
+ use_checkpoint=use_checkpoint,
650
+ num_heads=num_heads,
651
+ num_head_channels=dim_head,
652
+ use_new_attention_order=use_new_attention_order,
653
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
654
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
655
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
656
+ use_checkpoint=use_checkpoint
657
+ ),
658
+ ResBlock(
659
+ ch,
660
+ time_embed_dim,
661
+ dropout,
662
+ dims=dims,
663
+ use_checkpoint=use_checkpoint,
664
+ use_scale_shift_norm=use_scale_shift_norm,
665
+ ),
666
+ )
667
+ self._feature_size += ch
668
+
669
+ self.output_blocks = nn.ModuleList([])
670
+ for level, mult in list(enumerate(channel_mult))[::-1]:
671
+ for i in range(self.num_res_blocks[level] + 1):
672
+ ich = input_block_chans.pop()
673
+ layers = [
674
+ ResBlock(
675
+ ch + ich,
676
+ time_embed_dim,
677
+ dropout,
678
+ out_channels=model_channels * mult,
679
+ dims=dims,
680
+ use_checkpoint=use_checkpoint,
681
+ use_scale_shift_norm=use_scale_shift_norm,
682
+ )
683
+ ]
684
+ ch = model_channels * mult
685
+ if ds in attention_resolutions:
686
+ if num_head_channels == -1:
687
+ dim_head = ch // num_heads
688
+ else:
689
+ num_heads = ch // num_head_channels
690
+ dim_head = num_head_channels
691
+ if legacy:
692
+ #num_heads = 1
693
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
694
+ if exists(disable_self_attentions):
695
+ disabled_sa = disable_self_attentions[level]
696
+ else:
697
+ disabled_sa = False
698
+
699
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
700
+ layers.append(
701
+ AttentionBlock(
702
+ ch,
703
+ use_checkpoint=use_checkpoint,
704
+ num_heads=num_heads_upsample,
705
+ num_head_channels=dim_head,
706
+ use_new_attention_order=use_new_attention_order,
707
+ ) if not use_spatial_transformer else SpatialTransformer(
708
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
709
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
710
+ use_checkpoint=use_checkpoint
711
+ )
712
+ )
713
+ if level and i == self.num_res_blocks[level]:
714
+ out_ch = ch
715
+ layers.append(
716
+ ResBlock(
717
+ ch,
718
+ time_embed_dim,
719
+ dropout,
720
+ out_channels=out_ch,
721
+ dims=dims,
722
+ use_checkpoint=use_checkpoint,
723
+ use_scale_shift_norm=use_scale_shift_norm,
724
+ up=True,
725
+ )
726
+ if resblock_updown
727
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
728
+ )
729
+ ds //= 2
730
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
731
+ self._feature_size += ch
732
+
733
+ self.out = nn.Sequential(
734
+ normalization(ch),
735
+ nn.SiLU(),
736
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
737
+ )
738
+ if self.predict_codebook_ids:
739
+ self.id_predictor = nn.Sequential(
740
+ normalization(ch),
741
+ conv_nd(dims, model_channels, n_embed, 1),
742
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
743
+ )
744
+
745
+ def convert_to_fp16(self):
746
+ """
747
+ Convert the torso of the model to float16.
748
+ """
749
+ self.input_blocks.apply(convert_module_to_f16)
750
+ self.middle_block.apply(convert_module_to_f16)
751
+ self.output_blocks.apply(convert_module_to_f16)
752
+
753
+ def convert_to_fp32(self):
754
+ """
755
+ Convert the torso of the model to float32.
756
+ """
757
+ self.input_blocks.apply(convert_module_to_f32)
758
+ self.middle_block.apply(convert_module_to_f32)
759
+ self.output_blocks.apply(convert_module_to_f32)
760
+
761
+ def forward(self, x, timesteps=None, context=None, y=None, encode=False, encode_uncon=False, decode_uncon=False, controller=None, inject=False, **kwargs):
762
+ """
763
+ Apply the model to an input batch.
764
+ :param x: an [N x C x ...] Tensor of inputs.
765
+ :param timesteps: a 1-D batch of timesteps.
766
+ :param context: conditioning plugged in via crossattn
767
+ :param y: an [N] Tensor of labels, if class-conditional.
768
+ :return: an [N x C x ...] Tensor of outputs.
769
+ """
770
+ assert (y is not None) == (
771
+ self.num_classes is not None
772
+ ), "must specify y if and only if the model is class-conditional"
773
+ hs = []
774
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
775
+ emb = self.time_embed(t_emb)
776
+
777
+ if self.num_classes is not None:
778
+ assert y.shape[0] == x.shape[0]
779
+ emb = emb + self.label_emb(y)
780
+
781
+ layernum = 0
782
+
783
+ h = x.type(self.dtype)
784
+ for module in self.input_blocks:
785
+ h, layernum = module(h, emb, context, encode=encode, encode_uncon=encode_uncon, decode_uncon=decode_uncon, controller=controller, inject=inject, layernum=layernum)
786
+ hs.append(h)
787
+ # print(layernum)
788
+
789
+ layernum = 0
790
+ h, layernum = self.middle_block(h, emb, context, encode=encode, encode_uncon=encode_uncon, decode_uncon=decode_uncon, controller=controller, inject=inject, layernum=layernum)
791
+ # print(layernum)
792
+
793
+ layernum = 0
794
+ for module in self.output_blocks:
795
+ h = th.cat([h, hs.pop()], dim=1)
796
+ h, layernum = module(h, emb, context, encode=encode, encode_uncon=encode_uncon, decode_uncon=decode_uncon, controller=controller, inject=inject, layernum=layernum)
797
+ # print(layernum)
798
+
799
+ h = h.type(x.dtype)
800
+ if self.predict_codebook_ids:
801
+ return self.id_predictor(h)
802
+ else:
803
+ return self.out(h)
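Editor's note: a hedged sketch of how the `UNetModel` above can be instantiated and called for a single denoising step. The channel sizes and shapes are illustrative only and assume the `ldm` package from this commit is importable; this is not a configuration used by the repository.

```python
# Hedged sketch: tiny UNetModel forward pass (illustrative sizes, CPU, no attention at the
# encoder/decoder levels; the middle block still uses its built-in AttentionBlock).
import torch
from ldm.modules.diffusionmodules.openaimodel import UNetModel

unet = UNetModel(
    image_size=32, in_channels=4, model_channels=32, out_channels=4,
    num_res_blocks=1, attention_resolutions=[], channel_mult=(1, 2), num_heads=1,
)
x = torch.randn(1, 4, 16, 16)          # a latent, batch size 1
t = torch.tensor([10])                 # one timestep index per sample
with torch.no_grad():
    eps = unet(x, timesteps=t)         # predicted noise, same shape as x
print(eps.shape)                       # torch.Size([1, 4, 16, 16])
```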
ldm/modules/diffusionmodules/upscaling.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from functools import partial
5
+
6
+ from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
7
+ from ldm.util import default
8
+
9
+
10
+ class AbstractLowScaleModel(nn.Module):
11
+ # for concatenating a downsampled image to the latent representation
12
+ def __init__(self, noise_schedule_config=None):
13
+ super(AbstractLowScaleModel, self).__init__()
14
+ if noise_schedule_config is not None:
15
+ self.register_schedule(**noise_schedule_config)
16
+
17
+ def register_schedule(self, beta_schedule="linear", timesteps=1000,
18
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
19
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
20
+ cosine_s=cosine_s)
21
+ alphas = 1. - betas
22
+ alphas_cumprod = np.cumprod(alphas, axis=0)
23
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
24
+
25
+ timesteps, = betas.shape
26
+ self.num_timesteps = int(timesteps)
27
+ self.linear_start = linear_start
28
+ self.linear_end = linear_end
29
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
30
+
31
+ to_torch = partial(torch.tensor, dtype=torch.float32)
32
+
33
+ self.register_buffer('betas', to_torch(betas))
34
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
36
+
37
+ # calculations for diffusion q(x_t | x_{t-1}) and others
38
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
39
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
40
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
41
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
42
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
43
+
44
+ def q_sample(self, x_start, t, noise=None):
45
+ noise = default(noise, lambda: torch.randn_like(x_start))
46
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
47
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
48
+
49
+ def forward(self, x):
50
+ return x, None
51
+
52
+ def decode(self, x):
53
+ return x
54
+
55
+
56
+ class SimpleImageConcat(AbstractLowScaleModel):
57
+ # no noise level conditioning
58
+ def __init__(self):
59
+ super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
60
+ self.max_noise_level = 0
61
+
62
+ def forward(self, x):
63
+ # fix to constant noise level
64
+ return x, torch.zeros(x.shape[0], device=x.device).long()
65
+
66
+
67
+ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
68
+ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
69
+ super().__init__(noise_schedule_config=noise_schedule_config)
70
+ self.max_noise_level = max_noise_level
71
+
72
+ def forward(self, x, noise_level=None):
73
+ if noise_level is None:
74
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
75
+ else:
76
+ assert isinstance(noise_level, torch.Tensor)
77
+ z = self.q_sample(x, noise_level)
78
+ return z, noise_level
79
+
80
+
81
+
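Editor's note: a hedged sketch of exercising `ImageConcatWithNoiseAugmentation` on its own. The schedule values and the noise-level cap below are illustrative, not taken from a config file in this repository.

```python
# Hedged sketch: noise-augmenting a low-resolution conditioning image.
import torch
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation

aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config={"linear_start": 1e-4, "linear_end": 2e-2, "timesteps": 1000},
    max_noise_level=350,               # illustrative cap on the sampled noise level
)
x = torch.randn(2, 3, 64, 64)          # stand-in low-resolution images
z, noise_level = aug(x)                # z = q_sample(x, noise_level)
print(z.shape, noise_level)
```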
ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,273 @@
1
+ # adapted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ from ldm.util import instantiate_from_config
19
+
20
+
21
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22
+ if schedule == "linear":
23
+ betas = (
24
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25
+ )
26
+
27
+ elif schedule == "cosine":
28
+ timesteps = (
29
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30
+ )
31
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
32
+ alphas = torch.cos(alphas).pow(2)
33
+ alphas = alphas / alphas[0]
34
+ betas = 1 - alphas[1:] / alphas[:-1]
35
+ betas = np.clip(betas, a_min=0, a_max=0.999)
36
+
37
+ elif schedule == "sqrt_linear":
38
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39
+ elif schedule == "sqrt":
40
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41
+ else:
42
+ raise ValueError(f"schedule '{schedule}' unknown.")
43
+ return betas.numpy()
44
+
45
+
46
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47
+ if ddim_discr_method == 'uniform':
48
+ c = num_ddpm_timesteps // num_ddim_timesteps
49
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50
+ elif ddim_discr_method == 'quad':
51
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52
+ else:
53
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54
+
55
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
57
+ steps_out = ddim_timesteps + 1
58
+ if verbose:
59
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
60
+ return steps_out
61
+
62
+
63
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64
+ # select alphas for computing the variance schedule
65
+ if ddim_timesteps[-1] == 1000:
66
+ ddim_timesteps = ddim_timesteps - 1
67
+ alphas = alphacums[ddim_timesteps]
68
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
69
+ alphas_next = np.asarray(alphacums[ddim_timesteps[1:]].tolist() + [alphacums[-1].tolist()])
70
+
71
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
72
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
73
+ if verbose:
74
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
75
+ print(f'For the chosen value of eta, which is {eta}, '
76
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
77
+ return sigmas, alphas, alphas_prev, alphas_next
78
+
79
+
80
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
81
+ """
82
+ Create a beta schedule that discretizes the given alpha_t_bar function,
83
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
84
+ :param num_diffusion_timesteps: the number of betas to produce.
85
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
86
+ produces the cumulative product of (1-beta) up to that
87
+ part of the diffusion process.
88
+ :param max_beta: the maximum beta to use; use values lower than 1 to
89
+ prevent singularities.
90
+ """
91
+ betas = []
92
+ for i in range(num_diffusion_timesteps):
93
+ t1 = i / num_diffusion_timesteps
94
+ t2 = (i + 1) / num_diffusion_timesteps
95
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
96
+ return np.array(betas)
97
+
98
+
99
+ def extract_into_tensor(a, t, x_shape):
100
+ b, *_ = t.shape
101
+ out = a.gather(-1, t)
102
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
103
+
104
+
105
+ def checkpoint(func, inputs, params, flag):
106
+ """
107
+ Evaluate a function without caching intermediate activations, allowing for
108
+ reduced memory at the expense of extra compute in the backward pass.
109
+ :param func: the function to evaluate.
110
+ :param inputs: the argument sequence to pass to `func`.
111
+ :param params: a sequence of parameters `func` depends on but does not
112
+ explicitly take as arguments.
113
+ :param flag: if False, disable gradient checkpointing.
114
+ """
115
+ if flag:
116
+ args = tuple(inputs) + tuple(params)
117
+ return CheckpointFunction.apply(func, len(inputs), *args)
118
+ else:
119
+ return func(*inputs)
120
+
121
+
122
+ class CheckpointFunction(torch.autograd.Function):
123
+ @staticmethod
124
+ def forward(ctx, run_function, length, *args):
125
+ ctx.run_function = run_function
126
+ ctx.input_tensors = list(args[:length])
127
+ ctx.input_params = list(args[length:])
128
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
129
+ "dtype": torch.get_autocast_gpu_dtype(),
130
+ "cache_enabled": torch.is_autocast_cache_enabled()}
131
+ with torch.no_grad():
132
+ output_tensors = ctx.run_function(*ctx.input_tensors)
133
+ return output_tensors
134
+
135
+ @staticmethod
136
+ def backward(ctx, *output_grads):
137
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
138
+ with torch.enable_grad(), \
139
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
140
+ # Fixes a bug where the first op in run_function modifies the
141
+ # Tensor storage in place, which is not allowed for detach()'d
142
+ # Tensors.
143
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
144
+ output_tensors = ctx.run_function(*shallow_copies)
145
+ input_grads = torch.autograd.grad(
146
+ output_tensors,
147
+ ctx.input_tensors + ctx.input_params,
148
+ output_grads,
149
+ allow_unused=True,
150
+ )
151
+ del ctx.input_tensors
152
+ del ctx.input_params
153
+ del output_tensors
154
+ return (None, None) + input_grads
155
+
156
+
157
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
158
+ """
159
+ Create sinusoidal timestep embeddings.
160
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
161
+ These may be fractional.
162
+ :param dim: the dimension of the output.
163
+ :param max_period: controls the minimum frequency of the embeddings.
164
+ :return: an [N x dim] Tensor of positional embeddings.
165
+ """
166
+ if not repeat_only:
167
+ half = dim // 2
168
+ freqs = torch.exp(
169
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
170
+ ).to(device=timesteps.device)
171
+ args = timesteps[:, None].float() * freqs[None]
172
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
173
+ if dim % 2:
174
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
175
+ else:
176
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
177
+ return embedding
178
+
179
+
180
+ def zero_module(module):
181
+ """
182
+ Zero out the parameters of a module and return it.
183
+ """
184
+ for p in module.parameters():
185
+ p.detach().zero_()
186
+ return module
187
+
188
+
189
+ def scale_module(module, scale):
190
+ """
191
+ Scale the parameters of a module and return it.
192
+ """
193
+ for p in module.parameters():
194
+ p.detach().mul_(scale)
195
+ return module
196
+
197
+
198
+ def mean_flat(tensor):
199
+ """
200
+ Take the mean over all non-batch dimensions.
201
+ """
202
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
203
+
204
+
205
+ def normalization(channels):
206
+ """
207
+ Make a standard normalization layer.
208
+ :param channels: number of input channels.
209
+ :return: an nn.Module for normalization.
210
+ """
211
+ return GroupNorm32(32, channels)
212
+
213
+
214
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
215
+ class SiLU(nn.Module):
216
+ def forward(self, x):
217
+ return x * torch.sigmoid(x)
218
+
219
+
220
+ class GroupNorm32(nn.GroupNorm):
221
+ def forward(self, x):
222
+ return super().forward(x.float()).type(x.dtype)
223
+
224
+ def conv_nd(dims, *args, **kwargs):
225
+ """
226
+ Create a 1D, 2D, or 3D convolution module.
227
+ """
228
+ if dims == 1:
229
+ return nn.Conv1d(*args, **kwargs)
230
+ elif dims == 2:
231
+ return nn.Conv2d(*args, **kwargs)
232
+ elif dims == 3:
233
+ return nn.Conv3d(*args, **kwargs)
234
+ raise ValueError(f"unsupported dimensions: {dims}")
235
+
236
+
237
+ def linear(*args, **kwargs):
238
+ """
239
+ Create a linear module.
240
+ """
241
+ return nn.Linear(*args, **kwargs)
242
+
243
+
244
+ def avg_pool_nd(dims, *args, **kwargs):
245
+ """
246
+ Create a 1D, 2D, or 3D average pooling module.
247
+ """
248
+ if dims == 1:
249
+ return nn.AvgPool1d(*args, **kwargs)
250
+ elif dims == 2:
251
+ return nn.AvgPool2d(*args, **kwargs)
252
+ elif dims == 3:
253
+ return nn.AvgPool3d(*args, **kwargs)
254
+ raise ValueError(f"unsupported dimensions: {dims}")
255
+
256
+
257
+ class HybridConditioner(nn.Module):
258
+
259
+ def __init__(self, c_concat_config, c_crossattn_config):
260
+ super().__init__()
261
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
262
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
263
+
264
+ def forward(self, c_concat, c_crossattn):
265
+ c_concat = self.concat_conditioner(c_concat)
266
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
267
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
268
+
269
+
270
+ def noise_like(shape, device, repeat=False):
271
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
272
+ noise = lambda: torch.randn(shape, device=device)
273
+ return repeat_noise() if repeat else noise()
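Editor's note: the two helpers in this module that most callers touch are `timestep_embedding` and `make_beta_schedule`. A hedged, self-contained usage sketch (the values are illustrative):

```python
# Hedged sketch: sinusoidal timestep embeddings and a linear beta schedule.
import torch
from ldm.modules.diffusionmodules.util import timestep_embedding, make_beta_schedule

t = torch.tensor([0, 250, 999])        # illustrative timestep indices
emb = timestep_embedding(t, dim=128)   # -> [3, 128] sin/cos features
betas = make_beta_schedule("linear", n_timestep=1000)
print(emb.shape, betas.shape)          # torch.Size([3, 128]) (1000,)
```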
ldm/modules/distributions/__init__.py ADDED
File without changes
ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self):
36
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37
+ return x
38
+
39
+ def kl(self, other=None):
40
+ if self.deterministic:
41
+ return torch.Tensor([0.])
42
+ else:
43
+ if other is None:
44
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
45
+ + self.var - 1.0 - self.logvar,
46
+ dim=[1, 2, 3])
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
51
+ dim=[1, 2, 3])
52
+
53
+ def nll(self, sample, dims=[1,2,3]):
54
+ if self.deterministic:
55
+ return torch.Tensor([0.])
56
+ logtwopi = np.log(2.0 * np.pi)
57
+ return 0.5 * torch.sum(
58
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59
+ dim=dims)
60
+
61
+ def mode(self):
62
+ return self.mean
63
+
64
+
65
+ def normal_kl(mean1, logvar1, mean2, logvar2):
66
+ """
67
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68
+ Compute the KL divergence between two gaussians.
69
+ Shapes are automatically broadcasted, so batches can be compared to
70
+ scalars, among other use cases.
71
+ """
72
+ tensor = None
73
+ for obj in (mean1, logvar1, mean2, logvar2):
74
+ if isinstance(obj, torch.Tensor):
75
+ tensor = obj
76
+ break
77
+ assert tensor is not None, "at least one argument must be a Tensor"
78
+
79
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
80
+ # Tensors, but it does not work for torch.exp().
81
+ logvar1, logvar2 = [
82
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83
+ for x in (logvar1, logvar2)
84
+ ]
85
+
86
+ return 0.5 * (
87
+ -1.0
88
+ + logvar2
89
+ - logvar1
90
+ + torch.exp(logvar1 - logvar2)
91
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92
+ )
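Editor's note: a hedged sketch of how `DiagonalGaussianDistribution` is typically used: the first half of the channel dimension is read as the mean and the second half as the log-variance. The shapes below are illustrative.

```python
# Hedged sketch: sampling from the diagonal-Gaussian posterior and computing its KL term.
import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

params = torch.randn(1, 8, 4, 4)       # 2*C channels: mean and log-variance stacked
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()                 # [1, 4, 4, 4] latent sample
kl = posterior.kl()                    # KL against a standard normal, one value per sample
print(z.shape, kl.shape)
```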
ldm/modules/ema.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError('Decay must be between 0 and 1')
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14
+ else torch.tensor(-1, dtype=torch.int))
15
+
16
+ for name, p in model.named_parameters():
17
+ if p.requires_grad:
18
+ # remove as '.'-character is not allowed in buffers
19
+ s_name = name.replace('.', '')
20
+ self.m_name2s_name.update({name: s_name})
21
+ self.register_buffer(s_name, p.clone().detach().data)
22
+
23
+ self.collected_params = []
24
+
25
+ def reset_num_updates(self):
26
+ del self.num_updates
27
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28
+
29
+ def forward(self, model):
30
+ decay = self.decay
31
+
32
+ if self.num_updates >= 0:
33
+ self.num_updates += 1
34
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35
+
36
+ one_minus_decay = 1.0 - decay
37
+
38
+ with torch.no_grad():
39
+ m_param = dict(model.named_parameters())
40
+ shadow_params = dict(self.named_buffers())
41
+
42
+ for key in m_param:
43
+ if m_param[key].requires_grad:
44
+ sname = self.m_name2s_name[key]
45
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47
+ else:
48
+ assert not key in self.m_name2s_name
49
+
50
+ def copy_to(self, model):
51
+ m_param = dict(model.named_parameters())
52
+ shadow_params = dict(self.named_buffers())
53
+ for key in m_param:
54
+ if m_param[key].requires_grad:
55
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56
+ else:
57
+ assert not key in self.m_name2s_name
58
+
59
+ def store(self, parameters):
60
+ """
61
+ Save the current parameters for restoring later.
62
+ Args:
63
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64
+ temporarily stored.
65
+ """
66
+ self.collected_params = [param.clone() for param in parameters]
67
+
68
+ def restore(self, parameters):
69
+ """
70
+ Restore the parameters stored with the `store` method.
71
+ Useful to validate the model with EMA parameters without affecting the
72
+ original optimization process. Store the parameters before the
73
+ `copy_to` method. After validation (or model saving), use this to
74
+ restore the former parameters.
75
+ Args:
76
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77
+ updated with the stored parameters.
78
+ """
79
+ for c_param, param in zip(self.collected_params, parameters):
80
+ param.data.copy_(c_param.data)
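Editor's note: a hedged sketch of the usual `LitEma` life cycle during training and evaluation. The model and update loop are placeholders, not code from this repository.

```python
# Hedged sketch: maintain an EMA shadow of a model and swap it in for evaluation.
import torch.nn as nn
from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)                # placeholder model
ema = LitEma(model, decay=0.999)

for _ in range(3):                     # stand-in for optimizer steps in a training loop
    model.weight.data.add_(0.01)
    ema(model)                         # update the shadow parameters

ema.store(model.parameters())          # stash the current weights
ema.copy_to(model)                     # evaluate with EMA weights
ema.restore(model.parameters())        # then switch back
```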
ldm/modules/encoders/__init__.py ADDED
File without changes
ldm/modules/encoders/modules.py ADDED
@@ -0,0 +1,221 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
6
+
7
+ import open_clip
8
+ from ldm.util import default, count_params
9
+ import einops
10
+
11
+ class AbstractEncoder(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def encode(self, *args, **kwargs):
16
+ raise NotImplementedError
17
+
18
+
19
+ class IdentityEncoder(AbstractEncoder):
20
+
21
+ def encode(self, x):
22
+ return x
23
+
24
+
25
+ class ClassEmbedder(nn.Module):
26
+ def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
27
+ super().__init__()
28
+ self.key = key
29
+ self.embedding = nn.Embedding(n_classes, embed_dim)
30
+ self.n_classes = n_classes
31
+ self.ucg_rate = ucg_rate
32
+
33
+ def forward(self, batch, key=None, disable_dropout=False):
34
+ if key is None:
35
+ key = self.key
36
+ # this is for use in crossattn
37
+ c = batch[key][:, None]
38
+ if self.ucg_rate > 0. and not disable_dropout:
39
+ mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
40
+ c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
41
+ c = c.long()
42
+ c = self.embedding(c)
43
+ return c
44
+
45
+ def get_unconditional_conditioning(self, bs, device="cuda"):
46
+ uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
47
+ uc = torch.ones((bs,), device=device) * uc_class
48
+ uc = {self.key: uc}
49
+ return uc
50
+
51
+
52
+ def disabled_train(self, mode=True):
53
+ """Overwrite model.train with this function to make sure train/eval mode
54
+ does not change anymore."""
55
+ return self
56
+
57
+
58
+ class FrozenT5Embedder(AbstractEncoder):
59
+ """Uses the T5 transformer encoder for text"""
60
+ def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
61
+ super().__init__()
62
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
63
+ self.transformer = T5EncoderModel.from_pretrained(version)
64
+ self.device = device
65
+ self.max_length = max_length # TODO: typical value?
66
+ if freeze:
67
+ self.freeze()
68
+
69
+ def freeze(self):
70
+ self.transformer = self.transformer.eval()
71
+ #self.train = disabled_train
72
+ for param in self.parameters():
73
+ param.requires_grad = False
74
+
75
+ def forward(self, text):
76
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
77
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
78
+ tokens = batch_encoding["input_ids"].to(self.device)
79
+ outputs = self.transformer(input_ids=tokens)
80
+
81
+ z = outputs.last_hidden_state
82
+ return z
83
+
84
+ def encode(self, text):
85
+ return self(text)
86
+
87
+
88
+ class FrozenCLIPEmbedder(AbstractEncoder):
89
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
90
+ LAYERS = [
91
+ "last",
92
+ "pooled",
93
+ "hidden"
94
+ ]
95
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
96
+ freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
97
+ super().__init__()
98
+ assert layer in self.LAYERS
99
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
100
+ self.transformer = CLIPTextModel.from_pretrained(version)
101
+ self.device = device
102
+ self.max_length = max_length
103
+ if freeze:
104
+ self.freeze()
105
+ self.layer = layer
106
+ self.layer_idx = layer_idx
107
+ if layer == "hidden":
108
+ assert layer_idx is not None
109
+ assert 0 <= abs(layer_idx) <= 12
110
+
111
+ def freeze(self):
112
+ self.transformer = self.transformer.eval()
113
+ #self.train = disabled_train
114
+ for param in self.parameters():
115
+ param.requires_grad = False
116
+
117
+ def forward(self, text):
118
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
119
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
120
+ tokens = batch_encoding["input_ids"].to(self.device)
121
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
122
+ if self.layer == "last":
123
+ z = outputs.last_hidden_state
124
+ elif self.layer == "pooled":
125
+ z = outputs.pooler_output[:, None, :]
126
+ else:
127
+ z = outputs.hidden_states[self.layer_idx]
128
+ return z
129
+
130
+ def encode(self, text):
131
+ return self(text)
132
+
133
+
134
+ class FrozenOpenCLIPEmbedder(AbstractEncoder):
135
+ """
136
+ Uses the OpenCLIP transformer encoder for text
137
+ """
138
+ LAYERS = [
139
+ #"pooled",
140
+ "last",
141
+ "penultimate"
142
+ ]
143
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
144
+ freeze=True, layer="last"):
145
+ super().__init__()
146
+ assert layer in self.LAYERS
147
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
148
+ del model.visual
149
+ self.model = model
150
+
151
+ self.device = device
152
+ self.max_length = max_length
153
+ if freeze:
154
+ self.freeze()
155
+ self.layer = layer
156
+ if self.layer == "last":
157
+ self.layer_idx = 0
158
+ elif self.layer == "penultimate":
159
+ self.layer_idx = 1
160
+ else:
161
+ raise NotImplementedError()
162
+
163
+ def freeze(self):
164
+ self.model = self.model.eval()
165
+ for param in self.parameters():
166
+ param.requires_grad = False
167
+
168
+ def forward(self, text, inv=False):
169
+ tokens = open_clip.tokenize(text)
170
+ if inv:
171
+ tokens[0] = torch.zeros(77) + 7788
172
+ z = self.encode_with_transformer(tokens.to(self.device), inv)
173
+ else:
174
+ z = self.encode_with_transformer(tokens.to(self.device))
175
+ return z
176
+
177
+ def encode_with_transformer(self, text, inv=False):
178
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
179
+ if inv == False:
180
+ # x = einops.repeat(x[:,0], 'i j -> i c j', c=77)
181
+ x = x + self.model.positional_embedding
182
+ x = x.permute(1, 0, 2) # NLD -> LND
183
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
184
+ x = x.permute(1, 0, 2) # LND -> NLD
185
+ x = self.model.ln_final(x)
186
+ return x
187
+
188
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
189
+ for i, r in enumerate(self.model.transformer.resblocks):
190
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
191
+ break
192
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
193
+ x = checkpoint(r, x, attn_mask)
194
+ else:
195
+ x = r(x, attn_mask=attn_mask)
196
+ return x
197
+
198
+ def encode(self, text, inv=False, device=None):
199
+ if device is not None:
200
+ self.device = device
201
+ return self(text, inv)
202
+
203
+
204
+ class FrozenCLIPT5Encoder(AbstractEncoder):
205
+ def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
206
+ clip_max_length=77, t5_max_length=77):
207
+ super().__init__()
208
+ self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
209
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
210
+ print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
211
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
212
+
213
+ def encode(self, text):
214
+ return self(text)
215
+
216
+ def forward(self, text):
217
+ clip_z = self.clip_encoder.encode(text)
218
+ t5_z = self.t5_encoder.encode(text)
219
+ return [clip_z, t5_z]
220
+
221
+
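Editor's note: a hedged sketch of producing cross-attention conditioning from a prompt with `FrozenOpenCLIPEmbedder`. It assumes the pretrained `laion2b_s32b_b79k` ViT-H-14 weights can be downloaded; `device="cpu"` is only for illustration.

```python
# Hedged sketch: text -> conditioning tensor for the UNet's cross-attention.
import torch
from ldm.modules.encoders.modules import FrozenOpenCLIPEmbedder

embedder = FrozenOpenCLIPEmbedder(device="cpu", layer="penultimate")
with torch.no_grad():
    cond = embedder.encode(["a photo of a cat on a sofa"])
print(cond.shape)                      # roughly [1, 77, 1024] for ViT-H-14
```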
ldm/modules/image_degradation/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2
+ from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
ldm/modules/image_degradation/bsrgan.py ADDED
@@ -0,0 +1,730 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ # --------------------------------------------
4
+ # Super-Resolution
5
+ # --------------------------------------------
6
+ #
7
+ # Kai Zhang ([email protected])
8
+ # https://github.com/cszn
9
+ # From 2019/03--2021/08
10
+ # --------------------------------------------
11
+ """
12
+
13
+ import numpy as np
14
+ import cv2
15
+ import torch
16
+
17
+ from functools import partial
18
+ import random
19
+ from scipy import ndimage
20
+ import scipy
21
+ import scipy.stats as ss
22
+ from scipy.interpolate import interp2d
23
+ from scipy.linalg import orth
24
+ import albumentations
25
+
26
+ import ldm.modules.image_degradation.utils_image as util
27
+
28
+
29
+ def modcrop_np(img, sf):
30
+ '''
31
+ Args:
32
+ img: numpy image, WxH or WxHxC
33
+ sf: scale factor
34
+ Return:
35
+ cropped image
36
+ '''
37
+ w, h = img.shape[:2]
38
+ im = np.copy(img)
39
+ return im[:w - w % sf, :h - h % sf, ...]
40
+
41
+
42
+ """
43
+ # --------------------------------------------
44
+ # anisotropic Gaussian kernels
45
+ # --------------------------------------------
46
+ """
47
+
48
+
49
+ def analytic_kernel(k):
50
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
51
+ k_size = k.shape[0]
52
+ # Calculate the big kernels size
53
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
54
+ # Loop over the small kernel to fill the big one
55
+ for r in range(k_size):
56
+ for c in range(k_size):
57
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
58
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
59
+ crop = k_size // 2
60
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
61
+ # Normalize to 1
62
+ return cropped_big_k / cropped_big_k.sum()
63
+
64
+
65
+ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
66
+ """ generate an anisotropic Gaussian kernel
67
+ Args:
68
+ ksize : e.g., 15, kernel size
69
+ theta : [0, pi], rotation angle range
70
+ l1 : [0.1,50], scaling of eigenvalues
71
+ l2 : [0.1,l1], scaling of eigenvalues
72
+ If l1 = l2, will get an isotropic Gaussian kernel.
73
+ Returns:
74
+ k : kernel
75
+ """
76
+
77
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
78
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
79
+ D = np.array([[l1, 0], [0, l2]])
80
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
81
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
82
+
83
+ return k
84
+
85
+
86
+ def gm_blur_kernel(mean, cov, size=15):
87
+ center = size / 2.0 + 0.5
88
+ k = np.zeros([size, size])
89
+ for y in range(size):
90
+ for x in range(size):
91
+ cy = y - center + 1
92
+ cx = x - center + 1
93
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
94
+
95
+ k = k / np.sum(k)
96
+ return k
97
+
98
+
99
+ def shift_pixel(x, sf, upper_left=True):
100
+ """shift pixel for super-resolution with different scale factors
101
+ Args:
102
+ x: WxHxC or WxH
103
+ sf: scale factor
104
+ upper_left: shift direction
105
+ """
106
+ h, w = x.shape[:2]
107
+ shift = (sf - 1) * 0.5
108
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
109
+ if upper_left:
110
+ x1 = xv + shift
111
+ y1 = yv + shift
112
+ else:
113
+ x1 = xv - shift
114
+ y1 = yv - shift
115
+
116
+ x1 = np.clip(x1, 0, w - 1)
117
+ y1 = np.clip(y1, 0, h - 1)
118
+
119
+ if x.ndim == 2:
120
+ x = interp2d(xv, yv, x)(x1, y1)
121
+ if x.ndim == 3:
122
+ for i in range(x.shape[-1]):
123
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
124
+
125
+ return x
126
+
127
+
128
+ def blur(x, k):
129
+ '''
130
+ x: image, NxcxHxW
131
+ k: kernel, Nx1xhxw
132
+ '''
133
+ n, c = x.shape[:2]
134
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
135
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
136
+ k = k.repeat(1, c, 1, 1)
137
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
138
+ x = x.view(1, -1, x.shape[2], x.shape[3])
139
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
140
+ x = x.view(n, c, x.shape[2], x.shape[3])
141
+
142
+ return x
143
+
144
+
145
+ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
146
+ """
147
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
148
+ # Kai Zhang
149
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
150
+ # max_var = 2.5 * sf
151
+ """
152
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
153
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
154
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
155
+ theta = np.random.rand() * np.pi # random theta
156
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
157
+
158
+ # Set COV matrix using Lambdas and Theta
159
+ LAMBDA = np.diag([lambda_1, lambda_2])
160
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
161
+ [np.sin(theta), np.cos(theta)]])
162
+ SIGMA = Q @ LAMBDA @ Q.T
163
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
164
+
165
+ # Set expectation position (shifting kernel for aligned image)
166
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
167
+ MU = MU[None, None, :, None]
168
+
169
+ # Create meshgrid for Gaussian
170
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
171
+ Z = np.stack([X, Y], 2)[:, :, :, None]
172
+
173
+ # Calculate Gaussian for every pixel of the kernel
174
+ ZZ = Z - MU
175
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
176
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
177
+
178
+ # shift the kernel so it will be centered
179
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
180
+
181
+ # Normalize the kernel and return
182
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
183
+ kernel = raw_kernel / np.sum(raw_kernel)
184
+ return kernel
185
+
186
+
187
+ def fspecial_gaussian(hsize, sigma):
188
+ hsize = [hsize, hsize]
189
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
190
+ std = sigma
191
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
192
+ arg = -(x * x + y * y) / (2 * std * std)
193
+ h = np.exp(arg)
194
+ h[h < scipy.finfo(float).eps * h.max()] = 0
195
+ sumh = h.sum()
196
+ if sumh != 0:
197
+ h = h / sumh
198
+ return h
199
+
200
+
201
+ def fspecial_laplacian(alpha):
202
+ alpha = max([0, min([alpha, 1])])
203
+ h1 = alpha / (alpha + 1)
204
+ h2 = (1 - alpha) / (alpha + 1)
205
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
206
+ h = np.array(h)
207
+ return h
208
+
209
+
210
+ def fspecial(filter_type, *args, **kwargs):
211
+ '''
212
+ python code from:
213
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
214
+ '''
215
+ if filter_type == 'gaussian':
216
+ return fspecial_gaussian(*args, **kwargs)
217
+ if filter_type == 'laplacian':
218
+ return fspecial_laplacian(*args, **kwargs)
219
+
220
+
221
+ """
222
+ # --------------------------------------------
223
+ # degradation models
224
+ # --------------------------------------------
225
+ """
226
+
227
+
228
+ def bicubic_degradation(x, sf=3):
229
+ '''
230
+ Args:
231
+ x: HxWxC image, [0, 1]
232
+ sf: down-scale factor
233
+ Return:
234
+ bicubicly downsampled LR image
235
+ '''
236
+ x = util.imresize_np(x, scale=1 / sf)
237
+ return x
238
+
239
+
240
+ def srmd_degradation(x, k, sf=3):
241
+ ''' blur + bicubic downsampling
242
+ Args:
243
+ x: HxWxC image, [0, 1]
244
+ k: hxw, double
245
+ sf: down-scale factor
246
+ Return:
247
+ downsampled LR image
248
+ Reference:
249
+ @inproceedings{zhang2018learning,
250
+ title={Learning a single convolutional super-resolution network for multiple degradations},
251
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
252
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
253
+ pages={3262--3271},
254
+ year={2018}
255
+ }
256
+ '''
257
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
258
+ x = bicubic_degradation(x, sf=sf)
259
+ return x
260
+
261
+
262
+ def dpsr_degradation(x, k, sf=3):
263
+ ''' bicubic downsampling + blur
264
+ Args:
265
+ x: HxWxC image, [0, 1]
266
+ k: hxw, double
267
+ sf: down-scale factor
268
+ Return:
269
+ downsampled LR image
270
+ Reference:
271
+ @inproceedings{zhang2019deep,
272
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
273
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
274
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
275
+ pages={1671--1681},
276
+ year={2019}
277
+ }
278
+ '''
279
+ x = bicubic_degradation(x, sf=sf)
280
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
281
+ return x
282
+
283
+
284
+ def classical_degradation(x, k, sf=3):
285
+ ''' blur + downsampling
286
+ Args:
287
+ x: HxWxC image, [0, 1]/[0, 255]
288
+ k: hxw, double
289
+ sf: down-scale factor
290
+ Return:
291
+ downsampled LR image
292
+ '''
293
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
294
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
295
+ st = 0
296
+ return x[st::sf, st::sf, ...]
297
+
298
+
299
+ def add_sharpening(img, weight=0.5, radius=50, threshold=10):
300
+ """USM sharpening. borrowed from real-ESRGAN
301
+ Input image: I; Blurry image: B.
302
+ 1. K = I + weight * (I - B)
303
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
304
+ 3. Blur mask:
305
+ 4. Out = Mask * K + (1 - Mask) * I
306
+ Args:
307
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
308
+ weight (float): Sharp weight. Default: 1.
309
+ radius (float): Kernel size of Gaussian blur. Default: 50.
310
+ threshold (int):
311
+ """
312
+ if radius % 2 == 0:
313
+ radius += 1
314
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
315
+ residual = img - blur
316
+ mask = np.abs(residual) * 255 > threshold
317
+ mask = mask.astype('float32')
318
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
319
+
320
+ K = img + weight * residual
321
+ K = np.clip(K, 0, 1)
322
+ return soft_mask * K + (1 - soft_mask) * img
323
+
324
+
325
+ def add_blur(img, sf=4):
326
+ wd2 = 4.0 + sf
327
+ wd = 2.0 + 0.2 * sf
328
+ if random.random() < 0.5:
329
+ l1 = wd2 * random.random()
330
+ l2 = wd2 * random.random()
331
+ k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
332
+ else:
333
+ k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
334
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
335
+
336
+ return img
337
+
338
+
339
+ def add_resize(img, sf=4):
340
+ rnum = np.random.rand()
341
+ if rnum > 0.8: # up
342
+ sf1 = random.uniform(1, 2)
343
+ elif rnum < 0.7: # down
344
+ sf1 = random.uniform(0.5 / sf, 1)
345
+ else:
346
+ sf1 = 1.0
347
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
348
+ img = np.clip(img, 0.0, 1.0)
349
+
350
+ return img
351
+
352
+
353
+ # def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
354
+ # noise_level = random.randint(noise_level1, noise_level2)
355
+ # rnum = np.random.rand()
356
+ # if rnum > 0.6: # add color Gaussian noise
357
+ # img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
358
+ # elif rnum < 0.4: # add grayscale Gaussian noise
359
+ # img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
360
+ # else: # add noise
361
+ # L = noise_level2 / 255.
362
+ # D = np.diag(np.random.rand(3))
363
+ # U = orth(np.random.rand(3, 3))
364
+ # conv = np.dot(np.dot(np.transpose(U), D), U)
365
+ # img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
366
+ # img = np.clip(img, 0.0, 1.0)
367
+ # return img
368
+
369
+ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
370
+ noise_level = random.randint(noise_level1, noise_level2)
371
+ rnum = np.random.rand()
372
+ if rnum > 0.6: # add color Gaussian noise
373
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
374
+ elif rnum < 0.4: # add grayscale Gaussian noise
375
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
376
+ else: # add noise
377
+ L = noise_level2 / 255.
378
+ D = np.diag(np.random.rand(3))
379
+ U = orth(np.random.rand(3, 3))
380
+ conv = np.dot(np.dot(np.transpose(U), D), U)
381
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
382
+ img = np.clip(img, 0.0, 1.0)
383
+ return img
384
+
385
+
386
+ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
387
+ noise_level = random.randint(noise_level1, noise_level2)
388
+ img = np.clip(img, 0.0, 1.0)
389
+ rnum = random.random()
390
+ if rnum > 0.6:
391
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
392
+ elif rnum < 0.4:
393
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
394
+ else:
395
+ L = noise_level2 / 255.
396
+ D = np.diag(np.random.rand(3))
397
+ U = orth(np.random.rand(3, 3))
398
+ conv = np.dot(np.dot(np.transpose(U), D), U)
399
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
400
+ img = np.clip(img, 0.0, 1.0)
401
+ return img
402
+
403
+
404
+ def add_Poisson_noise(img):
405
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
406
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
407
+ if random.random() < 0.5:
408
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
409
+ else:
410
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
411
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
412
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
413
+ img += noise_gray[:, :, np.newaxis]
414
+ img = np.clip(img, 0.0, 1.0)
415
+ return img
416
+
417
+
418
+ def add_JPEG_noise(img):
419
+ quality_factor = random.randint(30, 95)
420
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
421
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
422
+ img = cv2.imdecode(encimg, 1)
423
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
424
+ return img
425
+
426
+
427
+ def random_crop(lq, hq, sf=4, lq_patchsize=64):
428
+ h, w = lq.shape[:2]
429
+ rnd_h = random.randint(0, h - lq_patchsize)
430
+ rnd_w = random.randint(0, w - lq_patchsize)
431
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
432
+
433
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
434
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
435
+ return lq, hq
436
+
437
+
438
+ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
439
+ """
440
+ This is the degradation model of BSRGAN from the paper
441
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
442
+ ----------
443
+ img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
444
+ sf: scale factor
445
+ isp_model: camera ISP model
446
+ Returns
447
+ -------
448
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
449
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
450
+ """
451
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
452
+ sf_ori = sf
453
+
454
+ h1, w1 = img.shape[:2]
455
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
456
+ h, w = img.shape[:2]
457
+
458
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
459
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
460
+
461
+ hq = img.copy()
462
+
463
+ if sf == 4 and random.random() < scale2_prob: # downsample1
464
+ if np.random.rand() < 0.5:
465
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
466
+ interpolation=random.choice([1, 2, 3]))
467
+ else:
468
+ img = util.imresize_np(img, 1 / 2, True)
469
+ img = np.clip(img, 0.0, 1.0)
470
+ sf = 2
471
+
472
+ shuffle_order = random.sample(range(7), 7)
473
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
474
+ if idx1 > idx2: # keep downsample3 last
475
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
476
+
477
+ for i in shuffle_order:
478
+
479
+ if i == 0:
480
+ img = add_blur(img, sf=sf)
481
+
482
+ elif i == 1:
483
+ img = add_blur(img, sf=sf)
484
+
485
+ elif i == 2:
486
+ a, b = img.shape[1], img.shape[0]
487
+ # downsample2
488
+ if random.random() < 0.75:
489
+ sf1 = random.uniform(1, 2 * sf)
490
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
491
+ interpolation=random.choice([1, 2, 3]))
492
+ else:
493
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
494
+ k_shifted = shift_pixel(k, sf)
495
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
496
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
497
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
498
+ img = np.clip(img, 0.0, 1.0)
499
+
500
+ elif i == 3:
501
+ # downsample3
502
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
503
+ img = np.clip(img, 0.0, 1.0)
504
+
505
+ elif i == 4:
506
+ # add Gaussian noise
507
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
508
+
509
+ elif i == 5:
510
+ # add JPEG noise
511
+ if random.random() < jpeg_prob:
512
+ img = add_JPEG_noise(img)
513
+
514
+ elif i == 6:
515
+ # add processed camera sensor noise
516
+ if random.random() < isp_prob and isp_model is not None:
517
+ with torch.no_grad():
518
+ img, hq = isp_model.forward(img.copy(), hq)
519
+
520
+ # add final JPEG compression noise
521
+ img = add_JPEG_noise(img)
522
+
523
+ # random crop
524
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
525
+
526
+ return img, hq
527
+
528
+
529
+ # todo no isp_model?
530
+ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
531
+ """
532
+ This is the degradation model of BSRGAN from the paper
533
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
534
+ ----------
535
+ sf: scale factor
536
+ isp_model: camera ISP model
537
+ Returns
538
+ -------
539
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
540
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
541
+ """
542
+ image = util.uint2single(image)
543
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
544
+ sf_ori = sf
545
+
546
+ h1, w1 = image.shape[:2]
547
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
548
+ h, w = image.shape[:2]
549
+
550
+ hq = image.copy()
551
+
552
+ if sf == 4 and random.random() < scale2_prob: # downsample1
553
+ if np.random.rand() < 0.5:
554
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
555
+ interpolation=random.choice([1, 2, 3]))
556
+ else:
557
+ image = util.imresize_np(image, 1 / 2, True)
558
+ image = np.clip(image, 0.0, 1.0)
559
+ sf = 2
560
+
561
+ shuffle_order = random.sample(range(7), 7)
562
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
563
+ if idx1 > idx2: # keep downsample3 last
564
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
565
+
566
+ for i in shuffle_order:
567
+
568
+ if i == 0:
569
+ image = add_blur(image, sf=sf)
570
+
571
+ elif i == 1:
572
+ image = add_blur(image, sf=sf)
573
+
574
+ elif i == 2:
575
+ a, b = image.shape[1], image.shape[0]
576
+ # downsample2
577
+ if random.random() < 0.75:
578
+ sf1 = random.uniform(1, 2 * sf)
579
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
580
+ interpolation=random.choice([1, 2, 3]))
581
+ else:
582
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
583
+ k_shifted = shift_pixel(k, sf)
584
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
585
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
586
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
587
+ image = np.clip(image, 0.0, 1.0)
588
+
589
+ elif i == 3:
590
+ # downsample3
591
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
592
+ image = np.clip(image, 0.0, 1.0)
593
+
594
+ elif i == 4:
595
+ # add Gaussian noise
596
+ image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
597
+
598
+ elif i == 5:
599
+ # add JPEG noise
600
+ if random.random() < jpeg_prob:
601
+ image = add_JPEG_noise(image)
602
+
603
+ # elif i == 6:
604
+ # # add processed camera sensor noise
605
+ # if random.random() < isp_prob and isp_model is not None:
606
+ # with torch.no_grad():
607
+ # img, hq = isp_model.forward(img.copy(), hq)
608
+
609
+ # add final JPEG compression noise
610
+ image = add_JPEG_noise(image)
611
+ image = util.single2uint(image)
612
+ example = {"image":image}
613
+ return example
614
+
615
+
616
+ # TODO: in case there is a pickle error, one needs to replace a += x with a = a + x in add_speckle_noise etc.
617
+ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
618
+ """
619
+ This is an extended degradation model by combining
620
+ the degradation models of BSRGAN and Real-ESRGAN
621
+ ----------
622
+ img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
623
+ sf: scale factor
624
+ use_shuffle: the degradation shuffle
625
+ use_sharp: sharpening the img
626
+ Returns
627
+ -------
628
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
629
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
630
+ """
631
+
632
+ h1, w1 = img.shape[:2]
633
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
634
+ h, w = img.shape[:2]
635
+
636
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
637
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
638
+
639
+ if use_sharp:
640
+ img = add_sharpening(img)
641
+ hq = img.copy()
642
+
643
+ if random.random() < shuffle_prob:
644
+ shuffle_order = random.sample(range(13), 13)
645
+ else:
646
+ shuffle_order = list(range(13))
647
+ # local shuffle for noise, JPEG is always the last one
648
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
649
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
650
+
651
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
652
+
653
+ for i in shuffle_order:
654
+ if i == 0:
655
+ img = add_blur(img, sf=sf)
656
+ elif i == 1:
657
+ img = add_resize(img, sf=sf)
658
+ elif i == 2:
659
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
660
+ elif i == 3:
661
+ if random.random() < poisson_prob:
662
+ img = add_Poisson_noise(img)
663
+ elif i == 4:
664
+ if random.random() < speckle_prob:
665
+ img = add_speckle_noise(img)
666
+ elif i == 5:
667
+ if random.random() < isp_prob and isp_model is not None:
668
+ with torch.no_grad():
669
+ img, hq = isp_model.forward(img.copy(), hq)
670
+ elif i == 6:
671
+ img = add_JPEG_noise(img)
672
+ elif i == 7:
673
+ img = add_blur(img, sf=sf)
674
+ elif i == 8:
675
+ img = add_resize(img, sf=sf)
676
+ elif i == 9:
677
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
678
+ elif i == 10:
679
+ if random.random() < poisson_prob:
680
+ img = add_Poisson_noise(img)
681
+ elif i == 11:
682
+ if random.random() < speckle_prob:
683
+ img = add_speckle_noise(img)
684
+ elif i == 12:
685
+ if random.random() < isp_prob and isp_model is not None:
686
+ with torch.no_grad():
687
+ img, hq = isp_model.forward(img.copy(), hq)
688
+ else:
689
+ print('check the shuffle!')
690
+
691
+ # resize to desired size
692
+ img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
693
+ interpolation=random.choice([1, 2, 3]))
694
+
695
+ # add final JPEG compression noise
696
+ img = add_JPEG_noise(img)
697
+
698
+ # random crop
699
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
700
+
701
+ return img, hq
702
+
703
+
704
+ if __name__ == '__main__':
705
+ print("hey")
706
+ img = util.imread_uint('utils/test.png', 3)
707
+ print(img)
708
+ img = img[:448, :448]
709
+ h = img.shape[0] // 4
710
+ print("resizing to", h)
711
+ sf = 4
712
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
713
+ for i in range(20):
714
+ print(i)
715
+ img_hq = img
716
+ img_lq = deg_fn(img)["image"]
717
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
718
+ print(img_lq)
719
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
720
+ print(img_lq.shape)
721
+ print("bicubic", img_lq_bicubic.shape)
722
+ print(img_hq.shape)
723
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
724
+ interpolation=0)
725
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
726
+ interpolation=0)
727
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
728
+ util.imsave(img_concat, str(i) + '.png')
729
+
730
+
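A hedged usage sketch for the degradation_bsrgan_variant defined in bsrgan.py above (not part of the commit); the random array stands in for a real high-quality image.

    import numpy as np
    from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant

    hq = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)  # placeholder uint8 RGB input
    lq = degradation_bsrgan_variant(hq, sf=4)["image"]          # uint8 output, (64, 64, 3) for this input
    print(hq.shape, lq.shape)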
ldm/modules/image_degradation/bsrgan_light.py ADDED
@@ -0,0 +1,651 @@
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+
6
+ from functools import partial
7
+ import random
8
+ from scipy import ndimage
9
+ import scipy
10
+ import scipy.stats as ss
11
+ from scipy.interpolate import interp2d
12
+ from scipy.linalg import orth
13
+ import albumentations
14
+
15
+ import ldm.modules.image_degradation.utils_image as util
16
+
17
+ """
18
+ # --------------------------------------------
19
+ # Super-Resolution
20
+ # --------------------------------------------
21
+ #
22
+ # Kai Zhang ([email protected])
23
+ # https://github.com/cszn
24
+ # From 2019/03--2021/08
25
+ # --------------------------------------------
26
+ """
27
+
28
+ def modcrop_np(img, sf):
29
+ '''
30
+ Args:
31
+ img: numpy image, WxH or WxHxC
32
+ sf: scale factor
33
+ Return:
34
+ cropped image
35
+ '''
36
+ w, h = img.shape[:2]
37
+ im = np.copy(img)
38
+ return im[:w - w % sf, :h - h % sf, ...]
39
+
40
+
41
+ """
42
+ # --------------------------------------------
43
+ # anisotropic Gaussian kernels
44
+ # --------------------------------------------
45
+ """
46
+
47
+
48
+ def analytic_kernel(k):
49
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
50
+ k_size = k.shape[0]
51
+ # Calculate the big kernels size
52
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
53
+ # Loop over the small kernel to fill the big one
54
+ for r in range(k_size):
55
+ for c in range(k_size):
56
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
57
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
58
+ crop = k_size // 2
59
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
60
+ # Normalize to 1
61
+ return cropped_big_k / cropped_big_k.sum()
62
+
63
+
64
+ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
65
+ """ generate an anisotropic Gaussian kernel
66
+ Args:
67
+ ksize : e.g., 15, kernel size
68
+ theta : [0, pi], rotation angle range
69
+ l1 : [0.1,50], scaling of eigenvalues
70
+ l2 : [0.1,l1], scaling of eigenvalues
71
+ If l1 = l2, will get an isotropic Gaussian kernel.
72
+ Returns:
73
+ k : kernel
74
+ """
75
+
76
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
77
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
78
+ D = np.array([[l1, 0], [0, l2]])
79
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
80
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
81
+
82
+ return k
83
+
84
+
85
+ def gm_blur_kernel(mean, cov, size=15):
86
+ center = size / 2.0 + 0.5
87
+ k = np.zeros([size, size])
88
+ for y in range(size):
89
+ for x in range(size):
90
+ cy = y - center + 1
91
+ cx = x - center + 1
92
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
93
+
94
+ k = k / np.sum(k)
95
+ return k
96
+
97
+
98
+ def shift_pixel(x, sf, upper_left=True):
99
+ """shift pixel for super-resolution with different scale factors
100
+ Args:
101
+ x: WxHxC or WxH
102
+ sf: scale factor
103
+ upper_left: shift direction
104
+ """
105
+ h, w = x.shape[:2]
106
+ shift = (sf - 1) * 0.5
107
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
108
+ if upper_left:
109
+ x1 = xv + shift
110
+ y1 = yv + shift
111
+ else:
112
+ x1 = xv - shift
113
+ y1 = yv - shift
114
+
115
+ x1 = np.clip(x1, 0, w - 1)
116
+ y1 = np.clip(y1, 0, h - 1)
117
+
118
+ if x.ndim == 2:
119
+ x = interp2d(xv, yv, x)(x1, y1)
120
+ if x.ndim == 3:
121
+ for i in range(x.shape[-1]):
122
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
123
+
124
+ return x
125
+
126
+
127
+ def blur(x, k):
128
+ '''
129
+ x: image, NxcxHxW
130
+ k: kernel, Nx1xhxw
131
+ '''
132
+ n, c = x.shape[:2]
133
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
134
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
135
+ k = k.repeat(1, c, 1, 1)
136
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
137
+ x = x.view(1, -1, x.shape[2], x.shape[3])
138
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
139
+ x = x.view(n, c, x.shape[2], x.shape[3])
140
+
141
+ return x
142
+
143
+
144
+ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
145
+ """
146
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
147
+ # Kai Zhang
148
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
149
+ # max_var = 2.5 * sf
150
+ """
151
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
152
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
153
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
154
+ theta = np.random.rand() * np.pi # random theta
155
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
156
+
157
+ # Set COV matrix using Lambdas and Theta
158
+ LAMBDA = np.diag([lambda_1, lambda_2])
159
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
160
+ [np.sin(theta), np.cos(theta)]])
161
+ SIGMA = Q @ LAMBDA @ Q.T
162
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
163
+
164
+ # Set expectation position (shifting kernel for aligned image)
165
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
166
+ MU = MU[None, None, :, None]
167
+
168
+ # Create meshgrid for Gaussian
169
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
170
+ Z = np.stack([X, Y], 2)[:, :, :, None]
171
+
172
+ # Calculate Gaussian for every pixel of the kernel
173
+ ZZ = Z - MU
174
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
175
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
176
+
177
+ # shift the kernel so it will be centered
178
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
179
+
180
+ # Normalize the kernel and return
181
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
182
+ kernel = raw_kernel / np.sum(raw_kernel)
183
+ return kernel
184
+
185
+
186
+ def fspecial_gaussian(hsize, sigma):
187
+ hsize = [hsize, hsize]
188
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
189
+ std = sigma
190
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
191
+ arg = -(x * x + y * y) / (2 * std * std)
192
+ h = np.exp(arg)
193
+ h[h < scipy.finfo(float).eps * h.max()] = 0
194
+ sumh = h.sum()
195
+ if sumh != 0:
196
+ h = h / sumh
197
+ return h
198
+
199
+
200
+ def fspecial_laplacian(alpha):
201
+ alpha = max([0, min([alpha, 1])])
202
+ h1 = alpha / (alpha + 1)
203
+ h2 = (1 - alpha) / (alpha + 1)
204
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
205
+ h = np.array(h)
206
+ return h
207
+
208
+
209
+ def fspecial(filter_type, *args, **kwargs):
210
+ '''
211
+ python code from:
212
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
213
+ '''
214
+ if filter_type == 'gaussian':
215
+ return fspecial_gaussian(*args, **kwargs)
216
+ if filter_type == 'laplacian':
217
+ return fspecial_laplacian(*args, **kwargs)
218
+
219
+
220
+ """
221
+ # --------------------------------------------
222
+ # degradation models
223
+ # --------------------------------------------
224
+ """
225
+
226
+
227
+ def bicubic_degradation(x, sf=3):
228
+ '''
229
+ Args:
230
+ x: HxWxC image, [0, 1]
231
+ sf: down-scale factor
232
+ Return:
233
+ bicubicly downsampled LR image
234
+ '''
235
+ x = util.imresize_np(x, scale=1 / sf)
236
+ return x
237
+
238
+
239
+ def srmd_degradation(x, k, sf=3):
240
+ ''' blur + bicubic downsampling
241
+ Args:
242
+ x: HxWxC image, [0, 1]
243
+ k: hxw, double
244
+ sf: down-scale factor
245
+ Return:
246
+ downsampled LR image
247
+ Reference:
248
+ @inproceedings{zhang2018learning,
249
+ title={Learning a single convolutional super-resolution network for multiple degradations},
250
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
251
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
252
+ pages={3262--3271},
253
+ year={2018}
254
+ }
255
+ '''
256
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
257
+ x = bicubic_degradation(x, sf=sf)
258
+ return x
259
+
260
+
261
+ def dpsr_degradation(x, k, sf=3):
262
+ ''' bicubic downsampling + blur
263
+ Args:
264
+ x: HxWxC image, [0, 1]
265
+ k: hxw, double
266
+ sf: down-scale factor
267
+ Return:
268
+ downsampled LR image
269
+ Reference:
270
+ @inproceedings{zhang2019deep,
271
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
272
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
273
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
274
+ pages={1671--1681},
275
+ year={2019}
276
+ }
277
+ '''
278
+ x = bicubic_degradation(x, sf=sf)
279
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
280
+ return x
281
+
282
+
283
+ def classical_degradation(x, k, sf=3):
284
+ ''' blur + downsampling
285
+ Args:
286
+ x: HxWxC image, [0, 1]/[0, 255]
287
+ k: hxw, double
288
+ sf: down-scale factor
289
+ Return:
290
+ downsampled LR image
291
+ '''
292
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
293
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
294
+ st = 0
295
+ return x[st::sf, st::sf, ...]
296
+
297
+
298
+ def add_sharpening(img, weight=0.5, radius=50, threshold=10):
299
+ """USM sharpening. borrowed from real-ESRGAN
300
+ Input image: I; Blurry image: B.
301
+ 1. K = I + weight * (I - B)
302
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
303
+ 3. Blur mask:
304
+ 4. Out = Mask * K + (1 - Mask) * I
305
+ Args:
306
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
307
+ weight (float): Sharp weight. Default: 1.
308
+ radius (float): Kernel size of Gaussian blur. Default: 50.
309
+ threshold (int):
310
+ """
311
+ if radius % 2 == 0:
312
+ radius += 1
313
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
314
+ residual = img - blur
315
+ mask = np.abs(residual) * 255 > threshold
316
+ mask = mask.astype('float32')
317
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
318
+
319
+ K = img + weight * residual
320
+ K = np.clip(K, 0, 1)
321
+ return soft_mask * K + (1 - soft_mask) * img
322
+
323
+
324
+ def add_blur(img, sf=4):
325
+ wd2 = 4.0 + sf
326
+ wd = 2.0 + 0.2 * sf
327
+
328
+ wd2 = wd2/4
329
+ wd = wd/4
330
+
331
+ if random.random() < 0.5:
332
+ l1 = wd2 * random.random()
333
+ l2 = wd2 * random.random()
334
+ k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
335
+ else:
336
+ k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
337
+ img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
338
+
339
+ return img
340
+
341
+
342
+ def add_resize(img, sf=4):
343
+ rnum = np.random.rand()
344
+ if rnum > 0.8: # up
345
+ sf1 = random.uniform(1, 2)
346
+ elif rnum < 0.7: # down
347
+ sf1 = random.uniform(0.5 / sf, 1)
348
+ else:
349
+ sf1 = 1.0
350
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
351
+ img = np.clip(img, 0.0, 1.0)
352
+
353
+ return img
354
+
355
+
356
+ # def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
357
+ # noise_level = random.randint(noise_level1, noise_level2)
358
+ # rnum = np.random.rand()
359
+ # if rnum > 0.6: # add color Gaussian noise
360
+ # img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
361
+ # elif rnum < 0.4: # add grayscale Gaussian noise
362
+ # img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
363
+ # else: # add noise
364
+ # L = noise_level2 / 255.
365
+ # D = np.diag(np.random.rand(3))
366
+ # U = orth(np.random.rand(3, 3))
367
+ # conv = np.dot(np.dot(np.transpose(U), D), U)
368
+ # img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
369
+ # img = np.clip(img, 0.0, 1.0)
370
+ # return img
371
+
372
+ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
373
+ noise_level = random.randint(noise_level1, noise_level2)
374
+ rnum = np.random.rand()
375
+ if rnum > 0.6: # add color Gaussian noise
376
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
377
+ elif rnum < 0.4: # add grayscale Gaussian noise
378
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
379
+ else: # add noise
380
+ L = noise_level2 / 255.
381
+ D = np.diag(np.random.rand(3))
382
+ U = orth(np.random.rand(3, 3))
383
+ conv = np.dot(np.dot(np.transpose(U), D), U)
384
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
385
+ img = np.clip(img, 0.0, 1.0)
386
+ return img
387
+
388
+
389
+ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
390
+ noise_level = random.randint(noise_level1, noise_level2)
391
+ img = np.clip(img, 0.0, 1.0)
392
+ rnum = random.random()
393
+ if rnum > 0.6:
394
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
395
+ elif rnum < 0.4:
396
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
397
+ else:
398
+ L = noise_level2 / 255.
399
+ D = np.diag(np.random.rand(3))
400
+ U = orth(np.random.rand(3, 3))
401
+ conv = np.dot(np.dot(np.transpose(U), D), U)
402
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
403
+ img = np.clip(img, 0.0, 1.0)
404
+ return img
405
+
406
+
407
+ def add_Poisson_noise(img):
408
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
409
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
410
+ if random.random() < 0.5:
411
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
412
+ else:
413
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
414
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
415
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
416
+ img += noise_gray[:, :, np.newaxis]
417
+ img = np.clip(img, 0.0, 1.0)
418
+ return img
419
+
420
+
421
+ def add_JPEG_noise(img):
422
+ quality_factor = random.randint(80, 95)
423
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
424
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
425
+ img = cv2.imdecode(encimg, 1)
426
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
427
+ return img
428
+
429
+
430
+ def random_crop(lq, hq, sf=4, lq_patchsize=64):
431
+ h, w = lq.shape[:2]
432
+ rnd_h = random.randint(0, h - lq_patchsize)
433
+ rnd_w = random.randint(0, w - lq_patchsize)
434
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
435
+
436
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
437
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
438
+ return lq, hq
439
+
440
+
441
+ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
442
+ """
443
+ This is the degradation model of BSRGAN from the paper
444
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
445
+ ----------
446
+ img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
447
+ sf: scale factor
448
+ isp_model: camera ISP model
449
+ Returns
450
+ -------
451
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
452
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
453
+ """
454
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
455
+ sf_ori = sf
456
+
457
+ h1, w1 = img.shape[:2]
458
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
459
+ h, w = img.shape[:2]
460
+
461
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
462
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
463
+
464
+ hq = img.copy()
465
+
466
+ if sf == 4 and random.random() < scale2_prob: # downsample1
467
+ if np.random.rand() < 0.5:
468
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
469
+ interpolation=random.choice([1, 2, 3]))
470
+ else:
471
+ img = util.imresize_np(img, 1 / 2, True)
472
+ img = np.clip(img, 0.0, 1.0)
473
+ sf = 2
474
+
475
+ shuffle_order = random.sample(range(7), 7)
476
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
477
+ if idx1 > idx2: # keep downsample3 last
478
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
479
+
480
+ for i in shuffle_order:
481
+
482
+ if i == 0:
483
+ img = add_blur(img, sf=sf)
484
+
485
+ elif i == 1:
486
+ img = add_blur(img, sf=sf)
487
+
488
+ elif i == 2:
489
+ a, b = img.shape[1], img.shape[0]
490
+ # downsample2
491
+ if random.random() < 0.75:
492
+ sf1 = random.uniform(1, 2 * sf)
493
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
494
+ interpolation=random.choice([1, 2, 3]))
495
+ else:
496
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
497
+ k_shifted = shift_pixel(k, sf)
498
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
499
+ img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
500
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
501
+ img = np.clip(img, 0.0, 1.0)
502
+
503
+ elif i == 3:
504
+ # downsample3
505
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
506
+ img = np.clip(img, 0.0, 1.0)
507
+
508
+ elif i == 4:
509
+ # add Gaussian noise
510
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
511
+
512
+ elif i == 5:
513
+ # add JPEG noise
514
+ if random.random() < jpeg_prob:
515
+ img = add_JPEG_noise(img)
516
+
517
+ elif i == 6:
518
+ # add processed camera sensor noise
519
+ if random.random() < isp_prob and isp_model is not None:
520
+ with torch.no_grad():
521
+ img, hq = isp_model.forward(img.copy(), hq)
522
+
523
+ # add final JPEG compression noise
524
+ img = add_JPEG_noise(img)
525
+
526
+ # random crop
527
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
528
+
529
+ return img, hq
530
+
531
+
532
+ # todo no isp_model?
533
+ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
534
+ """
535
+ This is the degradation model of BSRGAN from the paper
536
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
537
+ ----------
538
+ sf: scale factor
539
+ isp_model: camera ISP model
540
+ Returns
541
+ -------
542
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
543
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
544
+ """
545
+ image = util.uint2single(image)
546
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
547
+ sf_ori = sf
548
+
549
+ h1, w1 = image.shape[:2]
550
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
551
+ h, w = image.shape[:2]
552
+
553
+ hq = image.copy()
554
+
555
+ if sf == 4 and random.random() < scale2_prob: # downsample1
556
+ if np.random.rand() < 0.5:
557
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
558
+ interpolation=random.choice([1, 2, 3]))
559
+ else:
560
+ image = util.imresize_np(image, 1 / 2, True)
561
+ image = np.clip(image, 0.0, 1.0)
562
+ sf = 2
563
+
564
+ shuffle_order = random.sample(range(7), 7)
565
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
566
+ if idx1 > idx2: # keep downsample3 last
567
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
568
+
569
+ for i in shuffle_order:
570
+
571
+ if i == 0:
572
+ image = add_blur(image, sf=sf)
573
+
574
+ # elif i == 1:
575
+ # image = add_blur(image, sf=sf)
576
+
577
+ if i == 0:
578
+ pass
579
+
580
+ elif i == 2:
581
+ a, b = image.shape[1], image.shape[0]
582
+ # downsample2
583
+ if random.random() < 0.8:
584
+ sf1 = random.uniform(1, 2 * sf)
585
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
586
+ interpolation=random.choice([1, 2, 3]))
587
+ else:
588
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
589
+ k_shifted = shift_pixel(k, sf)
590
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
591
+ image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
592
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
593
+
594
+ image = np.clip(image, 0.0, 1.0)
595
+
596
+ elif i == 3:
597
+ # downsample3
598
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
599
+ image = np.clip(image, 0.0, 1.0)
600
+
601
+ elif i == 4:
602
+ # add Gaussian noise
603
+ image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
604
+
605
+ elif i == 5:
606
+ # add JPEG noise
607
+ if random.random() < jpeg_prob:
608
+ image = add_JPEG_noise(image)
609
+ #
610
+ # elif i == 6:
611
+ # # add processed camera sensor noise
612
+ # if random.random() < isp_prob and isp_model is not None:
613
+ # with torch.no_grad():
614
+ # img, hq = isp_model.forward(img.copy(), hq)
615
+
616
+ # add final JPEG compression noise
617
+ image = add_JPEG_noise(image)
618
+ image = util.single2uint(image)
619
+ if up:
620
+ image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
621
+ example = {"image": image}
622
+ return example
623
+
624
+
625
+
626
+
627
+ if __name__ == '__main__':
628
+ print("hey")
629
+ img = util.imread_uint('utils/test.png', 3)
630
+ img = img[:448, :448]
631
+ h = img.shape[0] // 4
632
+ print("resizing to", h)
633
+ sf = 4
634
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
635
+ for i in range(20):
636
+ print(i)
637
+ img_hq = img
638
+ img_lq = deg_fn(img)["image"]
639
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
640
+ print(img_lq)
641
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
642
+ print(img_lq.shape)
643
+ print("bicubic", img_lq_bicubic.shape)
644
+ print(img_hq.shape)
645
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
646
+ interpolation=0)
647
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
648
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
649
+ interpolation=0)
650
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
651
+ util.imsave(img_concat, str(i) + '.png')
ldm/modules/image_degradation/utils/test.png ADDED