basso4 committed
Commit b6e2095 · verified · Parent: ee24757

Upload 57 files

This view is limited to 50 of the 57 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. ldm/__pycache__/util.cpython-38.pyc +0 -0
  2. ldm/data/.DS_Store +0 -0
  3. ldm/data/__init__.py +0 -0
  4. ldm/data/__pycache__/__init__.cpython-38.pyc +0 -0
  5. ldm/data/__pycache__/image_dresscode.cpython-38.pyc +0 -0
  6. ldm/data/__pycache__/image_vitonhd.cpython-38.pyc +0 -0
  7. ldm/data/__pycache__/viton-images.cpython-38.pyc +0 -0
  8. ldm/data/base.py +23 -0
  9. ldm/data/image_dresscode.py +86 -0
  10. ldm/data/image_vitonhd.py +198 -0
  11. ldm/data/imagenet.py +394 -0
  12. ldm/data/lsun.py +92 -0
  13. ldm/lr_scheduler.py +81 -0
  14. ldm/models/.DS_Store +0 -0
  15. ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  16. ldm/models/autoencoder.py +408 -0
  17. ldm/models/diffusion/.DS_Store +0 -0
  18. ldm/models/diffusion/__init__.py +0 -0
  19. ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  20. ldm/models/diffusion/__pycache__/control.cpython-38.pyc +0 -0
  21. ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc +0 -0
  22. ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc +0 -0
  23. ldm/models/diffusion/classifier.py +267 -0
  24. ldm/models/diffusion/control.py +460 -0
  25. ldm/models/diffusion/ddim.py +265 -0
  26. ldm/models/diffusion/ddpm.py +144 -0
  27. ldm/models/diffusion/plms.py +239 -0
  28. ldm/modules/.DS_Store +0 -0
  29. ldm/modules/__pycache__/attention.cpython-38.pyc +0 -0
  30. ldm/modules/__pycache__/x_transformer.cpython-38.pyc +0 -0
  31. ldm/modules/attention.py +345 -0
  32. ldm/modules/diffusionmodules/.DS_Store +0 -0
  33. ldm/modules/diffusionmodules/__init__.py +0 -0
  34. ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc +0 -0
  35. ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc +0 -0
  36. ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc +0 -0
  37. ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc +0 -0
  38. ldm/modules/diffusionmodules/model.py +835 -0
  39. ldm/modules/diffusionmodules/openaimodel.py +707 -0
  40. ldm/modules/diffusionmodules/util.py +255 -0
  41. ldm/modules/distributions/.DS_Store +0 -0
  42. ldm/modules/distributions/__init__.py +0 -0
  43. ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc +0 -0
  44. ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc +0 -0
  45. ldm/modules/distributions/distributions.py +92 -0
  46. ldm/modules/ema.py +76 -0
  47. ldm/modules/encoders/.DS_Store +0 -0
  48. ldm/modules/encoders/__init__.py +0 -0
  49. ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc +0 -0
  50. ldm/modules/encoders/__pycache__/modules.cpython-38.pyc +0 -0
ldm/__pycache__/util.cpython-38.pyc ADDED
Binary file (5.87 kB)
 
ldm/data/.DS_Store ADDED
Binary file (6.15 kB)
 
ldm/data/__init__.py ADDED
File without changes
ldm/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (138 Bytes)
 
ldm/data/__pycache__/image_dresscode.cpython-38.pyc ADDED
Binary file (2.52 kB)
 
ldm/data/__pycache__/image_vitonhd.cpython-38.pyc ADDED
Binary file (3.13 kB)
 
ldm/data/__pycache__/viton-images.cpython-38.pyc ADDED
Binary file (2.68 kB)
 
ldm/data/base.py ADDED
@@ -0,0 +1,23 @@
+ from abc import abstractmethod
+ from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
+
+
+ class Txt2ImgIterableBaseDataset(IterableDataset):
+     '''
+     Define an interface to make the IterableDatasets for text2img data chainable
+     '''
+     def __init__(self, num_records=0, valid_ids=None, size=256):
+         super().__init__()
+         self.num_records = num_records
+         self.valid_ids = valid_ids
+         self.sample_ids = valid_ids
+         self.size = size
+
+         print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
+
+     def __len__(self):
+         return self.num_records
+
+     @abstractmethod
+     def __iter__(self):
+         pass
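A minimal sketch (not part of the upload) of how this interface is meant to be used; the subclass and its in-memory caption list are hypothetical:

from ldm.data.base import Txt2ImgIterableBaseDataset

class ToyTxt2ImgDataset(Txt2ImgIterableBaseDataset):
    # Hypothetical subclass: iterates over a plain list of captions.
    def __init__(self, captions, size=256):
        super().__init__(num_records=len(captions), valid_ids=list(range(len(captions))), size=size)
        self.captions = captions

    def __iter__(self):
        # Instances remain chainable via torch.utils.data.ChainDataset, as the base class intends.
        for idx in self.sample_ids:
            yield {"caption": self.captions[idx], "id": idx}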
ldm/data/image_dresscode.py ADDED
@@ -0,0 +1,86 @@
+ import os
+ import torch
+ import torchvision
+ import torch.utils.data as data
+ import torchvision.transforms.functional as F
+ import numpy as np
+ from PIL import Image
+
+ class OpenImageDataset(data.Dataset):
+     def __init__(self, state, dataset_dir, type="paired"):
+         self.state = state              # train or test
+         self.dataset_dir = dataset_dir  # /home/sd/zjh/Dataset/DressCode
+
+         # Pick the pair list for the requested split
+         if state == "train":
+             self.dataset_file = os.path.join(dataset_dir, "train_pairs.txt")
+         if state == "test":
+             if type == "unpaired":
+                 self.dataset_file = os.path.join(dataset_dir, "test_pairs_unpaired.txt")
+             if type == "paired":
+                 self.dataset_file = os.path.join(dataset_dir, "test_pairs_paired.txt")
+
+         # Load the dataset
+         self.people_list = []
+         self.clothes_list = []
+         with open(self.dataset_file, 'r') as f:
+             for line in f.readlines():
+                 people, clothes, category = line.strip().split()
+                 if category == "0":
+                     category = "upper_body"
+                 elif category == "1":
+                     category = "lower_body"
+                 elif category == "2":
+                     category = "dresses"
+                 people_path = os.path.join(self.dataset_dir, category, "images", people)
+                 clothes_path = os.path.join(self.dataset_dir, category, "images", clothes)
+                 self.people_list.append(people_path)
+                 self.clothes_list.append(clothes_path)
+
+
+     def __len__(self):
+         return len(self.people_list)
+
+     def __getitem__(self, index):
+         people_path = self.people_list[index]
+         # /home/sd/zjh/Dataset/DressCode/upper_body/images/000000_0.jpg
+         clothes_path = self.clothes_list[index]
+         # /home/sd/zjh/Dataset/DressCode/upper_body/images/000000_1.jpg
+         dense_path = people_path.replace("images", "dense")[:-5] + "5_uv.npz"
+         # /home/sd/zjh/Dataset/DressCode/upper_body/dense/000000_5_uv.npz
+         mask_path = people_path.replace("images", "mask")[:-3] + "png"
+         # /home/sd/Harddisk/zjh/DressCode/upper_body/mask/000000_0.png
+
+         # Load the images
+         img = Image.open(people_path).convert("RGB").resize((512, 512))
+         img = torchvision.transforms.ToTensor()(img)
+         reference = Image.open(clothes_path).convert("RGB").resize((224, 224))
+         reference = torchvision.transforms.ToTensor()(reference)
+         mask = Image.open(mask_path).convert("L").resize((512, 512))
+         mask = torchvision.transforms.ToTensor()(mask)
+         mask = 1 - mask
+         densepose = np.load(dense_path)
+         densepose = torch.from_numpy(densepose['uv'])
+         densepose = torch.nn.functional.interpolate(densepose.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=True).squeeze(0)
+
+         # Normalization
+         img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
+         reference = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                                                      (0.26862954, 0.26130258, 0.27577711))(reference)
+         # densepose = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(densepose)
+
+         # Build the inpaint image and the hint tensor
+         inpaint = img * mask
+         hint = torchvision.transforms.Resize((512, 512))(reference)
+         hint = torch.cat((hint, densepose), dim=0)
+
+
+         return {"GT": img,                  # [3, 512, 512]
+                 "inpaint_image": inpaint,   # [3, 512, 512]
+                 "inpaint_mask": mask,       # [1, 512, 512]
+                 "ref_imgs": reference,      # [3, 224, 224]
+                 "hint": hint                # [5, 512, 512]
+                 }
+
+
+
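For orientation, a usage sketch of the DressCode loader above (not part of the upload); the dataset root is a placeholder and must follow the directory layout shown in the path comments:

from torch.utils.data import DataLoader
from ldm.data.image_dresscode import OpenImageDataset

dataset = OpenImageDataset(state="train", dataset_dir="/data/DressCode")  # placeholder path
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

batch = next(iter(loader))
# Expected shapes per the return dict above: GT / inpaint_image [B, 3, 512, 512],
# inpaint_mask [B, 1, 512, 512], ref_imgs [B, 3, 224, 224],
# hint [B, 5, 512, 512] (3 CLIP-normalized cloth channels + 2 DensePose UV channels).
print({k: tuple(v.shape) for k, v in batch.items()})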
ldm/data/image_vitonhd.py ADDED
@@ -0,0 +1,198 @@
+ import os
+ import json
+ import random
+ import torch
+ import torchvision
+ from torchvision import transforms
+ import torch.utils.data as data
+ import torchvision.transforms.functional as TF
+ from PIL import Image
+ from typing import Literal, Tuple, List
+
+
+ class OpenImageDataset(data.Dataset):
+     def __init__(
+         self,
+         state: Literal["train", "test"],
+         dataset_dir: str,
+         type: Literal["paired", "unpaired"] = "paired",
+     ):
+         self.state = state
+         self.dataset_dir = dataset_dir
+         self.flip_transform = transforms.RandomHorizontalFlip(p=1)
+
+         with open(
+             os.path.join(dataset_dir, state, "vitonhd_" + state + "_tagged.json"), "r"
+         ) as file1:
+             data1 = json.load(file1)
+
+         annotation_list = [
+             # "colors",
+             # "textures",
+             "sleeveLength",
+             "neckLine",
+             "item",
+         ]
+
+         self.annotations_pair = {}
+         for k, v in data1.items():
+             for elem in v:
+                 annotation_str = ""
+                 for template in annotation_list:
+                     for tag in elem["tag_info"]:
+                         if (
+                             tag["tag_name"] == template
+                             and tag["tag_category"] is not None
+                         ):
+                             annotation_str += tag["tag_category"]
+                             annotation_str += " "
+                 self.annotations_pair[elem["file_name"]] = annotation_str
+
+
+         im_names = []
+         c_names = []
+
+         if state == "train":
+             filename = os.path.join(dataset_dir, f"{state}_pairs.txt")
+         else:
+             filename = os.path.join(dataset_dir, f"{state}_pairs.txt")
+
+         with open(filename, "r") as f:
+             for line in f.readlines():
+                 if state == "train":
+                     im_name, _ = line.strip().split()
+                     c_name = im_name
+                 else:
+                     if type == "paired":
+                         im_name, _ = line.strip().split()
+                         c_name = im_name
+                     else:
+                         im_name, c_name = line.strip().split()
+
+                 im_names.append(im_name)
+                 c_names.append(c_name)
+
+         self.im_names = im_names
+         self.c_names = c_names
+
+     def __len__(self):
+         return len(self.im_names)
+
+     def __getitem__(self, index):
+         c_name = self.c_names[index]
+         im_name = self.im_names[index]
+
+         if c_name in self.annotations_pair:
+             cloth_annotation = self.annotations_pair[c_name]
+         else:
+             cloth_annotation = "shirts"
+
+         # Resolve file paths
+         img_path = os.path.join(self.dataset_dir, self.state, "image", im_name)
+         reference_path = os.path.join(self.dataset_dir, self.state, "cloth", c_name)
+         mask_path = os.path.join(self.dataset_dir, self.state, "agnostic-mask", im_name[:-4] + "_mask.png")
+         densepose_path = os.path.join(self.dataset_dir, self.state, "image-densepose", im_name)
+
+         # Load the images
+         img = Image.open(img_path).convert("RGB").resize((512, 512))
+         img = torchvision.transforms.ToTensor()(img)
+         reference = Image.open(reference_path).convert("RGB").resize((224, 224))
+         reference = torchvision.transforms.ToTensor()(reference)
+         mask = Image.open(mask_path).convert("L").resize((512, 512))
+         mask = torchvision.transforms.ToTensor()(mask)
+         mask = 1 - mask
+         densepose = Image.open(densepose_path).convert("RGB").resize((512, 512))
+         densepose = torchvision.transforms.ToTensor()(densepose)
+
+
+         # Data augmentation for the training phase
+         if self.state == "train":
+             # Random horizontal flip
+             if random.random() > 0.5:
+                 img = self.flip_transform(img)
+                 mask = self.flip_transform(mask)
+                 densepose = self.flip_transform(densepose)
+                 reference = self.flip_transform(reference)
+
+             # Color jittering
+             if random.random() > 0.5:
+                 color_jitter = transforms.ColorJitter(brightness=0.5, contrast=0.3, saturation=0.5, hue=0.5)
+                 fn_idx, b, c, s, h = transforms.ColorJitter.get_params(color_jitter.brightness, color_jitter.contrast, color_jitter.saturation, color_jitter.hue)
+
+                 img = TF.adjust_contrast(img, c)
+                 img = TF.adjust_brightness(img, b)
+                 img = TF.adjust_hue(img, h)
+                 img = TF.adjust_saturation(img, s)
+
+                 reference = TF.adjust_contrast(reference, c)
+                 reference = TF.adjust_brightness(reference, b)
+                 reference = TF.adjust_hue(reference, h)
+                 reference = TF.adjust_saturation(reference, s)
+
+             # Scaling and shifting
+             if random.random() > 0.5:
+                 scale_val = random.uniform(0.8, 1.2)
+                 img = transforms.functional.affine(
+                     img, angle=0, translate=[0, 0], scale=scale_val, shear=0
+                 )
+                 mask = transforms.functional.affine(
+                     mask, angle=0, translate=[0, 0], scale=scale_val, shear=0
+                 )
+                 densepose = transforms.functional.affine(
+                     densepose, angle=0, translate=[0, 0], scale=scale_val, shear=0
+                 )
+
+             if random.random() > 0.5:
+                 shift_valx = random.uniform(-0.2, 0.2)
+                 shift_valy = random.uniform(-0.2, 0.2)
+                 img = transforms.functional.affine(
+                     img,
+                     angle=0,
+                     translate=[shift_valx * img.shape[-1], shift_valy * img.shape[-2]],
+                     scale=1,
+                     shear=0
+                 )
+                 mask = transforms.functional.affine(
+                     mask,
+                     angle=0,
+                     translate=[shift_valx * mask.shape[-1], shift_valy * mask.shape[-2]],
+                     scale=1,
+                     shear=0
+                 )
+                 densepose = transforms.functional.affine(
+                     densepose,
+                     angle=0,
+                     translate=[
+                         shift_valx * densepose.shape[-1],
+                         shift_valy * densepose.shape[-2]
+                     ],
+                     scale=1,
+                     shear=0
+                 )
+
+
+
+         # Normalization
+         img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
+         reference = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                                                      (0.26862954, 0.26130258, 0.27577711))(reference)
+         densepose = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(densepose)
+
+         # Build the inpaint image and the hint tensor
+         inpaint = img * mask
+         hint = torchvision.transforms.Resize((512, 512))(reference)
+         hint = torch.cat((hint, densepose), dim=0)
+
+         cloth_annotation = "a photo of " + cloth_annotation
+
+         return {"GT": img,                          # [3, 512, 512]
+                 "inpaint_image": inpaint,           # [3, 512, 512]
+                 "inpaint_mask": mask,               # [1, 512, 512]
+                 "ref_imgs": reference,              # [3, 224, 224]
+                 "hint": hint,                       # [6, 512, 512]
+                 "caption_cloth": cloth_annotation,
+                 # "caption": "model is wearing " + cloth_annotation,
+                 }
+
+
+
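A similar sketch for the VITON-HD loader (not part of the upload), here for the unpaired test split; the dataset root is a placeholder:

from torch.utils.data import DataLoader
from ldm.data.image_vitonhd import OpenImageDataset

test_set = OpenImageDataset(state="test", dataset_dir="/data/VITON-HD", type="unpaired")  # placeholder path
batch = next(iter(DataLoader(test_set, batch_size=1)))

print(tuple(batch["hint"].shape))  # (1, 6, 512, 512): cloth RGB stacked with DensePose RGB
print(batch["caption_cloth"][0])   # e.g. "a photo of long sleeve round neck t-shirts "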
ldm/data/imagenet.py ADDED
@@ -0,0 +1,394 @@
1
+ import os, yaml, pickle, shutil, tarfile, glob
2
+ import cv2
3
+ import albumentations
4
+ import PIL
5
+ import numpy as np
6
+ import torchvision.transforms.functional as TF
7
+ from omegaconf import OmegaConf
8
+ from functools import partial
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+ from torch.utils.data import Dataset, Subset
12
+
13
+ import taming.data.utils as tdu
14
+ from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15
+ from taming.data.imagenet import ImagePaths
16
+
17
+ from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18
+
19
+
20
+ def synset2idx(path_to_yaml="data/index_synset.yaml"):
21
+ with open(path_to_yaml) as f:
22
+ di2s = yaml.load(f, Loader=yaml.SafeLoader)
23
+ return dict((v,k) for k,v in di2s.items())
24
+
25
+
26
+ class ImageNetBase(Dataset):
27
+ def __init__(self, config=None):
28
+ self.config = config or OmegaConf.create()
29
+ if not type(self.config)==dict:
30
+ self.config = OmegaConf.to_container(self.config)
31
+ self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
32
+ self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
33
+ self._prepare()
34
+ self._prepare_synset_to_human()
35
+ self._prepare_idx_to_synset()
36
+ self._prepare_human_to_integer_label()
37
+ self._load()
38
+
39
+ def __len__(self):
40
+ return len(self.data)
41
+
42
+ def __getitem__(self, i):
43
+ return self.data[i]
44
+
45
+ def _prepare(self):
46
+ raise NotImplementedError()
47
+
48
+ def _filter_relpaths(self, relpaths):
49
+ ignore = set([
50
+ "n06596364_9591.JPEG",
51
+ ])
52
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
53
+ if "sub_indices" in self.config:
54
+ indices = str_to_indices(self.config["sub_indices"])
55
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
56
+ self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
57
+ files = []
58
+ for rpath in relpaths:
59
+ syn = rpath.split("/")[0]
60
+ if syn in synsets:
61
+ files.append(rpath)
62
+ return files
63
+ else:
64
+ return relpaths
65
+
66
+ def _prepare_synset_to_human(self):
67
+ SIZE = 2655750
68
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
69
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
70
+ if (not os.path.exists(self.human_dict) or
71
+ not os.path.getsize(self.human_dict)==SIZE):
72
+ download(URL, self.human_dict)
73
+
74
+ def _prepare_idx_to_synset(self):
75
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
76
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
77
+ if (not os.path.exists(self.idx2syn)):
78
+ download(URL, self.idx2syn)
79
+
80
+ def _prepare_human_to_integer_label(self):
81
+ URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
82
+ self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
83
+ if (not os.path.exists(self.human2integer)):
84
+ download(URL, self.human2integer)
85
+ with open(self.human2integer, "r") as f:
86
+ lines = f.read().splitlines()
87
+ assert len(lines) == 1000
88
+ self.human2integer_dict = dict()
89
+ for line in lines:
90
+ value, key = line.split(":")
91
+ self.human2integer_dict[key] = int(value)
92
+
93
+ def _load(self):
94
+ with open(self.txt_filelist, "r") as f:
95
+ self.relpaths = f.read().splitlines()
96
+ l1 = len(self.relpaths)
97
+ self.relpaths = self._filter_relpaths(self.relpaths)
98
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
99
+
100
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
101
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
102
+
103
+ unique_synsets = np.unique(self.synsets)
104
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
105
+ if not self.keep_orig_class_label:
106
+ self.class_labels = [class_dict[s] for s in self.synsets]
107
+ else:
108
+ self.class_labels = [self.synset2idx[s] for s in self.synsets]
109
+
110
+ with open(self.human_dict, "r") as f:
111
+ human_dict = f.read().splitlines()
112
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
113
+
114
+ self.human_labels = [human_dict[s] for s in self.synsets]
115
+
116
+ labels = {
117
+ "relpath": np.array(self.relpaths),
118
+ "synsets": np.array(self.synsets),
119
+ "class_label": np.array(self.class_labels),
120
+ "human_label": np.array(self.human_labels),
121
+ }
122
+
123
+ if self.process_images:
124
+ self.size = retrieve(self.config, "size", default=256)
125
+ self.data = ImagePaths(self.abspaths,
126
+ labels=labels,
127
+ size=self.size,
128
+ random_crop=self.random_crop,
129
+ )
130
+ else:
131
+ self.data = self.abspaths
132
+
133
+
134
+ class ImageNetTrain(ImageNetBase):
135
+ NAME = "ILSVRC2012_train"
136
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
137
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
138
+ FILES = [
139
+ "ILSVRC2012_img_train.tar",
140
+ ]
141
+ SIZES = [
142
+ 147897477120,
143
+ ]
144
+
145
+ def __init__(self, process_images=True, data_root=None, **kwargs):
146
+ self.process_images = process_images
147
+ self.data_root = data_root
148
+ super().__init__(**kwargs)
149
+
150
+ def _prepare(self):
151
+ if self.data_root:
152
+ self.root = os.path.join(self.data_root, self.NAME)
153
+ else:
154
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
155
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
156
+
157
+ self.datadir = os.path.join(self.root, "data")
158
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
159
+ self.expected_length = 1281167
160
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
161
+ default=True)
162
+ if not tdu.is_prepared(self.root):
163
+ # prep
164
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
165
+
166
+ datadir = self.datadir
167
+ if not os.path.exists(datadir):
168
+ path = os.path.join(self.root, self.FILES[0])
169
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
170
+ import academictorrents as at
171
+ atpath = at.get(self.AT_HASH, datastore=self.root)
172
+ assert atpath == path
173
+
174
+ print("Extracting {} to {}".format(path, datadir))
175
+ os.makedirs(datadir, exist_ok=True)
176
+ with tarfile.open(path, "r:") as tar:
177
+ tar.extractall(path=datadir)
178
+
179
+ print("Extracting sub-tars.")
180
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
181
+ for subpath in tqdm(subpaths):
182
+ subdir = subpath[:-len(".tar")]
183
+ os.makedirs(subdir, exist_ok=True)
184
+ with tarfile.open(subpath, "r:") as tar:
185
+ tar.extractall(path=subdir)
186
+
187
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
188
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
189
+ filelist = sorted(filelist)
190
+ filelist = "\n".join(filelist)+"\n"
191
+ with open(self.txt_filelist, "w") as f:
192
+ f.write(filelist)
193
+
194
+ tdu.mark_prepared(self.root)
195
+
196
+
197
+ class ImageNetValidation(ImageNetBase):
198
+ NAME = "ILSVRC2012_validation"
199
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
200
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
201
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
202
+ FILES = [
203
+ "ILSVRC2012_img_val.tar",
204
+ "validation_synset.txt",
205
+ ]
206
+ SIZES = [
207
+ 6744924160,
208
+ 1950000,
209
+ ]
210
+
211
+ def __init__(self, process_images=True, data_root=None, **kwargs):
212
+ self.data_root = data_root
213
+ self.process_images = process_images
214
+ super().__init__(**kwargs)
215
+
216
+ def _prepare(self):
217
+ if self.data_root:
218
+ self.root = os.path.join(self.data_root, self.NAME)
219
+ else:
220
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
221
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
222
+ self.datadir = os.path.join(self.root, "data")
223
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
224
+ self.expected_length = 50000
225
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
226
+ default=False)
227
+ if not tdu.is_prepared(self.root):
228
+ # prep
229
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
230
+
231
+ datadir = self.datadir
232
+ if not os.path.exists(datadir):
233
+ path = os.path.join(self.root, self.FILES[0])
234
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
235
+ import academictorrents as at
236
+ atpath = at.get(self.AT_HASH, datastore=self.root)
237
+ assert atpath == path
238
+
239
+ print("Extracting {} to {}".format(path, datadir))
240
+ os.makedirs(datadir, exist_ok=True)
241
+ with tarfile.open(path, "r:") as tar:
242
+ tar.extractall(path=datadir)
243
+
244
+ vspath = os.path.join(self.root, self.FILES[1])
245
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
246
+ download(self.VS_URL, vspath)
247
+
248
+ with open(vspath, "r") as f:
249
+ synset_dict = f.read().splitlines()
250
+ synset_dict = dict(line.split() for line in synset_dict)
251
+
252
+ print("Reorganizing into synset folders")
253
+ synsets = np.unique(list(synset_dict.values()))
254
+ for s in synsets:
255
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
256
+ for k, v in synset_dict.items():
257
+ src = os.path.join(datadir, k)
258
+ dst = os.path.join(datadir, v)
259
+ shutil.move(src, dst)
260
+
261
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
262
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
263
+ filelist = sorted(filelist)
264
+ filelist = "\n".join(filelist)+"\n"
265
+ with open(self.txt_filelist, "w") as f:
266
+ f.write(filelist)
267
+
268
+ tdu.mark_prepared(self.root)
269
+
270
+
271
+
272
+ class ImageNetSR(Dataset):
273
+ def __init__(self, size=None,
274
+ degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
275
+ random_crop=True):
276
+ """
277
+ Imagenet Superresolution Dataloader
278
+ Performs following ops in order:
279
+ 1. crops a crop of size s from image either as random or center crop
280
+ 2. resizes crop to size with cv2.area_interpolation
281
+ 3. degrades resized crop with degradation_fn
282
+
283
+ :param size: resizing to size after cropping
284
+ :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
285
+ :param downscale_f: Low Resolution Downsample factor
286
+ :param min_crop_f: determines crop size s,
287
+ where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
288
+ :param max_crop_f: ""
289
+ :param data_root:
290
+ :param random_crop:
291
+ """
292
+ self.base = self.get_base()
293
+ assert size
294
+ assert (size / downscale_f).is_integer()
295
+ self.size = size
296
+ self.LR_size = int(size / downscale_f)
297
+ self.min_crop_f = min_crop_f
298
+ self.max_crop_f = max_crop_f
299
+ assert(max_crop_f <= 1.)
300
+ self.center_crop = not random_crop
301
+
302
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
303
+
304
+ self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
305
+
306
+ if degradation == "bsrgan":
307
+ self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
308
+
309
+ elif degradation == "bsrgan_light":
310
+ self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
311
+
312
+ else:
313
+ interpolation_fn = {
314
+ "cv_nearest": cv2.INTER_NEAREST,
315
+ "cv_bilinear": cv2.INTER_LINEAR,
316
+ "cv_bicubic": cv2.INTER_CUBIC,
317
+ "cv_area": cv2.INTER_AREA,
318
+ "cv_lanczos": cv2.INTER_LANCZOS4,
319
+ "pil_nearest": PIL.Image.NEAREST,
320
+ "pil_bilinear": PIL.Image.BILINEAR,
321
+ "pil_bicubic": PIL.Image.BICUBIC,
322
+ "pil_box": PIL.Image.BOX,
323
+ "pil_hamming": PIL.Image.HAMMING,
324
+ "pil_lanczos": PIL.Image.LANCZOS,
325
+ }[degradation]
326
+
327
+ self.pil_interpolation = degradation.startswith("pil_")
328
+
329
+ if self.pil_interpolation:
330
+ self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
331
+
332
+ else:
333
+ self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
334
+ interpolation=interpolation_fn)
335
+
336
+ def __len__(self):
337
+ return len(self.base)
338
+
339
+ def __getitem__(self, i):
340
+ example = self.base[i]
341
+ image = Image.open(example["file_path_"])
342
+
343
+ if not image.mode == "RGB":
344
+ image = image.convert("RGB")
345
+
346
+ image = np.array(image).astype(np.uint8)
347
+
348
+ min_side_len = min(image.shape[:2])
349
+ crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
350
+ crop_side_len = int(crop_side_len)
351
+
352
+ if self.center_crop:
353
+ self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
354
+
355
+ else:
356
+ self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
357
+
358
+ image = self.cropper(image=image)["image"]
359
+ image = self.image_rescaler(image=image)["image"]
360
+
361
+ if self.pil_interpolation:
362
+ image_pil = PIL.Image.fromarray(image)
363
+ LR_image = self.degradation_process(image_pil)
364
+ LR_image = np.array(LR_image).astype(np.uint8)
365
+
366
+ else:
367
+ LR_image = self.degradation_process(image=image)["image"]
368
+
369
+ example["image"] = (image/127.5 - 1.0).astype(np.float32)
370
+ example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
371
+
372
+ return example
373
+
374
+
375
+ class ImageNetSRTrain(ImageNetSR):
376
+ def __init__(self, **kwargs):
377
+ super().__init__(**kwargs)
378
+
379
+ def get_base(self):
380
+ with open("data/imagenet_train_hr_indices.p", "rb") as f:
381
+ indices = pickle.load(f)
382
+ dset = ImageNetTrain(process_images=False,)
383
+ return Subset(dset, indices)
384
+
385
+
386
+ class ImageNetSRValidation(ImageNetSR):
387
+ def __init__(self, **kwargs):
388
+ super().__init__(**kwargs)
389
+
390
+ def get_base(self):
391
+ with open("data/imagenet_val_hr_indices.p", "rb") as f:
392
+ indices = pickle.load(f)
393
+ dset = ImageNetValidation(process_images=False,)
394
+ return Subset(dset, indices)
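The superresolution classes above are normally built from a config via instantiate_from_config, but a direct sketch (not part of the upload) may help; it assumes the ImageNet validation data and the data/imagenet_val_hr_indices.p index file have already been prepared:

from ldm.data.imagenet import ImageNetSRValidation

# 256 px HR crops, 4x downscaling with the light BSRGAN degradation pipeline.
val_set = ImageNetSRValidation(size=256, degradation="bsrgan_light", downscale_f=4)
example = val_set[0]
print(example["image"].shape, example["LR_image"].shape)  # (256, 256, 3) and (64, 64, 3), values in [-1, 1]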
ldm/data/lsun.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+ import numpy as np
3
+ import PIL
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+
8
+
9
+ class LSUNBase(Dataset):
10
+ def __init__(self,
11
+ txt_file,
12
+ data_root,
13
+ size=None,
14
+ interpolation="bicubic",
15
+ flip_p=0.5
16
+ ):
17
+ self.data_paths = txt_file
18
+ self.data_root = data_root
19
+ with open(self.data_paths, "r") as f:
20
+ self.image_paths = f.read().splitlines()
21
+ self._length = len(self.image_paths)
22
+ self.labels = {
23
+ "relative_file_path_": [l for l in self.image_paths],
24
+ "file_path_": [os.path.join(self.data_root, l)
25
+ for l in self.image_paths],
26
+ }
27
+
28
+ self.size = size
29
+ self.interpolation = {"linear": PIL.Image.LINEAR,
30
+ "bilinear": PIL.Image.BILINEAR,
31
+ "bicubic": PIL.Image.BICUBIC,
32
+ "lanczos": PIL.Image.LANCZOS,
33
+ }[interpolation]
34
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35
+
36
+ def __len__(self):
37
+ return self._length
38
+
39
+ def __getitem__(self, i):
40
+ example = dict((k, self.labels[k][i]) for k in self.labels)
41
+ image = Image.open(example["file_path_"])
42
+ if not image.mode == "RGB":
43
+ image = image.convert("RGB")
44
+
45
+ # default to score-sde preprocessing
46
+ img = np.array(image).astype(np.uint8)
47
+ crop = min(img.shape[0], img.shape[1])
48
+ h, w, = img.shape[0], img.shape[1]
49
+ img = img[(h - crop) // 2:(h + crop) // 2,
50
+ (w - crop) // 2:(w + crop) // 2]
51
+
52
+ image = Image.fromarray(img)
53
+ if self.size is not None:
54
+ image = image.resize((self.size, self.size), resample=self.interpolation)
55
+
56
+ image = self.flip(image)
57
+ image = np.array(image).astype(np.uint8)
58
+ example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59
+ return example
60
+
61
+
62
+ class LSUNChurchesTrain(LSUNBase):
63
+ def __init__(self, **kwargs):
64
+ super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65
+
66
+
67
+ class LSUNChurchesValidation(LSUNBase):
68
+ def __init__(self, flip_p=0., **kwargs):
69
+ super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70
+ flip_p=flip_p, **kwargs)
71
+
72
+
73
+ class LSUNBedroomsTrain(LSUNBase):
74
+ def __init__(self, **kwargs):
75
+ super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76
+
77
+
78
+ class LSUNBedroomsValidation(LSUNBase):
79
+ def __init__(self, flip_p=0.0, **kwargs):
80
+ super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81
+ flip_p=flip_p, **kwargs)
82
+
83
+
84
+ class LSUNCatsTrain(LSUNBase):
85
+ def __init__(self, **kwargs):
86
+ super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87
+
88
+
89
+ class LSUNCatsValidation(LSUNBase):
90
+ def __init__(self, flip_p=0., **kwargs):
91
+ super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92
+ flip_p=flip_p, **kwargs)
ldm/lr_scheduler.py ADDED
@@ -0,0 +1,81 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n, **kwargs):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n, **kwargs):
33
+ return self.schedule(n,**kwargs)
34
+
35
+
36
+ class LambdaWarmUpCosineScheduler2:
37
+ """
38
+ supports repeated iterations, configurable via lists
39
+ note: use with a base_lr of 1.0.
40
+ """
41
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43
+ self.lr_warm_up_steps = warm_up_steps # [10000]
44
+ self.f_start = f_start # 1.e-6
45
+ self.f_min = f_min # 1.
46
+ self.f_max = f_max # 1.
47
+ self.cycle_lengths = cycle_lengths # [10000000000000]
48
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49
+ self.last_f = 0.
50
+ self.verbosity_interval = verbosity_interval # 0
51
+
52
+ def find_in_interval(self, n):
53
+ interval = 0
54
+ for cl in self.cum_cycles[1:]:
55
+ if n <= cl:
56
+ return interval
57
+ interval += 1
58
+
59
+
60
+ def __call__(self, n, **kwargs):
61
+ return self.schedule(n, **kwargs)
62
+
63
+
64
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
65
+
66
+ def schedule(self, n, **kwargs):
67
+ cycle = self.find_in_interval(n)
68
+ n = n - self.cum_cycles[cycle]
69
+ if self.verbosity_interval > 0:
70
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
71
+ f"current cycle {cycle}")
72
+
73
+ if n < self.lr_warm_up_steps[cycle]:
74
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
75
+ self.last_f = f
76
+ return f
77
+ else:
78
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
79
+ self.last_f = f
80
+ return f
81
+
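These schedulers return a multiplier for a base learning rate of 1.0 and are meant to be wrapped in torch.optim.lr_scheduler.LambdaLR; a minimal sketch (not part of the upload) with placeholder hyperparameters:

import torch
from torch.optim.lr_scheduler import LambdaLR
from ldm.lr_scheduler import LambdaLinearScheduler

model = torch.nn.Linear(8, 8)  # stand-in module
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

# One cycle: 10k warm-up steps from 1e-6 to 1.0, then held near 1.0 (f_min == f_max).
sched_fn = LambdaLinearScheduler(warm_up_steps=[10000], f_min=[1.0], f_max=[1.0],
                                 f_start=[1.e-6], cycle_lengths=[10000000000000])
scheduler = LambdaLR(opt, lr_lambda=sched_fn)  # effective lr = 1e-4 * sched_fn(step)

for step in range(1000):
    opt.step()
    scheduler.step()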
ldm/models/.DS_Store ADDED
Binary file (6.15 kB)
 
ldm/models/__pycache__/autoencoder.cpython-38.pyc ADDED
Binary file (12.1 kB)
 
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,408 @@
+ import torch
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch.nn.functional as F
+ from contextlib import contextmanager
+ from packaging import version
+ from torch.optim.lr_scheduler import LambdaLR
+
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+ from ldm.modules.ema import LitEma
+
+ from ldm.util import instantiate_from_config
+
13
+
14
+ class VQModel(pl.LightningModule):
15
+ def __init__(self,
16
+ ddconfig,
17
+ lossconfig,
18
+ n_embed,
19
+ embed_dim,
20
+ ckpt_path=None,
21
+ ignore_keys=[],
22
+ image_key="image",
23
+ colorize_nlabels=None,
24
+ monitor=None,
25
+ batch_resize_range=None,
26
+ scheduler_config=None,
27
+ lr_g_factor=1.0,
28
+ remap=None,
29
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
30
+ use_ema=False
31
+ ):
32
+ super().__init__()
33
+ self.embed_dim = embed_dim
34
+ self.n_embed = n_embed
35
+ self.image_key = image_key
36
+ self.encoder = Encoder(**ddconfig)
37
+ self.decoder = Decoder(**ddconfig)
38
+ self.loss = instantiate_from_config(lossconfig)
39
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40
+ remap=remap,
41
+ sane_index_shape=sane_index_shape)
42
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44
+ if colorize_nlabels is not None:
45
+ assert type(colorize_nlabels)==int
46
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47
+ if monitor is not None:
48
+ self.monitor = monitor
49
+ self.batch_resize_range = batch_resize_range
50
+ if self.batch_resize_range is not None:
51
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52
+
53
+ self.use_ema = use_ema
54
+ if self.use_ema:
55
+ self.model_ema = LitEma(self)
56
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57
+
58
+ if ckpt_path is not None:
59
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60
+ self.scheduler_config = scheduler_config
61
+ self.lr_g_factor = lr_g_factor
62
+
63
+ @contextmanager
64
+ def ema_scope(self, context=None):
65
+ if self.use_ema:
66
+ self.model_ema.store(self.parameters())
67
+ self.model_ema.copy_to(self)
68
+ if context is not None:
69
+ print(f"{context}: Switched to EMA weights")
70
+ try:
71
+ yield None
72
+ finally:
73
+ if self.use_ema:
74
+ self.model_ema.restore(self.parameters())
75
+ if context is not None:
76
+ print(f"{context}: Restored training weights")
77
+
78
+ def init_from_ckpt(self, path, ignore_keys=list()):
79
+ sd = torch.load(path, map_location="cpu")["state_dict"]
80
+ keys = list(sd.keys())
81
+ for k in keys:
82
+ for ik in ignore_keys:
83
+ if k.startswith(ik):
84
+ print("Deleting key {} from state_dict.".format(k))
85
+ del sd[k]
86
+ missing, unexpected = self.load_state_dict(sd, strict=False)
87
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88
+ if len(missing) > 0:
89
+ print(f"Missing Keys: {missing}")
90
+ print(f"Unexpected Keys: {unexpected}")
91
+
92
+ def on_train_batch_end(self, *args, **kwargs):
93
+ if self.use_ema:
94
+ self.model_ema(self)
95
+
96
+ def encode(self, x):
97
+ h = self.encoder(x)
98
+ h = self.quant_conv(h)
99
+ quant, emb_loss, info = self.quantize(h)
100
+ return quant, emb_loss, info
101
+
102
+ def encode_to_prequant(self, x):
103
+ h = self.encoder(x)
104
+ h = self.quant_conv(h)
105
+ return h
106
+
107
+ def decode(self, quant):
108
+ quant = self.post_quant_conv(quant)
109
+ dec = self.decoder(quant)
110
+ return dec
111
+
112
+ def decode_code(self, code_b):
113
+ quant_b = self.quantize.embed_code(code_b)
114
+ dec = self.decode(quant_b)
115
+ return dec
116
+
117
+ def forward(self, input, return_pred_indices=False):
118
+ quant, diff, (_,_,ind) = self.encode(input)
119
+ dec = self.decode(quant)
120
+ if return_pred_indices:
121
+ return dec, diff, ind
122
+ return dec, diff
123
+
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129
+ if self.batch_resize_range is not None:
130
+ lower_size = self.batch_resize_range[0]
131
+ upper_size = self.batch_resize_range[1]
132
+ if self.global_step <= 4:
133
+ # do the first few batches with max size to avoid later oom
134
+ new_resize = upper_size
135
+ else:
136
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137
+ if new_resize != x.shape[2]:
138
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
139
+ x = x.detach()
140
+ return x
141
+
142
+ def training_step(self, batch, batch_idx, optimizer_idx):
143
+ # https://github.com/pytorch/pytorch/issues/37142
144
+ # try not to fool the heuristics
145
+ x = self.get_input(batch, self.image_key)
146
+ xrec, qloss, ind = self(x, return_pred_indices=True)
147
+
148
+ if optimizer_idx == 0:
149
+ # autoencode
150
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151
+ last_layer=self.get_last_layer(), split="train",
152
+ predicted_indices=ind)
153
+
154
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155
+ return aeloss
156
+
157
+ if optimizer_idx == 1:
158
+ # discriminator
159
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160
+ last_layer=self.get_last_layer(), split="train")
161
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162
+ return discloss
163
+
164
+ def validation_step(self, batch, batch_idx):
165
+ log_dict = self._validation_step(batch, batch_idx)
166
+ with self.ema_scope():
167
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168
+ return log_dict
169
+
170
+ def _validation_step(self, batch, batch_idx, suffix=""):
171
+ x = self.get_input(batch, self.image_key)
172
+ xrec, qloss, ind = self(x, return_pred_indices=True)
173
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174
+ self.global_step,
175
+ last_layer=self.get_last_layer(),
176
+ split="val"+suffix,
177
+ predicted_indices=ind
178
+ )
179
+
180
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181
+ self.global_step,
182
+ last_layer=self.get_last_layer(),
183
+ split="val"+suffix,
184
+ predicted_indices=ind
185
+ )
186
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187
+ self.log(f"val{suffix}/rec_loss", rec_loss,
188
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189
+ self.log(f"val{suffix}/aeloss", aeloss,
190
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
192
+ del log_dict_ae[f"val{suffix}/rec_loss"]
193
+ self.log_dict(log_dict_ae)
194
+ self.log_dict(log_dict_disc)
195
+ return self.log_dict
196
+
197
+ def configure_optimizers(self):
198
+ lr_d = self.learning_rate
199
+ lr_g = self.lr_g_factor*self.learning_rate
200
+ print("lr_d", lr_d)
201
+ print("lr_g", lr_g)
202
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203
+ list(self.decoder.parameters())+
204
+ list(self.quantize.parameters())+
205
+ list(self.quant_conv.parameters())+
206
+ list(self.post_quant_conv.parameters()),
207
+ lr=lr_g, betas=(0.5, 0.9))
208
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209
+ lr=lr_d, betas=(0.5, 0.9))
210
+
211
+ if self.scheduler_config is not None:
212
+ scheduler = instantiate_from_config(self.scheduler_config)
213
+
214
+ print("Setting up LambdaLR scheduler...")
215
+ scheduler = [
216
+ {
217
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218
+ 'interval': 'step',
219
+ 'frequency': 1
220
+ },
221
+ {
222
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223
+ 'interval': 'step',
224
+ 'frequency': 1
225
+ },
226
+ ]
227
+ return [opt_ae, opt_disc], scheduler
228
+ return [opt_ae, opt_disc], []
229
+
230
+ def get_last_layer(self):
231
+ return self.decoder.conv_out.weight
232
+
233
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234
+ log = dict()
235
+ x = self.get_input(batch, self.image_key)
236
+ x = x.to(self.device)
237
+ if only_inputs:
238
+ log["inputs"] = x
239
+ return log
240
+ xrec, _ = self(x)
241
+ if x.shape[1] > 3:
242
+ # colorize with random projection
243
+ assert xrec.shape[1] > 3
244
+ x = self.to_rgb(x)
245
+ xrec = self.to_rgb(xrec)
246
+ log["inputs"] = x
247
+ log["reconstructions"] = xrec
248
+ if plot_ema:
249
+ with self.ema_scope():
250
+ xrec_ema, _ = self(x)
251
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252
+ log["reconstructions_ema"] = xrec_ema
253
+ return log
254
+
255
+ def to_rgb(self, x):
256
+ assert self.image_key == "segmentation"
257
+ if not hasattr(self, "colorize"):
258
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259
+ x = F.conv2d(x, weight=self.colorize)
260
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261
+ return x
262
+
263
+
264
+ class VQModelInterface(VQModel):
265
+ def __init__(self, embed_dim, *args, **kwargs):
266
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
267
+ self.embed_dim = embed_dim
268
+
269
+ def encode(self, x):
270
+ h = self.encoder(x)
271
+ h = self.quant_conv(h)
272
+ return h
273
+
274
+ def decode(self, h, force_not_quantize=False):
275
+ # also go through quantization layer
276
+ if not force_not_quantize:
277
+ quant, emb_loss, info = self.quantize(h)
278
+ else:
279
+ quant = h
280
+ quant = self.post_quant_conv(quant)
281
+ dec = self.decoder(quant)
282
+ return dec
283
+
284
+
285
+ class AutoencoderKL(pl.LightningModule):
286
+ def __init__(self,
287
+ embed_dim, # 4
288
+ monitor, # "val/rec_loss"
289
+ ddconfig, # {...}
290
+ lossconfig, # {target: torch.nn.Identity}
291
+ ckpt_path=None,
292
+ ignore_keys=[],
293
+ image_key="image",
294
+ colorize_nlabels=None):
295
+ super().__init__()
296
+ self.image_key = image_key # "image"
297
+ self.encoder = Encoder(**ddconfig)
298
+ self.decoder = Decoder(**ddconfig)
299
+ self.loss = instantiate_from_config(lossconfig)
300
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) # 8 -> 8
301
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) # 4 -> 4
302
+ self.embed_dim = embed_dim
303
+ self.monitor = monitor
304
+
305
+ def encode(self, x):
306
+ h = self.encoder(x)
307
+ moments = self.quant_conv(h)
308
+ posterior = DiagonalGaussianDistribution(moments)
309
+ return posterior
310
+
311
+ def decode(self, z):
312
+ z = self.post_quant_conv(z)
313
+ dec = self.decoder(z)
314
+ return dec
315
+
316
+ def forward(self, input, sample_posterior=True):
317
+ posterior = self.encode(input)
318
+ if sample_posterior:
319
+ z = posterior.sample()
320
+ else:
321
+ z = posterior.mode()
322
+ dec = self.decode(z)
323
+ return dec, posterior
324
+
325
+ def get_input(self, batch, k):
326
+ x = batch[k]
327
+ if len(x.shape) == 3:
328
+ x = x[..., None]
329
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
330
+ return x
331
+
332
+ def training_step(self, batch, batch_idx, optimizer_idx):
333
+ inputs = self.get_input(batch, self.image_key)
334
+ reconstructions, posterior = self(inputs)
335
+
336
+ if optimizer_idx == 0:
337
+ # train encoder+decoder+logvar
338
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
339
+ last_layer=self.get_last_layer(), split="train")
340
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
341
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
342
+ return aeloss
343
+
344
+ if optimizer_idx == 1:
345
+ # train the discriminator
346
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
347
+ last_layer=self.get_last_layer(), split="train")
348
+
349
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
350
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
351
+ return discloss
352
+
353
+ def validation_step(self, batch, batch_idx):
354
+ inputs = self.get_input(batch, self.image_key)
355
+ reconstructions, posterior = self(inputs)
356
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
357
+ last_layer=self.get_last_layer(), split="val")
358
+
359
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
360
+ last_layer=self.get_last_layer(), split="val")
361
+
362
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
363
+ self.log_dict(log_dict_ae)
364
+ self.log_dict(log_dict_disc)
365
+ return self.log_dict
366
+
367
+ def configure_optimizers(self):
368
+ lr = self.learning_rate
369
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
370
+ list(self.decoder.parameters())+
371
+ list(self.quant_conv.parameters())+
372
+ list(self.post_quant_conv.parameters()),
373
+ lr=lr, betas=(0.5, 0.9))
374
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
375
+ lr=lr, betas=(0.5, 0.9))
376
+ return [opt_ae, opt_disc], []
377
+
378
+ def get_last_layer(self):
379
+ return self.decoder.conv_out.weight
380
+
381
+ @torch.no_grad()
382
+ def log_images(self, batch, only_inputs=False, **kwargs):
383
+ log = dict()
384
+ x = self.get_input(batch, self.image_key)
385
+ x = x.to(self.device)
386
+ if not only_inputs:
387
+ xrec, posterior = self(x)
388
+ if x.shape[1] > 3:
389
+ # colorize with random projection
390
+ assert xrec.shape[1] > 3
391
+ x = self.to_rgb(x)
392
+ xrec = self.to_rgb(xrec)
393
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
394
+ log["reconstructions"] = xrec
395
+ log["inputs"] = x
396
+ return log
397
+
398
+ def to_rgb(self, x):
399
+ assert self.image_key == "segmentation"
400
+ if not hasattr(self, "colorize"):
401
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
402
+ x = F.conv2d(x, weight=self.colorize)
403
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
404
+ return x
405
+
406
+
407
+
408
+
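A round-trip sketch for the KL autoencoder above (not part of the upload), assuming a first-stage config and checkpoint equivalent to the Stable Diffusion VAE; the config path is a placeholder:

import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

cfg = OmegaConf.load("configs/first_stage_kl.yaml")  # placeholder config describing AutoencoderKL
vae = instantiate_from_config(cfg.model).eval()

x = torch.randn(1, 3, 512, 512)  # stand-in image tensor in [-1, 1]
with torch.no_grad():
    posterior = vae.encode(x)     # DiagonalGaussianDistribution
    z = posterior.sample()        # latent, e.g. [1, 4, 64, 64] for f=8, embed_dim=4
    recon = vae.decode(z)         # back to [1, 3, 512, 512]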
ldm/models/diffusion/.DS_Store ADDED
Binary file (6.15 kB)
 
ldm/models/diffusion/__init__.py ADDED
File without changes
ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (150 Bytes)
 
ldm/models/diffusion/__pycache__/control.cpython-38.pyc ADDED
Binary file (11.5 kB)
 
ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc ADDED
Binary file (7.76 kB)
 
ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc ADDED
Binary file (5.1 kB)
 
ldm/models/diffusion/classifier.py ADDED
@@ -0,0 +1,267 @@
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ from omegaconf import OmegaConf
5
+ from torch.nn import functional as F
6
+ from torch.optim import AdamW
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ from copy import deepcopy
9
+ from einops import rearrange
10
+ from glob import glob
11
+ from natsort import natsorted
12
+
13
+ from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14
+ from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15
+
16
+ __models__ = {
17
+ 'class_label': EncoderUNetModel,
18
+ 'segmentation': UNetModel
19
+ }
20
+
21
+
22
+ def disabled_train(self, mode=True):
23
+ """Overwrite model.train with this function to make sure train/eval mode
24
+ does not change anymore."""
25
+ return self
26
+
27
+
28
+ class NoisyLatentImageClassifier(pl.LightningModule):
29
+
30
+ def __init__(self,
31
+ diffusion_path,
32
+ num_classes,
33
+ ckpt_path=None,
34
+ pool='attention',
35
+ label_key=None,
36
+ diffusion_ckpt_path=None,
37
+ scheduler_config=None,
38
+ weight_decay=1.e-2,
39
+ log_steps=10,
40
+ monitor='val/loss',
41
+ *args,
42
+ **kwargs):
43
+ super().__init__(*args, **kwargs)
44
+ self.num_classes = num_classes
45
+ # get latest config of diffusion model
46
+ diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47
+ self.diffusion_config = OmegaConf.load(diffusion_config).model
48
+ self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49
+ self.load_diffusion()
50
+
51
+ self.monitor = monitor
52
+ self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53
+ self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54
+ self.log_steps = log_steps
55
+
56
+ self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57
+ else self.diffusion_model.cond_stage_key
58
+
59
+ assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60
+
61
+ if self.label_key not in __models__:
62
+ raise NotImplementedError()
63
+
64
+ self.load_classifier(ckpt_path, pool)
65
+
66
+ self.scheduler_config = scheduler_config
67
+ self.use_scheduler = self.scheduler_config is not None
68
+ self.weight_decay = weight_decay
69
+
70
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71
+ sd = torch.load(path, map_location="cpu")
72
+ if "state_dict" in list(sd.keys()):
73
+ sd = sd["state_dict"]
74
+ keys = list(sd.keys())
75
+ for k in keys:
76
+ for ik in ignore_keys:
77
+ if k.startswith(ik):
78
+ print("Deleting key {} from state_dict.".format(k))
79
+ del sd[k]
80
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81
+ sd, strict=False)
82
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83
+ if len(missing) > 0:
84
+ print(f"Missing Keys: {missing}")
85
+ if len(unexpected) > 0:
86
+ print(f"Unexpected Keys: {unexpected}")
87
+
88
+ def load_diffusion(self):
89
+ model = instantiate_from_config(self.diffusion_config)
90
+ self.diffusion_model = model.eval()
91
+ self.diffusion_model.train = disabled_train
92
+ for param in self.diffusion_model.parameters():
93
+ param.requires_grad = False
94
+
95
+ def load_classifier(self, ckpt_path, pool):
96
+ model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97
+ model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98
+ model_config.out_channels = self.num_classes
99
+ if self.label_key == 'class_label':
100
+ model_config.pool = pool
101
+
102
+ self.model = __models__[self.label_key](**model_config)
103
+ if ckpt_path is not None:
104
+ print('#####################################################################')
105
+ print(f'load from ckpt "{ckpt_path}"')
106
+ print('#####################################################################')
107
+ self.init_from_ckpt(ckpt_path)
108
+
109
+ @torch.no_grad()
110
+ def get_x_noisy(self, x, t, noise=None):
111
+ noise = default(noise, lambda: torch.randn_like(x))
112
+ continuous_sqrt_alpha_cumprod = None
113
+ if self.diffusion_model.use_continuous_noise:
114
+ continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115
+ # todo: make sure t+1 is correct here
116
+
117
+ return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118
+ continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119
+
120
+ def forward(self, x_noisy, t, *args, **kwargs):
121
+ return self.model(x_noisy, t)
122
+
123
+ @torch.no_grad()
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = rearrange(x, 'b h w c -> b c h w')
129
+ x = x.to(memory_format=torch.contiguous_format).float()
130
+ return x
131
+
132
+ @torch.no_grad()
133
+ def get_conditioning(self, batch, k=None):
134
+ if k is None:
135
+ k = self.label_key
136
+ assert k is not None, 'Needs to provide label key'
137
+
138
+ targets = batch[k].to(self.device)
139
+
140
+ if self.label_key == 'segmentation':
141
+ targets = rearrange(targets, 'b h w c -> b c h w')
142
+ for down in range(self.numd):
143
+ h, w = targets.shape[-2:]
144
+ targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145
+
146
+ # targets = rearrange(targets,'b c h w -> b h w c')
147
+
148
+ return targets
149
+
150
+ def compute_top_k(self, logits, labels, k, reduction="mean"):
151
+ _, top_ks = torch.topk(logits, k, dim=1)
152
+ if reduction == "mean":
153
+ return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154
+ elif reduction == "none":
155
+ return (top_ks == labels[:, None]).float().sum(dim=-1)
156
+
157
+ def on_train_epoch_start(self):
158
+ # save some memory
159
+ self.diffusion_model.model.to('cpu')
160
+
161
+ @torch.no_grad()
162
+ def write_logs(self, loss, logits, targets):
163
+ log_prefix = 'train' if self.training else 'val'
164
+ log = {}
165
+ log[f"{log_prefix}/loss"] = loss.mean()
166
+ log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167
+ logits, targets, k=1, reduction="mean"
168
+ )
169
+ log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170
+ logits, targets, k=5, reduction="mean"
171
+ )
172
+
173
+ self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174
+ self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175
+ self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176
+ lr = self.optimizers().param_groups[0]['lr']
177
+ self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178
+
179
+ def shared_step(self, batch, t=None):
180
+ x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181
+ targets = self.get_conditioning(batch)
182
+ if targets.dim() == 4:
183
+ targets = targets.argmax(dim=1)
184
+ if t is None:
185
+ t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186
+ else:
187
+ t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188
+ x_noisy = self.get_x_noisy(x, t)
189
+ logits = self(x_noisy, t)
190
+
191
+ loss = F.cross_entropy(logits, targets, reduction='none')
192
+
193
+ self.write_logs(loss.detach(), logits.detach(), targets.detach())
194
+
195
+ loss = loss.mean()
196
+ return loss, logits, x_noisy, targets
197
+
198
+ def training_step(self, batch, batch_idx):
199
+ loss, *_ = self.shared_step(batch)
200
+ return loss
201
+
202
+ def reset_noise_accs(self):
203
+ self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204
+ range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205
+
206
+ def on_validation_start(self):
207
+ self.reset_noise_accs()
208
+
209
+ @torch.no_grad()
210
+ def validation_step(self, batch, batch_idx):
211
+ loss, *_ = self.shared_step(batch)
212
+
213
+ for t in self.noisy_acc:
214
+ _, logits, _, targets = self.shared_step(batch, t)
215
+ self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216
+ self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217
+
218
+ return loss
219
+
220
+ def configure_optimizers(self):
221
+ optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222
+
223
+ if self.use_scheduler:
224
+ scheduler = instantiate_from_config(self.scheduler_config)
225
+
226
+ print("Setting up LambdaLR scheduler...")
227
+ scheduler = [
228
+ {
229
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230
+ 'interval': 'step',
231
+ 'frequency': 1
232
+ }]
233
+ return [optimizer], scheduler
234
+
235
+ return optimizer
236
+
237
+ @torch.no_grad()
238
+ def log_images(self, batch, N=8, *args, **kwargs):
239
+ log = dict()
240
+ x = self.get_input(batch, self.diffusion_model.first_stage_key)
241
+ log['inputs'] = x
242
+
243
+ y = self.get_conditioning(batch)
244
+
245
+ if self.label_key == 'class_label':
246
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247
+ log['labels'] = y
248
+
249
+ if ismap(y):
250
+ log['labels'] = self.diffusion_model.to_rgb(y)
251
+
252
+ for step in range(self.log_steps):
253
+ current_time = step * self.log_time_interval
254
+
255
+ _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256
+
257
+ log[f'inputs@t{current_time}'] = x_noisy
258
+
259
+ pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260
+ pred = rearrange(pred, 'b h w c -> b c h w')
261
+
262
+ log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263
+
264
+ for key in log:
265
+ log[key] = log[key][:N]
266
+
267
+ return log
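As a quick check of the `compute_top_k` helper above, here is a minimal, self-contained sketch with toy tensors (the numbers are illustrative only and not part of the upload):

import torch

logits = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.1]])       # (batch=2, num_classes=3)
labels = torch.tensor([1, 2])                  # the second sample is misclassified
_, top_ks = torch.topk(logits, k=1, dim=1)     # top-1 class index per sample
acc1 = (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
print(acc1)                                    # 0.5 -> mean top-1 accuracy over the batch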
ldm/models/diffusion/control.py ADDED
@@ -0,0 +1,460 @@
1
+ import random
2
+ import torch
3
+ import torchvision
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.optim.lr_scheduler import CosineAnnealingLR
7
+
8
+ from einops import rearrange
9
+ from ldm.modules.diffusionmodules.util import (
10
+ conv_nd,
11
+ linear,
12
+ zero_module,
13
+ timestep_embedding,
14
+ )
15
+ from ldm.models.diffusion.ddpm import DDPM
16
+ from ldm.modules.attention import SpatialTransformer
17
+ from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
18
+ from ldm.util import instantiate_from_config, default
19
+ from ldm.models.diffusion.ddim import DDIMSampler
20
+
21
+ import torch
22
+ from torch.optim.optimizer import Optimizer
23
+ from torch.optim.lr_scheduler import LambdaLR
24
+
25
+
26
+ def disabled_train(self, mode=True):
27
+ return self
28
+
29
+
30
+ # =============================================================
31
+ # Trainable part: ControlNet
32
+ # =============================================================
33
+ class ControlNet(nn.Module):
34
+ def __init__(
35
+ self,
36
+ in_channels, # 9
37
+ model_channels, # 320
38
+ hint_channels, # 20
39
+ attention_resolutions, # [4,2,1]
40
+ num_res_blocks, # 2
41
+ channel_mult=(1, 2, 4, 8), # [1,2,4,4]
42
+ num_head_channels=-1, # 64
43
+ transformer_depth=1, # 1
44
+ context_dim=None, # 768
45
+ use_checkpoint=False, # True
46
+ dropout=0,
47
+ conv_resample=True,
48
+ dims=2,
49
+ num_heads=-1,
50
+ use_scale_shift_norm=False):
51
+ super(ControlNet, self).__init__()
52
+ self.dims = dims
53
+ self.in_channels = in_channels
54
+ self.model_channels = model_channels
55
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
56
+ self.attention_resolutions = attention_resolutions
57
+ self.dropout = dropout
58
+ self.channel_mult = channel_mult
59
+ self.use_checkpoint = use_checkpoint
60
+ self.dtype = torch.float32
61
+ self.num_heads = num_heads
62
+ self.num_head_channels = num_head_channels
63
+
64
+ # timestep encoder
65
+ time_embed_dim = model_channels * 4
66
+ self.time_embed = nn.Sequential(
67
+ linear(model_channels, time_embed_dim),
68
+ nn.SiLU(),
69
+ linear(time_embed_dim, time_embed_dim),
70
+ )
71
+
72
+ # input encoder
73
+ self.input_blocks = nn.ModuleList(
74
+ [
75
+ TimestepEmbedSequential(
76
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
77
+ )
78
+ ]
79
+ )
80
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
81
+
82
+ # hint encoder
83
+ self.input_hint_block = TimestepEmbedSequential(
84
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
85
+ nn.SiLU(),
86
+ conv_nd(dims, 16, 16, 3, padding=1),
87
+ nn.SiLU(),
88
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
89
+ nn.SiLU(),
90
+ conv_nd(dims, 32, 32, 3, padding=1),
91
+ nn.SiLU(),
92
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
93
+ nn.SiLU(),
94
+ conv_nd(dims, 96, 96, 3, padding=1),
95
+ nn.SiLU(),
96
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
97
+ nn.SiLU(),
98
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
99
+ )
100
+
101
+ # UNet
102
+ input_block_chans = [model_channels]
103
+ ch = model_channels
104
+ ds = 1
105
+ for level, mult in enumerate(channel_mult):
106
+ for nr in range(self.num_res_blocks[level]):
107
+ layers = [
108
+ ResBlock(
109
+ ch,
110
+ time_embed_dim,
111
+ dropout,
112
+ out_channels=mult * model_channels,
113
+ dims=dims,
114
+ use_checkpoint=use_checkpoint,
115
+ use_scale_shift_norm=use_scale_shift_norm,
116
+ )
117
+ ]
118
+ ch = mult * model_channels
119
+ if ds in attention_resolutions:
120
+ num_heads = ch // num_head_channels
121
+ dim_head = num_head_channels
122
+ disabled_sa = False
123
+
124
+ layers.append(
125
+ SpatialTransformer(
126
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim)
127
+ )
128
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
129
+ self.zero_convs.append(self.make_zero_conv(ch))
130
+ input_block_chans.append(ch)
131
+ if level != len(channel_mult) - 1:
132
+ out_ch = ch
133
+ self.input_blocks.append(
134
+ TimestepEmbedSequential(
135
+ Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
136
+ )
137
+ )
138
+ ch = out_ch
139
+ input_block_chans.append(ch)
140
+ self.zero_convs.append(self.make_zero_conv(ch))
141
+ ds *= 2
142
+ num_heads = ch // num_head_channels
143
+ dim_head = num_head_channels
144
+ self.middle_block = TimestepEmbedSequential(
145
+ ResBlock(
146
+ ch,
147
+ time_embed_dim,
148
+ dropout,
149
+ dims=dims,
150
+ use_checkpoint=use_checkpoint,
151
+ use_scale_shift_norm=use_scale_shift_norm,
152
+ ),
153
+ SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
154
+ ResBlock(
155
+ ch,
156
+ time_embed_dim,
157
+ dropout,
158
+ dims=dims,
159
+ use_checkpoint=use_checkpoint,
160
+ use_scale_shift_norm=use_scale_shift_norm,
161
+ ),
162
+ )
163
+ self.middle_block_out = self.make_zero_conv(ch)
164
+
165
+ def make_zero_conv(self, channels):
166
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
167
+
168
+ def forward(self, x, hint, timesteps, reference_dino):
169
+ # prepare the inputs
170
+ context = reference_dino
171
+ t_emb = timestep_embedding(timesteps, self.model_channels)
172
+ emb = self.time_embed(t_emb)
173
+ guided_hint = self.input_hint_block(hint, emb)
174
+
175
+ # predict the control features
176
+ outs = []
177
+ h = x.type(self.dtype)
178
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
179
+ if guided_hint is not None:
180
+ h = module(h, emb, context)
181
+ h += guided_hint
182
+ guided_hint = None
183
+ else:
184
+ h = module(h, emb, context)
185
+ outs.append(zero_conv(h, emb, context))
186
+ h = self.middle_block(h, emb, context)
187
+ outs.append(self.middle_block_out(h, emb, context))
188
+
189
+ return outs
190
+
191
+ # =============================================================
192
+ # Frozen part: ControlledUnetModel
193
+ # =============================================================
194
+ class ControlledUnetModel(UNetModel):
195
+ def forward(self, x, timesteps=None, context=None, control=None):
196
+ hs = []
197
+
198
+ # encoder half of the UNet (kept frozen)
199
+ with torch.no_grad():
200
+ t_emb = timestep_embedding(timesteps, self.model_channels)
201
+ emb = self.time_embed(t_emb)
202
+ h = x.type(self.dtype)
203
+ for module in self.input_blocks:
204
+ h = module(h, emb, context)
205
+ hs.append(h)
206
+ h = self.middle_block(h, emb, context)
207
+
208
+ # inject the control signal
209
+ if control is not None:
210
+ h += control.pop()
211
+
212
+ # decoder half of the UNet
213
+ for i, module in enumerate(self.output_blocks):
214
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
215
+ h = module(h, emb, context)
216
+
217
+ # output
218
+ h = h.type(x.dtype)
219
+ h = self.out(h)
220
+
221
+ return h
222
+
223
+ # =============================================================
224
+ # Backbone network: ControlLDM
225
+ # =============================================================
226
+ class ControlLDM(DDPM):
227
+ def __init__(self,
228
+ control_stage_config, # ControlNet
229
+ first_stage_config, # AutoencoderKL
230
+ cond_stage_config, # FrozenCLIPImageEmbedder
231
+ condi_stage_config, # FrozenCLIPTextEmbedder
232
+ scale_factor=1.0, # 0.18215
233
+ *args, **kwargs):
234
+ self.num_timesteps_cond = 1
235
+ super().__init__(*args, **kwargs) # sets up self.model and the schedule buffers
236
+ self.control_model = instantiate_from_config(control_stage_config) # self.control_model
237
+ self.instantiate_first_stage(first_stage_config) # self.first_stage_model (AutoencoderKL)
238
+ self.instantiate_cond_stage(cond_stage_config) # self.cond_stage_model (FrozenCLIPImageEmbedder)
239
+ self.instantiate_condi_stage(condi_stage_config) # self.condi_stage_model (FrozenCLIPTextEmbedder)
240
+ self.proj_out = nn.Linear(1024, 768) # fully connected projection layer
241
+ self.scale_factor = scale_factor # 0.18215
242
+ self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=False)
243
+ self.trainable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True)
244
+
245
+ self.dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
246
+ self.dinov2_vitl14.eval()
247
+ self.dinov2_vitl14.train = disabled_train
248
+ for param in self.dinov2_vitl14.parameters():
249
+ param.requires_grad = False
250
+ self.linear = nn.Linear(1024, 768)
251
+
252
+ # AutoencoderKL is kept frozen
253
+ def instantiate_first_stage(self, config):
254
+ model = instantiate_from_config(config)
255
+ self.first_stage_model = model.eval()
256
+ self.first_stage_model.train = disabled_train
257
+ for param in self.first_stage_model.parameters():
258
+ param.requires_grad = False
259
+
260
+ # FrozenCLIPImageEmbedder is kept frozen
261
+ def instantiate_cond_stage(self, config):
262
+ model = instantiate_from_config(config)
263
+ self.cond_stage_model = model.eval()
264
+ self.cond_stage_model.train = disabled_train
265
+ for param in self.cond_stage_model.parameters():
266
+ param.requires_grad = False
267
+
268
+ def instantiate_condi_stage(self, config):
269
+ model = instantiate_from_config(config)
270
+ self.condi_stage_model = model.eval()
271
+ self.condi_stage_model.train = disabled_train
272
+ for param in self.condi_stage_model.parameters():
273
+ param.requires_grad = False
274
+
275
+ # training
276
+ def training_step(self, batch, batch_idx):
277
+ z_new, reference, hint, cloth_annotation = self.get_input(batch) # load the inputs
278
+ loss = self(z_new, reference, hint, cloth_annotation) # compute the loss
279
+ self.log("loss", # 记录损失
280
+ loss,
281
+ prog_bar=True,
282
+ logger=True,
283
+ on_step=True,
284
+ on_epoch=True)
285
+ self.log('lr_abs', # log the learning rate
286
+ self.optimizers().param_groups[0]['lr'],
287
+ prog_bar=True,
288
+ logger=True,
289
+ on_step=True,
290
+ on_epoch=False)
291
+ return loss
292
+
293
+ # load data
294
+ @torch.no_grad()
295
+ def get_input(self, batch):
296
+
297
+ # load the raw tensors
298
+ x, inpaint, mask, reference, hint, cloth_annotation = super().get_input(batch)
299
+
300
+ # encode the ground truth with AutoencoderKL
301
+ encoder_posterior = self.first_stage_model.encode(x)
302
+ z = self.scale_factor * (encoder_posterior.sample()).detach()
303
+
304
+ # encode the inpaint image with AutoencoderKL
305
+ encoder_posterior_inpaint = self.first_stage_model.encode(inpaint)
306
+ z_inpaint = self.scale_factor * (encoder_posterior_inpaint.sample()).detach()
307
+
308
+ # Resize mask
309
+ mask_resize = torchvision.transforms.Resize([z.shape[-2],z.shape[-1]])(mask)
310
+
311
+ # assemble z_new
312
+ z_new = torch.cat((z,z_inpaint,mask_resize),dim=1)
313
+ out = [z_new, reference, hint, cloth_annotation]
314
+
315
+ return out
316
+
317
+ # compute the loss
318
+ def forward(self, z_new, reference, hint, cloth_annotation):
319
+
320
+ # sample a random timestep t
321
+ t = torch.randint(0, self.num_timesteps, (z_new.shape[0],), device=self.device).long()
322
+
323
+ # encode the reference with the CLIP image encoder
324
+ reference_clip = self.cond_stage_model.encode(reference)
325
+ reference_clip = self.proj_out(reference_clip)
326
+
327
+ # encode the garment caption with the CLIP text encoder
328
+ reference_clip_text = self.condi_stage_model.encode(cloth_annotation)
329
+
330
+ # apply cross-attention to fuse the image and text features (note: the module is re-created on every forward pass, so its weights are untrained)
331
+ cross_att = CrossAttention().to('cuda')
332
+ reference_clip = cross_att(reference_clip, reference_clip_text, reference_clip_text)
333
+
334
+ # encode the reference with DINOv2
335
+ dino = self.dinov2_vitl14(reference,is_training=True)
336
+ dino1 = dino["x_norm_clstoken"].unsqueeze(1)
337
+ dino2 = dino["x_norm_patchtokens"]
338
+ reference_dino = torch.cat((dino1, dino2), dim=1)
339
+ reference_dino = self.linear(reference_dino)
340
+
341
+ # add noise (forward diffusion)
342
+ noise = torch.randn_like(z_new[:,:4,:,:])
343
+ x_noisy = self.q_sample(x_start=z_new[:,:4,:,:], t=t, noise=noise)
344
+ x_noisy = torch.cat((x_noisy, z_new[:,4:,:,:]),dim=1)
345
+
346
+ # predict the noise
347
+ if random.uniform(0, 1)<0.2:
348
+ model_output = self.apply_model(x_noisy, hint, t, reference_clip, reference_dino)
349
+ else:
350
+ model_output = self.apply_model(x_noisy, hint, t, reference_clip, reference_dino)
351
+
352
+ # compute the loss
353
+ loss = self.get_loss(model_output, noise, mean=False).mean([1, 2, 3])
354
+ loss = loss.mean()
355
+
356
+ return loss
357
+
358
+ # noise prediction
359
+ def apply_model(self, x_noisy, hint, t, reference_clip, reference_dino):
360
+
361
+ # predict the control features with the ControlNet
362
+ control = self.control_model(x_noisy, hint, t, reference_dino)
363
+
364
+ # call the PBE (Paint-by-Example) UNet
365
+ model_output = self.model(x_noisy, t, reference_clip, control)
366
+
367
+ return model_output
368
+
369
+ # optimizer
370
+ def configure_optimizers(self):
371
+ # learning-rate setup
372
+ lr = self.learning_rate
373
+ params = list(self.control_model.parameters())+list(self.linear.parameters())
374
+ opt = torch.optim.AdamW(params, lr=lr)
375
+
376
+ return opt
377
+
378
+ # sampling
379
+ @torch.no_grad()
380
+ def sample_log(self, batch, ddim_steps=50, ddim_eta=0.):
381
+ z_new, reference, hint, cloth_annotation = self.get_input(batch)
382
+ x, _, mask, _, _, _ = super().get_input(batch)
383
+ log = dict()
384
+
385
+ # log["reference"] = reference
386
+ # reconstruction = 1. / self.scale_factor * z_new[:,:4,:,:]
387
+ # log["reconstruction"] = self.first_stage_model.decode(reconstruction)
388
+ log["mask"] = mask
389
+
390
+ test_model_kwargs = {}
391
+ test_model_kwargs['inpaint_image'] = z_new[:,4:8,:,:]
392
+ test_model_kwargs['inpaint_mask'] = z_new[:,8:,:,:]
393
+ ddim_sampler = DDIMSampler(self)
394
+ shape = (self.channels, self.image_size, self.image_size)
395
+ samples, _ = ddim_sampler.sample(ddim_steps,
396
+ reference.shape[0],
397
+ shape,
398
+ hint,
399
+ reference,
400
+ verbose=False,
401
+ eta=ddim_eta,
402
+ test_model_kwargs=test_model_kwargs)
403
+ samples = 1. / self.scale_factor * samples
404
+ x_samples = self.first_stage_model.decode(samples[:,:4,:,:])
405
+ # log["samples"] = x_samples
406
+
407
+ x = torchvision.transforms.Resize([512, 512])(x)
408
+ reference = torchvision.transforms.Resize([512, 512])(reference)
409
+ x_samples = torchvision.transforms.Resize([512, 512])(x_samples)
410
+ log["grid"] = torch.cat((x, reference, x_samples), dim=2)
411
+
412
+ return log
413
+
414
+
415
+
416
+ # CrossAttention class applies cross-attention between two embeddings: an image embedding and a text embedding.
417
+ class CrossAttention(nn.Module):
418
+ def __init__(
419
+ self,
420
+ embed_dim: int=768,
421
+ num_heads: int=8
422
+ ):
423
+ """
424
+ Initializes a CrossAttention layer using multi-head attention.
425
+
426
+ Args:
427
+ embed_dim (int): Dimensionality of the embeddings, which should match
428
+ the size of both reference_clip and reference_clip_text.
429
+ num_heads (int): Number of attention heads. Using multiple heads allows
430
+ the model to focus on different parts of the input embeddings.
431
+ """
432
+ super(CrossAttention, self).__init__()
433
+ self.cross_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
434
+
435
+ def forward(self, query, key, value):
436
+ """
437
+ Applies cross-attention to the query, key, and value inputs.
438
+
439
+ Args:
440
+ query (Tensor): The query tensor (in this case, reference_clip).
441
+ Shape should be [batch_size, seq_length, embed_dim].
442
+ key (Tensor): The key tensor (in this case, reference_clip_text).
443
+ Shape should be [batch_size, seq_length, embed_dim].
444
+ value (Tensor): The value tensor (in this case, reference_clip_text).
445
+ Shape should be [batch_size, seq_length, embed_dim].
446
+
447
+ Returns:
448
+ Tensor: The attention output after combining reference_clip and
449
+ reference_clip_text through cross-attention.
450
+ Shape is [batch_size, seq_length, embed_dim].
451
+ """
452
+
453
+ query = query.to('cuda')
454
+ key = key.to('cuda')
455
+ value = value.to('cuda')
456
+
457
+ # Apply cross-attention, where `query` attends to `key` and `value`.
458
+ # `attn_output` contains the resulting embeddings, and `attn_weights` contains the attention weights.
459
+ attn_output, attn_weights = self.cross_attn(query, key, value)
460
+ return attn_output
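Since the docstrings above already describe how this `CrossAttention` fuses the CLIP image and text embeddings, a minimal usage sketch may help. The shapes below are illustrative assumptions (any `[batch, seq, 768]` tensors work), the module here is randomly initialised rather than trained, and a CUDA device is required because `forward()` moves its inputs to the GPU; note that this is the `CrossAttention` defined in control.py, not the one in ldm/modules/attention.py:

import torch

attn = CrossAttention(embed_dim=768, num_heads=8).to('cuda')  # defaults used in ControlLDM.forward
reference_clip = torch.randn(2, 1, 768)                       # image embedding, used as the query
reference_clip_text = torch.randn(2, 77, 768)                 # text embedding, used as key and value
fused = attn(reference_clip, reference_clip_text, reference_clip_text)
print(fused.shape)                                            # torch.Size([2, 1, 768]) -- same shape as the query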
ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,265 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
9
+ extract_into_tensor
10
+
11
+
12
+ class DDIMSampler(object):
13
+ def __init__(self, model, schedule="linear", **kwargs):
14
+ super().__init__()
15
+ self.model = model
16
+ self.ddpm_num_timesteps = model.num_timesteps
17
+ self.schedule = schedule
18
+
19
+ def register_buffer(self, name, attr):
20
+ if type(attr) == torch.Tensor:
21
+ if attr.device != torch.device("cuda"):
22
+ attr = attr.to(torch.device("cuda"))
23
+ setattr(self, name, attr)
24
+
25
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
27
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
28
+ alphas_cumprod = self.model.alphas_cumprod
29
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
30
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
31
+
32
+ self.register_buffer('betas', to_torch(self.model.betas))
33
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
34
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
35
+
36
+ # calculations for diffusion q(x_t | x_{t-1}) and others
37
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
38
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
39
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
42
+
43
+ # ddim sampling parameters
44
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
45
+ ddim_timesteps=self.ddim_timesteps,
46
+ eta=ddim_eta,verbose=verbose)
47
+
48
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
49
+ self.register_buffer('ddim_alphas', ddim_alphas)
50
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56
+
57
+ @torch.no_grad()
58
+ def sample(self,
59
+ S,
60
+ batch_size,
61
+ shape,
62
+ pose,
63
+ conditioning=None,
64
+ callback=None,
65
+ normals_sequence=None,
66
+ img_callback=None,
67
+ quantize_x0=False,
68
+ eta=0.,
69
+ mask=None,
70
+ x0=None,
71
+ temperature=1.,
72
+ noise_dropout=0.,
73
+ score_corrector=None,
74
+ corrector_kwargs=None,
75
+ verbose=True,
76
+ x_T=None,
77
+ log_every_t=100,
78
+ unconditional_guidance_scale=1.,
79
+ unconditional_conditioning=None,
80
+ **kwargs
81
+ ):
82
+ if conditioning is not None:
83
+ if isinstance(conditioning, dict):
84
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
85
+ if cbs != batch_size:
86
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
+ else:
88
+ if conditioning.shape[0] != batch_size:
89
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
90
+
91
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
92
+ # sampling
93
+ C, H, W = shape
94
+ size = (batch_size, C, H, W)
95
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
96
+
97
+ samples, intermediates = self.ddim_sampling(conditioning, size, pose,
98
+ callback=callback,
99
+ img_callback=img_callback,
100
+ quantize_denoised=quantize_x0,
101
+ mask=mask, x0=x0,
102
+ ddim_use_original_steps=False,
103
+ noise_dropout=noise_dropout,
104
+ temperature=temperature,
105
+ score_corrector=score_corrector,
106
+ corrector_kwargs=corrector_kwargs,
107
+ x_T=x_T,
108
+ log_every_t=log_every_t,
109
+ unconditional_guidance_scale=unconditional_guidance_scale,
110
+ unconditional_conditioning=unconditional_conditioning,
111
+ **kwargs
112
+ )
113
+ return samples, intermediates
114
+
115
+ @torch.no_grad()
116
+ def ddim_sampling(self, cond, shape, pose,
117
+ x_T=None, ddim_use_original_steps=False,
118
+ callback=None, timesteps=None, quantize_denoised=False,
119
+ mask=None, x0=None, img_callback=None, log_every_t=100,
120
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
121
+ unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs):
122
+ device = self.model.betas.device
123
+ b = shape[0]
124
+ if x_T is None:
125
+ img = torch.randn(shape, device=device)
126
+ else:
127
+ img = x_T
128
+
129
+ if timesteps is None:
130
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
131
+ elif timesteps is not None and not ddim_use_original_steps:
132
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
133
+ timesteps = self.ddim_timesteps[:subset_end]
134
+
135
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
136
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
137
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
138
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
139
+
140
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
141
+
142
+ for i, step in enumerate(iterator):
143
+ index = total_steps - i - 1
144
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
145
+
146
+ outs = self.p_sample_ddim(img, cond, ts, pose, index=index, use_original_steps=ddim_use_original_steps,
147
+ quantize_denoised=quantize_denoised, temperature=temperature,
148
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
149
+ corrector_kwargs=corrector_kwargs,
150
+ unconditional_guidance_scale=unconditional_guidance_scale,
151
+ unconditional_conditioning=unconditional_conditioning,**kwargs)
152
+ img, pred_x0 = outs
153
+ if callback: callback(i)
154
+ if img_callback: img_callback(pred_x0, i)
155
+
156
+ if index % log_every_t == 0 or index == total_steps - 1:
157
+ intermediates['x_inter'].append(img)
158
+ intermediates['pred_x0'].append(pred_x0)
159
+
160
+ return img, intermediates
161
+
162
+ @torch.no_grad()
163
+ def p_sample_ddim(self, x, c, t, pose, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
164
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
165
+ unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs):
166
+ b, *_, device = *x.shape, x.device
167
+
168
+ if 'test_model_kwargs' in kwargs:
169
+ kwargs=kwargs['test_model_kwargs']
170
+ x = torch.cat([x, kwargs['inpaint_image'], kwargs['inpaint_mask']],dim=1)
171
+ elif 'rest' in kwargs:
172
+ x = torch.cat((x, kwargs['rest']), dim=1)
173
+ else:
174
+ raise Exception("kwargs must contain either 'test_model_kwargs' or 'rest' key")
175
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
176
+ reference_clip = self.model.cond_stage_model.encode(c)
177
+ reference_clip= self.model.proj_out(reference_clip)
178
+ dino = self.model.dinov2_vitl14(c,is_training=True)
179
+ dino1 = dino["x_norm_clstoken"].unsqueeze(1)
180
+ dino2 = dino["x_norm_patchtokens"]
181
+ reference_dino = torch.cat((dino1, dino2), dim=1)
182
+ reference_dino = self.model.linear(reference_dino)
183
+ control = self.model.control_model(x, pose, t, reference_dino)
184
+ e_t = self.model.model(x, t, reference_clip, control)
185
+ else:
186
+ x_in = torch.cat([x] * 2)
187
+ t_in = torch.cat([t] * 2)
188
+ c_in = torch.cat([unconditional_conditioning, c])
189
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
190
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
191
+
192
+ if score_corrector is not None:
193
+ assert self.model.parameterization == "eps"
194
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
195
+
196
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
197
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
198
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
199
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
200
+ # select parameters corresponding to the currently considered timestep
201
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205
+
206
+ # current prediction for x_0
207
+ if x.shape[1]!=4:
208
+ pred_x0 = (x[:,:4,:,:] - sqrt_one_minus_at * e_t) / a_t.sqrt()
209
+ else:
210
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
211
+
212
+ if quantize_denoised:
213
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
214
+ # direction pointing to x_t
215
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
216
+ noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature
217
+ if noise_dropout > 0.:
218
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
219
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
220
+
221
+ # noise
222
+ # noise = torch.randn_like(kwargs['inpaint_image'])
223
+ # x_noisy = self.model.q_sample(x_start=kwargs['inpaint_image'], t=t, noise=noise)
224
+ # x_prev = x_prev*(1-kwargs['inpaint_mask'])+x_noisy*(kwargs['inpaint_mask'])
225
+
226
+
227
+ return x_prev, pred_x0
228
+
229
+ @torch.no_grad()
230
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
231
+ # fast, but does not allow for exact reconstruction
232
+ # t serves as an index to gather the correct alphas
233
+ if use_original_steps:
234
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
235
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
236
+ else:
237
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
238
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
239
+
240
+ if noise is None:
241
+ noise = torch.randn_like(x0)
242
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
243
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
244
+
245
+ @torch.no_grad()
246
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
247
+ use_original_steps=False):
248
+
249
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
250
+ timesteps = timesteps[:t_start]
251
+
252
+ time_range = np.flip(timesteps)
253
+ total_steps = timesteps.shape[0]
254
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
255
+
256
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
257
+ x_dec = x_latent
258
+ for i, step in enumerate(iterator):
259
+ index = total_steps - i - 1
260
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
261
+ # NOTE: p_sample_ddim also expects a `pose` argument, which this decode() path does not pass
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
262
+ unconditional_guidance_scale=unconditional_guidance_scale,
263
+ unconditional_conditioning=unconditional_conditioning)
264
+ return x_dec
265
+
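For reference, the update performed in `p_sample_ddim` above is the standard DDIM step, written out directly from the code:

\[
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t)}{\sqrt{\bar{\alpha}_t}},
\qquad
x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t) + \sigma_t z,
\]

where `a_t`, `a_prev`, `sqrt_one_minus_at` and `sigma_t` in the code hold \(\bar{\alpha}_t\), \(\bar{\alpha}_{t-1}\), \(\sqrt{1-\bar{\alpha}_t}\) and \(\sigma_t\) at the selected DDIM indices; with `eta=0` every \(\sigma_t\) is zero and the sampler is deterministic.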
ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,144 @@
1
+ import torch
2
+ import einops
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import pytorch_lightning as pl
6
+ from torch.optim.lr_scheduler import LambdaLR
7
+ from einops import rearrange, repeat
8
+ from functools import partial
9
+ from torchvision.utils import make_grid
10
+ from ldm.util import default, count_params, instantiate_from_config
11
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
12
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor
13
+ from ldm.models.diffusion.ddim import DDIMSampler
14
+ from torchvision.transforms import Resize
15
+ import random
16
+
17
+
18
+
19
+ def disabled_train(self, mode=True):
20
+ return self
21
+
22
+
23
+ class DiffusionWrapper(pl.LightningModule):
24
+ def __init__(self, unet_config):
25
+ super().__init__()
26
+ self.diffusion_model = instantiate_from_config(unet_config)
27
+
28
+ def forward(self, x, timesteps=None, context=None, control=None):
29
+ out = self.diffusion_model(x, timesteps, context, control)
30
+ return out
31
+
32
+
33
+ class DDPM(pl.LightningModule):
34
+ def __init__(self,
35
+ unet_config,
36
+ linear_start=1e-4, # 0.00085
37
+ linear_end=2e-2, # 0.0120
38
+ log_every_t=100, # 200
39
+ timesteps=1000, # 1000
40
+ image_size=256, # 64
41
+ channels=3, # 4
42
+ u_cond_percent=0, # 0.2
43
+ use_ema=True, # False
44
+ beta_schedule="linear",
45
+ loss_type="l2",
46
+ clip_denoised=True,
47
+ cosine_s=8e-3,
48
+ original_elbo_weight=0.,
49
+ v_posterior=0.,
50
+ l_simple_weight=1.,
51
+ parameterization="eps",
52
+ use_positional_encodings=False,
53
+ learn_logvar=False,
54
+ logvar_init=0.):
55
+ super().__init__()
56
+ self.parameterization = parameterization
57
+ self.cond_stage_model = None
58
+ self.clip_denoised = clip_denoised
59
+ self.log_every_t = log_every_t
60
+ self.image_size = image_size
61
+ self.channels = channels
62
+ self.u_cond_percent=u_cond_percent
63
+ self.use_positional_encodings = use_positional_encodings
64
+ self.model = DiffusionWrapper(unet_config) # wraps the UNet model
65
+
66
+ self.use_ema = use_ema
67
+ self.use_scheduler = True
68
+ self.v_posterior = v_posterior
69
+ self.original_elbo_weight = original_elbo_weight
70
+ self.l_simple_weight = l_simple_weight
71
+ self.register_schedule(beta_schedule=beta_schedule, # "linear"
72
+ timesteps=timesteps, # 1000
73
+ linear_start=linear_start, # 0.00085
74
+ linear_end=linear_end, # 0.0120
75
+ cosine_s=cosine_s) # 8e-3
76
+ self.loss_type = loss_type
77
+ self.learn_logvar = learn_logvar
78
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
79
+
80
+ def register_schedule(self,
81
+ beta_schedule="linear",
82
+ timesteps=1000,
83
+ linear_start=0.00085,
84
+ linear_end=0.0120,
85
+ cosine_s=8e-3):
86
+ betas = make_beta_schedule(beta_schedule,
87
+ timesteps,
88
+ linear_start=linear_start,
89
+ linear_end=linear_end,
90
+ cosine_s=cosine_s)
91
+ alphas = 1. - betas
92
+ alphas_cumprod = np.cumprod(alphas, axis=0)
93
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
94
+
95
+ timesteps, = betas.shape
96
+ self.num_timesteps = int(timesteps)
97
+ self.linear_start = linear_start
98
+ self.linear_end = linear_end
99
+ to_torch = partial(torch.tensor, dtype=torch.float32)
100
+ self.register_buffer('betas', to_torch(betas))
101
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
102
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
103
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
104
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
105
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
106
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
107
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
108
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + self.v_posterior * betas
109
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
110
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
111
+ self.register_buffer('posterior_mean_coef1', to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
112
+ self.register_buffer('posterior_mean_coef2', to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
113
+ lvlb_weights = self.betas ** 2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
114
+ lvlb_weights[0] = lvlb_weights[1]
115
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
116
+
117
+ def get_input(self, batch):
118
+ x = batch['GT']
119
+ mask = batch['inpaint_mask']
120
+ inpaint = batch['inpaint_image']
121
+ reference = batch['ref_imgs']
122
+ hint = batch['hint']
123
+ cloth_annotation = batch['caption_cloth']
124
+
125
+ x = x.to(memory_format=torch.contiguous_format).float()
126
+ mask = mask.to(memory_format=torch.contiguous_format).float()
127
+ inpaint = inpaint.to(memory_format=torch.contiguous_format).float()
128
+ reference = reference.to(memory_format=torch.contiguous_format).float()
129
+ hint = hint.to(memory_format=torch.contiguous_format).float()
130
+
131
+ return x, inpaint, mask, reference, hint, cloth_annotation
132
+
133
+ def q_sample(self, x_start, t, noise=None):
134
+ noise = default(noise, lambda: torch.randn_like(x_start))
135
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
136
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
137
+
138
+ def get_loss(self, pred, target, mean=True):
139
+ if mean:
140
+ loss = torch.nn.functional.mse_loss(target, pred)
141
+ else:
142
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
143
+ return loss
144
+
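The closed-form forward process implemented by `q_sample` above can be exercised in isolation. The snippet below is a minimal sketch with a toy linear beta schedule; the schedule values and latent shape are illustrative assumptions, not taken from the training configs:

import torch

betas = torch.linspace(1e-4, 2e-2, 1000)              # toy linear schedule
alphas_cumprod = torch.cumprod(1. - betas, dim=0)

x_start = torch.randn(1, 4, 64, 64)                   # e.g. an AutoencoderKL latent
t = torch.tensor([500])
noise = torch.randn_like(x_start)

sqrt_ac = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1)
sqrt_om = (1. - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1)
x_noisy = sqrt_ac * x_start + sqrt_om * noise         # q(x_t | x_0), same formula as q_sample
print(x_noisy.shape)                                  # torch.Size([1, 4, 64, 64])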
ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,239 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
+ class PLMSSampler(object):
12
+ def __init__(self, model, schedule="linear", **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+ self.ddpm_num_timesteps = model.num_timesteps
16
+ self.schedule = schedule
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
+ if ddim_eta != 0:
26
+ raise ValueError('ddim_eta must be 0 for PLMS')
27
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
28
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
29
+ alphas_cumprod = self.model.alphas_cumprod
30
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
31
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
32
+
33
+ self.register_buffer('betas', to_torch(self.model.betas))
34
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
36
+
37
+ # calculations for diffusion q(x_t | x_{t-1}) and others
38
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
40
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
43
+
44
+ # ddim sampling parameters
45
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
46
+ ddim_timesteps=self.ddim_timesteps,
47
+ eta=ddim_eta,verbose=verbose)
48
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
49
+ self.register_buffer('ddim_alphas', ddim_alphas)
50
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56
+
57
+ @torch.no_grad()
58
+ def sample(self,
59
+ S,
60
+ batch_size,
61
+ shape,
62
+ conditioning=None,
63
+ callback=None,
64
+ normals_sequence=None,
65
+ img_callback=None,
66
+ quantize_x0=False,
67
+ eta=0.,
68
+ mask=None,
69
+ x0=None,
70
+ temperature=1.,
71
+ noise_dropout=0.,
72
+ score_corrector=None,
73
+ corrector_kwargs=None,
74
+ verbose=True,
75
+ x_T=None,
76
+ log_every_t=100,
77
+ unconditional_guidance_scale=1.,
78
+ unconditional_conditioning=None,
79
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
80
+ **kwargs
81
+ ):
82
+ if conditioning is not None:
83
+ if isinstance(conditioning, dict):
84
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
85
+ if cbs != batch_size:
86
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
+ else:
88
+ if conditioning.shape[0] != batch_size:
89
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
90
+
91
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
92
+ # sampling
93
+ C, H, W = shape
94
+ size = (batch_size, C, H, W)
95
+ print(f'Data shape for PLMS sampling is {size}')
96
+
97
+ samples, intermediates = self.plms_sampling(conditioning, size,
98
+ callback=callback,
99
+ img_callback=img_callback,
100
+ quantize_denoised=quantize_x0,
101
+ mask=mask, x0=x0,
102
+ ddim_use_original_steps=False,
103
+ noise_dropout=noise_dropout,
104
+ temperature=temperature,
105
+ score_corrector=score_corrector,
106
+ corrector_kwargs=corrector_kwargs,
107
+ x_T=x_T,
108
+ log_every_t=log_every_t,
109
+ unconditional_guidance_scale=unconditional_guidance_scale,
110
+ unconditional_conditioning=unconditional_conditioning,
111
+ **kwargs
112
+ )
113
+ return samples, intermediates
114
+
115
+ @torch.no_grad()
116
+ def plms_sampling(self, cond, shape,
117
+ x_T=None, ddim_use_original_steps=False,
118
+ callback=None, timesteps=None, quantize_denoised=False,
119
+ mask=None, x0=None, img_callback=None, log_every_t=100,
120
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
121
+ unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs):
122
+ device = self.model.betas.device
123
+ b = shape[0]
124
+ if x_T is None:
125
+ img = torch.randn(shape, device=device)
126
+ else:
127
+ img = x_T
128
+
129
+ if timesteps is None:
130
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
131
+ elif timesteps is not None and not ddim_use_original_steps:
132
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
133
+ timesteps = self.ddim_timesteps[:subset_end]
134
+
135
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
136
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
137
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
138
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
139
+
140
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
141
+ old_eps = []
142
+
143
+ for i, step in enumerate(iterator):
144
+ index = total_steps - i - 1
145
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
146
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
147
+
148
+ if mask is not None:
149
+ assert x0 is not None
150
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
151
+ img = img_orig * mask + (1. - mask) * img
152
+
153
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
154
+ quantize_denoised=quantize_denoised, temperature=temperature,
155
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
156
+ corrector_kwargs=corrector_kwargs,
157
+ unconditional_guidance_scale=unconditional_guidance_scale,
158
+ unconditional_conditioning=unconditional_conditioning,
159
+ old_eps=old_eps, t_next=ts_next,**kwargs)
160
+ img, pred_x0, e_t = outs
161
+ old_eps.append(e_t)
162
+ if len(old_eps) >= 4:
163
+ old_eps.pop(0)
164
+ if callback: callback(i)
165
+ if img_callback: img_callback(pred_x0, i)
166
+
167
+ if index % log_every_t == 0 or index == total_steps - 1:
168
+ intermediates['x_inter'].append(img)
169
+ intermediates['pred_x0'].append(pred_x0)
170
+
171
+ return img, intermediates
172
+
173
+ @torch.no_grad()
174
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
175
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
176
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,**kwargs):
177
+ b, *_, device = *x.shape, x.device
178
+ def get_model_output(x, t):
179
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
180
+ e_t = self.model.apply_model(x, t, c)
181
+ else:
182
+ x_in = torch.cat([x] * 2)
183
+ t_in = torch.cat([t] * 2)
184
+ c_in = torch.cat([unconditional_conditioning, c])
185
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
186
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
187
+
188
+ if score_corrector is not None:
189
+ assert self.model.parameterization == "eps"
190
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
191
+
192
+ return e_t
193
+
194
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
195
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
196
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
197
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
198
+
199
+ def get_x_prev_and_pred_x0(e_t, index):
200
+ # select parameters corresponding to the currently considered timestep
201
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205
+
206
+ # current prediction for x_0
207
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
208
+ if quantize_denoised:
209
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
210
+ # direction pointing to x_t
211
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
212
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
213
+ if noise_dropout > 0.:
214
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
215
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
216
+ return x_prev, pred_x0
217
+ kwargs=kwargs['test_model_kwargs']
218
+ print(x.shape,kwargs['inpaint_image'].shape,kwargs['inpaint_mask'].shape)
219
+ x_new=torch.cat([x,kwargs['inpaint_image'],kwargs['inpaint_mask']],dim=1)
220
+ e_t = get_model_output(x_new, t)
221
+ if len(old_eps) == 0:
222
+ # Pseudo Improved Euler (2nd order)
223
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
224
+ x_prev_new=torch.cat([x_prev,kwargs['inpaint_image'],kwargs['inpaint_mask']],dim=1)
225
+ e_t_next = get_model_output(x_prev_new, t_next)
226
+ e_t_prime = (e_t + e_t_next) / 2
227
+ elif len(old_eps) == 1:
228
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
229
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
230
+ elif len(old_eps) == 2:
231
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
232
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
233
+ elif len(old_eps) >= 3:
234
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
235
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
236
+
237
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
238
+
239
+ return x_prev, pred_x0, e_t
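The coefficient sets in `p_sample_plms` above are the standard Adams-Bashforth linear-multistep weights; written out, the combined noise estimate used in place of \(\epsilon_t\) is

\[
\epsilon'_t = \tfrac{1}{2}\bigl(3\epsilon_t - \epsilon_{t-1}\bigr), \qquad
\epsilon'_t = \tfrac{1}{12}\bigl(23\epsilon_t - 16\epsilon_{t-1} + 5\epsilon_{t-2}\bigr), \qquad
\epsilon'_t = \tfrac{1}{24}\bigl(55\epsilon_t - 59\epsilon_{t-1} + 37\epsilon_{t-2} - 9\epsilon_{t-3}\bigr)
\]

for the 2nd-, 3rd- and 4th-order cases, which is exactly what the `old_eps` branches compute before the final `get_x_prev_and_pred_x0` call.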
ldm/modules/.DS_Store ADDED
Binary file (6.15 kB). View file
 
ldm/modules/__pycache__/attention.cpython-38.pyc ADDED
Binary file (11.1 kB). View file
 
ldm/modules/__pycache__/x_transformer.cpython-38.pyc ADDED
Binary file (18.3 kB). View file
 
ldm/modules/attention.py ADDED
@@ -0,0 +1,345 @@
1
+ from inspect import isfunction
+ from typing import Any, Optional
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from ldm.modules.diffusionmodules.util import checkpoint
9
+
10
+ try:
11
+ import xformers
12
+ import xformers.ops
13
+ XFORMERS_IS_AVAILBLE = True
14
+ except:
15
+ XFORMERS_IS_AVAILBLE = False
16
+
17
+
18
+ def exists(val):
19
+ return val is not None
20
+
21
+
22
+ def uniq(arr):
23
+ return{el: True for el in arr}.keys()
24
+
25
+
26
+ def default(val, d):
27
+ if exists(val):
28
+ return val
29
+ return d() if isfunction(d) else d
30
+
31
+
32
+ def max_neg_value(t):
33
+ return -torch.finfo(t.dtype).max
34
+
35
+
36
+ def init_(tensor):
37
+ dim = tensor.shape[-1]
38
+ std = 1 / math.sqrt(dim)
39
+ tensor.uniform_(-std, std)
40
+ return tensor
41
+
42
+
43
+ # feedforward
44
+ class GEGLU(nn.Module):
45
+ def __init__(self, dim_in, dim_out):
46
+ super().__init__()
47
+ self.proj = nn.Linear(dim_in, dim_out * 2)
48
+
49
+ def forward(self, x):
50
+ x, gate = self.proj(x).chunk(2, dim=-1)
51
+ return x * F.gelu(gate)
52
+
53
+
54
+ class FeedForward(nn.Module):
55
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
56
+ super().__init__()
57
+ inner_dim = int(dim * mult)
58
+ dim_out = default(dim_out, dim)
59
+ project_in = nn.Sequential(
60
+ nn.Linear(dim, inner_dim),
61
+ nn.GELU()
62
+ ) if not glu else GEGLU(dim, inner_dim)
63
+
64
+ self.net = nn.Sequential(
65
+ project_in,
66
+ nn.Dropout(dropout),
67
+ nn.Linear(inner_dim, dim_out)
68
+ )
69
+
70
+ def forward(self, x):
71
+ return self.net(x)
72
+
73
+
74
+ def zero_module(module):
75
+ """
76
+ Zero out the parameters of a module and return it.
77
+ """
78
+ for p in module.parameters():
79
+ p.detach().zero_()
80
+ return module
81
+
82
+
83
+ def Normalize(in_channels):
84
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
85
+
86
+
87
+ class LinearAttention(nn.Module):
88
+ def __init__(self, dim, heads=4, dim_head=32):
89
+ super().__init__()
90
+ self.heads = heads
91
+ hidden_dim = dim_head * heads
92
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
93
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
94
+
95
+ def forward(self, x):
96
+ b, c, h, w = x.shape
97
+ qkv = self.to_qkv(x)
98
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
99
+ k = k.softmax(dim=-1)
100
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
101
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
102
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
103
+ return self.to_out(out)
104
+
105
+
106
+ class SpatialSelfAttention(nn.Module):
107
+ def __init__(self, in_channels):
108
+ super().__init__()
109
+ self.in_channels = in_channels
110
+
111
+ self.norm = Normalize(in_channels)
112
+ self.q = torch.nn.Conv2d(in_channels,
113
+ in_channels,
114
+ kernel_size=1,
115
+ stride=1,
116
+ padding=0)
117
+ self.k = torch.nn.Conv2d(in_channels,
118
+ in_channels,
119
+ kernel_size=1,
120
+ stride=1,
121
+ padding=0)
122
+ self.v = torch.nn.Conv2d(in_channels,
123
+ in_channels,
124
+ kernel_size=1,
125
+ stride=1,
126
+ padding=0)
127
+ self.proj_out = torch.nn.Conv2d(in_channels,
128
+ in_channels,
129
+ kernel_size=1,
130
+ stride=1,
131
+ padding=0)
132
+
133
+ def forward(self, x):
134
+ h_ = x
135
+ h_ = self.norm(h_)
136
+ q = self.q(h_)
137
+ k = self.k(h_)
138
+ v = self.v(h_)
139
+
140
+ # compute attention
141
+ b,c,h,w = q.shape
142
+ q = rearrange(q, 'b c h w -> b (h w) c')
143
+ k = rearrange(k, 'b c h w -> b c (h w)')
144
+ w_ = torch.einsum('bij,bjk->bik', q, k)
145
+
146
+ w_ = w_ * (int(c)**(-0.5))
147
+ w_ = torch.nn.functional.softmax(w_, dim=2)
148
+
149
+ # attend to values
150
+ v = rearrange(v, 'b c h w -> b c (h w)')
151
+ w_ = rearrange(w_, 'b i j -> b j i')
152
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
153
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
154
+ h_ = self.proj_out(h_)
155
+
156
+ return x+h_
157
+
158
+
159
+ class CrossAttention(nn.Module):
160
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
161
+ super().__init__()
162
+ inner_dim = dim_head * heads
163
+ context_dim = default(context_dim, query_dim)
164
+
165
+ self.scale = dim_head ** -0.5
166
+ self.heads = heads
167
+
168
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
169
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
170
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
171
+
172
+ self.to_out = nn.Sequential(
173
+ nn.Linear(inner_dim, query_dim),
174
+ nn.Dropout(dropout)
175
+ )
176
+
177
+ def forward(self, x, context=None, mask=None):
178
+ h = self.heads
179
+
180
+ q = self.to_q(x)
181
+ context = default(context, x)
182
+ k = self.to_k(context)
183
+ v = self.to_v(context)
184
+
185
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
186
+
187
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
188
+
189
+ if exists(mask):
190
+ mask = rearrange(mask, 'b ... -> b (...)')
191
+ max_neg_value = -torch.finfo(sim.dtype).max
192
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
193
+ sim.masked_fill_(~mask, max_neg_value)
194
+
195
+ # attention, what we cannot get enough of
196
+ attn = sim.softmax(dim=-1)
197
+
198
+ out = einsum('b i j, b j d -> b i d', attn, v)
199
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
200
+ return self.to_out(out)
201
+
202
+
203
+ class MemoryEfficientCrossAttention(nn.Module):
204
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
205
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
206
+ super().__init__()
207
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
208
+ f"{heads} heads.")
209
+ inner_dim = dim_head * heads
210
+ context_dim = default(context_dim, query_dim)
211
+
212
+ self.heads = heads
213
+ self.dim_head = dim_head
214
+
215
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
216
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
217
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
218
+
219
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
220
+ self.attention_op = None  # type: Optional[Any] (typing is not imported in this file, so the annotation is kept as a comment)
221
+
222
+ def forward(self, x, context=None, mask=None):
223
+ q = self.to_q(x)
224
+ context = default(context, x)
225
+ k = self.to_k(context)
226
+ v = self.to_v(context)
227
+
228
+ b, _, _ = q.shape
229
+ q, k, v = map(
230
+ lambda t: t.unsqueeze(3)
231
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
232
+ .permute(0, 2, 1, 3)
233
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
234
+ .contiguous(),
235
+ (q, k, v),
236
+ )
237
+
238
+ # actually compute the attention, what we cannot get enough of
239
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
240
+
241
+ if exists(mask):
242
+ raise NotImplementedError
243
+ out = (
244
+ out.unsqueeze(0)
245
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
246
+ .permute(0, 2, 1, 3)
247
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
248
+ )
249
+ return self.to_out(out)
250
+
251
+ class BasicTransformerBlock(nn.Module):
252
+ ATTENTION_MODES = {
253
+ "softmax": CrossAttention, # vanilla attention
254
+ "softmax-xformers": MemoryEfficientCrossAttention
255
+ }
256
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
257
+ disable_self_attn=False):
258
+ super().__init__()
259
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILABLE else "softmax"
260
+ assert attn_mode in self.ATTENTION_MODES
261
+ attn_cls = self.ATTENTION_MODES[attn_mode]
262
+ self.disable_self_attn = disable_self_attn
263
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
264
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
265
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
266
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
267
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
268
+ self.norm1 = nn.LayerNorm(dim)
269
+ self.norm2 = nn.LayerNorm(dim)
270
+ self.norm3 = nn.LayerNorm(dim)
271
+ self.checkpoint = checkpoint
272
+
273
+ def forward(self, x, context=None):
274
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
275
+
276
+ def _forward(self, x, context=None):
277
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
278
+ x = self.attn2(self.norm2(x), context=context) + x
279
+ x = self.ff(self.norm3(x)) + x
280
+ return x
281
+
282
+
283
+ class SpatialTransformer(nn.Module):
284
+ """
285
+ Transformer block for image-like data.
286
+ First, project the input (aka embedding)
287
+ and reshape to b, t, d.
288
+ Then apply standard transformer action.
289
+ Finally, reshape to image
290
+ NEW: use_linear for more efficiency instead of the 1x1 convs
291
+ """
292
+ def __init__(self, in_channels, n_heads, d_head,
293
+ depth=1, dropout=0., context_dim=None,
294
+ disable_self_attn=False, use_linear=False,
295
+ use_checkpoint=True):
296
+ super().__init__()
297
+ if exists(context_dim) and not isinstance(context_dim, list):
298
+ context_dim = [context_dim]
299
+ self.in_channels = in_channels
300
+ inner_dim = n_heads * d_head
301
+ self.norm = Normalize(in_channels)
302
+ if not use_linear:
303
+ self.proj_in = nn.Conv2d(in_channels,
304
+ inner_dim,
305
+ kernel_size=1,
306
+ stride=1,
307
+ padding=0)
308
+ else:
309
+ self.proj_in = nn.Linear(in_channels, inner_dim)
310
+
311
+ self.transformer_blocks = nn.ModuleList(
312
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
313
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
314
+ for d in range(depth)]
315
+ )
316
+ if not use_linear:
317
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
318
+ in_channels,
319
+ kernel_size=1,
320
+ stride=1,
321
+ padding=0))
322
+ else:
323
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))  # map back from inner_dim to in_channels so the residual add matches x_in
324
+ self.use_linear = use_linear
325
+
326
+ def forward(self, x, context=None):
327
+ # note: if no context is given, cross-attention defaults to self-attention
328
+ if not isinstance(context, list):
329
+ context = [context]
330
+ b, c, h, w = x.shape
331
+ x_in = x
332
+ x = self.norm(x)
333
+ if not self.use_linear:
334
+ x = self.proj_in(x)
335
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
336
+ if self.use_linear:
337
+ x = self.proj_in(x)
338
+ for i, block in enumerate(self.transformer_blocks):
339
+ x = block(x, context=context[i])
340
+ if self.use_linear:
341
+ x = self.proj_out(x)
342
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
343
+ if not self.use_linear:
344
+ x = self.proj_out(x)
345
+ return x + x_in
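
(Illustrative usage note, not part of the uploaded file.) A quick way to sanity-check the attention module above is to run its two main building blocks on dummy tensors. The sizes below are assumptions chosen for illustration (320 channels, 8 heads of 40 dims, a 77-token context of width 768); they are not taken from this repository's configs.

import torch
from ldm.modules.attention import CrossAttention, SpatialTransformer

# cross-attention: 64 query tokens attending to a 77-token text context
x = torch.randn(2, 64, 320)    # (batch, tokens, query_dim)
ctx = torch.randn(2, 77, 768)  # (batch, context tokens, context_dim)
attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
print(attn(x, context=ctx).shape)  # torch.Size([2, 64, 320])

# spatial transformer: the same context applied to an image-like feature map
feat = torch.randn(2, 320, 16, 16)  # (batch, channels, h, w)
st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1, context_dim=768)
print(st(feat, context=ctx).shape)  # torch.Size([2, 320, 16, 16])
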
ldm/modules/diffusionmodules/.DS_Store ADDED
Binary file (6.15 kB). View file
 
ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (158 Bytes). View file
 
ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc ADDED
Binary file (20.7 kB). View file
 
ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc ADDED
Binary file (16.8 kB). View file
 
ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc ADDED
Binary file (8.96 kB). View file
 
ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,835 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ from ldm.util import instantiate_from_config
9
+ from ldm.modules.attention import LinearAttention
10
+
11
+
12
+ def get_timestep_embedding(timesteps, embedding_dim):
13
+ """
14
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
15
+ From Fairseq.
16
+ Build sinusoidal embeddings.
17
+ This matches the implementation in tensor2tensor, but differs slightly
18
+ from the description in Section 3.5 of "Attention Is All You Need".
19
+ """
20
+ assert len(timesteps.shape) == 1
21
+
22
+ half_dim = embedding_dim // 2
23
+ emb = math.log(10000) / (half_dim - 1)
24
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
25
+ emb = emb.to(device=timesteps.device)
26
+ emb = timesteps.float()[:, None] * emb[None, :]
27
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28
+ if embedding_dim % 2 == 1: # zero pad
29
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
30
+ return emb
31
+
32
+
33
+ def nonlinearity(x):
34
+ # swish
35
+ return x*torch.sigmoid(x)
36
+
37
+
38
+ def Normalize(in_channels, num_groups=32):
39
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
40
+
41
+
42
+ class Upsample(nn.Module):
43
+ def __init__(self, in_channels, with_conv):
44
+ super().__init__()
45
+ self.with_conv = with_conv
46
+ if self.with_conv:
47
+ self.conv = torch.nn.Conv2d(in_channels,
48
+ in_channels,
49
+ kernel_size=3,
50
+ stride=1,
51
+ padding=1)
52
+
53
+ def forward(self, x):
54
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
55
+ if self.with_conv:
56
+ x = self.conv(x)
57
+ return x
58
+
59
+
60
+ class Downsample(nn.Module):
61
+ def __init__(self, in_channels, with_conv):
62
+ super().__init__()
63
+ self.with_conv = with_conv
64
+ if self.with_conv:
65
+ # no asymmetric padding in torch conv, must do it ourselves
66
+ self.conv = torch.nn.Conv2d(in_channels,
67
+ in_channels,
68
+ kernel_size=3,
69
+ stride=2,
70
+ padding=0)
71
+
72
+ def forward(self, x):
73
+ if self.with_conv:
74
+ pad = (0,1,0,1)
75
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
76
+ x = self.conv(x)
77
+ else:
78
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
79
+ return x
80
+
81
+
82
+ class ResnetBlock(nn.Module):
83
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
84
+ dropout, temb_channels=512):
85
+ super().__init__()
86
+ self.in_channels = in_channels
87
+ out_channels = in_channels if out_channels is None else out_channels
88
+ self.out_channels = out_channels
89
+ self.use_conv_shortcut = conv_shortcut
90
+
91
+ self.norm1 = Normalize(in_channels)
92
+ self.conv1 = torch.nn.Conv2d(in_channels,
93
+ out_channels,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1)
97
+ if temb_channels > 0:
98
+ self.temb_proj = torch.nn.Linear(temb_channels,
99
+ out_channels)
100
+ self.norm2 = Normalize(out_channels)
101
+ self.dropout = torch.nn.Dropout(dropout)
102
+ self.conv2 = torch.nn.Conv2d(out_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ if self.in_channels != self.out_channels:
108
+ if self.use_conv_shortcut:
109
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
110
+ out_channels,
111
+ kernel_size=3,
112
+ stride=1,
113
+ padding=1)
114
+ else:
115
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
116
+ out_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+
121
+ def forward(self, x, temb):
122
+ h = x
123
+ h = self.norm1(h)
124
+ h = nonlinearity(h)
125
+ h = self.conv1(h)
126
+
127
+ if temb is not None:
128
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
129
+
130
+ h = self.norm2(h)
131
+ h = nonlinearity(h)
132
+ h = self.dropout(h)
133
+ h = self.conv2(h)
134
+
135
+ if self.in_channels != self.out_channels:
136
+ if self.use_conv_shortcut:
137
+ x = self.conv_shortcut(x)
138
+ else:
139
+ x = self.nin_shortcut(x)
140
+
141
+ return x+h
142
+
143
+
144
+ class LinAttnBlock(LinearAttention):
145
+ """to match AttnBlock usage"""
146
+ def __init__(self, in_channels):
147
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
148
+
149
+
150
+ class AttnBlock(nn.Module):
151
+ def __init__(self, in_channels):
152
+ super().__init__()
153
+ self.in_channels = in_channels
154
+
155
+ self.norm = Normalize(in_channels)
156
+ self.q = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.k = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+ self.v = torch.nn.Conv2d(in_channels,
167
+ in_channels,
168
+ kernel_size=1,
169
+ stride=1,
170
+ padding=0)
171
+ self.proj_out = torch.nn.Conv2d(in_channels,
172
+ in_channels,
173
+ kernel_size=1,
174
+ stride=1,
175
+ padding=0)
176
+
177
+
178
+ def forward(self, x):
179
+ h_ = x
180
+ h_ = self.norm(h_)
181
+ q = self.q(h_)
182
+ k = self.k(h_)
183
+ v = self.v(h_)
184
+
185
+ # compute attention
186
+ b,c,h,w = q.shape
187
+ q = q.reshape(b,c,h*w)
188
+ q = q.permute(0,2,1) # b,hw,c
189
+ k = k.reshape(b,c,h*w) # b,c,hw
190
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
191
+ w_ = w_ * (int(c)**(-0.5))
192
+ w_ = torch.nn.functional.softmax(w_, dim=2)
193
+
194
+ # attend to values
195
+ v = v.reshape(b,c,h*w)
196
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
197
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
198
+ h_ = h_.reshape(b,c,h,w)
199
+
200
+ h_ = self.proj_out(h_)
201
+
202
+ return x+h_
203
+
204
+
205
+ def make_attn(in_channels, attn_type="vanilla"):
206
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
207
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
208
+ if attn_type == "vanilla":
209
+ return AttnBlock(in_channels)
210
+ elif attn_type == "none":
211
+ return nn.Identity(in_channels)
212
+ else:
213
+ return LinAttnBlock(in_channels)
214
+
215
+
216
+ class Model(nn.Module):
217
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
218
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
219
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
220
+ super().__init__()
221
+ if use_linear_attn: attn_type = "linear"
222
+ self.ch = ch
223
+ self.temb_ch = self.ch*4
224
+ self.num_resolutions = len(ch_mult)
225
+ self.num_res_blocks = num_res_blocks
226
+ self.resolution = resolution
227
+ self.in_channels = in_channels
228
+
229
+ self.use_timestep = use_timestep
230
+ if self.use_timestep:
231
+ # timestep embedding
232
+ self.temb = nn.Module()
233
+ self.temb.dense = nn.ModuleList([
234
+ torch.nn.Linear(self.ch,
235
+ self.temb_ch),
236
+ torch.nn.Linear(self.temb_ch,
237
+ self.temb_ch),
238
+ ])
239
+
240
+ # downsampling
241
+ self.conv_in = torch.nn.Conv2d(in_channels,
242
+ self.ch,
243
+ kernel_size=3,
244
+ stride=1,
245
+ padding=1)
246
+
247
+ curr_res = resolution
248
+ in_ch_mult = (1,)+tuple(ch_mult)
249
+ self.down = nn.ModuleList()
250
+ for i_level in range(self.num_resolutions):
251
+ block = nn.ModuleList()
252
+ attn = nn.ModuleList()
253
+ block_in = ch*in_ch_mult[i_level]
254
+ block_out = ch*ch_mult[i_level]
255
+ for i_block in range(self.num_res_blocks):
256
+ block.append(ResnetBlock(in_channels=block_in,
257
+ out_channels=block_out,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout))
260
+ block_in = block_out
261
+ if curr_res in attn_resolutions:
262
+ attn.append(make_attn(block_in, attn_type=attn_type))
263
+ down = nn.Module()
264
+ down.block = block
265
+ down.attn = attn
266
+ if i_level != self.num_resolutions-1:
267
+ down.downsample = Downsample(block_in, resamp_with_conv)
268
+ curr_res = curr_res // 2
269
+ self.down.append(down)
270
+
271
+ # middle
272
+ self.mid = nn.Module()
273
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
274
+ out_channels=block_in,
275
+ temb_channels=self.temb_ch,
276
+ dropout=dropout)
277
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
278
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
279
+ out_channels=block_in,
280
+ temb_channels=self.temb_ch,
281
+ dropout=dropout)
282
+
283
+ # upsampling
284
+ self.up = nn.ModuleList()
285
+ for i_level in reversed(range(self.num_resolutions)):
286
+ block = nn.ModuleList()
287
+ attn = nn.ModuleList()
288
+ block_out = ch*ch_mult[i_level]
289
+ skip_in = ch*ch_mult[i_level]
290
+ for i_block in range(self.num_res_blocks+1):
291
+ if i_block == self.num_res_blocks:
292
+ skip_in = ch*in_ch_mult[i_level]
293
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
294
+ out_channels=block_out,
295
+ temb_channels=self.temb_ch,
296
+ dropout=dropout))
297
+ block_in = block_out
298
+ if curr_res in attn_resolutions:
299
+ attn.append(make_attn(block_in, attn_type=attn_type))
300
+ up = nn.Module()
301
+ up.block = block
302
+ up.attn = attn
303
+ if i_level != 0:
304
+ up.upsample = Upsample(block_in, resamp_with_conv)
305
+ curr_res = curr_res * 2
306
+ self.up.insert(0, up) # prepend to get consistent order
307
+
308
+ # end
309
+ self.norm_out = Normalize(block_in)
310
+ self.conv_out = torch.nn.Conv2d(block_in,
311
+ out_ch,
312
+ kernel_size=3,
313
+ stride=1,
314
+ padding=1)
315
+
316
+ def forward(self, x, t=None, context=None):
317
+ #assert x.shape[2] == x.shape[3] == self.resolution
318
+ if context is not None:
319
+ # assume aligned context, cat along channel axis
320
+ x = torch.cat((x, context), dim=1)
321
+ if self.use_timestep:
322
+ # timestep embedding
323
+ assert t is not None
324
+ temb = get_timestep_embedding(t, self.ch)
325
+ temb = self.temb.dense[0](temb)
326
+ temb = nonlinearity(temb)
327
+ temb = self.temb.dense[1](temb)
328
+ else:
329
+ temb = None
330
+
331
+ # downsampling
332
+ hs = [self.conv_in(x)]
333
+ for i_level in range(self.num_resolutions):
334
+ for i_block in range(self.num_res_blocks):
335
+ h = self.down[i_level].block[i_block](hs[-1], temb)
336
+ if len(self.down[i_level].attn) > 0:
337
+ h = self.down[i_level].attn[i_block](h)
338
+ hs.append(h)
339
+ if i_level != self.num_resolutions-1:
340
+ hs.append(self.down[i_level].downsample(hs[-1]))
341
+
342
+ # middle
343
+ h = hs[-1]
344
+ h = self.mid.block_1(h, temb)
345
+ h = self.mid.attn_1(h)
346
+ h = self.mid.block_2(h, temb)
347
+
348
+ # upsampling
349
+ for i_level in reversed(range(self.num_resolutions)):
350
+ for i_block in range(self.num_res_blocks+1):
351
+ h = self.up[i_level].block[i_block](
352
+ torch.cat([h, hs.pop()], dim=1), temb)
353
+ if len(self.up[i_level].attn) > 0:
354
+ h = self.up[i_level].attn[i_block](h)
355
+ if i_level != 0:
356
+ h = self.up[i_level].upsample(h)
357
+
358
+ # end
359
+ h = self.norm_out(h)
360
+ h = nonlinearity(h)
361
+ h = self.conv_out(h)
362
+ return h
363
+
364
+ def get_last_layer(self):
365
+ return self.conv_out.weight
366
+
367
+
368
+ class Encoder(nn.Module):
369
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
370
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
371
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
372
+ **ignore_kwargs):
373
+ super().__init__()
374
+ if use_linear_attn: attn_type = "linear"
375
+ self.ch = ch
376
+ self.temb_ch = 0
377
+ self.num_resolutions = len(ch_mult)
378
+ self.num_res_blocks = num_res_blocks
379
+ self.resolution = resolution
380
+ self.in_channels = in_channels
381
+
382
+ # downsampling
383
+ self.conv_in = torch.nn.Conv2d(in_channels,
384
+ self.ch,
385
+ kernel_size=3,
386
+ stride=1,
387
+ padding=1)
388
+
389
+ curr_res = resolution
390
+ in_ch_mult = (1,)+tuple(ch_mult)
391
+ self.in_ch_mult = in_ch_mult
392
+ self.down = nn.ModuleList()
393
+ for i_level in range(self.num_resolutions):
394
+ block = nn.ModuleList()
395
+ attn = nn.ModuleList()
396
+ block_in = ch*in_ch_mult[i_level]
397
+ block_out = ch*ch_mult[i_level]
398
+ for i_block in range(self.num_res_blocks):
399
+ block.append(ResnetBlock(in_channels=block_in,
400
+ out_channels=block_out,
401
+ temb_channels=self.temb_ch,
402
+ dropout=dropout))
403
+ block_in = block_out
404
+ if curr_res in attn_resolutions:
405
+ attn.append(make_attn(block_in, attn_type=attn_type))
406
+ down = nn.Module()
407
+ down.block = block
408
+ down.attn = attn
409
+ if i_level != self.num_resolutions-1:
410
+ down.downsample = Downsample(block_in, resamp_with_conv)
411
+ curr_res = curr_res // 2
412
+ self.down.append(down)
413
+
414
+ # middle
415
+ self.mid = nn.Module()
416
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
417
+ out_channels=block_in,
418
+ temb_channels=self.temb_ch,
419
+ dropout=dropout)
420
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
421
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
422
+ out_channels=block_in,
423
+ temb_channels=self.temb_ch,
424
+ dropout=dropout)
425
+
426
+ # end
427
+ self.norm_out = Normalize(block_in)
428
+ self.conv_out = torch.nn.Conv2d(block_in,
429
+ 2*z_channels if double_z else z_channels,
430
+ kernel_size=3,
431
+ stride=1,
432
+ padding=1)
433
+
434
+ def forward(self, x):
435
+ # timestep embedding
436
+ temb = None
437
+
438
+ # downsampling
439
+ hs = [self.conv_in(x)]
440
+ for i_level in range(self.num_resolutions):
441
+ for i_block in range(self.num_res_blocks):
442
+ h = self.down[i_level].block[i_block](hs[-1], temb)
443
+ if len(self.down[i_level].attn) > 0:
444
+ h = self.down[i_level].attn[i_block](h)
445
+ hs.append(h)
446
+ if i_level != self.num_resolutions-1:
447
+ hs.append(self.down[i_level].downsample(hs[-1]))
448
+
449
+ # middle
450
+ h = hs[-1]
451
+ h = self.mid.block_1(h, temb)
452
+ h = self.mid.attn_1(h)
453
+ h = self.mid.block_2(h, temb)
454
+
455
+ # end
456
+ h = self.norm_out(h)
457
+ h = nonlinearity(h)
458
+ h = self.conv_out(h)
459
+ return h
460
+
461
+
462
+ class Decoder(nn.Module):
463
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
464
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
465
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
466
+ attn_type="vanilla", **ignorekwargs):
467
+ super().__init__()
468
+ if use_linear_attn: attn_type = "linear"
469
+ self.ch = ch
470
+ self.temb_ch = 0
471
+ self.num_resolutions = len(ch_mult)
472
+ self.num_res_blocks = num_res_blocks
473
+ self.resolution = resolution
474
+ self.in_channels = in_channels
475
+ self.give_pre_end = give_pre_end
476
+ self.tanh_out = tanh_out
477
+
478
+ # compute in_ch_mult, block_in and curr_res at lowest res
479
+ in_ch_mult = (1,)+tuple(ch_mult)
480
+ block_in = ch*ch_mult[self.num_resolutions-1]
481
+ curr_res = resolution // 2**(self.num_resolutions-1)
482
+ self.z_shape = (1,z_channels,curr_res,curr_res)
483
+ print("Working with z of shape {} = {} dimensions.".format(
484
+ self.z_shape, np.prod(self.z_shape)))
485
+
486
+ # z to block_in
487
+ self.conv_in = torch.nn.Conv2d(z_channels,
488
+ block_in,
489
+ kernel_size=3,
490
+ stride=1,
491
+ padding=1)
492
+
493
+ # middle
494
+ self.mid = nn.Module()
495
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
496
+ out_channels=block_in,
497
+ temb_channels=self.temb_ch,
498
+ dropout=dropout)
499
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
500
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
501
+ out_channels=block_in,
502
+ temb_channels=self.temb_ch,
503
+ dropout=dropout)
504
+
505
+ # upsampling
506
+ self.up = nn.ModuleList()
507
+ for i_level in reversed(range(self.num_resolutions)):
508
+ block = nn.ModuleList()
509
+ attn = nn.ModuleList()
510
+ block_out = ch*ch_mult[i_level]
511
+ for i_block in range(self.num_res_blocks+1):
512
+ block.append(ResnetBlock(in_channels=block_in,
513
+ out_channels=block_out,
514
+ temb_channels=self.temb_ch,
515
+ dropout=dropout))
516
+ block_in = block_out
517
+ if curr_res in attn_resolutions:
518
+ attn.append(make_attn(block_in, attn_type=attn_type))
519
+ up = nn.Module()
520
+ up.block = block
521
+ up.attn = attn
522
+ if i_level != 0:
523
+ up.upsample = Upsample(block_in, resamp_with_conv)
524
+ curr_res = curr_res * 2
525
+ self.up.insert(0, up) # prepend to get consistent order
526
+
527
+ # end
528
+ self.norm_out = Normalize(block_in)
529
+ self.conv_out = torch.nn.Conv2d(block_in,
530
+ out_ch,
531
+ kernel_size=3,
532
+ stride=1,
533
+ padding=1)
534
+
535
+ def forward(self, z):
536
+ #assert z.shape[1:] == self.z_shape[1:]
537
+ self.last_z_shape = z.shape
538
+
539
+ # timestep embedding
540
+ temb = None
541
+
542
+ # z to block_in
543
+ h = self.conv_in(z)
544
+
545
+ # middle
546
+ h = self.mid.block_1(h, temb)
547
+ h = self.mid.attn_1(h)
548
+ h = self.mid.block_2(h, temb)
549
+
550
+ # upsampling
551
+ for i_level in reversed(range(self.num_resolutions)):
552
+ for i_block in range(self.num_res_blocks+1):
553
+ h = self.up[i_level].block[i_block](h, temb)
554
+ if len(self.up[i_level].attn) > 0:
555
+ h = self.up[i_level].attn[i_block](h)
556
+ if i_level != 0:
557
+ h = self.up[i_level].upsample(h)
558
+
559
+ # end
560
+ if self.give_pre_end:
561
+ return h
562
+
563
+ h = self.norm_out(h)
564
+ h = nonlinearity(h)
565
+ h = self.conv_out(h)
566
+ if self.tanh_out:
567
+ h = torch.tanh(h)
568
+ return h
569
+
570
+
571
+ class SimpleDecoder(nn.Module):
572
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
573
+ super().__init__()
574
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
575
+ ResnetBlock(in_channels=in_channels,
576
+ out_channels=2 * in_channels,
577
+ temb_channels=0, dropout=0.0),
578
+ ResnetBlock(in_channels=2 * in_channels,
579
+ out_channels=4 * in_channels,
580
+ temb_channels=0, dropout=0.0),
581
+ ResnetBlock(in_channels=4 * in_channels,
582
+ out_channels=2 * in_channels,
583
+ temb_channels=0, dropout=0.0),
584
+ nn.Conv2d(2*in_channels, in_channels, 1),
585
+ Upsample(in_channels, with_conv=True)])
586
+ # end
587
+ self.norm_out = Normalize(in_channels)
588
+ self.conv_out = torch.nn.Conv2d(in_channels,
589
+ out_channels,
590
+ kernel_size=3,
591
+ stride=1,
592
+ padding=1)
593
+
594
+ def forward(self, x):
595
+ for i, layer in enumerate(self.model):
596
+ if i in [1,2,3]:
597
+ x = layer(x, None)
598
+ else:
599
+ x = layer(x)
600
+
601
+ h = self.norm_out(x)
602
+ h = nonlinearity(h)
603
+ x = self.conv_out(h)
604
+ return x
605
+
606
+
607
+ class UpsampleDecoder(nn.Module):
608
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
609
+ ch_mult=(2,2), dropout=0.0):
610
+ super().__init__()
611
+ # upsampling
612
+ self.temb_ch = 0
613
+ self.num_resolutions = len(ch_mult)
614
+ self.num_res_blocks = num_res_blocks
615
+ block_in = in_channels
616
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
617
+ self.res_blocks = nn.ModuleList()
618
+ self.upsample_blocks = nn.ModuleList()
619
+ for i_level in range(self.num_resolutions):
620
+ res_block = []
621
+ block_out = ch * ch_mult[i_level]
622
+ for i_block in range(self.num_res_blocks + 1):
623
+ res_block.append(ResnetBlock(in_channels=block_in,
624
+ out_channels=block_out,
625
+ temb_channels=self.temb_ch,
626
+ dropout=dropout))
627
+ block_in = block_out
628
+ self.res_blocks.append(nn.ModuleList(res_block))
629
+ if i_level != self.num_resolutions - 1:
630
+ self.upsample_blocks.append(Upsample(block_in, True))
631
+ curr_res = curr_res * 2
632
+
633
+ # end
634
+ self.norm_out = Normalize(block_in)
635
+ self.conv_out = torch.nn.Conv2d(block_in,
636
+ out_channels,
637
+ kernel_size=3,
638
+ stride=1,
639
+ padding=1)
640
+
641
+ def forward(self, x):
642
+ # upsampling
643
+ h = x
644
+ for k, i_level in enumerate(range(self.num_resolutions)):
645
+ for i_block in range(self.num_res_blocks + 1):
646
+ h = self.res_blocks[i_level][i_block](h, None)
647
+ if i_level != self.num_resolutions - 1:
648
+ h = self.upsample_blocks[k](h)
649
+ h = self.norm_out(h)
650
+ h = nonlinearity(h)
651
+ h = self.conv_out(h)
652
+ return h
653
+
654
+
655
+ class LatentRescaler(nn.Module):
656
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
657
+ super().__init__()
658
+ # residual block, interpolate, residual block
659
+ self.factor = factor
660
+ self.conv_in = nn.Conv2d(in_channels,
661
+ mid_channels,
662
+ kernel_size=3,
663
+ stride=1,
664
+ padding=1)
665
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
666
+ out_channels=mid_channels,
667
+ temb_channels=0,
668
+ dropout=0.0) for _ in range(depth)])
669
+ self.attn = AttnBlock(mid_channels)
670
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
671
+ out_channels=mid_channels,
672
+ temb_channels=0,
673
+ dropout=0.0) for _ in range(depth)])
674
+
675
+ self.conv_out = nn.Conv2d(mid_channels,
676
+ out_channels,
677
+ kernel_size=1,
678
+ )
679
+
680
+ def forward(self, x):
681
+ x = self.conv_in(x)
682
+ for block in self.res_block1:
683
+ x = block(x, None)
684
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
685
+ x = self.attn(x)
686
+ for block in self.res_block2:
687
+ x = block(x, None)
688
+ x = self.conv_out(x)
689
+ return x
690
+
691
+
692
+ class MergedRescaleEncoder(nn.Module):
693
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
694
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
695
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
696
+ super().__init__()
697
+ intermediate_chn = ch * ch_mult[-1]
698
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
699
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
700
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
701
+ out_ch=None)
702
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
703
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
704
+
705
+ def forward(self, x):
706
+ x = self.encoder(x)
707
+ x = self.rescaler(x)
708
+ return x
709
+
710
+
711
+ class MergedRescaleDecoder(nn.Module):
712
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
713
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
714
+ super().__init__()
715
+ tmp_chn = z_channels*ch_mult[-1]
716
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
717
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
718
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
719
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
720
+ out_channels=tmp_chn, depth=rescale_module_depth)
721
+
722
+ def forward(self, x):
723
+ x = self.rescaler(x)
724
+ x = self.decoder(x)
725
+ return x
726
+
727
+
728
+ class Upsampler(nn.Module):
729
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
730
+ super().__init__()
731
+ assert out_size >= in_size
732
+ num_blocks = int(np.log2(out_size//in_size))+1
733
+ factor_up = 1.+ (out_size % in_size)
734
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
735
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
736
+ out_channels=in_channels)
737
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
738
+ attn_resolutions=[], in_channels=None, ch=in_channels,
739
+ ch_mult=[ch_mult for _ in range(num_blocks)])
740
+
741
+ def forward(self, x):
742
+ x = self.rescaler(x)
743
+ x = self.decoder(x)
744
+ return x
745
+
746
+
747
+ class Resize(nn.Module):
748
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
749
+ super().__init__()
750
+ self.with_conv = learned
751
+ self.mode = mode
752
+ if self.with_conv:
753
+ print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
754
+ raise NotImplementedError()
755
+ assert in_channels is not None
756
+ # no asymmetric padding in torch conv, must do it ourselves
757
+ self.conv = torch.nn.Conv2d(in_channels,
758
+ in_channels,
759
+ kernel_size=4,
760
+ stride=2,
761
+ padding=1)
762
+
763
+ def forward(self, x, scale_factor=1.0):
764
+ if scale_factor==1.0:
765
+ return x
766
+ else:
767
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
768
+ return x
769
+
770
+ class FirstStagePostProcessor(nn.Module):
771
+
772
+ def __init__(self, ch_mult:list, in_channels,
773
+ pretrained_model:nn.Module=None,
774
+ reshape=False,
775
+ n_channels=None,
776
+ dropout=0.,
777
+ pretrained_config=None):
778
+ super().__init__()
779
+ if pretrained_config is None:
780
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
781
+ self.pretrained_model = pretrained_model
782
+ else:
783
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
784
+ self.instantiate_pretrained(pretrained_config)
785
+
786
+ self.do_reshape = reshape
787
+
788
+ if n_channels is None:
789
+ n_channels = self.pretrained_model.encoder.ch
790
+
791
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
792
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
793
+ stride=1,padding=1)
794
+
795
+ blocks = []
796
+ downs = []
797
+ ch_in = n_channels
798
+ for m in ch_mult:
799
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
800
+ ch_in = m * n_channels
801
+ downs.append(Downsample(ch_in, with_conv=False))
802
+
803
+ self.model = nn.ModuleList(blocks)
804
+ self.downsampler = nn.ModuleList(downs)
805
+
806
+
807
+ def instantiate_pretrained(self, config):
808
+ model = instantiate_from_config(config)
809
+ self.pretrained_model = model.eval()
810
+ # self.pretrained_model.train = False
811
+ for param in self.pretrained_model.parameters():
812
+ param.requires_grad = False
813
+
814
+
815
+ @torch.no_grad()
816
+ def encode_with_pretrained(self,x):
817
+ c = self.pretrained_model.encode(x)
818
+ if isinstance(c, DiagonalGaussianDistribution):
819
+ c = c.mode()
820
+ return c
821
+
822
+ def forward(self,x):
823
+ z_fs = self.encode_with_pretrained(x)
824
+ z = self.proj_norm(z_fs)
825
+ z = self.proj(z)
826
+ z = nonlinearity(z)
827
+
828
+ for submodel, downmodel in zip(self.model,self.downsampler):
829
+ z = submodel(z,temb=None)
830
+ z = downmodel(z)
831
+
832
+ if self.do_reshape:
833
+ z = rearrange(z,'b c h w -> b (h w) c')
834
+ return z
835
+
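
(Illustrative usage note, not part of the uploaded file.) Encoder and Decoder above are the two halves of the first-stage autoencoder; the sketch below traces the shape flow of an encode/decode round trip. The channel multipliers and the 256-pixel resolution are assumptions patterned after a typical KL first-stage config, not values read from this repository.

import torch
from ldm.modules.diffusionmodules.model import Encoder, Decoder

cfg = dict(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
           attn_resolutions=[], in_channels=3, resolution=256, z_channels=4)
enc = Encoder(double_z=True, **cfg)   # three downsamples: 256 -> 32
dec = Decoder(**cfg)                  # mirror image: 32 -> 256

x = torch.randn(1, 3, 256, 256)       # input image
moments = enc(x)                      # (1, 2*z_channels, 32, 32): mean and logvar stacked
z = moments[:, :4]                    # use the mean half as a stand-in for sampling
print(moments.shape, dec(z).shape)    # torch.Size([1, 8, 32, 32]) torch.Size([1, 3, 256, 256])
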
ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,707 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+ import math
4
+ from typing import Iterable
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from ldm.modules.diffusionmodules.util import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+ from ldm.modules.attention import SpatialTransformer
21
+
22
+
23
+ # dummy replace
24
+ def convert_module_to_f16(x):
25
+ pass
26
+
27
+ def convert_module_to_f32(x):
28
+ pass
29
+
30
+
31
+ ## go
32
+ class AttentionPool2d(nn.Module):
33
+ """
34
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ spacial_dim: int,
40
+ embed_dim: int,
41
+ num_heads_channels: int,
42
+ output_dim: int = None,
43
+ ):
44
+ super().__init__()
45
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
46
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
47
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
48
+ self.num_heads = embed_dim // num_heads_channels
49
+ self.attention = QKVAttention(self.num_heads)
50
+
51
+ def forward(self, x):
52
+ b, c, *_spatial = x.shape
53
+ x = x.reshape(b, c, -1) # NC(HW)
54
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
55
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
56
+ x = self.qkv_proj(x)
57
+ x = self.attention(x)
58
+ x = self.c_proj(x)
59
+ return x[:, :, 0]
60
+
61
+
62
+ class TimestepBlock(nn.Module):
63
+ @abstractmethod
64
+ def forward(self, x, emb):
65
+ """
66
+ Apply the module to `x` given `emb` timestep embeddings.
67
+ """
68
+
69
+
70
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
71
+ def forward(self, x, emb, context=None):
72
+ for layer in self:
73
+ if isinstance(layer, TimestepBlock):
74
+ x = layer(x, emb)
75
+ elif isinstance(layer, SpatialTransformer):
76
+ x = layer(x, context)
77
+ else:
78
+ x = layer(x)
79
+ return x
80
+
81
+
82
+ class Upsample(nn.Module):
83
+ """
84
+ An upsampling layer with an optional convolution.
85
+ :param channels: channels in the inputs and outputs.
86
+ :param use_conv: a bool determining if a convolution is applied.
87
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
88
+ upsampling occurs in the inner-two dimensions.
89
+ """
90
+
91
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
92
+ super().__init__()
93
+ self.channels = channels
94
+ self.out_channels = out_channels or channels
95
+ self.use_conv = use_conv
96
+ self.dims = dims
97
+ if use_conv:
98
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
99
+
100
+ def forward(self, x):
101
+ assert x.shape[1] == self.channels
102
+ if self.dims == 3:
103
+ x = F.interpolate(
104
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
105
+ )
106
+ else:
107
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
108
+ if self.use_conv:
109
+ x = self.conv(x)
110
+ return x
111
+
112
+ class TransposedUpsample(nn.Module):
113
+ 'Learned 2x upsampling without padding'
114
+ def __init__(self, channels, out_channels=None, ks=5):
115
+ super().__init__()
116
+ self.channels = channels
117
+ self.out_channels = out_channels or channels
118
+
119
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
120
+
121
+ def forward(self,x):
122
+ return self.up(x)
123
+
124
+
125
+ class Downsample(nn.Module):
126
+ """
127
+ A downsampling layer with an optional convolution.
128
+ :param channels: channels in the inputs and outputs.
129
+ :param use_conv: a bool determining if a convolution is applied.
130
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
131
+ downsampling occurs in the inner-two dimensions.
132
+ """
133
+
134
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
135
+ super().__init__()
136
+ self.channels = channels
137
+ self.out_channels = out_channels or channels
138
+ self.use_conv = use_conv
139
+ self.dims = dims
140
+ stride = 2 if dims != 3 else (1, 2, 2)
141
+ if use_conv:
142
+ self.op = conv_nd(
143
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
144
+ )
145
+ else:
146
+ assert self.channels == self.out_channels
147
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
148
+
149
+ def forward(self, x):
150
+ assert x.shape[1] == self.channels
151
+ return self.op(x)
152
+
153
+
154
+ class ResBlock(TimestepBlock):
155
+ def __init__(
156
+ self,
157
+ channels,
158
+ emb_channels,
159
+ dropout,
160
+ out_channels=None,
161
+ use_conv=False,
162
+ use_scale_shift_norm=False,
163
+ dims=2,
164
+ use_checkpoint=False,
165
+ up=False,
166
+ down=False,
167
+ ):
168
+ super().__init__()
169
+ self.channels = channels
170
+ self.emb_channels = emb_channels
171
+ self.dropout = dropout
172
+ self.out_channels = out_channels or channels
173
+ self.use_conv = use_conv
174
+ self.use_checkpoint = use_checkpoint
175
+ self.use_scale_shift_norm = use_scale_shift_norm
176
+
177
+ self.in_layers = nn.Sequential(
178
+ normalization(channels),
179
+ nn.SiLU(),
180
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
181
+ )
182
+
183
+ self.updown = up or down
184
+
185
+ if up:
186
+ self.h_upd = Upsample(channels, False, dims)
187
+ self.x_upd = Upsample(channels, False, dims)
188
+ elif down:
189
+ self.h_upd = Downsample(channels, False, dims)
190
+ self.x_upd = Downsample(channels, False, dims)
191
+ else:
192
+ self.h_upd = self.x_upd = nn.Identity()
193
+
194
+ self.emb_layers = nn.Sequential(
195
+ nn.SiLU(),
196
+ linear(
197
+ emb_channels,
198
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
199
+ ),
200
+ )
201
+ self.out_layers = nn.Sequential(
202
+ normalization(self.out_channels),
203
+ nn.SiLU(),
204
+ nn.Dropout(p=dropout),
205
+ zero_module(
206
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
207
+ ),
208
+ )
209
+
210
+ if self.out_channels == channels:
211
+ self.skip_connection = nn.Identity()
212
+ elif use_conv:
213
+ self.skip_connection = conv_nd(
214
+ dims, channels, self.out_channels, 3, padding=1
215
+ )
216
+ else:
217
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
218
+
219
+ def forward(self, x, emb):
220
+ """
221
+ Apply the block to a Tensor, conditioned on a timestep embedding.
222
+ :param x: an [N x C x ...] Tensor of features.
223
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
224
+ :return: an [N x C x ...] Tensor of outputs.
225
+ """
226
+ return checkpoint(
227
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
228
+ )
229
+
230
+
231
+ def _forward(self, x, emb):
232
+ if self.updown:
233
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
234
+ h = in_rest(x)
235
+ h = self.h_upd(h)
236
+ x = self.x_upd(x)
237
+ h = in_conv(h)
238
+ else:
239
+ h = self.in_layers(x)
240
+ emb_out = self.emb_layers(emb).type(h.dtype)
241
+ while len(emb_out.shape) < len(h.shape):
242
+ emb_out = emb_out[..., None]
243
+ if self.use_scale_shift_norm:
244
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
245
+ scale, shift = th.chunk(emb_out, 2, dim=1)
246
+ h = out_norm(h) * (1 + scale) + shift
247
+ h = out_rest(h)
248
+ else:
249
+ h = h + emb_out
250
+ h = self.out_layers(h)
251
+ return self.skip_connection(x) + h
252
+
253
+
254
+ class My_ResBlock(TimestepBlock):
255
+ """
256
+ A residual block that can optionally change the number of channels.
257
+ :param channels: the number of input channels.
258
+ :param emb_channels: the number of timestep embedding channels.
259
+ :param dropout: the rate of dropout.
260
+ :param out_channels: if specified, the number of out channels.
261
+ :param use_conv: if True and out_channels is specified, use a spatial
262
+ convolution instead of a smaller 1x1 convolution to change the
263
+ channels in the skip connection.
264
+ :param dims: determines if the signal is 1D, 2D, or 3D.
265
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
266
+ :param up: if True, use this block for upsampling.
267
+ :param down: if True, use this block for downsampling.
268
+ """
269
+
270
+ def __init__(
271
+ self,
272
+ channels,
273
+ emb_channels,
274
+ dropout,
275
+ out_channels=None,
276
+ use_conv=False,
277
+ use_scale_shift_norm=False,
278
+ dims=2,
279
+ use_checkpoint=False,
280
+ up=False,
281
+ down=False,
282
+ ):
283
+ super().__init__()
284
+ self.channels = channels
285
+ self.emb_channels = emb_channels
286
+ self.dropout = dropout
287
+ self.out_channels = out_channels or channels
288
+ self.use_conv = use_conv
289
+ self.use_checkpoint = use_checkpoint
290
+ self.use_scale_shift_norm = use_scale_shift_norm
291
+
292
+ self.in_layers = nn.Sequential(
293
+ normalization(channels),
294
+ nn.SiLU(),
295
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
296
+ )
297
+
298
+ self.updown = up or down
299
+
300
+ if up:
301
+ self.h_upd = Upsample(channels, False, dims)
302
+ self.x_upd = Upsample(channels, False, dims)
303
+ elif down:
304
+ self.h_upd = Downsample(channels, False, dims)
305
+ self.x_upd = Downsample(channels, False, dims)
306
+ else:
307
+ self.h_upd = self.x_upd = nn.Identity()
308
+
309
+ self.emb_layers = nn.Sequential(
310
+ nn.SiLU(),
311
+ linear(
312
+ emb_channels,
313
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
314
+ ),
315
+ )
316
+ self.out_layers = nn.Sequential(
317
+ normalization(self.out_channels),
318
+ nn.SiLU(),
319
+ nn.Dropout(p=dropout),
320
+ zero_module(
321
+ conv_nd(dims, self.out_channels, 4, 3, padding=1)
322
+ ),
323
+ )
324
+
325
+ if self.out_channels == channels:
326
+ self.skip_connection = nn.Identity()
327
+ elif use_conv:
328
+ self.skip_connection = conv_nd(
329
+ dims, channels, self.out_channels, 3, padding=1
330
+ )
331
+ else:
332
+ self.skip_connection = conv_nd(dims, channels, 4, 1)
333
+
334
+ def forward(self, x, emb):
335
+ """
336
+ Apply the block to a Tensor, conditioned on a timestep embedding.
337
+ :param x: an [N x C x ...] Tensor of features.
338
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
339
+ :return: an [N x C x ...] Tensor of outputs.
340
+ """
341
+ return checkpoint(
342
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
343
+ )
344
+
345
+
346
+ def _forward(self, x, emb):
347
+ if self.updown:
348
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
349
+ h = in_rest(x)
350
+ h = self.h_upd(h)
351
+ x = self.x_upd(x)
352
+ h = in_conv(h)
353
+ else:
354
+ h = self.in_layers(x)
355
+ emb_out = self.emb_layers(emb).type(h.dtype)
356
+ while len(emb_out.shape) < len(h.shape):
357
+ emb_out = emb_out[..., None]
358
+ if self.use_scale_shift_norm:
359
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
360
+ scale, shift = th.chunk(emb_out, 2, dim=1)
361
+ h = out_norm(h) * (1 + scale) + shift
362
+ h = out_rest(h)
363
+ else:
364
+ h = h + emb_out
365
+ h = self.out_layers(h)
366
+ return h
367
+
368
+
369
+ class AttentionBlock(nn.Module):
370
+ """
371
+ An attention block that allows spatial positions to attend to each other.
372
+ Originally ported from here, but adapted to the N-d case.
373
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
374
+ """
375
+
376
+ def __init__(
377
+ self,
378
+ channels,
379
+ num_heads=1,
380
+ num_head_channels=-1,
381
+ use_checkpoint=False,
382
+ use_new_attention_order=False,
383
+ ):
384
+ super().__init__()
385
+ self.channels = channels
386
+ if num_head_channels == -1:
387
+ self.num_heads = num_heads
388
+ else:
389
+ assert (
390
+ channels % num_head_channels == 0
391
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
392
+ self.num_heads = channels // num_head_channels
393
+ self.use_checkpoint = use_checkpoint
394
+ self.norm = normalization(channels)
395
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
396
+ if use_new_attention_order:
397
+ # split qkv before split heads
398
+ self.attention = QKVAttention(self.num_heads)
399
+ else:
400
+ # split heads before split qkv
401
+ self.attention = QKVAttentionLegacy(self.num_heads)
402
+
403
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
404
+
405
+ def forward(self, x):
406
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
407
+ #return pt_checkpoint(self._forward, x) # pytorch
408
+
409
+ def _forward(self, x):
410
+ b, c, *spatial = x.shape
411
+ x = x.reshape(b, c, -1)
412
+ qkv = self.qkv(self.norm(x))
413
+ h = self.attention(qkv)
414
+ h = self.proj_out(h)
415
+ return (x + h).reshape(b, c, *spatial)
416
+
417
+
418
+ def count_flops_attn(model, _x, y):
419
+ """
420
+ A counter for the `thop` package to count the operations in an
421
+ attention operation.
422
+ Meant to be used like:
423
+ macs, params = thop.profile(
424
+ model,
425
+ inputs=(inputs, timestamps),
426
+ custom_ops={QKVAttention: QKVAttention.count_flops},
427
+ )
428
+ """
429
+ b, c, *spatial = y[0].shape
430
+ num_spatial = int(np.prod(spatial))
431
+ # We perform two matmuls with the same number of ops.
432
+ # The first computes the weight matrix, the second computes
433
+ # the combination of the value vectors.
434
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
435
+ model.total_ops += th.DoubleTensor([matmul_ops])
436
+
437
+
438
+ class QKVAttentionLegacy(nn.Module):
439
+ """
440
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
441
+ """
442
+
443
+ def __init__(self, n_heads):
444
+ super().__init__()
445
+ self.n_heads = n_heads
446
+
447
+ def forward(self, qkv):
448
+ """
449
+ Apply QKV attention.
450
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
451
+ :return: an [N x (H * C) x T] tensor after attention.
452
+ """
453
+ bs, width, length = qkv.shape
454
+ assert width % (3 * self.n_heads) == 0
455
+ ch = width // (3 * self.n_heads)
456
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
457
+ scale = 1 / math.sqrt(math.sqrt(ch))
458
+ weight = th.einsum(
459
+ "bct,bcs->bts", q * scale, k * scale
460
+ ) # More stable with f16 than dividing afterwards
461
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
462
+ a = th.einsum("bts,bcs->bct", weight, v)
463
+ return a.reshape(bs, -1, length)
464
+
465
+ @staticmethod
466
+ def count_flops(model, _x, y):
467
+ return count_flops_attn(model, _x, y)
468
+
469
+
470
+ class QKVAttention(nn.Module):
471
+ """
472
+ A module which performs QKV attention and splits in a different order.
473
+ """
474
+
475
+ def __init__(self, n_heads):
476
+ super().__init__()
477
+ self.n_heads = n_heads
478
+
479
+ def forward(self, qkv):
480
+ """
481
+ Apply QKV attention.
482
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
483
+ :return: an [N x (H * C) x T] tensor after attention.
484
+ """
485
+ bs, width, length = qkv.shape
486
+ assert width % (3 * self.n_heads) == 0
487
+ ch = width // (3 * self.n_heads)
488
+ q, k, v = qkv.chunk(3, dim=1)
489
+ scale = 1 / math.sqrt(math.sqrt(ch))
490
+ weight = th.einsum(
491
+ "bct,bcs->bts",
492
+ (q * scale).view(bs * self.n_heads, ch, length),
493
+ (k * scale).view(bs * self.n_heads, ch, length),
494
+ ) # More stable with f16 than dividing afterwards
495
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
496
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
497
+ return a.reshape(bs, -1, length)
498
+
499
+ @staticmethod
500
+ def count_flops(model, _x, y):
501
+ return count_flops_attn(model, _x, y)
502
+
503
+
504
+ class UNetModel(nn.Module):
505
+ def __init__(self,
506
+ image_size, # 32
507
+ in_channels, # 9
508
+ out_channels, # 4
509
+ model_channels, # 320
510
+ attention_resolutions, # [ 4, 2, 1 ]
511
+ num_res_blocks, # 2
512
+ channel_mult=(1, 2, 4, 8), # [ 1, 2, 4, 4 ]
513
+ num_heads=-1, # 8
514
+ use_spatial_transformer=False, # True
515
+ transformer_depth=1, # 1
516
+ context_dim=None, # 768
517
+ use_checkpoint=False, # True
518
+ legacy=True, # False
519
+ add_conv_in_front_of_unet=False, # False
520
+ dropout=0,
521
+ conv_resample=True,
522
+ dims=2,
523
+ num_classes=None,
524
+ num_head_channels=-1,
525
+ num_heads_upsample=-1,
526
+ use_scale_shift_norm=False):
527
+
528
+ super().__init__()
529
+
530
+ self.image_size = image_size # 32
531
+ self.in_channels = in_channels # 9
532
+ self.out_channels = out_channels # 4
533
+ self.model_channels = model_channels # 320
534
+ self.attention_resolutions = attention_resolutions # [4,2,1]
535
+ self.num_res_blocks = num_res_blocks # 2
536
+ self.channel_mult = channel_mult # [1,2,4,4]
537
+ num_heads_upsample = num_heads # 8
538
+ self.use_checkpoint = use_checkpoint # True
539
+ self.add_conv_in_front_of_unet=add_conv_in_front_of_unet # False
540
+ self.dropout = dropout # 0
541
+ self.conv_resample = conv_resample # True
542
+ self.num_classes = num_classes # None
543
+ self.num_heads = num_heads # 8
544
+ self.num_head_channels = num_head_channels # -1
545
+ self.num_heads_upsample = num_heads_upsample # 8 (copied from num_heads above)
546
+ self.dtype = th.float32
547
+
548
+
549
+ # timestep embedding MLP: 320 -> 320*4 -> 320*4
550
+ time_embed_dim = model_channels * 4
551
+ self.time_embed = nn.Sequential(
552
+ linear(model_channels, time_embed_dim),
553
+ nn.SiLU(),
554
+ linear(time_embed_dim, time_embed_dim),
555
+ )
556
+
557
+
558
+ # stage 1: self.input_blocks
559
+ self.input_blocks = nn.ModuleList(
560
+ [
561
+ TimestepEmbedSequential(
562
+ conv_nd(
563
+ dims, # 2
564
+ in_channels, # 9
565
+ model_channels, # 320
566
+ kernel_size=3, # 3
567
+ padding=1 # 1
568
+ )
569
+ )
570
+ ]
571
+ )
572
+
573
+ input_block_chans = [model_channels] # [320]
574
+ ch = model_channels # 320
575
+ ds = 1 # 1
576
+
577
+ for level, mult in enumerate(channel_mult): # [0,1,2,3], [1,2,4,4]
578
+ for _ in range(num_res_blocks): # 2
579
+ layers = [
580
+ ResBlock(
581
+ ch, # [1,1,1,2,2,4,4,4]*320
582
+ time_embed_dim, # 320*4
583
+ dropout, # 0
584
+ out_channels=mult * model_channels, # [1,1,2,2,4,4,4,4]*320
585
+ dims=dims, # 2
586
+ use_checkpoint=use_checkpoint, # True
587
+ use_scale_shift_norm=use_scale_shift_norm, # False
588
+ )
589
+ ]
590
+ ch = mult * model_channels # [1,1,2,2,4,4,4,4]*320
591
+ if ds in attention_resolutions: # first 6 inner iterations (ds in [1,2,4])
592
+ dim_head = ch // num_heads # [1,1,2,2,4,4]*40
593
+ layers.append(
594
+ SpatialTransformer(
595
+ ch, # [1,1,2,2,4,4]*320
596
+ num_heads, # 8
597
+ dim_head, # [1,1,2,2,4,4]*40
598
+ depth=transformer_depth, # 1
599
+ context_dim=context_dim # 768
600
+ )
601
+ )
602
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
603
+ input_block_chans.append(ch) # [1,1,2,2,4,4,4,4]*320
604
+ if level != len(channel_mult) - 1: # first 3 outer iterations (no Downsample at the last level)
605
+ out_ch = ch # [1,2,4]*320
606
+ self.input_blocks.append(
607
+ TimestepEmbedSequential(
608
+ Downsample(
609
+ ch, # [1,2,4]*320
610
+ conv_resample, # True
611
+ dims=dims, # 2
612
+ out_channels=out_ch # [1,2,4]*320
613
+ )
614
+ )
615
+ )
616
+ ch = out_ch # [1,2,4]*320
617
+ input_block_chans.append(ch) # [1,2,4]*320
618
+ ds *= 2 # 1 -> 2 -> 4 -> 8
619
+
620
+
621
+ # Stage 2: self.middle_block (bottleneck)
622
+ dim_head = ch // num_heads # 1280 // 8
623
+
624
+ self.middle_block = TimestepEmbedSequential(
625
+ ResBlock(
626
+ ch, # 4*320
627
+ time_embed_dim, # 320*4
628
+ dropout, # 0
629
+ dims=dims, # 2
630
+ use_checkpoint=use_checkpoint, # True
631
+ use_scale_shift_norm=use_scale_shift_norm, # False
632
+ ),
633
+ SpatialTransformer(
634
+ ch, # 4*320
635
+ num_heads, # 8
636
+ dim_head, # 160
637
+ depth=transformer_depth, # 1
638
+ context_dim=context_dim # 768
639
+ ),
640
+ ResBlock(
641
+ ch, # 4*320
642
+ time_embed_dim, # 320*4
643
+ dropout, # 0
644
+ dims=dims, # 2
645
+ use_checkpoint=use_checkpoint, # True
646
+ use_scale_shift_norm=use_scale_shift_norm, # False
647
+ ),
648
+ )
649
+
650
+
651
+ # Stage 3: self.output_blocks (decoder / upsampling path)
652
+ self.output_blocks = nn.ModuleList([])
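+ # channel_mult is walked in reverse; every decoder block pops one entry from
+ # input_block_chans and concatenates that skip connection with the current
+ # feature map (hence ResBlock receives ch + ich input channels), upsampling
+ # at the end of every level except the final full-resolution one.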
653
+ for level, mult in list(enumerate(channel_mult))[::-1]: # [3,2,1,0], [4,4,2,1]
654
+ for i in range(num_res_blocks + 1): # 3
655
+ ich = input_block_chans.pop() # [4,4, 4,4,4, 2,2,2, 1,1,1, 1]*320
656
+ layers = [
657
+ ResBlock(
658
+ ch + ich, # [4,4,4,4,4,4,4,2,2,2,1,1]*320+ich
659
+ time_embed_dim, # 320*4
660
+ dropout, # 0
661
+ out_channels=model_channels*mult, # [4,4,4,4,4,4,2,2,2,1,1,1]*320
662
+ dims=dims, # 2
663
+ use_checkpoint=use_checkpoint, # True
664
+ use_scale_shift_norm=use_scale_shift_norm, # False
665
+ )
666
+ ]
667
+ ch = model_channels * mult # [4,4,4,4,4,4,2,2,2,1,1,1]*320
668
+ if ds in attention_resolutions: # last 3 outer iterations (ds in [4,2,1])
669
+ dim_head = ch // num_heads
670
+ layers.append(
671
+ SpatialTransformer(
672
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
673
+ )
674
+ )
675
+ if level and i == num_res_blocks: # last inner iteration of each of the first 3 outer iterations
676
+ out_ch = ch
677
+ layers.append(
678
+ Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
679
+ )
680
+ ds //= 2
681
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
682
+
683
+ # Stage 4: self.out (final projection)
684
+ self.out = nn.Sequential(
685
+ normalization(ch),
686
+ nn.SiLU(),
687
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
688
+ )
689
+
690
+ def forward(self, x, timesteps=None, context=None):
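+ """
+ :param x: [N, in_channels, H, W] latent-space input.
+ :param timesteps: [N] diffusion timesteps.
+ :param context: [N, seq_len, context_dim] cross-attention conditioning
+     (shapes inferred from the configuration comments above).
+ :return: [N, out_channels, H, W] prediction.
+ """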
691
+ hs = []
692
+
693
+ t_emb = timestep_embedding(timesteps, self.model_channels) # [N, 320]
694
+ emb = self.time_embed(t_emb) # [N, 320*4]
695
+ h = x.type(self.dtype) # cast x to torch.float32
696
+
697
+ for module in self.input_blocks:
698
+ h = module(h, emb, context)
699
+ hs.append(h)
700
+ h = self.middle_block(h, emb, context)
701
+ for module in self.output_blocks:
702
+ h = th.cat([h, hs.pop()], dim=1)
703
+ h = module(h, emb, context)
704
+ h = h.type(x.dtype)
705
+
706
+ return self.out(h) # [N,4,H,W]
707
+
ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,255 @@
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ from ldm.util import instantiate_from_config
19
+
20
+
21
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
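+ # Returns a numpy array of n_timestep betas. A minimal usage sketch,
+ # assuming the standard DDPM linear settings mirrored by the defaults here:
+ #   betas = make_beta_schedule("linear", 1000, linear_start=1e-4, linear_end=2e-2)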
22
+ if schedule == "linear":
23
+ betas = (
24
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25
+ )
26
+
27
+ elif schedule == "cosine":
28
+ timesteps = (
29
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30
+ )
31
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
32
+ alphas = torch.cos(alphas).pow(2)
33
+ alphas = alphas / alphas[0]
34
+ betas = 1 - alphas[1:] / alphas[:-1]
35
+ betas = np.clip(betas, a_min=0, a_max=0.999)
36
+
37
+ elif schedule == "sqrt_linear":
38
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39
+ elif schedule == "sqrt":
40
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41
+ else:
42
+ raise ValueError(f"schedule '{schedule}' unknown.")
43
+ return betas.numpy()
44
+
45
+
46
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
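+ # e.g. 'uniform' with num_ddpm_timesteps=1000 and num_ddim_timesteps=50
+ # keeps every 20th step and (after the +1 below) returns [1, 21, ..., 981].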
47
+ if ddim_discr_method == 'uniform':
48
+ c = num_ddpm_timesteps // num_ddim_timesteps
49
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50
+ elif ddim_discr_method == 'quad':
51
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52
+ else:
53
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54
+
55
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
57
+ steps_out = ddim_timesteps + 1
58
+ if verbose:
59
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
60
+ return steps_out
61
+
62
+
63
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64
+ # select alphas for computing the variance schedule
65
+ alphas = alphacums[ddim_timesteps]
66
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67
+
68
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
69
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70
+ if verbose:
71
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72
+ print(f'For the chosen value of eta, which is {eta}, '
73
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74
+ return sigmas, alphas, alphas_prev
75
+
76
+
77
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78
+ """
79
+ Create a beta schedule that discretizes the given alpha_t_bar function,
80
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
81
+ :param num_diffusion_timesteps: the number of betas to produce.
82
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83
+ produces the cumulative product of (1-beta) up to that
84
+ part of the diffusion process.
85
+ :param max_beta: the maximum beta to use; use values lower than 1 to
86
+ prevent singularities.
87
+ """
88
+ betas = []
89
+ for i in range(num_diffusion_timesteps):
90
+ t1 = i / num_diffusion_timesteps
91
+ t2 = (i + 1) / num_diffusion_timesteps
92
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93
+ return np.array(betas)
94
+
95
+
96
+ def extract_into_tensor(a, t, x_shape):
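+ # Gather one value per batch element from the 1-D schedule `a` at the
+ # timestep indices `t`, reshaped to [b, 1, 1, ...] so it broadcasts against
+ # a tensor of shape x_shape.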
97
+ b, *_ = t.shape
98
+ out = a.gather(-1, t)
99
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100
+
101
+
102
+ def checkpoint(func, inputs, params, flag):
103
+ """
104
+ Evaluate a function without caching intermediate activations, allowing for
105
+ reduced memory at the expense of extra compute in the backward pass.
106
+ :param func: the function to evaluate.
107
+ :param inputs: the argument sequence to pass to `func`.
108
+ :param params: a sequence of parameters `func` depends on but does not
109
+ explicitly take as arguments.
110
+ :param flag: if False, disable gradient checkpointing.
111
+ """
112
+ if flag:
113
+ args = tuple(inputs) + tuple(params)
114
+ return CheckpointFunction.apply(func, len(inputs), *args)
115
+ else:
116
+ return func(*inputs)
117
+
118
+
119
+ class CheckpointFunction(torch.autograd.Function):
120
+ @staticmethod
121
+ def forward(ctx, run_function, length, *args):
122
+ ctx.run_function = run_function
123
+ ctx.input_tensors = list(args[:length])
124
+ ctx.input_params = list(args[length:])
125
+
126
+ with torch.no_grad():
127
+ output_tensors = ctx.run_function(*ctx.input_tensors)
128
+ return output_tensors
129
+
130
+ @staticmethod
131
+ def backward(ctx, *output_grads):
132
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133
+ with torch.enable_grad():
134
+ # Fixes a bug where the first op in run_function modifies the
135
+ # Tensor storage in place, which is not allowed for detach()'d
136
+ # Tensors.
137
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138
+ output_tensors = ctx.run_function(*shallow_copies)
139
+ input_grads = torch.autograd.grad(
140
+ output_tensors,
141
+ ctx.input_tensors + ctx.input_params,
142
+ output_grads,
143
+ allow_unused=True,
144
+ )
145
+ del ctx.input_tensors
146
+ del ctx.input_params
147
+ del output_tensors
148
+ return (None, None) + input_grads
149
+
150
+
151
+ def timestep_embedding(timesteps, dim, max_period=10000):
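+ # Standard sinusoidal timestep embedding: cos/sin of the timesteps times
+ # `half` log-spaced frequencies. `dim` is assumed to be even here (the
+ # guided-diffusion original pads a zero column for odd dims).
+ # e.g. timestep_embedding(torch.tensor([0, 999]), 320) -> shape [2, 320]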
152
+ half = dim // 2 # 160
153
+ # a tensor of shape [half] with frequencies decaying exponentially from 1 to 1/max_period
154
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=timesteps.device)
155
+ # [N, 1] * [1, half] -> [N, half]
156
+ args = timesteps[:, None].float() * freqs[None]
157
+ # [N, dim]
158
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
159
+ return embedding
160
+
161
+
162
+ def zero_module(module):
163
+ """
164
+ Zero out the parameters of a module and return it.
165
+ """
166
+ for p in module.parameters():
167
+ p.detach().zero_()
168
+ return module
169
+
170
+
171
+ def scale_module(module, scale):
172
+ """
173
+ Scale the parameters of a module and return it.
174
+ """
175
+ for p in module.parameters():
176
+ p.detach().mul_(scale)
177
+ return module
178
+
179
+
180
+ def mean_flat(tensor):
181
+ """
182
+ Take the mean over all non-batch dimensions.
183
+ """
184
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
185
+
186
+
187
+ def normalization(channels):
188
+ """
189
+ Make a standard normalization layer.
190
+ :param channels: number of input channels.
191
+ :return: an nn.Module for normalization.
192
+ """
193
+ return GroupNorm32(32, channels)
194
+
195
+
196
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
197
+ class SiLU(nn.Module):
198
+ def forward(self, x):
199
+ return x * torch.sigmoid(x)
200
+
201
+
202
+ class GroupNorm32(nn.GroupNorm):
203
+ def forward(self, x):
204
+ return super().forward(x.float()).type(x.dtype)
205
+
206
+ def conv_nd(dims, *args, **kwargs):
207
+ """
208
+ Create a 1D, 2D, or 3D convolution module.
209
+ """
210
+ if dims == 1:
211
+ return nn.Conv1d(*args, **kwargs)
212
+ elif dims == 2:
213
+ return nn.Conv2d(*args, **kwargs)
214
+ elif dims == 3:
215
+ return nn.Conv3d(*args, **kwargs)
216
+ raise ValueError(f"unsupported dimensions: {dims}")
217
+
218
+
219
+ def linear(*args, **kwargs):
220
+ """
221
+ Create a linear module.
222
+ """
223
+ return nn.Linear(*args, **kwargs)
224
+
225
+
226
+ def avg_pool_nd(dims, *args, **kwargs):
227
+ """
228
+ Create a 1D, 2D, or 3D average pooling module.
229
+ """
230
+ if dims == 1:
231
+ return nn.AvgPool1d(*args, **kwargs)
232
+ elif dims == 2:
233
+ return nn.AvgPool2d(*args, **kwargs)
234
+ elif dims == 3:
235
+ return nn.AvgPool3d(*args, **kwargs)
236
+ raise ValueError(f"unsupported dimensions: {dims}")
237
+
238
+
239
+ class HybridConditioner(nn.Module):
240
+
241
+ def __init__(self, c_concat_config, c_crossattn_config):
242
+ super().__init__()
243
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
244
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
245
+
246
+ def forward(self, c_concat, c_crossattn):
247
+ c_concat = self.concat_conditioner(c_concat)
248
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
249
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
250
+
251
+
252
+ def noise_like(shape, device, repeat=False):
253
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
254
+ noise = lambda: torch.randn(shape, device=device)
255
+ return repeat_noise() if repeat else noise()
ldm/modules/distributions/.DS_Store ADDED
Binary file (6.15 kB). View file
 
ldm/modules/distributions/__init__.py ADDED
File without changes
ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (155 Bytes). View file
 
ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc ADDED
Binary file (3.8 kB). View file
 
ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
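+ # `parameters` holds the mean and log-variance concatenated along dim=1
+ # (presumably the moments produced by the autoencoder elsewhere in this
+ # repo); kl() and nll() reduce over the channel and spatial dimensions
+ # per sample.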
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self):
36
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37
+ return x
38
+
39
+ def kl(self, other=None):
40
+ if self.deterministic:
41
+ return torch.Tensor([0.])
42
+ else:
43
+ if other is None:
44
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
45
+ + self.var - 1.0 - self.logvar,
46
+ dim=[1, 2, 3])
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
51
+ dim=[1, 2, 3])
52
+
53
+ def nll(self, sample, dims=[1,2,3]):
54
+ if self.deterministic:
55
+ return torch.Tensor([0.])
56
+ logtwopi = np.log(2.0 * np.pi)
57
+ return 0.5 * torch.sum(
58
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59
+ dim=dims)
60
+
61
+ def mode(self):
62
+ return self.mean
63
+
64
+
65
+ def normal_kl(mean1, logvar1, mean2, logvar2):
66
+ """
67
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68
+ Compute the KL divergence between two gaussians.
69
+ Shapes are automatically broadcasted, so batches can be compared to
70
+ scalars, among other use cases.
71
+ """
72
+ tensor = None
73
+ for obj in (mean1, logvar1, mean2, logvar2):
74
+ if isinstance(obj, torch.Tensor):
75
+ tensor = obj
76
+ break
77
+ assert tensor is not None, "at least one argument must be a Tensor"
78
+
79
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
80
+ # Tensors, but it does not work for torch.exp().
81
+ logvar1, logvar2 = [
82
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83
+ for x in (logvar1, logvar2)
84
+ ]
85
+
86
+ return 0.5 * (
87
+ -1.0
88
+ + logvar2
89
+ - logvar1
90
+ + torch.exp(logvar1 - logvar2)
91
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92
+ )
ldm/modules/ema.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
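+ # Keeps an exponential moving average of the model's trainable parameters as
+ # registered buffers. With num_updates tracking enabled, the effective decay
+ # warms up as (1 + n) / (10 + n), capped at `decay`.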
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError('Decay must be between 0 and 1')
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14
+ else torch.tensor(-1,dtype=torch.int))
15
+
16
+ for name, p in model.named_parameters():
17
+ if p.requires_grad:
18
+ # remove '.' since it is not allowed in buffer names
19
+ s_name = name.replace('.','')
20
+ self.m_name2s_name.update({name:s_name})
21
+ self.register_buffer(s_name,p.clone().detach().data)
22
+
23
+ self.collected_params = []
24
+
25
+ def forward(self,model):
26
+ decay = self.decay
27
+
28
+ if self.num_updates >= 0:
29
+ self.num_updates += 1
30
+ decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31
+
32
+ one_minus_decay = 1.0 - decay
33
+
34
+ with torch.no_grad():
35
+ m_param = dict(model.named_parameters())
36
+ shadow_params = dict(self.named_buffers())
37
+
38
+ for key in m_param:
39
+ if m_param[key].requires_grad:
40
+ sname = self.m_name2s_name[key]
41
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43
+ else:
44
+ assert not key in self.m_name2s_name
45
+
46
+ def copy_to(self, model):
47
+ m_param = dict(model.named_parameters())
48
+ shadow_params = dict(self.named_buffers())
49
+ for key in m_param:
50
+ if m_param[key].requires_grad:
51
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52
+ else:
53
+ assert not key in self.m_name2s_name
54
+
55
+ def store(self, parameters):
56
+ """
57
+ Save the current parameters for restoring later.
58
+ Args:
59
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60
+ temporarily stored.
61
+ """
62
+ self.collected_params = [param.clone() for param in parameters]
63
+
64
+ def restore(self, parameters):
65
+ """
66
+ Restore the parameters stored with the `store` method.
67
+ Useful to validate the model with EMA parameters without affecting the
68
+ original optimization process. Store the parameters before the
69
+ `copy_to` method. After validation (or model saving), use this to
70
+ restore the former parameters.
71
+ Args:
72
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73
+ updated with the stored parameters.
74
+ """
75
+ for c_param, param in zip(self.collected_params, parameters):
76
+ param.data.copy_(c_param.data)
ldm/modules/encoders/.DS_Store ADDED
Binary file (6.15 kB). View file
 
ldm/modules/encoders/__init__.py ADDED
File without changes
ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (150 Bytes). View file
 
ldm/modules/encoders/__pycache__/modules.cpython-38.pyc ADDED
Binary file (8.74 kB). View file