basso4 committed
Commit ee24757 · verified · 1 Parent(s): c38c888

Delete ldm

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. ldm/__pycache__/util.cpython-38.pyc +0 -0
  2. ldm/data/.DS_Store +0 -0
  3. ldm/data/__init__.py +0 -0
  4. ldm/data/__pycache__/__init__.cpython-38.pyc +0 -0
  5. ldm/data/__pycache__/image_dresscode.cpython-38.pyc +0 -0
  6. ldm/data/__pycache__/image_vitonhd.cpython-38.pyc +0 -0
  7. ldm/data/__pycache__/viton-images.cpython-38.pyc +0 -0
  8. ldm/data/base.py +0 -23
  9. ldm/data/image_dresscode.py +0 -86
  10. ldm/data/image_vitonhd.py +0 -75
  11. ldm/data/imagenet.py +0 -394
  12. ldm/data/lsun.py +0 -92
  13. ldm/lr_scheduler.py +0 -81
  14. ldm/models/.DS_Store +0 -0
  15. ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  16. ldm/models/autoencoder.py +0 -408
  17. ldm/models/diffusion/.DS_Store +0 -0
  18. ldm/models/diffusion/__init__.py +0 -0
  19. ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  20. ldm/models/diffusion/__pycache__/control.cpython-38.pyc +0 -0
  21. ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc +0 -0
  22. ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc +0 -0
  23. ldm/models/diffusion/classifier.py +0 -267
  24. ldm/models/diffusion/control.py +0 -406
  25. ldm/models/diffusion/ddim.py +0 -265
  26. ldm/models/diffusion/ddpm.py +0 -143
  27. ldm/models/diffusion/plms.py +0 -239
  28. ldm/modules/.DS_Store +0 -0
  29. ldm/modules/__pycache__/attention.cpython-38.pyc +0 -0
  30. ldm/modules/__pycache__/x_transformer.cpython-38.pyc +0 -0
  31. ldm/modules/attention.py +0 -345
  32. ldm/modules/diffusionmodules/.DS_Store +0 -0
  33. ldm/modules/diffusionmodules/__init__.py +0 -0
  34. ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc +0 -0
  35. ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc +0 -0
  36. ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc +0 -0
  37. ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc +0 -0
  38. ldm/modules/diffusionmodules/model.py +0 -835
  39. ldm/modules/diffusionmodules/openaimodel.py +0 -707
  40. ldm/modules/diffusionmodules/util.py +0 -255
  41. ldm/modules/distributions/.DS_Store +0 -0
  42. ldm/modules/distributions/__init__.py +0 -0
  43. ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc +0 -0
  44. ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc +0 -0
  45. ldm/modules/distributions/distributions.py +0 -92
  46. ldm/modules/ema.py +0 -76
  47. ldm/modules/encoders/.DS_Store +0 -0
  48. ldm/modules/encoders/__init__.py +0 -0
  49. ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc +0 -0
  50. ldm/modules/encoders/__pycache__/modules.cpython-38.pyc +0 -0
ldm/__pycache__/util.cpython-38.pyc DELETED
Binary file (5.87 kB)
 
ldm/data/.DS_Store DELETED
Binary file (6.15 kB)
 
ldm/data/__init__.py DELETED
File without changes
ldm/data/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (138 Bytes)
 
ldm/data/__pycache__/image_dresscode.cpython-38.pyc DELETED
Binary file (2.52 kB)
 
ldm/data/__pycache__/image_vitonhd.cpython-38.pyc DELETED
Binary file (3.16 kB)
 
ldm/data/__pycache__/viton-images.cpython-38.pyc DELETED
Binary file (2.68 kB)
 
ldm/data/base.py DELETED
@@ -1,23 +0,0 @@
1
- from abc import abstractmethod
2
- from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3
-
4
-
5
- class Txt2ImgIterableBaseDataset(IterableDataset):
6
- '''
7
- Define an interface to make the IterableDatasets for text2img data chainable
8
- '''
9
- def __init__(self, num_records=0, valid_ids=None, size=256):
10
- super().__init__()
11
- self.num_records = num_records
12
- self.valid_ids = valid_ids
13
- self.sample_ids = valid_ids
14
- self.size = size
15
-
16
- print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17
-
18
- def __len__(self):
19
- return self.num_records
20
-
21
- @abstractmethod
22
- def __iter__(self):
23
- pass
 
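For context on what is being removed: the deleted base.py only defines an abstract interface, so a concrete dataset was expected to subclass it and supply __iter__. A minimal hedged sketch follows; the class name and yielded fields are hypothetical, and the import only resolves against the tree before this commit.

from ldm.data.base import Txt2ImgIterableBaseDataset  # pre-deletion import

class ToyTxt2ImgIterable(Txt2ImgIterableBaseDataset):
    """Hypothetical concrete dataset implementing the abstract __iter__."""

    def __iter__(self):
        # A real dataset would stream records from disk or a webdataset shard.
        for idx in range(self.num_records):
            yield {"image_id": idx, "caption": f"sample {idx}"}

ds = ToyTxt2ImgIterable(num_records=10, size=256)
for record in ds:
    pass  # each record is a dict produced by __iter__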
ldm/data/image_dresscode.py DELETED
@@ -1,86 +0,0 @@
1
- import os
2
- import torch
3
- import torchvision
4
- import torch.utils.data as data
5
- import torchvision.transforms.functional as F
6
- import numpy as np
7
- from PIL import Image
8
-
9
- class OpenImageDataset(data.Dataset):
10
- def __init__(self, state, dataset_dir, type="paired"):
11
- self.state = state # train or test
12
- self.dataset_dir = dataset_dir # /home/sd/zjh/Dataset/DressCode
13
-
14
-         # Determine the split and pick the matching pair list
15
- if state == "train":
16
- self.dataset_file = os.path.join(dataset_dir, "train_pairs.txt")
17
- if state == "test":
18
- if type == "unpaired":
19
- self.dataset_file = os.path.join(dataset_dir, "test_pairs_unpaired.txt")
20
- if type == "paired":
21
- self.dataset_file = os.path.join(dataset_dir, "test_pairs_paired.txt")
22
-
23
-         # Load the dataset pairs
24
- self.people_list = []
25
- self.clothes_list = []
26
- with open(self.dataset_file, 'r') as f:
27
- for line in f.readlines():
28
- people, clothes, category = line.strip().split()
29
- if category == "0":
30
- category = "upper_body"
31
- elif category == "1":
32
- category = "lower_body"
33
- elif category == "2":
34
- category = "dresses"
35
- people_path = os.path.join(self.dataset_dir, category, "images", people)
36
- clothes_path = os.path.join(self.dataset_dir, category, "images", clothes)
37
- self.people_list.append(people_path)
38
- self.clothes_list.append(clothes_path)
39
-
40
-
41
- def __len__(self):
42
- return len(self.people_list)
43
-
44
- def __getitem__(self, index):
45
- people_path = self.people_list[index]
46
- # /home/sd/zjh/Dataset/DressCode/upper_body/images/000000_0.jpg
47
- clothes_path = self.clothes_list[index]
48
- # /home/sd/zjh/Dataset/DressCode/upper_body/images/000000_1.jpg
49
- dense_path = people_path.replace("images", "dense")[:-5] + "5_uv.npz"
50
- # /home/sd/zjh/Dataset/DressCode/upper_body/dense/000000_5_uv.npz
51
- mask_path = people_path.replace("images", "mask")[:-3] + "png"
52
- # /home/sd/Harddisk/zjh/DressCode/upper_body/mask/000000_0.png
53
-
54
-         # Load the images
55
- img = Image.open(people_path).convert("RGB").resize((512, 512))
56
- img = torchvision.transforms.ToTensor()(img)
57
- refernce = Image.open(clothes_path).convert("RGB").resize((224, 224))
58
- refernce = torchvision.transforms.ToTensor()(refernce)
59
- mask = Image.open(mask_path).convert("L").resize((512, 512))
60
- mask = torchvision.transforms.ToTensor()(mask)
61
- mask = 1-mask
62
- densepose = np.load(dense_path)
63
- densepose = torch.from_numpy(densepose['uv'])
64
- densepose = torch.nn.functional.interpolate(densepose.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=True).squeeze(0)
65
-
66
-         # Normalize
67
- img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
68
- refernce = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
69
- (0.26862954, 0.26130258, 0.27577711))(refernce)
70
- # densepose = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(densepose)
71
-
72
-         # Build the inpaint image and the hint
73
- inpaint = img * mask
74
- hint = torchvision.transforms.Resize((512, 512))(refernce)
75
- hint = torch.cat((hint,densepose),dim = 0)
76
-
77
-
78
- return {"GT": img, # [3, 512, 512]
79
- "inpaint_image": inpaint, # [3, 512, 512]
80
- "inpaint_mask": mask, # [1, 512, 512]
81
- "ref_imgs": refernce, # [3, 224, 224]
82
- "hint": hint # [5, 512, 512]
83
- }
84
-
85
-
86
-
 
ldm/data/image_vitonhd.py DELETED
@@ -1,75 +0,0 @@
1
- import os
2
- import torch
3
- import torchvision
4
- import torch.utils.data as data
5
- import torchvision.transforms.functional as F
6
- from PIL import Image
7
-
8
- class OpenImageDataset(data.Dataset):
9
- def __init__(self, state, dataset_dir, type="paired"):
10
- self.state=state
11
- self.dataset_dir = dataset_dir
12
- self.dataset_list = []
13
-
14
- if state == "train":
15
- self.dataset_file = os.path.join(dataset_dir, "train_pairs.txt")
16
- with open(self.dataset_file, 'r') as f:
17
- for line in f.readlines():
18
- person, garment = line.strip().split()
19
- self.dataset_list.append([person, person])
20
-
21
- if state == "test":
22
- self.dataset_file = os.path.join(dataset_dir, "test_pairs.txt")
23
- if type == "unpaired":
24
- with open(self.dataset_file, 'r') as f:
25
- for line in f.readlines():
26
- person, garment = line.strip().split()
27
- self.dataset_list.append([person, garment])
28
-
29
- if type == "paired":
30
- with open(self.dataset_file, 'r') as f:
31
- for line in f.readlines():
32
- person, garment = line.strip().split()
33
- self.dataset_list.append([person, person])
34
-
35
- def __len__(self):
36
- return len(self.dataset_list)
37
-
38
- def __getitem__(self, index):
39
-
40
- person, garment = self.dataset_list[index]
41
-
42
-         # Build the file paths
43
- img_path = os.path.join(self.dataset_dir, self.state, "image", person)
44
- reference_path = os.path.join(self.dataset_dir, self.state, "cloth", garment)
45
- mask_path = os.path.join(self.dataset_dir, self.state, "mask", person[:-4]+".png")
46
- densepose_path = os.path.join(self.dataset_dir, self.state, "image-densepose", person)
47
-
48
-         # Load the images
49
- img = Image.open(img_path).convert("RGB").resize((512, 512))
50
- img = torchvision.transforms.ToTensor()(img)
51
- refernce = Image.open(reference_path).convert("RGB").resize((224, 224))
52
- refernce = torchvision.transforms.ToTensor()(refernce)
53
- mask = Image.open(mask_path).convert("L").resize((512, 512))
54
- mask = torchvision.transforms.ToTensor()(mask)
55
- mask = 1-mask
56
- densepose = Image.open(densepose_path).convert("RGB").resize((512, 512))
57
- densepose = torchvision.transforms.ToTensor()(densepose)
58
-
59
-         # Normalize
60
- img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
61
- refernce = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
62
- (0.26862954, 0.26130258, 0.27577711))(refernce)
63
- densepose = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(densepose)
64
-
65
-         # Build the inpaint image and the hint
66
- inpaint = img * mask
67
- hint = torchvision.transforms.Resize((512, 512))(refernce)
68
- hint = torch.cat((hint,densepose),dim = 0)
69
-
70
- return {"GT": img, # [3, 512, 512]
71
- "inpaint_image": inpaint, # [3, 512, 512]
72
- "inpaint_mask": mask, # [1, 512, 512]
73
- "ref_imgs": refernce, # [3, 224, 224]
74
- "hint": hint, # [6, 512, 512]
75
- }
 
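For reference, a hedged sketch of how this VITON-HD dataset class was typically consumed before the deletion; the dataset path is a placeholder, and the tensor shapes follow the comments on the returned dict above.

from torch.utils.data import DataLoader
from ldm.data.image_vitonhd import OpenImageDataset  # pre-deletion import

# "/path/to/VITON-HD" is a placeholder for a local copy of the dataset.
dataset = OpenImageDataset(state="train", dataset_dir="/path/to/VITON-HD")
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
print(batch["GT"].shape)             # torch.Size([4, 3, 512, 512])
print(batch["inpaint_image"].shape)  # torch.Size([4, 3, 512, 512])
print(batch["inpaint_mask"].shape)   # torch.Size([4, 1, 512, 512])
print(batch["ref_imgs"].shape)       # torch.Size([4, 3, 224, 224])
print(batch["hint"].shape)           # torch.Size([4, 6, 512, 512])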
ldm/data/imagenet.py DELETED
@@ -1,394 +0,0 @@
1
- import os, yaml, pickle, shutil, tarfile, glob
2
- import cv2
3
- import albumentations
4
- import PIL
5
- import numpy as np
6
- import torchvision.transforms.functional as TF
7
- from omegaconf import OmegaConf
8
- from functools import partial
9
- from PIL import Image
10
- from tqdm import tqdm
11
- from torch.utils.data import Dataset, Subset
12
-
13
- import taming.data.utils as tdu
14
- from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15
- from taming.data.imagenet import ImagePaths
16
-
17
- from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18
-
19
-
20
- def synset2idx(path_to_yaml="data/index_synset.yaml"):
21
- with open(path_to_yaml) as f:
22
- di2s = yaml.load(f)
23
- return dict((v,k) for k,v in di2s.items())
24
-
25
-
26
- class ImageNetBase(Dataset):
27
- def __init__(self, config=None):
28
- self.config = config or OmegaConf.create()
29
- if not type(self.config)==dict:
30
- self.config = OmegaConf.to_container(self.config)
31
- self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
32
- self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
33
- self._prepare()
34
- self._prepare_synset_to_human()
35
- self._prepare_idx_to_synset()
36
- self._prepare_human_to_integer_label()
37
- self._load()
38
-
39
- def __len__(self):
40
- return len(self.data)
41
-
42
- def __getitem__(self, i):
43
- return self.data[i]
44
-
45
- def _prepare(self):
46
- raise NotImplementedError()
47
-
48
- def _filter_relpaths(self, relpaths):
49
- ignore = set([
50
- "n06596364_9591.JPEG",
51
- ])
52
- relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
53
- if "sub_indices" in self.config:
54
- indices = str_to_indices(self.config["sub_indices"])
55
- synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
56
- self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
57
- files = []
58
- for rpath in relpaths:
59
- syn = rpath.split("/")[0]
60
- if syn in synsets:
61
- files.append(rpath)
62
- return files
63
- else:
64
- return relpaths
65
-
66
- def _prepare_synset_to_human(self):
67
- SIZE = 2655750
68
- URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
69
- self.human_dict = os.path.join(self.root, "synset_human.txt")
70
- if (not os.path.exists(self.human_dict) or
71
- not os.path.getsize(self.human_dict)==SIZE):
72
- download(URL, self.human_dict)
73
-
74
- def _prepare_idx_to_synset(self):
75
- URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
76
- self.idx2syn = os.path.join(self.root, "index_synset.yaml")
77
- if (not os.path.exists(self.idx2syn)):
78
- download(URL, self.idx2syn)
79
-
80
- def _prepare_human_to_integer_label(self):
81
- URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
82
- self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
83
- if (not os.path.exists(self.human2integer)):
84
- download(URL, self.human2integer)
85
- with open(self.human2integer, "r") as f:
86
- lines = f.read().splitlines()
87
- assert len(lines) == 1000
88
- self.human2integer_dict = dict()
89
- for line in lines:
90
- value, key = line.split(":")
91
- self.human2integer_dict[key] = int(value)
92
-
93
- def _load(self):
94
- with open(self.txt_filelist, "r") as f:
95
- self.relpaths = f.read().splitlines()
96
- l1 = len(self.relpaths)
97
- self.relpaths = self._filter_relpaths(self.relpaths)
98
- print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
99
-
100
- self.synsets = [p.split("/")[0] for p in self.relpaths]
101
- self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
102
-
103
- unique_synsets = np.unique(self.synsets)
104
- class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
105
- if not self.keep_orig_class_label:
106
- self.class_labels = [class_dict[s] for s in self.synsets]
107
- else:
108
- self.class_labels = [self.synset2idx[s] for s in self.synsets]
109
-
110
- with open(self.human_dict, "r") as f:
111
- human_dict = f.read().splitlines()
112
- human_dict = dict(line.split(maxsplit=1) for line in human_dict)
113
-
114
- self.human_labels = [human_dict[s] for s in self.synsets]
115
-
116
- labels = {
117
- "relpath": np.array(self.relpaths),
118
- "synsets": np.array(self.synsets),
119
- "class_label": np.array(self.class_labels),
120
- "human_label": np.array(self.human_labels),
121
- }
122
-
123
- if self.process_images:
124
- self.size = retrieve(self.config, "size", default=256)
125
- self.data = ImagePaths(self.abspaths,
126
- labels=labels,
127
- size=self.size,
128
- random_crop=self.random_crop,
129
- )
130
- else:
131
- self.data = self.abspaths
132
-
133
-
134
- class ImageNetTrain(ImageNetBase):
135
- NAME = "ILSVRC2012_train"
136
- URL = "http://www.image-net.org/challenges/LSVRC/2012/"
137
- AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
138
- FILES = [
139
- "ILSVRC2012_img_train.tar",
140
- ]
141
- SIZES = [
142
- 147897477120,
143
- ]
144
-
145
- def __init__(self, process_images=True, data_root=None, **kwargs):
146
- self.process_images = process_images
147
- self.data_root = data_root
148
- super().__init__(**kwargs)
149
-
150
- def _prepare(self):
151
- if self.data_root:
152
- self.root = os.path.join(self.data_root, self.NAME)
153
- else:
154
- cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
155
- self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
156
-
157
- self.datadir = os.path.join(self.root, "data")
158
- self.txt_filelist = os.path.join(self.root, "filelist.txt")
159
- self.expected_length = 1281167
160
- self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
161
- default=True)
162
- if not tdu.is_prepared(self.root):
163
- # prep
164
- print("Preparing dataset {} in {}".format(self.NAME, self.root))
165
-
166
- datadir = self.datadir
167
- if not os.path.exists(datadir):
168
- path = os.path.join(self.root, self.FILES[0])
169
- if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
170
- import academictorrents as at
171
- atpath = at.get(self.AT_HASH, datastore=self.root)
172
- assert atpath == path
173
-
174
- print("Extracting {} to {}".format(path, datadir))
175
- os.makedirs(datadir, exist_ok=True)
176
- with tarfile.open(path, "r:") as tar:
177
- tar.extractall(path=datadir)
178
-
179
- print("Extracting sub-tars.")
180
- subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
181
- for subpath in tqdm(subpaths):
182
- subdir = subpath[:-len(".tar")]
183
- os.makedirs(subdir, exist_ok=True)
184
- with tarfile.open(subpath, "r:") as tar:
185
- tar.extractall(path=subdir)
186
-
187
- filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
188
- filelist = [os.path.relpath(p, start=datadir) for p in filelist]
189
- filelist = sorted(filelist)
190
- filelist = "\n".join(filelist)+"\n"
191
- with open(self.txt_filelist, "w") as f:
192
- f.write(filelist)
193
-
194
- tdu.mark_prepared(self.root)
195
-
196
-
197
- class ImageNetValidation(ImageNetBase):
198
- NAME = "ILSVRC2012_validation"
199
- URL = "http://www.image-net.org/challenges/LSVRC/2012/"
200
- AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
201
- VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
202
- FILES = [
203
- "ILSVRC2012_img_val.tar",
204
- "validation_synset.txt",
205
- ]
206
- SIZES = [
207
- 6744924160,
208
- 1950000,
209
- ]
210
-
211
- def __init__(self, process_images=True, data_root=None, **kwargs):
212
- self.data_root = data_root
213
- self.process_images = process_images
214
- super().__init__(**kwargs)
215
-
216
- def _prepare(self):
217
- if self.data_root:
218
- self.root = os.path.join(self.data_root, self.NAME)
219
- else:
220
- cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
221
- self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
222
- self.datadir = os.path.join(self.root, "data")
223
- self.txt_filelist = os.path.join(self.root, "filelist.txt")
224
- self.expected_length = 50000
225
- self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
226
- default=False)
227
- if not tdu.is_prepared(self.root):
228
- # prep
229
- print("Preparing dataset {} in {}".format(self.NAME, self.root))
230
-
231
- datadir = self.datadir
232
- if not os.path.exists(datadir):
233
- path = os.path.join(self.root, self.FILES[0])
234
- if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
235
- import academictorrents as at
236
- atpath = at.get(self.AT_HASH, datastore=self.root)
237
- assert atpath == path
238
-
239
- print("Extracting {} to {}".format(path, datadir))
240
- os.makedirs(datadir, exist_ok=True)
241
- with tarfile.open(path, "r:") as tar:
242
- tar.extractall(path=datadir)
243
-
244
- vspath = os.path.join(self.root, self.FILES[1])
245
- if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
246
- download(self.VS_URL, vspath)
247
-
248
- with open(vspath, "r") as f:
249
- synset_dict = f.read().splitlines()
250
- synset_dict = dict(line.split() for line in synset_dict)
251
-
252
- print("Reorganizing into synset folders")
253
- synsets = np.unique(list(synset_dict.values()))
254
- for s in synsets:
255
- os.makedirs(os.path.join(datadir, s), exist_ok=True)
256
- for k, v in synset_dict.items():
257
- src = os.path.join(datadir, k)
258
- dst = os.path.join(datadir, v)
259
- shutil.move(src, dst)
260
-
261
- filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
262
- filelist = [os.path.relpath(p, start=datadir) for p in filelist]
263
- filelist = sorted(filelist)
264
- filelist = "\n".join(filelist)+"\n"
265
- with open(self.txt_filelist, "w") as f:
266
- f.write(filelist)
267
-
268
- tdu.mark_prepared(self.root)
269
-
270
-
271
-
272
- class ImageNetSR(Dataset):
273
- def __init__(self, size=None,
274
- degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
275
- random_crop=True):
276
- """
277
- Imagenet Superresolution Dataloader
278
- Performs following ops in order:
279
- 1. crops a crop of size s from image either as random or center crop
280
- 2. resizes crop to size with cv2.area_interpolation
281
- 3. degrades resized crop with degradation_fn
282
-
283
- :param size: resizing to size after cropping
284
- :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
285
- :param downscale_f: Low Resolution Downsample factor
286
- :param min_crop_f: determines crop size s,
287
- where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
288
- :param max_crop_f: ""
289
- :param data_root:
290
- :param random_crop:
291
- """
292
- self.base = self.get_base()
293
- assert size
294
- assert (size / downscale_f).is_integer()
295
- self.size = size
296
- self.LR_size = int(size / downscale_f)
297
- self.min_crop_f = min_crop_f
298
- self.max_crop_f = max_crop_f
299
- assert(max_crop_f <= 1.)
300
- self.center_crop = not random_crop
301
-
302
- self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
303
-
304
-         self.pil_interpolation = False  # gets reset later in case interp_op is from Pillow
305
-
306
- if degradation == "bsrgan":
307
- self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
308
-
309
- elif degradation == "bsrgan_light":
310
- self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
311
-
312
- else:
313
- interpolation_fn = {
314
- "cv_nearest": cv2.INTER_NEAREST,
315
- "cv_bilinear": cv2.INTER_LINEAR,
316
- "cv_bicubic": cv2.INTER_CUBIC,
317
- "cv_area": cv2.INTER_AREA,
318
- "cv_lanczos": cv2.INTER_LANCZOS4,
319
- "pil_nearest": PIL.Image.NEAREST,
320
- "pil_bilinear": PIL.Image.BILINEAR,
321
- "pil_bicubic": PIL.Image.BICUBIC,
322
- "pil_box": PIL.Image.BOX,
323
- "pil_hamming": PIL.Image.HAMMING,
324
- "pil_lanczos": PIL.Image.LANCZOS,
325
- }[degradation]
326
-
327
- self.pil_interpolation = degradation.startswith("pil_")
328
-
329
- if self.pil_interpolation:
330
- self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
331
-
332
- else:
333
- self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
334
- interpolation=interpolation_fn)
335
-
336
- def __len__(self):
337
- return len(self.base)
338
-
339
- def __getitem__(self, i):
340
- example = self.base[i]
341
- image = Image.open(example["file_path_"])
342
-
343
- if not image.mode == "RGB":
344
- image = image.convert("RGB")
345
-
346
- image = np.array(image).astype(np.uint8)
347
-
348
- min_side_len = min(image.shape[:2])
349
- crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
350
- crop_side_len = int(crop_side_len)
351
-
352
- if self.center_crop:
353
- self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
354
-
355
- else:
356
- self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
357
-
358
- image = self.cropper(image=image)["image"]
359
- image = self.image_rescaler(image=image)["image"]
360
-
361
- if self.pil_interpolation:
362
- image_pil = PIL.Image.fromarray(image)
363
- LR_image = self.degradation_process(image_pil)
364
- LR_image = np.array(LR_image).astype(np.uint8)
365
-
366
- else:
367
- LR_image = self.degradation_process(image=image)["image"]
368
-
369
- example["image"] = (image/127.5 - 1.0).astype(np.float32)
370
- example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
371
-
372
- return example
373
-
374
-
375
- class ImageNetSRTrain(ImageNetSR):
376
- def __init__(self, **kwargs):
377
- super().__init__(**kwargs)
378
-
379
- def get_base(self):
380
- with open("data/imagenet_train_hr_indices.p", "rb") as f:
381
- indices = pickle.load(f)
382
- dset = ImageNetTrain(process_images=False,)
383
- return Subset(dset, indices)
384
-
385
-
386
- class ImageNetSRValidation(ImageNetSR):
387
- def __init__(self, **kwargs):
388
- super().__init__(**kwargs)
389
-
390
- def get_base(self):
391
- with open("data/imagenet_val_hr_indices.p", "rb") as f:
392
- indices = pickle.load(f)
393
- dset = ImageNetValidation(process_images=False,)
394
- return Subset(dset, indices)
 
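The ImageNetSR pipeline above boils down to crop → rescale → degrade. Below is a minimal hedged sketch of that sequence on a synthetic array, using the same albumentations calls; a center crop is shown for simplicity, whereas the class actually samples the crop size and may use BSRGAN-style degradation instead of area resizing.

import cv2
import numpy as np
import albumentations

size, lr_size = 256, 64   # corresponds to size and size / downscale_f
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # stand-in image

# 1. square crop
crop = min(image.shape[:2])
image = albumentations.CenterCrop(height=crop, width=crop)(image=image)["image"]

# 2. rescale so the smallest side equals `size`
rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
image = rescaler(image=image)["image"]

# 3. degrade to the low-resolution counterpart (the "cv_area" branch)
degrader = albumentations.SmallestMaxSize(max_size=lr_size, interpolation=cv2.INTER_AREA)
lr_image = degrader(image=image)["image"]

example = {
    "image": (image / 127.5 - 1.0).astype(np.float32),
    "LR_image": (lr_image / 127.5 - 1.0).astype(np.float32),
}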
ldm/data/lsun.py DELETED
@@ -1,92 +0,0 @@
1
- import os
2
- import numpy as np
3
- import PIL
4
- from PIL import Image
5
- from torch.utils.data import Dataset
6
- from torchvision import transforms
7
-
8
-
9
- class LSUNBase(Dataset):
10
- def __init__(self,
11
- txt_file,
12
- data_root,
13
- size=None,
14
- interpolation="bicubic",
15
- flip_p=0.5
16
- ):
17
- self.data_paths = txt_file
18
- self.data_root = data_root
19
- with open(self.data_paths, "r") as f:
20
- self.image_paths = f.read().splitlines()
21
- self._length = len(self.image_paths)
22
- self.labels = {
23
- "relative_file_path_": [l for l in self.image_paths],
24
- "file_path_": [os.path.join(self.data_root, l)
25
- for l in self.image_paths],
26
- }
27
-
28
- self.size = size
29
- self.interpolation = {"linear": PIL.Image.LINEAR,
30
- "bilinear": PIL.Image.BILINEAR,
31
- "bicubic": PIL.Image.BICUBIC,
32
- "lanczos": PIL.Image.LANCZOS,
33
- }[interpolation]
34
- self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35
-
36
- def __len__(self):
37
- return self._length
38
-
39
- def __getitem__(self, i):
40
- example = dict((k, self.labels[k][i]) for k in self.labels)
41
- image = Image.open(example["file_path_"])
42
- if not image.mode == "RGB":
43
- image = image.convert("RGB")
44
-
45
- # default to score-sde preprocessing
46
- img = np.array(image).astype(np.uint8)
47
- crop = min(img.shape[0], img.shape[1])
48
- h, w, = img.shape[0], img.shape[1]
49
- img = img[(h - crop) // 2:(h + crop) // 2,
50
- (w - crop) // 2:(w + crop) // 2]
51
-
52
- image = Image.fromarray(img)
53
- if self.size is not None:
54
- image = image.resize((self.size, self.size), resample=self.interpolation)
55
-
56
- image = self.flip(image)
57
- image = np.array(image).astype(np.uint8)
58
- example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59
- return example
60
-
61
-
62
- class LSUNChurchesTrain(LSUNBase):
63
- def __init__(self, **kwargs):
64
- super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65
-
66
-
67
- class LSUNChurchesValidation(LSUNBase):
68
- def __init__(self, flip_p=0., **kwargs):
69
- super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70
- flip_p=flip_p, **kwargs)
71
-
72
-
73
- class LSUNBedroomsTrain(LSUNBase):
74
- def __init__(self, **kwargs):
75
- super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76
-
77
-
78
- class LSUNBedroomsValidation(LSUNBase):
79
- def __init__(self, flip_p=0.0, **kwargs):
80
- super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81
- flip_p=flip_p, **kwargs)
82
-
83
-
84
- class LSUNCatsTrain(LSUNBase):
85
- def __init__(self, **kwargs):
86
- super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87
-
88
-
89
- class LSUNCatsValidation(LSUNBase):
90
- def __init__(self, flip_p=0., **kwargs):
91
- super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92
- flip_p=flip_p, **kwargs)
 
ldm/lr_scheduler.py DELETED
@@ -1,81 +0,0 @@
1
- import numpy as np
2
-
3
-
4
- class LambdaWarmUpCosineScheduler:
5
- """
6
- note: use with a base_lr of 1.0
7
- """
8
- def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
- self.lr_warm_up_steps = warm_up_steps
10
- self.lr_start = lr_start
11
- self.lr_min = lr_min
12
- self.lr_max = lr_max
13
- self.lr_max_decay_steps = max_decay_steps
14
- self.last_lr = 0.
15
- self.verbosity_interval = verbosity_interval
16
-
17
- def schedule(self, n, **kwargs):
18
- if self.verbosity_interval > 0:
19
- if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
- if n < self.lr_warm_up_steps:
21
- lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
- self.last_lr = lr
23
- return lr
24
- else:
25
- t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
- t = min(t, 1.0)
27
- lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
- 1 + np.cos(t * np.pi))
29
- self.last_lr = lr
30
- return lr
31
-
32
- def __call__(self, n, **kwargs):
33
- return self.schedule(n,**kwargs)
34
-
35
-
36
- class LambdaWarmUpCosineScheduler2:
37
- """
38
- supports repeated iterations, configurable via lists
39
- note: use with a base_lr of 1.0.
40
- """
41
- def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42
- assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43
- self.lr_warm_up_steps = warm_up_steps # [10000]
44
- self.f_start = f_start # 1.e-6
45
- self.f_min = f_min # 1.
46
- self.f_max = f_max # 1.
47
- self.cycle_lengths = cycle_lengths # [10000000000000]
48
- self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49
- self.last_f = 0.
50
- self.verbosity_interval = verbosity_interval # 0
51
-
52
- def find_in_interval(self, n):
53
- interval = 0
54
- for cl in self.cum_cycles[1:]:
55
- if n <= cl:
56
- return interval
57
- interval += 1
58
-
59
-
60
- def __call__(self, n, **kwargs):
61
- return self.schedule(n, **kwargs)
62
-
63
-
64
- class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
65
-
66
- def schedule(self, n, **kwargs):
67
- cycle = self.find_in_interval(n)
68
- n = n - self.cum_cycles[cycle]
69
- if self.verbosity_interval > 0:
70
- if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
71
- f"current cycle {cycle}")
72
-
73
- if n < self.lr_warm_up_steps[cycle]:
74
- f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
75
- self.last_f = f
76
- return f
77
- else:
78
- f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
79
- self.last_f = f
80
- return f
81
-
 
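These schedulers return multiplicative factors rather than absolute learning rates, which is why the docstrings say to use a base_lr of 1.0: the factor scales the optimizer's lr through LambdaLR. A hedged usage sketch with arbitrary step counts, not values from any shipped config:

import torch
from torch.optim.lr_scheduler import LambdaLR
from ldm.lr_scheduler import LambdaWarmUpCosineScheduler  # pre-deletion import

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Linear warm-up from 1e-6 to 1.0 over 1,000 steps, then cosine decay
# towards 0.01 until step 10,000. Numbers are illustrative only.
factor_fn = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=0.01, lr_max=1.0, lr_start=1e-6, max_decay_steps=10000
)
sched = LambdaLR(opt, lr_lambda=factor_fn)

for _ in range(3):
    opt.step()
    sched.step()  # effective lr = 1e-4 * factor_fn(current_step)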
ldm/models/.DS_Store DELETED
Binary file (6.15 kB)
 
ldm/models/__pycache__/autoencoder.cpython-38.pyc DELETED
Binary file (12.1 kB)
 
ldm/models/autoencoder.py DELETED
@@ -1,408 +0,0 @@
1
- import torch
2
- import pytorch_lightning as pl
3
- import torch.nn.functional as F
4
- from contextlib import contextmanager
5
-
6
- from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7
-
8
- from ldm.modules.diffusionmodules.model import Encoder, Decoder
9
- from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
10
-
11
- from ldm.util import instantiate_from_config
12
-
13
-
14
- class VQModel(pl.LightningModule):
15
- def __init__(self,
16
- ddconfig,
17
- lossconfig,
18
- n_embed,
19
- embed_dim,
20
- ckpt_path=None,
21
- ignore_keys=[],
22
- image_key="image",
23
- colorize_nlabels=None,
24
- monitor=None,
25
- batch_resize_range=None,
26
- scheduler_config=None,
27
- lr_g_factor=1.0,
28
- remap=None,
29
- sane_index_shape=False, # tell vector quantizer to return indices as bhw
30
- use_ema=False
31
- ):
32
- super().__init__()
33
- self.embed_dim = embed_dim
34
- self.n_embed = n_embed
35
- self.image_key = image_key
36
- self.encoder = Encoder(**ddconfig)
37
- self.decoder = Decoder(**ddconfig)
38
- self.loss = instantiate_from_config(lossconfig)
39
- self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40
- remap=remap,
41
- sane_index_shape=sane_index_shape)
42
- self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44
- if colorize_nlabels is not None:
45
- assert type(colorize_nlabels)==int
46
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47
- if monitor is not None:
48
- self.monitor = monitor
49
- self.batch_resize_range = batch_resize_range
50
- if self.batch_resize_range is not None:
51
- print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52
-
53
- self.use_ema = use_ema
54
- if self.use_ema:
55
- self.model_ema = LitEma(self)
56
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57
-
58
- if ckpt_path is not None:
59
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60
- self.scheduler_config = scheduler_config
61
- self.lr_g_factor = lr_g_factor
62
-
63
- @contextmanager
64
- def ema_scope(self, context=None):
65
- if self.use_ema:
66
- self.model_ema.store(self.parameters())
67
- self.model_ema.copy_to(self)
68
- if context is not None:
69
- print(f"{context}: Switched to EMA weights")
70
- try:
71
- yield None
72
- finally:
73
- if self.use_ema:
74
- self.model_ema.restore(self.parameters())
75
- if context is not None:
76
- print(f"{context}: Restored training weights")
77
-
78
- def init_from_ckpt(self, path, ignore_keys=list()):
79
- sd = torch.load(path, map_location="cpu")["state_dict"]
80
- keys = list(sd.keys())
81
- for k in keys:
82
- for ik in ignore_keys:
83
- if k.startswith(ik):
84
- print("Deleting key {} from state_dict.".format(k))
85
- del sd[k]
86
- missing, unexpected = self.load_state_dict(sd, strict=False)
87
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88
- if len(missing) > 0:
89
- print(f"Missing Keys: {missing}")
90
- print(f"Unexpected Keys: {unexpected}")
91
-
92
- def on_train_batch_end(self, *args, **kwargs):
93
- if self.use_ema:
94
- self.model_ema(self)
95
-
96
- def encode(self, x):
97
- h = self.encoder(x)
98
- h = self.quant_conv(h)
99
- quant, emb_loss, info = self.quantize(h)
100
- return quant, emb_loss, info
101
-
102
- def encode_to_prequant(self, x):
103
- h = self.encoder(x)
104
- h = self.quant_conv(h)
105
- return h
106
-
107
- def decode(self, quant):
108
- quant = self.post_quant_conv(quant)
109
- dec = self.decoder(quant)
110
- return dec
111
-
112
- def decode_code(self, code_b):
113
- quant_b = self.quantize.embed_code(code_b)
114
- dec = self.decode(quant_b)
115
- return dec
116
-
117
- def forward(self, input, return_pred_indices=False):
118
- quant, diff, (_,_,ind) = self.encode(input)
119
- dec = self.decode(quant)
120
- if return_pred_indices:
121
- return dec, diff, ind
122
- return dec, diff
123
-
124
- def get_input(self, batch, k):
125
- x = batch[k]
126
- if len(x.shape) == 3:
127
- x = x[..., None]
128
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129
- if self.batch_resize_range is not None:
130
- lower_size = self.batch_resize_range[0]
131
- upper_size = self.batch_resize_range[1]
132
- if self.global_step <= 4:
133
- # do the first few batches with max size to avoid later oom
134
- new_resize = upper_size
135
- else:
136
- new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137
- if new_resize != x.shape[2]:
138
- x = F.interpolate(x, size=new_resize, mode="bicubic")
139
- x = x.detach()
140
- return x
141
-
142
- def training_step(self, batch, batch_idx, optimizer_idx):
143
- # https://github.com/pytorch/pytorch/issues/37142
144
- # try not to fool the heuristics
145
- x = self.get_input(batch, self.image_key)
146
- xrec, qloss, ind = self(x, return_pred_indices=True)
147
-
148
- if optimizer_idx == 0:
149
- # autoencode
150
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151
- last_layer=self.get_last_layer(), split="train",
152
- predicted_indices=ind)
153
-
154
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155
- return aeloss
156
-
157
- if optimizer_idx == 1:
158
- # discriminator
159
- discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160
- last_layer=self.get_last_layer(), split="train")
161
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162
- return discloss
163
-
164
- def validation_step(self, batch, batch_idx):
165
- log_dict = self._validation_step(batch, batch_idx)
166
- with self.ema_scope():
167
- log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168
- return log_dict
169
-
170
- def _validation_step(self, batch, batch_idx, suffix=""):
171
- x = self.get_input(batch, self.image_key)
172
- xrec, qloss, ind = self(x, return_pred_indices=True)
173
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174
- self.global_step,
175
- last_layer=self.get_last_layer(),
176
- split="val"+suffix,
177
- predicted_indices=ind
178
- )
179
-
180
- discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181
- self.global_step,
182
- last_layer=self.get_last_layer(),
183
- split="val"+suffix,
184
- predicted_indices=ind
185
- )
186
- rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187
- self.log(f"val{suffix}/rec_loss", rec_loss,
188
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189
- self.log(f"val{suffix}/aeloss", aeloss,
190
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191
- if version.parse(pl.__version__) >= version.parse('1.4.0'):
192
- del log_dict_ae[f"val{suffix}/rec_loss"]
193
- self.log_dict(log_dict_ae)
194
- self.log_dict(log_dict_disc)
195
- return self.log_dict
196
-
197
- def configure_optimizers(self):
198
- lr_d = self.learning_rate
199
- lr_g = self.lr_g_factor*self.learning_rate
200
- print("lr_d", lr_d)
201
- print("lr_g", lr_g)
202
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203
- list(self.decoder.parameters())+
204
- list(self.quantize.parameters())+
205
- list(self.quant_conv.parameters())+
206
- list(self.post_quant_conv.parameters()),
207
- lr=lr_g, betas=(0.5, 0.9))
208
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209
- lr=lr_d, betas=(0.5, 0.9))
210
-
211
- if self.scheduler_config is not None:
212
- scheduler = instantiate_from_config(self.scheduler_config)
213
-
214
- print("Setting up LambdaLR scheduler...")
215
- scheduler = [
216
- {
217
- 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218
- 'interval': 'step',
219
- 'frequency': 1
220
- },
221
- {
222
- 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223
- 'interval': 'step',
224
- 'frequency': 1
225
- },
226
- ]
227
- return [opt_ae, opt_disc], scheduler
228
- return [opt_ae, opt_disc], []
229
-
230
- def get_last_layer(self):
231
- return self.decoder.conv_out.weight
232
-
233
- def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234
- log = dict()
235
- x = self.get_input(batch, self.image_key)
236
- x = x.to(self.device)
237
- if only_inputs:
238
- log["inputs"] = x
239
- return log
240
- xrec, _ = self(x)
241
- if x.shape[1] > 3:
242
- # colorize with random projection
243
- assert xrec.shape[1] > 3
244
- x = self.to_rgb(x)
245
- xrec = self.to_rgb(xrec)
246
- log["inputs"] = x
247
- log["reconstructions"] = xrec
248
- if plot_ema:
249
- with self.ema_scope():
250
- xrec_ema, _ = self(x)
251
- if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252
- log["reconstructions_ema"] = xrec_ema
253
- return log
254
-
255
- def to_rgb(self, x):
256
- assert self.image_key == "segmentation"
257
- if not hasattr(self, "colorize"):
258
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259
- x = F.conv2d(x, weight=self.colorize)
260
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261
- return x
262
-
263
-
264
- class VQModelInterface(VQModel):
265
- def __init__(self, embed_dim, *args, **kwargs):
266
- super().__init__(embed_dim=embed_dim, *args, **kwargs)
267
- self.embed_dim = embed_dim
268
-
269
- def encode(self, x):
270
- h = self.encoder(x)
271
- h = self.quant_conv(h)
272
- return h
273
-
274
- def decode(self, h, force_not_quantize=False):
275
- # also go through quantization layer
276
- if not force_not_quantize:
277
- quant, emb_loss, info = self.quantize(h)
278
- else:
279
- quant = h
280
- quant = self.post_quant_conv(quant)
281
- dec = self.decoder(quant)
282
- return dec
283
-
284
-
285
- class AutoencoderKL(pl.LightningModule):
286
- def __init__(self,
287
- embed_dim, # 4
288
- monitor, # "val/rec_loss"
289
- ddconfig, # {...}
290
- lossconfig, # {target: torch.nn.Identity}
291
- ckpt_path=None,
292
- ignore_keys=[],
293
- image_key="image",
294
- colorize_nlabels=None):
295
- super().__init__()
296
- self.image_key = image_key # "image"
297
- self.encoder = Encoder(**ddconfig)
298
- self.decoder = Decoder(**ddconfig)
299
- self.loss = instantiate_from_config(lossconfig)
300
- self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) # 8 -> 8
301
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) # 4 -> 4
302
- self.embed_dim = embed_dim
303
- self.monitor = monitor
304
-
305
- def encode(self, x):
306
- h = self.encoder(x)
307
- moments = self.quant_conv(h)
308
- posterior = DiagonalGaussianDistribution(moments)
309
- return posterior
310
-
311
- def decode(self, z):
312
- z = self.post_quant_conv(z)
313
- dec = self.decoder(z)
314
- return dec
315
-
316
- def forward(self, input, sample_posterior=True):
317
- posterior = self.encode(input)
318
- if sample_posterior:
319
- z = posterior.sample()
320
- else:
321
- z = posterior.mode()
322
- dec = self.decode(z)
323
- return dec, posterior
324
-
325
- def get_input(self, batch, k):
326
- x = batch[k]
327
- if len(x.shape) == 3:
328
- x = x[..., None]
329
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
330
- return x
331
-
332
- def training_step(self, batch, batch_idx, optimizer_idx):
333
- inputs = self.get_input(batch, self.image_key)
334
- reconstructions, posterior = self(inputs)
335
-
336
- if optimizer_idx == 0:
337
- # train encoder+decoder+logvar
338
- aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
339
- last_layer=self.get_last_layer(), split="train")
340
- self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
341
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
342
- return aeloss
343
-
344
- if optimizer_idx == 1:
345
- # train the discriminator
346
- discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
347
- last_layer=self.get_last_layer(), split="train")
348
-
349
- self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
350
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
351
- return discloss
352
-
353
- def validation_step(self, batch, batch_idx):
354
- inputs = self.get_input(batch, self.image_key)
355
- reconstructions, posterior = self(inputs)
356
- aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
357
- last_layer=self.get_last_layer(), split="val")
358
-
359
- discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
360
- last_layer=self.get_last_layer(), split="val")
361
-
362
- self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
363
- self.log_dict(log_dict_ae)
364
- self.log_dict(log_dict_disc)
365
- return self.log_dict
366
-
367
- def configure_optimizers(self):
368
- lr = self.learning_rate
369
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
370
- list(self.decoder.parameters())+
371
- list(self.quant_conv.parameters())+
372
- list(self.post_quant_conv.parameters()),
373
- lr=lr, betas=(0.5, 0.9))
374
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
375
- lr=lr, betas=(0.5, 0.9))
376
- return [opt_ae, opt_disc], []
377
-
378
- def get_last_layer(self):
379
- return self.decoder.conv_out.weight
380
-
381
- @torch.no_grad()
382
- def log_images(self, batch, only_inputs=False, **kwargs):
383
- log = dict()
384
- x = self.get_input(batch, self.image_key)
385
- x = x.to(self.device)
386
- if not only_inputs:
387
- xrec, posterior = self(x)
388
- if x.shape[1] > 3:
389
- # colorize with random projection
390
- assert xrec.shape[1] > 3
391
- x = self.to_rgb(x)
392
- xrec = self.to_rgb(xrec)
393
- log["samples"] = self.decode(torch.randn_like(posterior.sample()))
394
- log["reconstructions"] = xrec
395
- log["inputs"] = x
396
- return log
397
-
398
- def to_rgb(self, x):
399
- assert self.image_key == "segmentation"
400
- if not hasattr(self, "colorize"):
401
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
402
- x = F.conv2d(x, weight=self.colorize)
403
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
404
- return x
405
-
406
-
407
-
408
-
 
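To clarify how the deleted AutoencoderKL is used at inference time, here is a hedged encode/decode sketch. The small ddconfig is illustrative only (the real settings live in the repository's configs), the identity lossconfig mirrors the comment in the constructor, and the import only resolves against the tree before this commit.

import torch
from ldm.models.autoencoder import AutoencoderKL  # pre-deletion import

ddconfig = dict(                      # illustrative, reduced-size config
    double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
    ch=32, ch_mult=(1, 2, 4, 4), num_res_blocks=1, attn_resolutions=[], dropout=0.0,
)
ae = AutoencoderKL(
    embed_dim=4,
    monitor="val/rec_loss",
    ddconfig=ddconfig,
    lossconfig={"target": "torch.nn.Identity"},
).eval()

x = torch.randn(1, 3, 256, 256)    # stand-in for a normalized image batch
with torch.no_grad():
    posterior = ae.encode(x)        # DiagonalGaussianDistribution over the latent
    z = posterior.sample()          # latent tensor, here [1, 4, 32, 32]
    x_rec = ae.decode(z)            # reconstruction, [1, 3, 256, 256]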
ldm/models/diffusion/.DS_Store DELETED
Binary file (6.15 kB)
 
ldm/models/diffusion/__init__.py DELETED
File without changes
ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (150 Bytes)
 
ldm/models/diffusion/__pycache__/control.cpython-38.pyc DELETED
Binary file (9.15 kB)
 
ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc DELETED
Binary file (7.76 kB)
 
ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc DELETED
Binary file (5.1 kB)
 
ldm/models/diffusion/classifier.py DELETED
@@ -1,267 +0,0 @@
1
- import os
2
- import torch
3
- import pytorch_lightning as pl
4
- from omegaconf import OmegaConf
5
- from torch.nn import functional as F
6
- from torch.optim import AdamW
7
- from torch.optim.lr_scheduler import LambdaLR
8
- from copy import deepcopy
9
- from einops import rearrange
10
- from glob import glob
11
- from natsort import natsorted
12
-
13
- from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14
- from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15
-
16
- __models__ = {
17
- 'class_label': EncoderUNetModel,
18
- 'segmentation': UNetModel
19
- }
20
-
21
-
22
- def disabled_train(self, mode=True):
23
- """Overwrite model.train with this function to make sure train/eval mode
24
- does not change anymore."""
25
- return self
26
-
27
-
28
- class NoisyLatentImageClassifier(pl.LightningModule):
29
-
30
- def __init__(self,
31
- diffusion_path,
32
- num_classes,
33
- ckpt_path=None,
34
- pool='attention',
35
- label_key=None,
36
- diffusion_ckpt_path=None,
37
- scheduler_config=None,
38
- weight_decay=1.e-2,
39
- log_steps=10,
40
- monitor='val/loss',
41
- *args,
42
- **kwargs):
43
- super().__init__(*args, **kwargs)
44
- self.num_classes = num_classes
45
- # get latest config of diffusion model
46
- diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47
- self.diffusion_config = OmegaConf.load(diffusion_config).model
48
- self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49
- self.load_diffusion()
50
-
51
- self.monitor = monitor
52
- self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53
- self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54
- self.log_steps = log_steps
55
-
56
- self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57
- else self.diffusion_model.cond_stage_key
58
-
59
- assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60
-
61
- if self.label_key not in __models__:
62
- raise NotImplementedError()
63
-
64
- self.load_classifier(ckpt_path, pool)
65
-
66
- self.scheduler_config = scheduler_config
67
- self.use_scheduler = self.scheduler_config is not None
68
- self.weight_decay = weight_decay
69
-
70
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71
- sd = torch.load(path, map_location="cpu")
72
- if "state_dict" in list(sd.keys()):
73
- sd = sd["state_dict"]
74
- keys = list(sd.keys())
75
- for k in keys:
76
- for ik in ignore_keys:
77
- if k.startswith(ik):
78
- print("Deleting key {} from state_dict.".format(k))
79
- del sd[k]
80
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81
- sd, strict=False)
82
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83
- if len(missing) > 0:
84
- print(f"Missing Keys: {missing}")
85
- if len(unexpected) > 0:
86
- print(f"Unexpected Keys: {unexpected}")
87
-
88
- def load_diffusion(self):
89
- model = instantiate_from_config(self.diffusion_config)
90
- self.diffusion_model = model.eval()
91
- self.diffusion_model.train = disabled_train
92
- for param in self.diffusion_model.parameters():
93
- param.requires_grad = False
94
-
95
- def load_classifier(self, ckpt_path, pool):
96
- model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97
- model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98
- model_config.out_channels = self.num_classes
99
- if self.label_key == 'class_label':
100
- model_config.pool = pool
101
-
102
- self.model = __models__[self.label_key](**model_config)
103
- if ckpt_path is not None:
104
- print('#####################################################################')
105
- print(f'load from ckpt "{ckpt_path}"')
106
- print('#####################################################################')
107
- self.init_from_ckpt(ckpt_path)
108
-
109
- @torch.no_grad()
110
- def get_x_noisy(self, x, t, noise=None):
111
- noise = default(noise, lambda: torch.randn_like(x))
112
- continuous_sqrt_alpha_cumprod = None
113
- if self.diffusion_model.use_continuous_noise:
114
- continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115
- # todo: make sure t+1 is correct here
116
-
117
- return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118
- continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119
-
120
- def forward(self, x_noisy, t, *args, **kwargs):
121
- return self.model(x_noisy, t)
122
-
123
- @torch.no_grad()
124
- def get_input(self, batch, k):
125
- x = batch[k]
126
- if len(x.shape) == 3:
127
- x = x[..., None]
128
- x = rearrange(x, 'b h w c -> b c h w')
129
- x = x.to(memory_format=torch.contiguous_format).float()
130
- return x
131
-
132
- @torch.no_grad()
133
- def get_conditioning(self, batch, k=None):
134
- if k is None:
135
- k = self.label_key
136
- assert k is not None, 'Needs to provide label key'
137
-
138
- targets = batch[k].to(self.device)
139
-
140
- if self.label_key == 'segmentation':
141
- targets = rearrange(targets, 'b h w c -> b c h w')
142
- for down in range(self.numd):
143
- h, w = targets.shape[-2:]
144
- targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145
-
146
- # targets = rearrange(targets,'b c h w -> b h w c')
147
-
148
- return targets
149
-
150
- def compute_top_k(self, logits, labels, k, reduction="mean"):
151
- _, top_ks = torch.topk(logits, k, dim=1)
152
- if reduction == "mean":
153
- return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154
- elif reduction == "none":
155
- return (top_ks == labels[:, None]).float().sum(dim=-1)
156
-
157
- def on_train_epoch_start(self):
158
- # save some memory
159
- self.diffusion_model.model.to('cpu')
160
-
161
- @torch.no_grad()
162
- def write_logs(self, loss, logits, targets):
163
- log_prefix = 'train' if self.training else 'val'
164
- log = {}
165
- log[f"{log_prefix}/loss"] = loss.mean()
166
- log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167
- logits, targets, k=1, reduction="mean"
168
- )
169
- log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170
- logits, targets, k=5, reduction="mean"
171
- )
172
-
173
- self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174
- self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175
- self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176
- lr = self.optimizers().param_groups[0]['lr']
177
- self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178
-
179
- def shared_step(self, batch, t=None):
180
- x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181
- targets = self.get_conditioning(batch)
182
- if targets.dim() == 4:
183
- targets = targets.argmax(dim=1)
184
- if t is None:
185
- t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186
- else:
187
- t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188
- x_noisy = self.get_x_noisy(x, t)
189
- logits = self(x_noisy, t)
190
-
191
- loss = F.cross_entropy(logits, targets, reduction='none')
192
-
193
- self.write_logs(loss.detach(), logits.detach(), targets.detach())
194
-
195
- loss = loss.mean()
196
- return loss, logits, x_noisy, targets
197
-
198
- def training_step(self, batch, batch_idx):
199
- loss, *_ = self.shared_step(batch)
200
- return loss
201
-
202
- def reset_noise_accs(self):
203
- self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204
- range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205
-
206
- def on_validation_start(self):
207
- self.reset_noise_accs()
208
-
209
- @torch.no_grad()
210
- def validation_step(self, batch, batch_idx):
211
- loss, *_ = self.shared_step(batch)
212
-
213
- for t in self.noisy_acc:
214
- _, logits, _, targets = self.shared_step(batch, t)
215
- self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216
- self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217
-
218
- return loss
219
-
220
- def configure_optimizers(self):
221
- optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222
-
223
- if self.use_scheduler:
224
- scheduler = instantiate_from_config(self.scheduler_config)
225
-
226
- print("Setting up LambdaLR scheduler...")
227
- scheduler = [
228
- {
229
- 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230
- 'interval': 'step',
231
- 'frequency': 1
232
- }]
233
- return [optimizer], scheduler
234
-
235
- return optimizer
236
-
237
- @torch.no_grad()
238
- def log_images(self, batch, N=8, *args, **kwargs):
239
- log = dict()
240
- x = self.get_input(batch, self.diffusion_model.first_stage_key)
241
- log['inputs'] = x
242
-
243
- y = self.get_conditioning(batch)
244
-
245
- if self.label_key == 'class_label':
246
- y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247
- log['labels'] = y
248
-
249
- if ismap(y):
250
- log['labels'] = self.diffusion_model.to_rgb(y)
251
-
252
- for step in range(self.log_steps):
253
- current_time = step * self.log_time_interval
254
-
255
- _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256
-
257
- log[f'inputs@t{current_time}'] = x_noisy
258
-
259
- pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260
- pred = rearrange(pred, 'b h w c -> b c h w')
261
-
262
- log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263
-
264
- for key in log:
265
- log[key] = log[key][:N]
266
-
267
- return log
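Note: the top-k accuracy bookkeeping in compute_top_k above is the core of the logged metrics. A minimal, self-contained sketch of that computation (the logits and labels below are made-up values for illustration, not from this repository):

import torch

logits = torch.tensor([[0.1, 0.7, 0.2],    # batch of 2, 3 classes
                       [0.6, 0.3, 0.1]])
labels = torch.tensor([1, 2])

_, top_ks = torch.topk(logits, k=2, dim=1)   # indices of the 2 largest logits per row
acc_at_2 = (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
# acc_at_2 == 0.5: label 1 is among the top-2 of row 0, label 2 is not among the top-2 of row 1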
 
ldm/models/diffusion/control.py DELETED
@@ -1,406 +0,0 @@
1
- import random
2
- import torch
3
- import torchvision
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- from torch.optim.lr_scheduler import CosineAnnealingLR
7
-
8
- from einops import rearrange
9
- from ldm.modules.diffusionmodules.util import (
10
- conv_nd,
11
- linear,
12
- zero_module,
13
- timestep_embedding,
14
- )
15
- from ldm.models.diffusion.ddpm import DDPM
16
- from ldm.modules.attention import SpatialTransformer
17
- from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
18
- from ldm.util import instantiate_from_config, default
19
- from ldm.models.diffusion.ddim import DDIMSampler
20
-
21
- import torch
22
- from torch.optim.optimizer import Optimizer
23
- from torch.optim.lr_scheduler import LambdaLR
24
-
25
-
26
- def disabled_train(self, mode=True):
27
- return self
28
-
29
-
30
- # =============================================================
31
- # Trainable part: ControlNet
32
- # =============================================================
33
- class ControlNet(nn.Module):
34
- def __init__(
35
- self,
36
- in_channels, # 9
37
- model_channels, # 320
38
- hint_channels, # 20
39
- attention_resolutions, # [4,2,1]
40
- num_res_blocks, # 2
41
- channel_mult=(1, 2, 4, 8), # [1,2,4,4]
42
- num_head_channels=-1, # 64
43
- transformer_depth=1, # 1
44
- context_dim=None, # 768
45
- use_checkpoint=False, # True
46
- dropout=0,
47
- conv_resample=True,
48
- dims=2,
49
- num_heads=-1,
50
- use_scale_shift_norm=False):
51
- super(ControlNet, self).__init__()
52
- self.dims = dims
53
- self.in_channels = in_channels
54
- self.model_channels = model_channels
55
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
56
- self.attention_resolutions = attention_resolutions
57
- self.dropout = dropout
58
- self.channel_mult = channel_mult
59
- self.use_checkpoint = use_checkpoint
60
- self.dtype = torch.float32
61
- self.num_heads = num_heads
62
- self.num_head_channels = num_head_channels
63
-
64
-         # time encoder
65
- time_embed_dim = model_channels * 4
66
- self.time_embed = nn.Sequential(
67
- linear(model_channels, time_embed_dim),
68
- nn.SiLU(),
69
- linear(time_embed_dim, time_embed_dim),
70
- )
71
-
72
-         # input encoder
73
- self.input_blocks = nn.ModuleList(
74
- [
75
- TimestepEmbedSequential(
76
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
77
- )
78
- ]
79
- )
80
- self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
81
-
82
-         # hint encoder
83
- self.input_hint_block = TimestepEmbedSequential(
84
- conv_nd(dims, hint_channels, 16, 3, padding=1),
85
- nn.SiLU(),
86
- conv_nd(dims, 16, 16, 3, padding=1),
87
- nn.SiLU(),
88
- conv_nd(dims, 16, 32, 3, padding=1, stride=2),
89
- nn.SiLU(),
90
- conv_nd(dims, 32, 32, 3, padding=1),
91
- nn.SiLU(),
92
- conv_nd(dims, 32, 96, 3, padding=1, stride=2),
93
- nn.SiLU(),
94
- conv_nd(dims, 96, 96, 3, padding=1),
95
- nn.SiLU(),
96
- conv_nd(dims, 96, 256, 3, padding=1, stride=2),
97
- nn.SiLU(),
98
- zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
99
- )
100
-
101
- # UNet
102
- input_block_chans = [model_channels]
103
- ch = model_channels
104
- ds = 1
105
- for level, mult in enumerate(channel_mult):
106
- for nr in range(self.num_res_blocks[level]):
107
- layers = [
108
- ResBlock(
109
- ch,
110
- time_embed_dim,
111
- dropout,
112
- out_channels=mult * model_channels,
113
- dims=dims,
114
- use_checkpoint=use_checkpoint,
115
- use_scale_shift_norm=use_scale_shift_norm,
116
- )
117
- ]
118
- ch = mult * model_channels
119
- if ds in attention_resolutions:
120
- num_heads = ch // num_head_channels
121
- dim_head = num_head_channels
122
- disabled_sa = False
123
-
124
- layers.append(
125
- SpatialTransformer(
126
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim)
127
- )
128
- self.input_blocks.append(TimestepEmbedSequential(*layers))
129
- self.zero_convs.append(self.make_zero_conv(ch))
130
- input_block_chans.append(ch)
131
- if level != len(channel_mult) - 1:
132
- out_ch = ch
133
- self.input_blocks.append(
134
- TimestepEmbedSequential(
135
- Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
136
- )
137
- )
138
- ch = out_ch
139
- input_block_chans.append(ch)
140
- self.zero_convs.append(self.make_zero_conv(ch))
141
- ds *= 2
142
- num_heads = ch // num_head_channels
143
- dim_head = num_head_channels
144
- self.middle_block = TimestepEmbedSequential(
145
- ResBlock(
146
- ch,
147
- time_embed_dim,
148
- dropout,
149
- dims=dims,
150
- use_checkpoint=use_checkpoint,
151
- use_scale_shift_norm=use_scale_shift_norm,
152
- ),
153
- SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
154
- ResBlock(
155
- ch,
156
- time_embed_dim,
157
- dropout,
158
- dims=dims,
159
- use_checkpoint=use_checkpoint,
160
- use_scale_shift_norm=use_scale_shift_norm,
161
- ),
162
- )
163
- self.middle_block_out = self.make_zero_conv(ch)
164
-
165
- def make_zero_conv(self, channels):
166
- return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
167
-
168
- def forward(self, x, hint, timesteps, reference_dino):
169
-         # Process the inputs
170
- context = reference_dino
171
- t_emb = timestep_embedding(timesteps, self.model_channels)
172
- emb = self.time_embed(t_emb)
173
- guided_hint = self.input_hint_block(hint, emb)
174
-
175
-         # Predict the control outputs
176
- outs = []
177
- h = x.type(self.dtype)
178
- for module, zero_conv in zip(self.input_blocks, self.zero_convs):
179
- if guided_hint is not None:
180
- h = module(h, emb, context)
181
- h += guided_hint
182
- guided_hint = None
183
- else:
184
- h = module(h, emb, context)
185
- outs.append(zero_conv(h, emb, context))
186
- h = self.middle_block(h, emb, context)
187
- outs.append(self.middle_block_out(h, emb, context))
188
-
189
- return outs
190
-
191
- # =============================================================
192
- # Fixed-parameter (frozen) part: ControlledUnetModel
193
- # =============================================================
194
- class ControlledUnetModel(UNetModel):
195
- def forward(self, x, timesteps=None, context=None, control=None):
196
- hs = []
197
-
198
-         # Encoder half of the UNet
199
- with torch.no_grad():
200
- t_emb = timestep_embedding(timesteps, self.model_channels)
201
- emb = self.time_embed(t_emb)
202
- h = x.type(self.dtype)
203
- for module in self.input_blocks:
204
- h = module(h, emb, context)
205
- hs.append(h)
206
- h = self.middle_block(h, emb, context)
207
-
208
-         # Inject the control signal
209
- if control is not None:
210
- h += control.pop()
211
-
212
-         # Decoder half of the UNet
213
- for i, module in enumerate(self.output_blocks):
214
- h = torch.cat([h, hs.pop() + control.pop()], dim=1)
215
- h = module(h, emb, context)
216
-
217
-         # Output
218
- h = h.type(x.dtype)
219
- h = self.out(h)
220
-
221
- return h
222
-
223
- # =============================================================
224
- # Backbone network: ControlLDM
225
- # =============================================================
226
- class ControlLDM(DDPM):
227
- def __init__(self,
228
- control_stage_config, # ControlNet
229
- first_stage_config, # AutoencoderKL
230
- cond_stage_config, # FrozenCLIPImageEmbedder
231
- scale_factor=1.0, # 0.18215
232
- *args, **kwargs):
233
- self.num_timesteps_cond = 1
234
-         super().__init__(*args, **kwargs) # self.model and self.register_buffer
235
- self.control_model = instantiate_from_config(control_stage_config) # self.control_model
236
-         self.instantiate_first_stage(first_stage_config) # self.first_stage_model: AutoencoderKL
237
-         self.instantiate_cond_stage(cond_stage_config) # self.cond_stage_model: FrozenCLIPImageEmbedder
238
-         self.proj_out=nn.Linear(1024, 768) # fully connected layer
239
- self.scale_factor = scale_factor # 0.18215
240
- self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=False)
241
- self.trainable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True)
242
-
243
- self.dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
244
- self.dinov2_vitl14.eval()
245
- self.dinov2_vitl14.train = disabled_train
246
- for param in self.dinov2_vitl14.parameters():
247
- param.requires_grad = False
248
- self.linear = nn.Linear(1024, 768)
249
-
250
- # self.dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')
251
- # self.dinov2_vitg14.eval()
252
- # self.dinov2_vitg14.train = disabled_train
253
- # for param in self.dinov2_vitg14.parameters():
254
- # param.requires_grad = False
255
- # self.linear = nn.Linear(1536, 768)
256
-
257
-     # AutoencoderKL is not trained
258
- def instantiate_first_stage(self, config):
259
- model = instantiate_from_config(config)
260
- self.first_stage_model = model.eval()
261
- self.first_stage_model.train = disabled_train
262
- for param in self.first_stage_model.parameters():
263
- param.requires_grad = False
264
-
265
-     # FrozenCLIPImageEmbedder is not trained
266
- def instantiate_cond_stage(self, config):
267
- model = instantiate_from_config(config)
268
- self.cond_stage_model = model.eval()
269
- self.cond_stage_model.train = disabled_train
270
- for param in self.cond_stage_model.parameters():
271
- param.requires_grad = False
272
-
273
-     # Training
274
- def training_step(self, batch, batch_idx):
275
-         z_new, reference, hint= self.get_input(batch) # load the data
276
-         loss= self(z_new, reference, hint) # compute the loss
277
-         self.log("loss", # log the loss
278
- loss,
279
- prog_bar=True,
280
- logger=True,
281
- on_step=True,
282
- on_epoch=True)
283
-         self.log('lr_abs', # log the learning rate
284
- self.optimizers().param_groups[0]['lr'],
285
- prog_bar=True,
286
- logger=True,
287
- on_step=True,
288
- on_epoch=False)
289
- return loss
290
-
291
-     # Load data
292
- @torch.no_grad()
293
- def get_input(self, batch):
294
-
295
-         # Load the raw data
296
- x, inpaint, mask, reference, hint = super().get_input(batch)
297
-
298
-         # Encode the ground truth with AutoencoderKL
299
- encoder_posterior = self.first_stage_model.encode(x)
300
- z = self.scale_factor * (encoder_posterior.sample()).detach()
301
-
302
-         # Encode the inpaint image with AutoencoderKL
303
- encoder_posterior_inpaint = self.first_stage_model.encode(inpaint)
304
- z_inpaint = self.scale_factor * (encoder_posterior_inpaint.sample()).detach()
305
-
306
- # Resize mask
307
- mask_resize = torchvision.transforms.Resize([z.shape[-2],z.shape[-1]])(mask)
308
-
309
-         # Assemble z_new
310
- z_new = torch.cat((z,z_inpaint,mask_resize),dim=1)
311
- out = [z_new, reference, hint]
312
-
313
- return out
314
-
315
-     # Compute the loss
316
- def forward(self, z_new, reference, hint):
317
-
318
-         # Random timestep t
319
- t = torch.randint(0, self.num_timesteps, (z_new.shape[0],), device=self.device).long()
320
-
321
-         # Encode the reference with CLIP
322
- reference_clip = self.cond_stage_model.encode(reference)
323
- reference_clip = self.proj_out(reference_clip)
324
-
325
-         # Encode the reference with DINO
326
- dino = self.dinov2_vitl14(reference,is_training=True)
327
- dino1 = dino["x_norm_clstoken"].unsqueeze(1)
328
- dino2 = dino["x_norm_patchtokens"]
329
- reference_dino = torch.cat((dino1, dino2), dim=1)
330
- reference_dino = self.linear(reference_dino)
331
-
332
-         # Add random noise
333
- noise = torch.randn_like(z_new[:,:4,:,:])
334
- x_noisy = self.q_sample(x_start=z_new[:,:4,:,:], t=t, noise=noise)
335
- x_noisy = torch.cat((x_noisy, z_new[:,4:,:,:]),dim=1)
336
-
337
-         # Predict the noise (note: both branches below make the identical call; the 20% split looks like a leftover of a conditioning-dropout experiment)
338
- if random.uniform(0, 1)<0.2:
339
- model_output = self.apply_model(x_noisy, hint, t, reference_clip, reference_dino)
340
- else:
341
- model_output = self.apply_model(x_noisy, hint, t, reference_clip, reference_dino)
342
-
343
-         # Compute the loss
344
- loss = self.get_loss(model_output, noise, mean=False).mean([1, 2, 3])
345
- loss = loss.mean()
346
-
347
- return loss
348
-
349
-     # Predict the noise
350
- def apply_model(self, x_noisy, hint, t, reference_clip, reference_dino):
351
-
352
-         # Predict the control
353
- control = self.control_model(x_noisy, hint, t, reference_dino)
354
-
355
-         # Call the PBE UNet (the frozen main model)
356
- model_output = self.model(x_noisy, t, reference_clip, control)
357
-
358
- return model_output
359
-
360
-     # Optimizer
361
- def configure_optimizers(self):
362
-         # Learning-rate setup
363
- lr = self.learning_rate
364
- params = list(self.control_model.parameters())+list(self.linear.parameters())
365
- opt = torch.optim.AdamW(params, lr=lr)
366
-
367
- return opt
368
-
369
-     # Sampling
370
- @torch.no_grad()
371
- def sample_log(self, batch, ddim_steps=50, ddim_eta=0.):
372
- z_new, reference, hint = self.get_input(batch)
373
- x, _, mask, _, _ = super().get_input(batch)
374
- log = dict()
375
-
376
- # log["reference"] = reference
377
- # reconstruction = 1. / self.scale_factor * z_new[:,:4,:,:]
378
- # log["reconstruction"] = self.first_stage_model.decode(reconstruction)
379
- log["mask"] = mask
380
-
381
- test_model_kwargs = {}
382
- test_model_kwargs['inpaint_image'] = z_new[:,4:8,:,:]
383
- test_model_kwargs['inpaint_mask'] = z_new[:,8:,:,:]
384
- ddim_sampler = DDIMSampler(self)
385
- shape = (self.channels, self.image_size, self.image_size)
386
- samples, _ = ddim_sampler.sample(ddim_steps,
387
- reference.shape[0],
388
- shape,
389
- hint,
390
- reference,
391
- verbose=False,
392
- eta=ddim_eta,
393
- test_model_kwargs=test_model_kwargs)
394
- samples = 1. / self.scale_factor * samples
395
- x_samples = self.first_stage_model.decode(samples[:,:4,:,:])
396
- # log["samples"] = x_samples
397
-
398
- x = torchvision.transforms.Resize([512, 512])(x)
399
- reference = torchvision.transforms.Resize([512, 512])(reference)
400
- x_samples = torchvision.transforms.Resize([512, 512])(x_samples)
401
- log["grid"] = torch.cat((x, reference, x_samples), dim=2)
402
-
403
- return log
404
-
405
-
406
-
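Note: the make_zero_conv / zero_module pattern above is what lets the ControlNet branch start out as a no-op: every residual it feeds into the frozen UNet passes through a 1x1 convolution whose weights and bias are zero-initialized. A minimal sketch of that behaviour (the channel and spatial sizes are illustrative assumptions):

import torch
import torch.nn as nn

def zero_module(module):
    # zero out all parameters, as in ldm.modules.diffusionmodules.util.zero_module
    for p in module.parameters():
        p.detach().zero_()
    return module

zero_conv = zero_module(nn.Conv2d(320, 320, kernel_size=1, padding=0))
h = torch.randn(1, 320, 64, 64)
print(zero_conv(h).abs().max())   # tensor(0.) -> the injected control residuals start at zero

Because every injected residual is exactly zero at initialization, ControlledUnetModel initially reproduces the frozen model's outputs and the control branch is learned gradually.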
 
ldm/models/diffusion/ddim.py DELETED
@@ -1,265 +0,0 @@
1
- """SAMPLING ONLY."""
2
-
3
- import torch
4
- import numpy as np
5
- from tqdm import tqdm
6
- from functools import partial
7
-
8
- from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
9
- extract_into_tensor
10
-
11
-
12
- class DDIMSampler(object):
13
- def __init__(self, model, schedule="linear", **kwargs):
14
- super().__init__()
15
- self.model = model
16
- self.ddpm_num_timesteps = model.num_timesteps
17
- self.schedule = schedule
18
-
19
- def register_buffer(self, name, attr):
20
- if type(attr) == torch.Tensor:
21
- if attr.device != torch.device("cuda"):
22
- attr = attr.to(torch.device("cuda"))
23
- setattr(self, name, attr)
24
-
25
- def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26
- self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
27
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
28
- alphas_cumprod = self.model.alphas_cumprod
29
- assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
30
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
31
-
32
- self.register_buffer('betas', to_torch(self.model.betas))
33
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
34
- self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
35
-
36
- # calculations for diffusion q(x_t | x_{t-1}) and others
37
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
38
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
39
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
40
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
41
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
42
-
43
- # ddim sampling parameters
44
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
45
- ddim_timesteps=self.ddim_timesteps,
46
- eta=ddim_eta,verbose=verbose)
47
-
48
- self.register_buffer('ddim_sigmas', ddim_sigmas)
49
- self.register_buffer('ddim_alphas', ddim_alphas)
50
- self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51
- self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53
- (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55
- self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56
-
57
- @torch.no_grad()
58
- def sample(self,
59
- S,
60
- batch_size,
61
- shape,
62
- pose,
63
- conditioning=None,
64
- callback=None,
65
- normals_sequence=None,
66
- img_callback=None,
67
- quantize_x0=False,
68
- eta=0.,
69
- mask=None,
70
- x0=None,
71
- temperature=1.,
72
- noise_dropout=0.,
73
- score_corrector=None,
74
- corrector_kwargs=None,
75
- verbose=True,
76
- x_T=None,
77
- log_every_t=100,
78
- unconditional_guidance_scale=1.,
79
- unconditional_conditioning=None,
80
- **kwargs
81
- ):
82
- if conditioning is not None:
83
- if isinstance(conditioning, dict):
84
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
85
- if cbs != batch_size:
86
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
- else:
88
- if conditioning.shape[0] != batch_size:
89
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
90
-
91
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
92
- # sampling
93
- C, H, W = shape
94
- size = (batch_size, C, H, W)
95
- print(f'Data shape for DDIM sampling is {size}, eta {eta}')
96
-
97
- samples, intermediates = self.ddim_sampling(conditioning, size, pose,
98
- callback=callback,
99
- img_callback=img_callback,
100
- quantize_denoised=quantize_x0,
101
- mask=mask, x0=x0,
102
- ddim_use_original_steps=False,
103
- noise_dropout=noise_dropout,
104
- temperature=temperature,
105
- score_corrector=score_corrector,
106
- corrector_kwargs=corrector_kwargs,
107
- x_T=x_T,
108
- log_every_t=log_every_t,
109
- unconditional_guidance_scale=unconditional_guidance_scale,
110
- unconditional_conditioning=unconditional_conditioning,
111
- **kwargs
112
- )
113
- return samples, intermediates
114
-
115
- @torch.no_grad()
116
- def ddim_sampling(self, cond, shape, pose,
117
- x_T=None, ddim_use_original_steps=False,
118
- callback=None, timesteps=None, quantize_denoised=False,
119
- mask=None, x0=None, img_callback=None, log_every_t=100,
120
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
121
- unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs):
122
- device = self.model.betas.device
123
- b = shape[0]
124
- if x_T is None:
125
- img = torch.randn(shape, device=device)
126
- else:
127
- img = x_T
128
-
129
- if timesteps is None:
130
- timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
131
- elif timesteps is not None and not ddim_use_original_steps:
132
- subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
133
- timesteps = self.ddim_timesteps[:subset_end]
134
-
135
- intermediates = {'x_inter': [img], 'pred_x0': [img]}
136
- time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
137
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
138
- print(f"Running DDIM Sampling with {total_steps} timesteps")
139
-
140
- iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
141
-
142
- for i, step in enumerate(iterator):
143
- index = total_steps - i - 1
144
- ts = torch.full((b,), step, device=device, dtype=torch.long)
145
-
146
- outs = self.p_sample_ddim(img, cond, ts, pose, index=index, use_original_steps=ddim_use_original_steps,
147
- quantize_denoised=quantize_denoised, temperature=temperature,
148
- noise_dropout=noise_dropout, score_corrector=score_corrector,
149
- corrector_kwargs=corrector_kwargs,
150
- unconditional_guidance_scale=unconditional_guidance_scale,
151
- unconditional_conditioning=unconditional_conditioning,**kwargs)
152
- img, pred_x0 = outs
153
- if callback: callback(i)
154
- if img_callback: img_callback(pred_x0, i)
155
-
156
- if index % log_every_t == 0 or index == total_steps - 1:
157
- intermediates['x_inter'].append(img)
158
- intermediates['pred_x0'].append(pred_x0)
159
-
160
- return img, intermediates
161
-
162
- @torch.no_grad()
163
- def p_sample_ddim(self, x, c, t, pose, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
164
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
165
- unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs):
166
- b, *_, device = *x.shape, x.device
167
-
168
- if 'test_model_kwargs' in kwargs:
169
- kwargs=kwargs['test_model_kwargs']
170
- x = torch.cat([x, kwargs['inpaint_image'], kwargs['inpaint_mask']],dim=1)
171
- elif 'rest' in kwargs:
172
- x = torch.cat((x, kwargs['rest']), dim=1)
173
- else:
174
- raise Exception("kwargs must contain either 'test_model_kwargs' or 'rest' key")
175
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
176
- reference_clip = self.model.cond_stage_model.encode(c)
177
- reference_clip= self.model.proj_out(reference_clip)
178
- dino = self.model.dinov2_vitl14(c,is_training=True)
179
- dino1 = dino["x_norm_clstoken"].unsqueeze(1)
180
- dino2 = dino["x_norm_patchtokens"]
181
- reference_dino = torch.cat((dino1, dino2), dim=1)
182
- reference_dino = self.model.linear(reference_dino)
183
- control = self.model.control_model(x, pose, t, reference_dino)
184
- e_t = self.model.model(x, t, reference_clip, control)
185
- else:
186
- x_in = torch.cat([x] * 2)
187
- t_in = torch.cat([t] * 2)
188
- c_in = torch.cat([unconditional_conditioning, c])
189
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
190
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
191
-
192
- if score_corrector is not None:
193
- assert self.model.parameterization == "eps"
194
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
195
-
196
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
197
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
198
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
199
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
200
- # select parameters corresponding to the currently considered timestep
201
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205
-
206
- # current prediction for x_0
207
- if x.shape[1]!=4:
208
- pred_x0 = (x[:,:4,:,:] - sqrt_one_minus_at * e_t) / a_t.sqrt()
209
- else:
210
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
211
-
212
- if quantize_denoised:
213
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
214
- # direction pointing to x_t
215
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
216
- noise = sigma_t * noise_like(dir_xt.shape, device, repeat_noise) * temperature
217
- if noise_dropout > 0.:
218
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
219
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
220
-
221
- # noise
222
- # noise = torch.randn_like(kwargs['inpaint_image'])
223
- # x_noisy = self.model.q_sample(x_start=kwargs['inpaint_image'], t=t, noise=noise)
224
- # x_prev = x_prev*(1-kwargs['inpaint_mask'])+x_noisy*(kwargs['inpaint_mask'])
225
-
226
-
227
- return x_prev, pred_x0
228
-
229
- @torch.no_grad()
230
- def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
231
- # fast, but does not allow for exact reconstruction
232
- # t serves as an index to gather the correct alphas
233
- if use_original_steps:
234
- sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
235
- sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
236
- else:
237
- sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
238
- sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
239
-
240
- if noise is None:
241
- noise = torch.randn_like(x0)
242
- return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
243
- extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
244
-
245
- @torch.no_grad()
246
- def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
247
- use_original_steps=False):
248
-
249
- timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
250
- timesteps = timesteps[:t_start]
251
-
252
- time_range = np.flip(timesteps)
253
- total_steps = timesteps.shape[0]
254
- print(f"Running DDIM Sampling with {total_steps} timesteps")
255
-
256
- iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
257
- x_dec = x_latent
258
- for i, step in enumerate(iterator):
259
- index = total_steps - i - 1
260
- ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
261
- x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
262
- unconditional_guidance_scale=unconditional_guidance_scale,
263
- unconditional_conditioning=unconditional_conditioning)
264
- return x_dec
265
-
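Note: stripped of the inpainting bookkeeping, p_sample_ddim above is the standard DDIM update. A minimal sketch of one step using the same variable names (the alpha and sigma values are illustrative assumptions; sigma_t = 0, i.e. eta = 0, makes the step deterministic):

import torch

def ddim_step(x_t, e_t, a_t, a_prev, sigma_t):
    pred_x0 = (x_t - (1. - a_t).sqrt() * e_t) / a_t.sqrt()      # current prediction for x_0
    dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t          # direction pointing to x_t
    noise = sigma_t * torch.randn_like(x_t)
    return a_prev.sqrt() * pred_x0 + dir_xt + noise, pred_x0

x_t = torch.randn(1, 4, 64, 64)
e_t = torch.randn_like(x_t)
x_prev, pred_x0 = ddim_step(x_t, e_t,
                            a_t=torch.tensor(0.5),
                            a_prev=torch.tensor(0.7),
                            sigma_t=torch.tensor(0.0))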
 
ldm/models/diffusion/ddpm.py DELETED
@@ -1,143 +0,0 @@
1
- import torch
2
- import einops
3
- import torch.nn as nn
4
- import numpy as np
5
- import pytorch_lightning as pl
6
- from torch.optim.lr_scheduler import LambdaLR
7
- from einops import rearrange, repeat
8
- from functools import partial
9
- from torchvision.utils import make_grid
10
- from ldm.util import default, count_params, instantiate_from_config
11
- from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
12
- from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor
13
- from ldm.models.diffusion.ddim import DDIMSampler
14
- from torchvision.transforms import Resize
15
- import random
16
-
17
-
18
-
19
- def disabled_train(self, mode=True):
20
- return self
21
-
22
-
23
- class DiffusionWrapper(pl.LightningModule):
24
- def __init__(self, unet_config):
25
- super().__init__()
26
- self.diffusion_model = instantiate_from_config(unet_config)
27
-
28
- def forward(self, x, timesteps=None, context=None, control=None):
29
- out = self.diffusion_model(x, timesteps, context, control)
30
- return out
31
-
32
-
33
- class DDPM(pl.LightningModule):
34
- def __init__(self,
35
- unet_config,
36
- linear_start=1e-4, # 0.00085
37
- linear_end=2e-2, # 0.0120
38
- log_every_t=100, # 200
39
- timesteps=1000, # 1000
40
- image_size=256, # 32
41
- channels=3, # 4
42
- u_cond_percent=0, # 0.2
43
- use_ema=True, # False
44
- beta_schedule="linear",
45
- loss_type="l2",
46
- clip_denoised=True,
47
- cosine_s=8e-3,
48
- original_elbo_weight=0.,
49
- v_posterior=0.,
50
- l_simple_weight=1.,
51
- parameterization="eps",
52
- use_positional_encodings=False,
53
- learn_logvar=False,
54
- logvar_init=0.):
55
- super().__init__()
56
- self.parameterization = parameterization
57
- self.cond_stage_model = None
58
- self.clip_denoised = clip_denoised
59
- self.log_every_t = log_every_t
60
- self.image_size = image_size
61
- self.channels = channels
62
- self.u_cond_percent=u_cond_percent
63
- self.use_positional_encodings = use_positional_encodings
64
-         self.model = DiffusionWrapper(unet_config) # wraps the UNet model
65
-
66
- self.use_ema = use_ema
67
- self.use_scheduler = True
68
- self.v_posterior = v_posterior
69
- self.original_elbo_weight = original_elbo_weight
70
- self.l_simple_weight = l_simple_weight
71
- self.register_schedule(beta_schedule=beta_schedule, # "linear"
72
- timesteps=timesteps, # 1000
73
- linear_start=linear_start, # 0.00085
74
- linear_end=linear_end, # 0.0120
75
- cosine_s=cosine_s) # 8e-3
76
- self.loss_type = loss_type
77
- self.learn_logvar = learn_logvar
78
- self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
79
-
80
- def register_schedule(self,
81
- beta_schedule="linear",
82
- timesteps=1000,
83
- linear_start=0.00085,
84
- linear_end=0.0120,
85
- cosine_s=8e-3):
86
- betas = make_beta_schedule(beta_schedule,
87
- timesteps,
88
- linear_start=linear_start,
89
- linear_end=linear_end,
90
- cosine_s=cosine_s)
91
- alphas = 1. - betas
92
- alphas_cumprod = np.cumprod(alphas, axis=0)
93
- alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
94
-
95
- timesteps, = betas.shape
96
- self.num_timesteps = int(timesteps)
97
- self.linear_start = linear_start
98
- self.linear_end = linear_end
99
- to_torch = partial(torch.tensor, dtype=torch.float32)
100
- self.register_buffer('betas', to_torch(betas))
101
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
102
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
103
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
104
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
105
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
106
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
107
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
108
- posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + self.v_posterior * betas
109
- self.register_buffer('posterior_variance', to_torch(posterior_variance))
110
- self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
111
- self.register_buffer('posterior_mean_coef1', to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
112
- self.register_buffer('posterior_mean_coef2', to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
113
- lvlb_weights = self.betas ** 2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
114
- lvlb_weights[0] = lvlb_weights[1]
115
- self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
116
-
117
- def get_input(self, batch):
118
- x = batch['GT']
119
- mask = batch['inpaint_mask']
120
- inpaint = batch['inpaint_image']
121
- reference = batch['ref_imgs']
122
- hint = batch['hint']
123
-
124
- x = x.to(memory_format=torch.contiguous_format).float()
125
- mask = mask.to(memory_format=torch.contiguous_format).float()
126
- inpaint = inpaint.to(memory_format=torch.contiguous_format).float()
127
- reference = reference.to(memory_format=torch.contiguous_format).float()
128
- hint = hint.to(memory_format=torch.contiguous_format).float()
129
-
130
- return x, inpaint, mask, reference, hint
131
-
132
- def q_sample(self, x_start, t, noise=None):
133
- noise = default(noise, lambda: torch.randn_like(x_start))
134
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
135
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
136
-
137
- def get_loss(self, pred, target, mean=True):
138
- if mean:
139
- loss = torch.nn.functional.mse_loss(target, pred)
140
- else:
141
- loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
142
- return loss
143
-
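Note: q_sample above is the closed-form forward noising q(x_t | x_0). A minimal sketch of the same computation with a plain linear beta schedule (the schedule values are illustrative assumptions; register_schedule actually builds the schedule via make_beta_schedule with the config values shown above):

import torch

timesteps = 1000
betas = torch.linspace(1e-4, 2e-2, timesteps)          # assumed linear schedule for illustration
alphas_cumprod = torch.cumprod(1. - betas, dim=0)

x0 = torch.randn(2, 4, 64, 64)
t = torch.randint(0, timesteps, (2,))
noise = torch.randn_like(x0)
a_t = alphas_cumprod[t].view(-1, 1, 1, 1)
x_t = a_t.sqrt() * x0 + (1. - a_t).sqrt() * noise      # same form as q_sample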
 
ldm/models/diffusion/plms.py DELETED
@@ -1,239 +0,0 @@
1
- """SAMPLING ONLY."""
2
-
3
- import torch
4
- import numpy as np
5
- from tqdm import tqdm
6
- from functools import partial
7
-
8
- from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
-
10
-
11
- class PLMSSampler(object):
12
- def __init__(self, model, schedule="linear", **kwargs):
13
- super().__init__()
14
- self.model = model
15
- self.ddpm_num_timesteps = model.num_timesteps
16
- self.schedule = schedule
17
-
18
- def register_buffer(self, name, attr):
19
- if type(attr) == torch.Tensor:
20
- if attr.device != torch.device("cuda"):
21
- attr = attr.to(torch.device("cuda"))
22
- setattr(self, name, attr)
23
-
24
- def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
- if ddim_eta != 0:
26
- raise ValueError('ddim_eta must be 0 for PLMS')
27
- self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
28
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
29
- alphas_cumprod = self.model.alphas_cumprod
30
- assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
31
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
32
-
33
- self.register_buffer('betas', to_torch(self.model.betas))
34
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
- self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
36
-
37
- # calculations for diffusion q(x_t | x_{t-1}) and others
38
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
39
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
40
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
41
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
43
-
44
- # ddim sampling parameters
45
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
46
- ddim_timesteps=self.ddim_timesteps,
47
- eta=ddim_eta,verbose=verbose)
48
- self.register_buffer('ddim_sigmas', ddim_sigmas)
49
- self.register_buffer('ddim_alphas', ddim_alphas)
50
- self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51
- self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53
- (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55
- self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56
-
57
- @torch.no_grad()
58
- def sample(self,
59
- S,
60
- batch_size,
61
- shape,
62
- conditioning=None,
63
- callback=None,
64
- normals_sequence=None,
65
- img_callback=None,
66
- quantize_x0=False,
67
- eta=0.,
68
- mask=None,
69
- x0=None,
70
- temperature=1.,
71
- noise_dropout=0.,
72
- score_corrector=None,
73
- corrector_kwargs=None,
74
- verbose=True,
75
- x_T=None,
76
- log_every_t=100,
77
- unconditional_guidance_scale=1.,
78
- unconditional_conditioning=None,
79
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
80
- **kwargs
81
- ):
82
- if conditioning is not None:
83
- if isinstance(conditioning, dict):
84
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
85
- if cbs != batch_size:
86
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
- else:
88
- if conditioning.shape[0] != batch_size:
89
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
90
-
91
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
92
- # sampling
93
- C, H, W = shape
94
- size = (batch_size, C, H, W)
95
- print(f'Data shape for PLMS sampling is {size}')
96
-
97
- samples, intermediates = self.plms_sampling(conditioning, size,
98
- callback=callback,
99
- img_callback=img_callback,
100
- quantize_denoised=quantize_x0,
101
- mask=mask, x0=x0,
102
- ddim_use_original_steps=False,
103
- noise_dropout=noise_dropout,
104
- temperature=temperature,
105
- score_corrector=score_corrector,
106
- corrector_kwargs=corrector_kwargs,
107
- x_T=x_T,
108
- log_every_t=log_every_t,
109
- unconditional_guidance_scale=unconditional_guidance_scale,
110
- unconditional_conditioning=unconditional_conditioning,
111
- **kwargs
112
- )
113
- return samples, intermediates
114
-
115
- @torch.no_grad()
116
- def plms_sampling(self, cond, shape,
117
- x_T=None, ddim_use_original_steps=False,
118
- callback=None, timesteps=None, quantize_denoised=False,
119
- mask=None, x0=None, img_callback=None, log_every_t=100,
120
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
121
- unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs):
122
- device = self.model.betas.device
123
- b = shape[0]
124
- if x_T is None:
125
- img = torch.randn(shape, device=device)
126
- else:
127
- img = x_T
128
-
129
- if timesteps is None:
130
- timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
131
- elif timesteps is not None and not ddim_use_original_steps:
132
- subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
133
- timesteps = self.ddim_timesteps[:subset_end]
134
-
135
- intermediates = {'x_inter': [img], 'pred_x0': [img]}
136
- time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
137
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
138
- print(f"Running PLMS Sampling with {total_steps} timesteps")
139
-
140
- iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
141
- old_eps = []
142
-
143
- for i, step in enumerate(iterator):
144
- index = total_steps - i - 1
145
- ts = torch.full((b,), step, device=device, dtype=torch.long)
146
- ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
147
-
148
- if mask is not None:
149
- assert x0 is not None
150
- img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
151
- img = img_orig * mask + (1. - mask) * img
152
-
153
- outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
154
- quantize_denoised=quantize_denoised, temperature=temperature,
155
- noise_dropout=noise_dropout, score_corrector=score_corrector,
156
- corrector_kwargs=corrector_kwargs,
157
- unconditional_guidance_scale=unconditional_guidance_scale,
158
- unconditional_conditioning=unconditional_conditioning,
159
- old_eps=old_eps, t_next=ts_next,**kwargs)
160
- img, pred_x0, e_t = outs
161
- old_eps.append(e_t)
162
- if len(old_eps) >= 4:
163
- old_eps.pop(0)
164
- if callback: callback(i)
165
- if img_callback: img_callback(pred_x0, i)
166
-
167
- if index % log_every_t == 0 or index == total_steps - 1:
168
- intermediates['x_inter'].append(img)
169
- intermediates['pred_x0'].append(pred_x0)
170
-
171
- return img, intermediates
172
-
173
- @torch.no_grad()
174
- def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
175
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
176
- unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,**kwargs):
177
- b, *_, device = *x.shape, x.device
178
- def get_model_output(x, t):
179
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
180
- e_t = self.model.apply_model(x, t, c)
181
- else:
182
- x_in = torch.cat([x] * 2)
183
- t_in = torch.cat([t] * 2)
184
- c_in = torch.cat([unconditional_conditioning, c])
185
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
186
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
187
-
188
- if score_corrector is not None:
189
- assert self.model.parameterization == "eps"
190
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
191
-
192
- return e_t
193
-
194
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
195
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
196
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
197
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
198
-
199
- def get_x_prev_and_pred_x0(e_t, index):
200
- # select parameters corresponding to the currently considered timestep
201
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205
-
206
- # current prediction for x_0
207
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
208
- if quantize_denoised:
209
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
210
- # direction pointing to x_t
211
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
212
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
213
- if noise_dropout > 0.:
214
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
215
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
216
- return x_prev, pred_x0
217
- kwargs=kwargs['test_model_kwargs']
218
- print(x.shape,kwargs['inpaint_image'].shape,kwargs['inpaint_mask'].shape)
219
- x_new=torch.cat([x,kwargs['inpaint_image'],kwargs['inpaint_mask']],dim=1)
220
- e_t = get_model_output(x_new, t)
221
- if len(old_eps) == 0:
222
- # Pseudo Improved Euler (2nd order)
223
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
224
- x_prev_new=torch.cat([x_prev,kwargs['inpaint_image'],kwargs['inpaint_mask']],dim=1)
225
- e_t_next = get_model_output(x_prev_new, t_next)
226
- e_t_prime = (e_t + e_t_next) / 2
227
- elif len(old_eps) == 1:
228
- # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
229
- e_t_prime = (3 * e_t - old_eps[-1]) / 2
230
- elif len(old_eps) == 2:
231
-             # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
232
- e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
233
- elif len(old_eps) >= 3:
234
-             # 4th order Pseudo Linear Multistep (Adams-Bashforth)
235
- e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
236
-
237
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
238
-
239
- return x_prev, pred_x0, e_t
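Note: the branch over len(old_eps) above is a pseudo linear multistep (Adams-Bashforth) combination of the most recent noise predictions. A compact sketch of just that combination; as in the sampler, the very first step is instead handled by averaging e_t with an extra Euler-step prediction:

def plms_eps(e_t, old_eps):
    # Combine the newest prediction e_t with the stored history old_eps (newest last).
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    if len(old_eps) >= 3:
        return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
    raise ValueError("old_eps must contain at least one previous prediction")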
 
ldm/modules/.DS_Store DELETED
Binary file (6.15 kB)
 
ldm/modules/__pycache__/attention.cpython-38.pyc DELETED
Binary file (11.1 kB)
 
ldm/modules/__pycache__/x_transformer.cpython-38.pyc DELETED
Binary file (18.3 kB)
 
ldm/modules/attention.py DELETED
@@ -1,345 +0,0 @@
1
- from inspect import isfunction
2
- import math
3
- import torch
4
- import torch.nn.functional as F
5
- from torch import nn, einsum
6
- from einops import rearrange, repeat
7
-
8
- from ldm.modules.diffusionmodules.util import checkpoint
9
-
10
- try:
11
- import xformers
12
- import xformers.ops
13
- XFORMERS_IS_AVAILBLE = True
14
- except:
15
- XFORMERS_IS_AVAILBLE = False
16
-
17
-
18
- def exists(val):
19
- return val is not None
20
-
21
-
22
- def uniq(arr):
23
- return{el: True for el in arr}.keys()
24
-
25
-
26
- def default(val, d):
27
- if exists(val):
28
- return val
29
- return d() if isfunction(d) else d
30
-
31
-
32
- def max_neg_value(t):
33
- return -torch.finfo(t.dtype).max
34
-
35
-
36
- def init_(tensor):
37
- dim = tensor.shape[-1]
38
- std = 1 / math.sqrt(dim)
39
- tensor.uniform_(-std, std)
40
- return tensor
41
-
42
-
43
- # feedforward
44
- class GEGLU(nn.Module):
45
- def __init__(self, dim_in, dim_out):
46
- super().__init__()
47
- self.proj = nn.Linear(dim_in, dim_out * 2)
48
-
49
- def forward(self, x):
50
- x, gate = self.proj(x).chunk(2, dim=-1)
51
- return x * F.gelu(gate)
52
-
53
-
54
- class FeedForward(nn.Module):
55
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
56
- super().__init__()
57
- inner_dim = int(dim * mult)
58
- dim_out = default(dim_out, dim)
59
- project_in = nn.Sequential(
60
- nn.Linear(dim, inner_dim),
61
- nn.GELU()
62
- ) if not glu else GEGLU(dim, inner_dim)
63
-
64
- self.net = nn.Sequential(
65
- project_in,
66
- nn.Dropout(dropout),
67
- nn.Linear(inner_dim, dim_out)
68
- )
69
-
70
- def forward(self, x):
71
- return self.net(x)
72
-
73
-
74
- def zero_module(module):
75
- """
76
- Zero out the parameters of a module and return it.
77
- """
78
- for p in module.parameters():
79
- p.detach().zero_()
80
- return module
81
-
82
-
83
- def Normalize(in_channels):
84
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
85
-
86
-
87
- class LinearAttention(nn.Module):
88
- def __init__(self, dim, heads=4, dim_head=32):
89
- super().__init__()
90
- self.heads = heads
91
- hidden_dim = dim_head * heads
92
- self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
93
- self.to_out = nn.Conv2d(hidden_dim, dim, 1)
94
-
95
- def forward(self, x):
96
- b, c, h, w = x.shape
97
- qkv = self.to_qkv(x)
98
- q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
99
- k = k.softmax(dim=-1)
100
- context = torch.einsum('bhdn,bhen->bhde', k, v)
101
- out = torch.einsum('bhde,bhdn->bhen', context, q)
102
- out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
103
- return self.to_out(out)
104
-
105
-
106
- class SpatialSelfAttention(nn.Module):
107
- def __init__(self, in_channels):
108
- super().__init__()
109
- self.in_channels = in_channels
110
-
111
- self.norm = Normalize(in_channels)
112
- self.q = torch.nn.Conv2d(in_channels,
113
- in_channels,
114
- kernel_size=1,
115
- stride=1,
116
- padding=0)
117
- self.k = torch.nn.Conv2d(in_channels,
118
- in_channels,
119
- kernel_size=1,
120
- stride=1,
121
- padding=0)
122
- self.v = torch.nn.Conv2d(in_channels,
123
- in_channels,
124
- kernel_size=1,
125
- stride=1,
126
- padding=0)
127
- self.proj_out = torch.nn.Conv2d(in_channels,
128
- in_channels,
129
- kernel_size=1,
130
- stride=1,
131
- padding=0)
132
-
133
- def forward(self, x):
134
- h_ = x
135
- h_ = self.norm(h_)
136
- q = self.q(h_)
137
- k = self.k(h_)
138
- v = self.v(h_)
139
-
140
- # compute attention
141
- b,c,h,w = q.shape
142
- q = rearrange(q, 'b c h w -> b (h w) c')
143
- k = rearrange(k, 'b c h w -> b c (h w)')
144
- w_ = torch.einsum('bij,bjk->bik', q, k)
145
-
146
- w_ = w_ * (int(c)**(-0.5))
147
- w_ = torch.nn.functional.softmax(w_, dim=2)
148
-
149
- # attend to values
150
- v = rearrange(v, 'b c h w -> b c (h w)')
151
- w_ = rearrange(w_, 'b i j -> b j i')
152
- h_ = torch.einsum('bij,bjk->bik', v, w_)
153
- h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
154
- h_ = self.proj_out(h_)
155
-
156
- return x+h_
157
-
158
-
159
- class CrossAttention(nn.Module):
160
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
161
- super().__init__()
162
- inner_dim = dim_head * heads
163
- context_dim = default(context_dim, query_dim)
164
-
165
- self.scale = dim_head ** -0.5
166
- self.heads = heads
167
-
168
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
169
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
170
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
171
-
172
- self.to_out = nn.Sequential(
173
- nn.Linear(inner_dim, query_dim),
174
- nn.Dropout(dropout)
175
- )
176
-
177
- def forward(self, x, context=None, mask=None):
178
- h = self.heads
179
-
180
- q = self.to_q(x)
181
- context = default(context, x)
182
- k = self.to_k(context)
183
- v = self.to_v(context)
184
-
185
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
186
-
187
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
188
-
189
- if exists(mask):
190
- mask = rearrange(mask, 'b ... -> b (...)')
191
- max_neg_value = -torch.finfo(sim.dtype).max
192
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
193
- sim.masked_fill_(~mask, max_neg_value)
194
-
195
- # attention, what we cannot get enough of
196
- attn = sim.softmax(dim=-1)
197
-
198
- out = einsum('b i j, b j d -> b i d', attn, v)
199
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
200
- return self.to_out(out)
201
-
202
-
203
- class MemoryEfficientCrossAttention(nn.Module):
204
- # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
205
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
206
- super().__init__()
207
- print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
208
- f"{heads} heads.")
209
- inner_dim = dim_head * heads
210
- context_dim = default(context_dim, query_dim)
211
-
212
- self.heads = heads
213
- self.dim_head = dim_head
214
-
215
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
216
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
217
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
218
-
219
- self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
220
- self.attention_op: Optional[Any] = None
221
-
222
- def forward(self, x, context=None, mask=None):
223
- q = self.to_q(x)
224
- context = default(context, x)
225
- k = self.to_k(context)
226
- v = self.to_v(context)
227
-
228
- b, _, _ = q.shape
229
- q, k, v = map(
230
- lambda t: t.unsqueeze(3)
231
- .reshape(b, t.shape[1], self.heads, self.dim_head)
232
- .permute(0, 2, 1, 3)
233
- .reshape(b * self.heads, t.shape[1], self.dim_head)
234
- .contiguous(),
235
- (q, k, v),
236
- )
237
-
238
- # actually compute the attention, what we cannot get enough of
239
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
240
-
241
- if exists(mask):
242
- raise NotImplementedError
243
- out = (
244
- out.unsqueeze(0)
245
- .reshape(b, self.heads, out.shape[1], self.dim_head)
246
- .permute(0, 2, 1, 3)
247
- .reshape(b, out.shape[1], self.heads * self.dim_head)
248
- )
249
- return self.to_out(out)
250
-
251
- class BasicTransformerBlock(nn.Module):
252
- ATTENTION_MODES = {
253
- "softmax": CrossAttention, # vanilla attention
254
- "softmax-xformers": MemoryEfficientCrossAttention
255
- }
256
- def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
257
- disable_self_attn=False):
258
- super().__init__()
259
- attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
260
- assert attn_mode in self.ATTENTION_MODES
261
- attn_cls = self.ATTENTION_MODES[attn_mode]
262
- self.disable_self_attn = disable_self_attn
263
- self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
264
- context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
265
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
266
- self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
267
- heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
268
- self.norm1 = nn.LayerNorm(dim)
269
- self.norm2 = nn.LayerNorm(dim)
270
- self.norm3 = nn.LayerNorm(dim)
271
- self.checkpoint = checkpoint
272
-
273
- def forward(self, x, context=None):
274
- return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
275
-
276
- def _forward(self, x, context=None):
277
- x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
278
- x = self.attn2(self.norm2(x), context=context) + x
279
- x = self.ff(self.norm3(x)) + x
280
- return x
281
-
282
-
283
- class SpatialTransformer(nn.Module):
284
- """
285
- Transformer block for image-like data.
286
- First, project the input (aka embedding)
287
- and reshape to b, t, d.
288
- Then apply standard transformer action.
289
- Finally, reshape to image
290
- NEW: use_linear for more efficiency instead of the 1x1 convs
291
- """
292
- def __init__(self, in_channels, n_heads, d_head,
293
- depth=1, dropout=0., context_dim=None,
294
- disable_self_attn=False, use_linear=False,
295
- use_checkpoint=True):
296
- super().__init__()
297
- if exists(context_dim) and not isinstance(context_dim, list):
298
- context_dim = [context_dim]
299
- self.in_channels = in_channels
300
- inner_dim = n_heads * d_head
301
- self.norm = Normalize(in_channels)
302
- if not use_linear:
303
- self.proj_in = nn.Conv2d(in_channels,
304
- inner_dim,
305
- kernel_size=1,
306
- stride=1,
307
- padding=0)
308
- else:
309
- self.proj_in = nn.Linear(in_channels, inner_dim)
310
-
311
- self.transformer_blocks = nn.ModuleList(
312
- [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
313
- disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
314
- for d in range(depth)]
315
- )
316
- if not use_linear:
317
- self.proj_out = zero_module(nn.Conv2d(inner_dim,
318
- in_channels,
319
- kernel_size=1,
320
- stride=1,
321
- padding=0))
322
- else:
323
- self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
324
- self.use_linear = use_linear
325
-
326
- def forward(self, x, context=None):
327
- # note: if no context is given, cross-attention defaults to self-attention
328
- if not isinstance(context, list):
329
- context = [context]
330
- b, c, h, w = x.shape
331
- x_in = x
332
- x = self.norm(x)
333
- if not self.use_linear:
334
- x = self.proj_in(x)
335
- x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
336
- if self.use_linear:
337
- x = self.proj_in(x)
338
- for i, block in enumerate(self.transformer_blocks):
339
- x = block(x, context=context[i])
340
- if self.use_linear:
341
- x = self.proj_out(x)
342
- x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
343
- if not self.use_linear:
344
- x = self.proj_out(x)
345
- return x + x_in
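The block above closes the deleted ldm/modules/attention.py. As a reference for anyone restoring it, here is a minimal usage sketch of SpatialTransformer, assuming a pre-deletion checkout where the module is still importable; the channel, head, and context sizes below are illustrative choices, not values fixed by this repository.

import torch
from ldm.modules.attention import SpatialTransformer

# 64-channel feature map, 8 heads of 8 dims each, cross-attending to a 768-dim context
block = SpatialTransformer(in_channels=64, n_heads=8, d_head=8,
                           depth=1, context_dim=768)
x = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)
ctx = torch.randn(2, 77, 768)    # (batch, tokens, context_dim), e.g. text embeddings
out = block(x, context=ctx)      # residual output with the same shape as x

If context is left as None, each BasicTransformerBlock falls back to self-attention, matching the comment in forward() above.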
ldm/modules/diffusionmodules/.DS_Store DELETED
Binary file (6.15 kB)
 
ldm/modules/diffusionmodules/__init__.py DELETED
File without changes
ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (158 Bytes)
 
ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc DELETED
Binary file (20.7 kB)
 
ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc DELETED
Binary file (16.8 kB)
 
ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc DELETED
Binary file (8.96 kB)
 
ldm/modules/diffusionmodules/model.py DELETED
@@ -1,835 +0,0 @@
1
- # pytorch_diffusion + derived encoder decoder
2
- import math
3
- import torch
4
- import torch.nn as nn
5
- import numpy as np
6
- from einops import rearrange
7
-
8
- from ldm.util import instantiate_from_config
9
- from ldm.modules.attention import LinearAttention
10
-
11
-
12
- def get_timestep_embedding(timesteps, embedding_dim):
13
- """
14
- This matches the implementation in Denoising Diffusion Probabilistic Models:
15
- From Fairseq.
16
- Build sinusoidal embeddings.
17
- This matches the implementation in tensor2tensor, but differs slightly
18
- from the description in Section 3.5 of "Attention Is All You Need".
19
- """
20
- assert len(timesteps.shape) == 1
21
-
22
- half_dim = embedding_dim // 2
23
- emb = math.log(10000) / (half_dim - 1)
24
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
25
- emb = emb.to(device=timesteps.device)
26
- emb = timesteps.float()[:, None] * emb[None, :]
27
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28
- if embedding_dim % 2 == 1: # zero pad
29
- emb = torch.nn.functional.pad(emb, (0,1,0,0))
30
- return emb
31
-
32
-
33
- def nonlinearity(x):
34
- # swish
35
- return x*torch.sigmoid(x)
36
-
37
-
38
- def Normalize(in_channels, num_groups=32):
39
- return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
40
-
41
-
42
- class Upsample(nn.Module):
43
- def __init__(self, in_channels, with_conv):
44
- super().__init__()
45
- self.with_conv = with_conv
46
- if self.with_conv:
47
- self.conv = torch.nn.Conv2d(in_channels,
48
- in_channels,
49
- kernel_size=3,
50
- stride=1,
51
- padding=1)
52
-
53
- def forward(self, x):
54
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
55
- if self.with_conv:
56
- x = self.conv(x)
57
- return x
58
-
59
-
60
- class Downsample(nn.Module):
61
- def __init__(self, in_channels, with_conv):
62
- super().__init__()
63
- self.with_conv = with_conv
64
- if self.with_conv:
65
- # no asymmetric padding in torch conv, must do it ourselves
66
- self.conv = torch.nn.Conv2d(in_channels,
67
- in_channels,
68
- kernel_size=3,
69
- stride=2,
70
- padding=0)
71
-
72
- def forward(self, x):
73
- if self.with_conv:
74
- pad = (0,1,0,1)
75
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
76
- x = self.conv(x)
77
- else:
78
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
79
- return x
80
-
81
-
82
- class ResnetBlock(nn.Module):
83
- def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
84
- dropout, temb_channels=512):
85
- super().__init__()
86
- self.in_channels = in_channels
87
- out_channels = in_channels if out_channels is None else out_channels
88
- self.out_channels = out_channels
89
- self.use_conv_shortcut = conv_shortcut
90
-
91
- self.norm1 = Normalize(in_channels)
92
- self.conv1 = torch.nn.Conv2d(in_channels,
93
- out_channels,
94
- kernel_size=3,
95
- stride=1,
96
- padding=1)
97
- if temb_channels > 0:
98
- self.temb_proj = torch.nn.Linear(temb_channels,
99
- out_channels)
100
- self.norm2 = Normalize(out_channels)
101
- self.dropout = torch.nn.Dropout(dropout)
102
- self.conv2 = torch.nn.Conv2d(out_channels,
103
- out_channels,
104
- kernel_size=3,
105
- stride=1,
106
- padding=1)
107
- if self.in_channels != self.out_channels:
108
- if self.use_conv_shortcut:
109
- self.conv_shortcut = torch.nn.Conv2d(in_channels,
110
- out_channels,
111
- kernel_size=3,
112
- stride=1,
113
- padding=1)
114
- else:
115
- self.nin_shortcut = torch.nn.Conv2d(in_channels,
116
- out_channels,
117
- kernel_size=1,
118
- stride=1,
119
- padding=0)
120
-
121
- def forward(self, x, temb):
122
- h = x
123
- h = self.norm1(h)
124
- h = nonlinearity(h)
125
- h = self.conv1(h)
126
-
127
- if temb is not None:
128
- h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
129
-
130
- h = self.norm2(h)
131
- h = nonlinearity(h)
132
- h = self.dropout(h)
133
- h = self.conv2(h)
134
-
135
- if self.in_channels != self.out_channels:
136
- if self.use_conv_shortcut:
137
- x = self.conv_shortcut(x)
138
- else:
139
- x = self.nin_shortcut(x)
140
-
141
- return x+h
142
-
143
-
144
- class LinAttnBlock(LinearAttention):
145
- """to match AttnBlock usage"""
146
- def __init__(self, in_channels):
147
- super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
148
-
149
-
150
- class AttnBlock(nn.Module):
151
- def __init__(self, in_channels):
152
- super().__init__()
153
- self.in_channels = in_channels
154
-
155
- self.norm = Normalize(in_channels)
156
- self.q = torch.nn.Conv2d(in_channels,
157
- in_channels,
158
- kernel_size=1,
159
- stride=1,
160
- padding=0)
161
- self.k = torch.nn.Conv2d(in_channels,
162
- in_channels,
163
- kernel_size=1,
164
- stride=1,
165
- padding=0)
166
- self.v = torch.nn.Conv2d(in_channels,
167
- in_channels,
168
- kernel_size=1,
169
- stride=1,
170
- padding=0)
171
- self.proj_out = torch.nn.Conv2d(in_channels,
172
- in_channels,
173
- kernel_size=1,
174
- stride=1,
175
- padding=0)
176
-
177
-
178
- def forward(self, x):
179
- h_ = x
180
- h_ = self.norm(h_)
181
- q = self.q(h_)
182
- k = self.k(h_)
183
- v = self.v(h_)
184
-
185
- # compute attention
186
- b,c,h,w = q.shape
187
- q = q.reshape(b,c,h*w)
188
- q = q.permute(0,2,1) # b,hw,c
189
- k = k.reshape(b,c,h*w) # b,c,hw
190
- w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
191
- w_ = w_ * (int(c)**(-0.5))
192
- w_ = torch.nn.functional.softmax(w_, dim=2)
193
-
194
- # attend to values
195
- v = v.reshape(b,c,h*w)
196
- w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
197
- h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
198
- h_ = h_.reshape(b,c,h,w)
199
-
200
- h_ = self.proj_out(h_)
201
-
202
- return x+h_
203
-
204
-
205
- def make_attn(in_channels, attn_type="vanilla"):
206
- assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
207
- print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
208
- if attn_type == "vanilla":
209
- return AttnBlock(in_channels)
210
- elif attn_type == "none":
211
- return nn.Identity(in_channels)
212
- else:
213
- return LinAttnBlock(in_channels)
214
-
215
-
216
- class Model(nn.Module):
217
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
218
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
219
- resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
220
- super().__init__()
221
- if use_linear_attn: attn_type = "linear"
222
- self.ch = ch
223
- self.temb_ch = self.ch*4
224
- self.num_resolutions = len(ch_mult)
225
- self.num_res_blocks = num_res_blocks
226
- self.resolution = resolution
227
- self.in_channels = in_channels
228
-
229
- self.use_timestep = use_timestep
230
- if self.use_timestep:
231
- # timestep embedding
232
- self.temb = nn.Module()
233
- self.temb.dense = nn.ModuleList([
234
- torch.nn.Linear(self.ch,
235
- self.temb_ch),
236
- torch.nn.Linear(self.temb_ch,
237
- self.temb_ch),
238
- ])
239
-
240
- # downsampling
241
- self.conv_in = torch.nn.Conv2d(in_channels,
242
- self.ch,
243
- kernel_size=3,
244
- stride=1,
245
- padding=1)
246
-
247
- curr_res = resolution
248
- in_ch_mult = (1,)+tuple(ch_mult)
249
- self.down = nn.ModuleList()
250
- for i_level in range(self.num_resolutions):
251
- block = nn.ModuleList()
252
- attn = nn.ModuleList()
253
- block_in = ch*in_ch_mult[i_level]
254
- block_out = ch*ch_mult[i_level]
255
- for i_block in range(self.num_res_blocks):
256
- block.append(ResnetBlock(in_channels=block_in,
257
- out_channels=block_out,
258
- temb_channels=self.temb_ch,
259
- dropout=dropout))
260
- block_in = block_out
261
- if curr_res in attn_resolutions:
262
- attn.append(make_attn(block_in, attn_type=attn_type))
263
- down = nn.Module()
264
- down.block = block
265
- down.attn = attn
266
- if i_level != self.num_resolutions-1:
267
- down.downsample = Downsample(block_in, resamp_with_conv)
268
- curr_res = curr_res // 2
269
- self.down.append(down)
270
-
271
- # middle
272
- self.mid = nn.Module()
273
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
274
- out_channels=block_in,
275
- temb_channels=self.temb_ch,
276
- dropout=dropout)
277
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
278
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
279
- out_channels=block_in,
280
- temb_channels=self.temb_ch,
281
- dropout=dropout)
282
-
283
- # upsampling
284
- self.up = nn.ModuleList()
285
- for i_level in reversed(range(self.num_resolutions)):
286
- block = nn.ModuleList()
287
- attn = nn.ModuleList()
288
- block_out = ch*ch_mult[i_level]
289
- skip_in = ch*ch_mult[i_level]
290
- for i_block in range(self.num_res_blocks+1):
291
- if i_block == self.num_res_blocks:
292
- skip_in = ch*in_ch_mult[i_level]
293
- block.append(ResnetBlock(in_channels=block_in+skip_in,
294
- out_channels=block_out,
295
- temb_channels=self.temb_ch,
296
- dropout=dropout))
297
- block_in = block_out
298
- if curr_res in attn_resolutions:
299
- attn.append(make_attn(block_in, attn_type=attn_type))
300
- up = nn.Module()
301
- up.block = block
302
- up.attn = attn
303
- if i_level != 0:
304
- up.upsample = Upsample(block_in, resamp_with_conv)
305
- curr_res = curr_res * 2
306
- self.up.insert(0, up) # prepend to get consistent order
307
-
308
- # end
309
- self.norm_out = Normalize(block_in)
310
- self.conv_out = torch.nn.Conv2d(block_in,
311
- out_ch,
312
- kernel_size=3,
313
- stride=1,
314
- padding=1)
315
-
316
- def forward(self, x, t=None, context=None):
317
- #assert x.shape[2] == x.shape[3] == self.resolution
318
- if context is not None:
319
- # assume aligned context, cat along channel axis
320
- x = torch.cat((x, context), dim=1)
321
- if self.use_timestep:
322
- # timestep embedding
323
- assert t is not None
324
- temb = get_timestep_embedding(t, self.ch)
325
- temb = self.temb.dense[0](temb)
326
- temb = nonlinearity(temb)
327
- temb = self.temb.dense[1](temb)
328
- else:
329
- temb = None
330
-
331
- # downsampling
332
- hs = [self.conv_in(x)]
333
- for i_level in range(self.num_resolutions):
334
- for i_block in range(self.num_res_blocks):
335
- h = self.down[i_level].block[i_block](hs[-1], temb)
336
- if len(self.down[i_level].attn) > 0:
337
- h = self.down[i_level].attn[i_block](h)
338
- hs.append(h)
339
- if i_level != self.num_resolutions-1:
340
- hs.append(self.down[i_level].downsample(hs[-1]))
341
-
342
- # middle
343
- h = hs[-1]
344
- h = self.mid.block_1(h, temb)
345
- h = self.mid.attn_1(h)
346
- h = self.mid.block_2(h, temb)
347
-
348
- # upsampling
349
- for i_level in reversed(range(self.num_resolutions)):
350
- for i_block in range(self.num_res_blocks+1):
351
- h = self.up[i_level].block[i_block](
352
- torch.cat([h, hs.pop()], dim=1), temb)
353
- if len(self.up[i_level].attn) > 0:
354
- h = self.up[i_level].attn[i_block](h)
355
- if i_level != 0:
356
- h = self.up[i_level].upsample(h)
357
-
358
- # end
359
- h = self.norm_out(h)
360
- h = nonlinearity(h)
361
- h = self.conv_out(h)
362
- return h
363
-
364
- def get_last_layer(self):
365
- return self.conv_out.weight
366
-
367
-
368
- class Encoder(nn.Module):
369
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
370
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
371
- resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
372
- **ignore_kwargs):
373
- super().__init__()
374
- if use_linear_attn: attn_type = "linear"
375
- self.ch = ch
376
- self.temb_ch = 0
377
- self.num_resolutions = len(ch_mult)
378
- self.num_res_blocks = num_res_blocks
379
- self.resolution = resolution
380
- self.in_channels = in_channels
381
-
382
- # downsampling
383
- self.conv_in = torch.nn.Conv2d(in_channels,
384
- self.ch,
385
- kernel_size=3,
386
- stride=1,
387
- padding=1)
388
-
389
- curr_res = resolution
390
- in_ch_mult = (1,)+tuple(ch_mult)
391
- self.in_ch_mult = in_ch_mult
392
- self.down = nn.ModuleList()
393
- for i_level in range(self.num_resolutions):
394
- block = nn.ModuleList()
395
- attn = nn.ModuleList()
396
- block_in = ch*in_ch_mult[i_level]
397
- block_out = ch*ch_mult[i_level]
398
- for i_block in range(self.num_res_blocks):
399
- block.append(ResnetBlock(in_channels=block_in,
400
- out_channels=block_out,
401
- temb_channels=self.temb_ch,
402
- dropout=dropout))
403
- block_in = block_out
404
- if curr_res in attn_resolutions:
405
- attn.append(make_attn(block_in, attn_type=attn_type))
406
- down = nn.Module()
407
- down.block = block
408
- down.attn = attn
409
- if i_level != self.num_resolutions-1:
410
- down.downsample = Downsample(block_in, resamp_with_conv)
411
- curr_res = curr_res // 2
412
- self.down.append(down)
413
-
414
- # middle
415
- self.mid = nn.Module()
416
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
417
- out_channels=block_in,
418
- temb_channels=self.temb_ch,
419
- dropout=dropout)
420
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
421
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
422
- out_channels=block_in,
423
- temb_channels=self.temb_ch,
424
- dropout=dropout)
425
-
426
- # end
427
- self.norm_out = Normalize(block_in)
428
- self.conv_out = torch.nn.Conv2d(block_in,
429
- 2*z_channels if double_z else z_channels,
430
- kernel_size=3,
431
- stride=1,
432
- padding=1)
433
-
434
- def forward(self, x):
435
- # timestep embedding
436
- temb = None
437
-
438
- # downsampling
439
- hs = [self.conv_in(x)]
440
- for i_level in range(self.num_resolutions):
441
- for i_block in range(self.num_res_blocks):
442
- h = self.down[i_level].block[i_block](hs[-1], temb)
443
- if len(self.down[i_level].attn) > 0:
444
- h = self.down[i_level].attn[i_block](h)
445
- hs.append(h)
446
- if i_level != self.num_resolutions-1:
447
- hs.append(self.down[i_level].downsample(hs[-1]))
448
-
449
- # middle
450
- h = hs[-1]
451
- h = self.mid.block_1(h, temb)
452
- h = self.mid.attn_1(h)
453
- h = self.mid.block_2(h, temb)
454
-
455
- # end
456
- h = self.norm_out(h)
457
- h = nonlinearity(h)
458
- h = self.conv_out(h)
459
- return h
460
-
461
-
462
- class Decoder(nn.Module):
463
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
464
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
465
- resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
466
- attn_type="vanilla", **ignorekwargs):
467
- super().__init__()
468
- if use_linear_attn: attn_type = "linear"
469
- self.ch = ch
470
- self.temb_ch = 0
471
- self.num_resolutions = len(ch_mult)
472
- self.num_res_blocks = num_res_blocks
473
- self.resolution = resolution
474
- self.in_channels = in_channels
475
- self.give_pre_end = give_pre_end
476
- self.tanh_out = tanh_out
477
-
478
- # compute in_ch_mult, block_in and curr_res at lowest res
479
- in_ch_mult = (1,)+tuple(ch_mult)
480
- block_in = ch*ch_mult[self.num_resolutions-1]
481
- curr_res = resolution // 2**(self.num_resolutions-1)
482
- self.z_shape = (1,z_channels,curr_res,curr_res)
483
- print("Working with z of shape {} = {} dimensions.".format(
484
- self.z_shape, np.prod(self.z_shape)))
485
-
486
- # z to block_in
487
- self.conv_in = torch.nn.Conv2d(z_channels,
488
- block_in,
489
- kernel_size=3,
490
- stride=1,
491
- padding=1)
492
-
493
- # middle
494
- self.mid = nn.Module()
495
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
496
- out_channels=block_in,
497
- temb_channels=self.temb_ch,
498
- dropout=dropout)
499
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
500
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
501
- out_channels=block_in,
502
- temb_channels=self.temb_ch,
503
- dropout=dropout)
504
-
505
- # upsampling
506
- self.up = nn.ModuleList()
507
- for i_level in reversed(range(self.num_resolutions)):
508
- block = nn.ModuleList()
509
- attn = nn.ModuleList()
510
- block_out = ch*ch_mult[i_level]
511
- for i_block in range(self.num_res_blocks+1):
512
- block.append(ResnetBlock(in_channels=block_in,
513
- out_channels=block_out,
514
- temb_channels=self.temb_ch,
515
- dropout=dropout))
516
- block_in = block_out
517
- if curr_res in attn_resolutions:
518
- attn.append(make_attn(block_in, attn_type=attn_type))
519
- up = nn.Module()
520
- up.block = block
521
- up.attn = attn
522
- if i_level != 0:
523
- up.upsample = Upsample(block_in, resamp_with_conv)
524
- curr_res = curr_res * 2
525
- self.up.insert(0, up) # prepend to get consistent order
526
-
527
- # end
528
- self.norm_out = Normalize(block_in)
529
- self.conv_out = torch.nn.Conv2d(block_in,
530
- out_ch,
531
- kernel_size=3,
532
- stride=1,
533
- padding=1)
534
-
535
- def forward(self, z):
536
- #assert z.shape[1:] == self.z_shape[1:]
537
- self.last_z_shape = z.shape
538
-
539
- # timestep embedding
540
- temb = None
541
-
542
- # z to block_in
543
- h = self.conv_in(z)
544
-
545
- # middle
546
- h = self.mid.block_1(h, temb)
547
- h = self.mid.attn_1(h)
548
- h = self.mid.block_2(h, temb)
549
-
550
- # upsampling
551
- for i_level in reversed(range(self.num_resolutions)):
552
- for i_block in range(self.num_res_blocks+1):
553
- h = self.up[i_level].block[i_block](h, temb)
554
- if len(self.up[i_level].attn) > 0:
555
- h = self.up[i_level].attn[i_block](h)
556
- if i_level != 0:
557
- h = self.up[i_level].upsample(h)
558
-
559
- # end
560
- if self.give_pre_end:
561
- return h
562
-
563
- h = self.norm_out(h)
564
- h = nonlinearity(h)
565
- h = self.conv_out(h)
566
- if self.tanh_out:
567
- h = torch.tanh(h)
568
- return h
569
-
570
-
571
- class SimpleDecoder(nn.Module):
572
- def __init__(self, in_channels, out_channels, *args, **kwargs):
573
- super().__init__()
574
- self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
575
- ResnetBlock(in_channels=in_channels,
576
- out_channels=2 * in_channels,
577
- temb_channels=0, dropout=0.0),
578
- ResnetBlock(in_channels=2 * in_channels,
579
- out_channels=4 * in_channels,
580
- temb_channels=0, dropout=0.0),
581
- ResnetBlock(in_channels=4 * in_channels,
582
- out_channels=2 * in_channels,
583
- temb_channels=0, dropout=0.0),
584
- nn.Conv2d(2*in_channels, in_channels, 1),
585
- Upsample(in_channels, with_conv=True)])
586
- # end
587
- self.norm_out = Normalize(in_channels)
588
- self.conv_out = torch.nn.Conv2d(in_channels,
589
- out_channels,
590
- kernel_size=3,
591
- stride=1,
592
- padding=1)
593
-
594
- def forward(self, x):
595
- for i, layer in enumerate(self.model):
596
- if i in [1,2,3]:
597
- x = layer(x, None)
598
- else:
599
- x = layer(x)
600
-
601
- h = self.norm_out(x)
602
- h = nonlinearity(h)
603
- x = self.conv_out(h)
604
- return x
605
-
606
-
607
- class UpsampleDecoder(nn.Module):
608
- def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
609
- ch_mult=(2,2), dropout=0.0):
610
- super().__init__()
611
- # upsampling
612
- self.temb_ch = 0
613
- self.num_resolutions = len(ch_mult)
614
- self.num_res_blocks = num_res_blocks
615
- block_in = in_channels
616
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
617
- self.res_blocks = nn.ModuleList()
618
- self.upsample_blocks = nn.ModuleList()
619
- for i_level in range(self.num_resolutions):
620
- res_block = []
621
- block_out = ch * ch_mult[i_level]
622
- for i_block in range(self.num_res_blocks + 1):
623
- res_block.append(ResnetBlock(in_channels=block_in,
624
- out_channels=block_out,
625
- temb_channels=self.temb_ch,
626
- dropout=dropout))
627
- block_in = block_out
628
- self.res_blocks.append(nn.ModuleList(res_block))
629
- if i_level != self.num_resolutions - 1:
630
- self.upsample_blocks.append(Upsample(block_in, True))
631
- curr_res = curr_res * 2
632
-
633
- # end
634
- self.norm_out = Normalize(block_in)
635
- self.conv_out = torch.nn.Conv2d(block_in,
636
- out_channels,
637
- kernel_size=3,
638
- stride=1,
639
- padding=1)
640
-
641
- def forward(self, x):
642
- # upsampling
643
- h = x
644
- for k, i_level in enumerate(range(self.num_resolutions)):
645
- for i_block in range(self.num_res_blocks + 1):
646
- h = self.res_blocks[i_level][i_block](h, None)
647
- if i_level != self.num_resolutions - 1:
648
- h = self.upsample_blocks[k](h)
649
- h = self.norm_out(h)
650
- h = nonlinearity(h)
651
- h = self.conv_out(h)
652
- return h
653
-
654
-
655
- class LatentRescaler(nn.Module):
656
- def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
657
- super().__init__()
658
- # residual block, interpolate, residual block
659
- self.factor = factor
660
- self.conv_in = nn.Conv2d(in_channels,
661
- mid_channels,
662
- kernel_size=3,
663
- stride=1,
664
- padding=1)
665
- self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
666
- out_channels=mid_channels,
667
- temb_channels=0,
668
- dropout=0.0) for _ in range(depth)])
669
- self.attn = AttnBlock(mid_channels)
670
- self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
671
- out_channels=mid_channels,
672
- temb_channels=0,
673
- dropout=0.0) for _ in range(depth)])
674
-
675
- self.conv_out = nn.Conv2d(mid_channels,
676
- out_channels,
677
- kernel_size=1,
678
- )
679
-
680
- def forward(self, x):
681
- x = self.conv_in(x)
682
- for block in self.res_block1:
683
- x = block(x, None)
684
- x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
685
- x = self.attn(x)
686
- for block in self.res_block2:
687
- x = block(x, None)
688
- x = self.conv_out(x)
689
- return x
690
-
691
-
692
- class MergedRescaleEncoder(nn.Module):
693
- def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
694
- attn_resolutions, dropout=0.0, resamp_with_conv=True,
695
- ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
696
- super().__init__()
697
- intermediate_chn = ch * ch_mult[-1]
698
- self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
699
- z_channels=intermediate_chn, double_z=False, resolution=resolution,
700
- attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
701
- out_ch=None)
702
- self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
703
- mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
704
-
705
- def forward(self, x):
706
- x = self.encoder(x)
707
- x = self.rescaler(x)
708
- return x
709
-
710
-
711
- class MergedRescaleDecoder(nn.Module):
712
- def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
713
- dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
714
- super().__init__()
715
- tmp_chn = z_channels*ch_mult[-1]
716
- self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
717
- resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
718
- ch_mult=ch_mult, resolution=resolution, ch=ch)
719
- self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
720
- out_channels=tmp_chn, depth=rescale_module_depth)
721
-
722
- def forward(self, x):
723
- x = self.rescaler(x)
724
- x = self.decoder(x)
725
- return x
726
-
727
-
728
- class Upsampler(nn.Module):
729
- def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
730
- super().__init__()
731
- assert out_size >= in_size
732
- num_blocks = int(np.log2(out_size//in_size))+1
733
- factor_up = 1.+ (out_size % in_size)
734
- print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
735
- self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
736
- out_channels=in_channels)
737
- self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
738
- attn_resolutions=[], in_channels=None, ch=in_channels,
739
- ch_mult=[ch_mult for _ in range(num_blocks)])
740
-
741
- def forward(self, x):
742
- x = self.rescaler(x)
743
- x = self.decoder(x)
744
- return x
745
-
746
-
747
- class Resize(nn.Module):
748
- def __init__(self, in_channels=None, learned=False, mode="bilinear"):
749
- super().__init__()
750
- self.with_conv = learned
751
- self.mode = mode
752
- if self.with_conv:
753
- print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
754
- raise NotImplementedError()
755
- assert in_channels is not None
756
- # no asymmetric padding in torch conv, must do it ourselves
757
- self.conv = torch.nn.Conv2d(in_channels,
758
- in_channels,
759
- kernel_size=4,
760
- stride=2,
761
- padding=1)
762
-
763
- def forward(self, x, scale_factor=1.0):
764
- if scale_factor==1.0:
765
- return x
766
- else:
767
- x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
768
- return x
769
-
770
- class FirstStagePostProcessor(nn.Module):
771
-
772
- def __init__(self, ch_mult:list, in_channels,
773
- pretrained_model:nn.Module=None,
774
- reshape=False,
775
- n_channels=None,
776
- dropout=0.,
777
- pretrained_config=None):
778
- super().__init__()
779
- if pretrained_config is None:
780
- assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
781
- self.pretrained_model = pretrained_model
782
- else:
783
- assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
784
- self.instantiate_pretrained(pretrained_config)
785
-
786
- self.do_reshape = reshape
787
-
788
- if n_channels is None:
789
- n_channels = self.pretrained_model.encoder.ch
790
-
791
- self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
792
- self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
793
- stride=1,padding=1)
794
-
795
- blocks = []
796
- downs = []
797
- ch_in = n_channels
798
- for m in ch_mult:
799
- blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
800
- ch_in = m * n_channels
801
- downs.append(Downsample(ch_in, with_conv=False))
802
-
803
- self.model = nn.ModuleList(blocks)
804
- self.downsampler = nn.ModuleList(downs)
805
-
806
-
807
- def instantiate_pretrained(self, config):
808
- model = instantiate_from_config(config)
809
- self.pretrained_model = model.eval()
810
- # self.pretrained_model.train = False
811
- for param in self.pretrained_model.parameters():
812
- param.requires_grad = False
813
-
814
-
815
- @torch.no_grad()
816
- def encode_with_pretrained(self,x):
817
- c = self.pretrained_model.encode(x)
818
- if isinstance(c, DiagonalGaussianDistribution):
819
- c = c.mode()
820
- return c
821
-
822
- def forward(self,x):
823
- z_fs = self.encode_with_pretrained(x)
824
- z = self.proj_norm(z_fs)
825
- z = self.proj(z)
826
- z = nonlinearity(z)
827
-
828
- for submodel, downmodel in zip(self.model,self.downsampler):
829
- z = submodel(z,temb=None)
830
- z = downmodel(z)
831
-
832
- if self.do_reshape:
833
- z = rearrange(z,'b c h w -> b (h w) c')
834
- return z
835
-
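That ends the deleted ldm/modules/diffusionmodules/model.py, whose Encoder and Decoder formed the first-stage autoencoder backbone. A minimal round-trip sketch follows, again assuming a pre-deletion checkout; the tiny channel counts are illustrative only, not the training configuration.

import torch
from ldm.modules.diffusionmodules.model import Encoder, Decoder

enc = Encoder(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
              attn_resolutions=[], in_channels=3, resolution=64,
              z_channels=4, double_z=False)
dec = Decoder(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
              attn_resolutions=[], in_channels=3, resolution=64,
              z_channels=4)
x = torch.randn(1, 3, 64, 64)    # input image batch
z = enc(x)                       # (1, 4, 32, 32): downsampled by 2**(len(ch_mult)-1)
x_rec = dec(z)                   # (1, 3, 64, 64): reconstruction

With the default double_z=True the encoder emits 2*z_channels, so a DiagonalGaussianDistribution can split them into mean and log-variance, which is how the KL autoencoder in the (also deleted) ldm/models/autoencoder.py consumes it.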
ldm/modules/diffusionmodules/openaimodel.py DELETED
@@ -1,707 +0,0 @@
1
- from abc import abstractmethod
2
- from functools import partial
3
- import math
4
- from typing import Iterable
5
-
6
- import numpy as np
7
- import torch as th
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
-
11
- from ldm.modules.diffusionmodules.util import (
12
- checkpoint,
13
- conv_nd,
14
- linear,
15
- avg_pool_nd,
16
- zero_module,
17
- normalization,
18
- timestep_embedding,
19
- )
20
- from ldm.modules.attention import SpatialTransformer
21
-
22
-
23
- # dummy replace
24
- def convert_module_to_f16(x):
25
- pass
26
-
27
- def convert_module_to_f32(x):
28
- pass
29
-
30
-
31
- ## go
32
- class AttentionPool2d(nn.Module):
33
- """
34
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
35
- """
36
-
37
- def __init__(
38
- self,
39
- spacial_dim: int,
40
- embed_dim: int,
41
- num_heads_channels: int,
42
- output_dim: int = None,
43
- ):
44
- super().__init__()
45
- self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
46
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
47
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
48
- self.num_heads = embed_dim // num_heads_channels
49
- self.attention = QKVAttention(self.num_heads)
50
-
51
- def forward(self, x):
52
- b, c, *_spatial = x.shape
53
- x = x.reshape(b, c, -1) # NC(HW)
54
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
55
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
56
- x = self.qkv_proj(x)
57
- x = self.attention(x)
58
- x = self.c_proj(x)
59
- return x[:, :, 0]
60
-
61
-
62
- class TimestepBlock(nn.Module):
63
- @abstractmethod
64
- def forward(self, x, emb):
65
- """
66
- Apply the module to `x` given `emb` timestep embeddings.
67
- """
68
-
69
-
70
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
71
- def forward(self, x, emb, context=None):
72
- for layer in self:
73
- if isinstance(layer, TimestepBlock):
74
- x = layer(x, emb)
75
- elif isinstance(layer, SpatialTransformer):
76
- x = layer(x, context)
77
- else:
78
- x = layer(x)
79
- return x
80
-
81
-
82
- class Upsample(nn.Module):
83
- """
84
- An upsampling layer with an optional convolution.
85
- :param channels: channels in the inputs and outputs.
86
- :param use_conv: a bool determining if a convolution is applied.
87
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
88
- upsampling occurs in the inner-two dimensions.
89
- """
90
-
91
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
92
- super().__init__()
93
- self.channels = channels
94
- self.out_channels = out_channels or channels
95
- self.use_conv = use_conv
96
- self.dims = dims
97
- if use_conv:
98
- self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
99
-
100
- def forward(self, x):
101
- assert x.shape[1] == self.channels
102
- if self.dims == 3:
103
- x = F.interpolate(
104
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
105
- )
106
- else:
107
- x = F.interpolate(x, scale_factor=2, mode="nearest")
108
- if self.use_conv:
109
- x = self.conv(x)
110
- return x
111
-
112
- class TransposedUpsample(nn.Module):
113
- 'Learned 2x upsampling without padding'
114
- def __init__(self, channels, out_channels=None, ks=5):
115
- super().__init__()
116
- self.channels = channels
117
- self.out_channels = out_channels or channels
118
-
119
- self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
120
-
121
- def forward(self,x):
122
- return self.up(x)
123
-
124
-
125
- class Downsample(nn.Module):
126
- """
127
- A downsampling layer with an optional convolution.
128
- :param channels: channels in the inputs and outputs.
129
- :param use_conv: a bool determining if a convolution is applied.
130
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
131
- downsampling occurs in the inner-two dimensions.
132
- """
133
-
134
- def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
135
- super().__init__()
136
- self.channels = channels
137
- self.out_channels = out_channels or channels
138
- self.use_conv = use_conv
139
- self.dims = dims
140
- stride = 2 if dims != 3 else (1, 2, 2)
141
- if use_conv:
142
- self.op = conv_nd(
143
- dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
144
- )
145
- else:
146
- assert self.channels == self.out_channels
147
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
148
-
149
- def forward(self, x):
150
- assert x.shape[1] == self.channels
151
- return self.op(x)
152
-
153
-
154
- class ResBlock(TimestepBlock):
155
- def __init__(
156
- self,
157
- channels,
158
- emb_channels,
159
- dropout,
160
- out_channels=None,
161
- use_conv=False,
162
- use_scale_shift_norm=False,
163
- dims=2,
164
- use_checkpoint=False,
165
- up=False,
166
- down=False,
167
- ):
168
- super().__init__()
169
- self.channels = channels
170
- self.emb_channels = emb_channels
171
- self.dropout = dropout
172
- self.out_channels = out_channels or channels
173
- self.use_conv = use_conv
174
- self.use_checkpoint = use_checkpoint
175
- self.use_scale_shift_norm = use_scale_shift_norm
176
-
177
- self.in_layers = nn.Sequential(
178
- normalization(channels),
179
- nn.SiLU(),
180
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
181
- )
182
-
183
- self.updown = up or down
184
-
185
- if up:
186
- self.h_upd = Upsample(channels, False, dims)
187
- self.x_upd = Upsample(channels, False, dims)
188
- elif down:
189
- self.h_upd = Downsample(channels, False, dims)
190
- self.x_upd = Downsample(channels, False, dims)
191
- else:
192
- self.h_upd = self.x_upd = nn.Identity()
193
-
194
- self.emb_layers = nn.Sequential(
195
- nn.SiLU(),
196
- linear(
197
- emb_channels,
198
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
199
- ),
200
- )
201
- self.out_layers = nn.Sequential(
202
- normalization(self.out_channels),
203
- nn.SiLU(),
204
- nn.Dropout(p=dropout),
205
- zero_module(
206
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
207
- ),
208
- )
209
-
210
- if self.out_channels == channels:
211
- self.skip_connection = nn.Identity()
212
- elif use_conv:
213
- self.skip_connection = conv_nd(
214
- dims, channels, self.out_channels, 3, padding=1
215
- )
216
- else:
217
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
218
-
219
- def forward(self, x, emb):
220
- """
221
- Apply the block to a Tensor, conditioned on a timestep embedding.
222
- :param x: an [N x C x ...] Tensor of features.
223
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
224
- :return: an [N x C x ...] Tensor of outputs.
225
- """
226
- return checkpoint(
227
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
228
- )
229
-
230
-
231
- def _forward(self, x, emb):
232
- if self.updown:
233
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
234
- h = in_rest(x)
235
- h = self.h_upd(h)
236
- x = self.x_upd(x)
237
- h = in_conv(h)
238
- else:
239
- h = self.in_layers(x)
240
- emb_out = self.emb_layers(emb).type(h.dtype)
241
- while len(emb_out.shape) < len(h.shape):
242
- emb_out = emb_out[..., None]
243
- if self.use_scale_shift_norm:
244
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
245
- scale, shift = th.chunk(emb_out, 2, dim=1)
246
- h = out_norm(h) * (1 + scale) + shift
247
- h = out_rest(h)
248
- else:
249
- h = h + emb_out
250
- h = self.out_layers(h)
251
- return self.skip_connection(x) + h
252
-
253
-
254
- class My_ResBlock(TimestepBlock):
255
- """
256
- A residual block that can optionally change the number of channels.
257
- :param channels: the number of input channels.
258
- :param emb_channels: the number of timestep embedding channels.
259
- :param dropout: the rate of dropout.
260
- :param out_channels: if specified, the number of out channels.
261
- :param use_conv: if True and out_channels is specified, use a spatial
262
- convolution instead of a smaller 1x1 convolution to change the
263
- channels in the skip connection.
264
- :param dims: determines if the signal is 1D, 2D, or 3D.
265
- :param use_checkpoint: if True, use gradient checkpointing on this module.
266
- :param up: if True, use this block for upsampling.
267
- :param down: if True, use this block for downsampling.
268
- """
269
-
270
- def __init__(
271
- self,
272
- channels,
273
- emb_channels,
274
- dropout,
275
- out_channels=None,
276
- use_conv=False,
277
- use_scale_shift_norm=False,
278
- dims=2,
279
- use_checkpoint=False,
280
- up=False,
281
- down=False,
282
- ):
283
- super().__init__()
284
- self.channels = channels
285
- self.emb_channels = emb_channels
286
- self.dropout = dropout
287
- self.out_channels = out_channels or channels
288
- self.use_conv = use_conv
289
- self.use_checkpoint = use_checkpoint
290
- self.use_scale_shift_norm = use_scale_shift_norm
291
-
292
- self.in_layers = nn.Sequential(
293
- normalization(channels),
294
- nn.SiLU(),
295
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
296
- )
297
-
298
- self.updown = up or down
299
-
300
- if up:
301
- self.h_upd = Upsample(channels, False, dims)
302
- self.x_upd = Upsample(channels, False, dims)
303
- elif down:
304
- self.h_upd = Downsample(channels, False, dims)
305
- self.x_upd = Downsample(channels, False, dims)
306
- else:
307
- self.h_upd = self.x_upd = nn.Identity()
308
-
309
- self.emb_layers = nn.Sequential(
310
- nn.SiLU(),
311
- linear(
312
- emb_channels,
313
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
314
- ),
315
- )
316
- self.out_layers = nn.Sequential(
317
- normalization(self.out_channels),
318
- nn.SiLU(),
319
- nn.Dropout(p=dropout),
320
- zero_module(
321
- conv_nd(dims, self.out_channels, 4, 3, padding=1)
322
- ),
323
- )
324
-
325
- if self.out_channels == channels:
326
- self.skip_connection = nn.Identity()
327
- elif use_conv:
328
- self.skip_connection = conv_nd(
329
- dims, channels, self.out_channels, 3, padding=1
330
- )
331
- else:
332
- self.skip_connection = conv_nd(dims, channels, 4, 1)
333
-
334
- def forward(self, x, emb):
335
- """
336
- Apply the block to a Tensor, conditioned on a timestep embedding.
337
- :param x: an [N x C x ...] Tensor of features.
338
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
339
- :return: an [N x C x ...] Tensor of outputs.
340
- """
341
- return checkpoint(
342
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
343
- )
344
-
345
-
346
- def _forward(self, x, emb):
347
- if self.updown:
348
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
349
- h = in_rest(x)
350
- h = self.h_upd(h)
351
- x = self.x_upd(x)
352
- h = in_conv(h)
353
- else:
354
- h = self.in_layers(x)
355
- emb_out = self.emb_layers(emb).type(h.dtype)
356
- while len(emb_out.shape) < len(h.shape):
357
- emb_out = emb_out[..., None]
358
- if self.use_scale_shift_norm:
359
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
360
- scale, shift = th.chunk(emb_out, 2, dim=1)
361
- h = out_norm(h) * (1 + scale) + shift
362
- h = out_rest(h)
363
- else:
364
- h = h + emb_out
365
- h = self.out_layers(h)
366
- return h
367
-
368
-
369
- class AttentionBlock(nn.Module):
370
- """
371
- An attention block that allows spatial positions to attend to each other.
372
- Originally ported from here, but adapted to the N-d case.
373
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
374
- """
375
-
376
- def __init__(
377
- self,
378
- channels,
379
- num_heads=1,
380
- num_head_channels=-1,
381
- use_checkpoint=False,
382
- use_new_attention_order=False,
383
- ):
384
- super().__init__()
385
- self.channels = channels
386
- if num_head_channels == -1:
387
- self.num_heads = num_heads
388
- else:
389
- assert (
390
- channels % num_head_channels == 0
391
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
392
- self.num_heads = channels // num_head_channels
393
- self.use_checkpoint = use_checkpoint
394
- self.norm = normalization(channels)
395
- self.qkv = conv_nd(1, channels, channels * 3, 1)
396
- if use_new_attention_order:
397
- # split qkv before split heads
398
- self.attention = QKVAttention(self.num_heads)
399
- else:
400
- # split heads before split qkv
401
- self.attention = QKVAttentionLegacy(self.num_heads)
402
-
403
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
404
-
405
- def forward(self, x):
406
- return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
407
- #return pt_checkpoint(self._forward, x) # pytorch
408
-
409
- def _forward(self, x):
410
- b, c, *spatial = x.shape
411
- x = x.reshape(b, c, -1)
412
- qkv = self.qkv(self.norm(x))
413
- h = self.attention(qkv)
414
- h = self.proj_out(h)
415
- return (x + h).reshape(b, c, *spatial)
416
-
417
-
418
- def count_flops_attn(model, _x, y):
419
- """
420
- A counter for the `thop` package to count the operations in an
421
- attention operation.
422
- Meant to be used like:
423
- macs, params = thop.profile(
424
- model,
425
- inputs=(inputs, timestamps),
426
- custom_ops={QKVAttention: QKVAttention.count_flops},
427
- )
428
- """
429
- b, c, *spatial = y[0].shape
430
- num_spatial = int(np.prod(spatial))
431
- # We perform two matmuls with the same number of ops.
432
- # The first computes the weight matrix, the second computes
433
- # the combination of the value vectors.
434
- matmul_ops = 2 * b * (num_spatial ** 2) * c
435
- model.total_ops += th.DoubleTensor([matmul_ops])
436
-
437
-
438
- class QKVAttentionLegacy(nn.Module):
439
- """
440
- A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
441
- """
442
-
443
- def __init__(self, n_heads):
444
- super().__init__()
445
- self.n_heads = n_heads
446
-
447
- def forward(self, qkv):
448
- """
449
- Apply QKV attention.
450
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
451
- :return: an [N x (H * C) x T] tensor after attention.
452
- """
453
- bs, width, length = qkv.shape
454
- assert width % (3 * self.n_heads) == 0
455
- ch = width // (3 * self.n_heads)
456
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
457
- scale = 1 / math.sqrt(math.sqrt(ch))
458
- weight = th.einsum(
459
- "bct,bcs->bts", q * scale, k * scale
460
- ) # More stable with f16 than dividing afterwards
461
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
462
- a = th.einsum("bts,bcs->bct", weight, v)
463
- return a.reshape(bs, -1, length)
464
-
465
- @staticmethod
466
- def count_flops(model, _x, y):
467
- return count_flops_attn(model, _x, y)
468
-
469
-
470
- class QKVAttention(nn.Module):
471
- """
472
- A module which performs QKV attention and splits in a different order.
473
- """
474
-
475
- def __init__(self, n_heads):
476
- super().__init__()
477
- self.n_heads = n_heads
478
-
479
- def forward(self, qkv):
480
- """
481
- Apply QKV attention.
482
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
483
- :return: an [N x (H * C) x T] tensor after attention.
484
- """
485
- bs, width, length = qkv.shape
486
- assert width % (3 * self.n_heads) == 0
487
- ch = width // (3 * self.n_heads)
488
- q, k, v = qkv.chunk(3, dim=1)
489
- scale = 1 / math.sqrt(math.sqrt(ch))
490
- weight = th.einsum(
491
- "bct,bcs->bts",
492
- (q * scale).view(bs * self.n_heads, ch, length),
493
- (k * scale).view(bs * self.n_heads, ch, length),
494
- ) # More stable with f16 than dividing afterwards
495
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
496
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
497
- return a.reshape(bs, -1, length)
498
-
499
- @staticmethod
500
- def count_flops(model, _x, y):
501
- return count_flops_attn(model, _x, y)
502
-
503
-
504
- class UNetModel(nn.Module):
505
- def __init__(self,
506
- image_size, # 32
507
- in_channels, # 9
508
- out_channels, # 4
509
- model_channels, # 320
510
- attention_resolutions, # [ 4, 2, 1 ]
511
- num_res_blocks, # 2
512
- channel_mult=(1, 2, 4, 8), # [ 1, 2, 4, 4 ]
513
- num_heads=-1, # 8
514
- use_spatial_transformer=False, # True
515
- transformer_depth=1, # 1
516
- context_dim=None, # 768
517
- use_checkpoint=False, # True
518
- legacy=True, # False
519
- add_conv_in_front_of_unet=False, # False
520
- dropout=0,
521
- conv_resample=True,
522
- dims=2,
523
- num_classes=None,
524
- num_head_channels=-1,
525
- num_heads_upsample=-1,
526
- use_scale_shift_norm=False):
527
-
528
- super().__init__()
529
-
530
- self.image_size = image_size # 32
531
- self.in_channels = in_channels # 9
532
- self.out_channels = out_channels # 4
533
- self.model_channels = model_channels # 320
534
- self.attention_resolutions = attention_resolutions # [4,2,1]
535
- self.num_res_blocks = num_res_blocks # 2
536
- self.channel_mult = channel_mult # [1,2,4,4]
537
- num_heads_upsample = num_heads # 8
538
- self.use_checkpoint = use_checkpoint # True
539
- self.add_conv_in_front_of_unet=add_conv_in_front_of_unet # False
540
- self.dropout = dropout # 0
541
- self.conv_resample = conv_resample # True
542
- self.num_classes = num_classes # None
543
- self.num_heads = num_heads # 8
544
- self.num_head_channels = num_head_channels # -1
545
- self.num_heads_upsample = num_heads_upsample # -1
546
- self.dtype = th.float32
547
-
548
-
549
- # timestep embedding MLP: 320 -> 320*4 -> 320*4
550
- time_embed_dim = model_channels * 4
551
- self.time_embed = nn.Sequential(
552
- linear(model_channels, time_embed_dim),
553
- nn.SiLU(),
554
- linear(time_embed_dim, time_embed_dim),
555
- )
556
-
557
-
558
- # stage 1: self.input_blocks
559
- self.input_blocks = nn.ModuleList(
560
- [
561
- TimestepEmbedSequential(
562
- conv_nd(
563
- dims, # 2
564
- in_channels, # 9
565
- model_channels, # 320
566
- kernel_size=3, # 3
567
- padding=1 # 1
568
- )
569
- )
570
- ]
571
- )
572
-
573
- input_block_chans = [model_channels] # [320]
574
- ch = model_channels # 320
575
- ds = 1 # 1
576
-
577
- for level, mult in enumerate(channel_mult): # [0,1,2,3], [1,2,4,4]
578
- for _ in range(num_res_blocks): # 2
579
- layers = [
580
- ResBlock(
581
- ch, # [1,1,1,2,2,4,4,4]*320
582
- time_embed_dim, # 320*4
583
- dropout, # 0
584
- out_channels=mult * model_channels, # [1,1,2,2,4,4,4,4]*320
585
- dims=dims, # 2
586
- use_checkpoint=use_checkpoint, # True
587
- use_scale_shift_norm=use_scale_shift_norm, # False
588
- )
589
- ]
590
- ch = mult * model_channels # [1,1,2,2,4,4,4,4]*320
591
- if ds in attention_resolutions: # first 6 inner-loop iterations
592
- dim_head = ch // num_heads # [1,1,2,2,4,4]*40
593
- layers.append(
594
- SpatialTransformer(
595
- ch, # [1,1,2,2,4,4]*320
596
- num_heads, # 8
597
- dim_head, # [1,1,2,2,4,4]*40
598
- depth=transformer_depth, # 1
599
- context_dim=context_dim # 768
600
- )
601
- )
602
- self.input_blocks.append(TimestepEmbedSequential(*layers))
603
- input_block_chans.append(ch) # [1,1,2,2,4,4,4,4]*320
604
- if level != len(channel_mult) - 1: # first 3 outer-loop iterations (all but the last level)
605
- out_ch = ch # [1,2,4]*320
606
- self.input_blocks.append(
607
- TimestepEmbedSequential(
608
- Downsample(
609
- ch, # [1,2,4]*320
610
- conv_resample, # True
611
- dims=dims, # 2
612
- out_channels=out_ch # [1,2,4]*320
613
- )
614
- )
615
- )
616
- ch = out_ch # [1,2,4]*320
617
- input_block_chans.append(ch) # [1,2,4]*320
618
- ds *= 2 # 1 -> 2 -> 4 -> 8
619
-
620
-
621
- # stage 2: self.middle_block
622
- dim_head = ch // num_heads # 1280 // 8
623
-
624
- self.middle_block = TimestepEmbedSequential(
625
- ResBlock(
626
- ch, # 4*320
627
- time_embed_dim, # 320
628
- dropout, # 0
629
- dims=dims, # 2
630
- use_checkpoint=use_checkpoint, # True
631
- use_scale_shift_norm=use_scale_shift_norm, # False
632
- ),
633
- SpatialTransformer(
634
- ch, # 4*320
635
- num_heads, # 8
636
- dim_head, # 160
637
- depth=transformer_depth, # 1
638
- context_dim=context_dim # 768
639
- ),
640
- ResBlock(
641
- ch, # 4*320
642
- time_embed_dim, # 320
643
- dropout, # 0
644
- dims=dims, # 2
645
- use_checkpoint=use_checkpoint, # True
646
- use_scale_shift_norm=use_scale_shift_norm, # False
647
- ),
648
- )
649
-
650
-
651
- # stage 3: self.output_blocks
652
- self.output_blocks = nn.ModuleList([])
653
- for level, mult in list(enumerate(channel_mult))[::-1]: # [3,2,1,0], [4,4,2,1]
654
- for i in range(num_res_blocks + 1): # 3
655
- ich = input_block_chans.pop() # [4,4, 4,4,4, 2,2,2, 1,1,1, 1]*320
656
- layers = [
657
- ResBlock(
658
- ch + ich, # [4,4,4,4,4,4,4,2,2,2,1,1]*320+ich
659
- time_embed_dim, # 320
660
- dropout, # 0
661
- out_channels=model_channels*mult, # [4,4,4,4,4,4,2,2,2,1,1,1]*320
662
- dims=dims, # 2
663
- use_checkpoint=use_checkpoint, # True
664
- use_scale_shift_norm=use_scale_shift_norm, # False
665
- )
666
- ]
667
- ch = model_channels * mult # [4,4,4,4,4,4,2,2,2,1,1,1]*320
668
- if ds in attention_resolutions: # last 3 outer-loop iterations
669
- dim_head = ch // num_heads
670
- layers.append(
671
- SpatialTransformer(
672
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
673
- )
674
- )
675
- if level and i == num_res_blocks: # last inner iteration of each of the first 3 outer-loop iterations
676
- out_ch = ch
677
- layers.append(
678
- Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
679
- )
680
- ds //= 2
681
- self.output_blocks.append(TimestepEmbedSequential(*layers))
682
-
683
- # stage 4: self.out
684
- self.out = nn.Sequential(
685
- normalization(ch),
686
- nn.SiLU(),
687
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
688
- )
689
-
690
- def forward(self, x, timesteps=None, context=None):
691
- hs = []
692
-
693
- t_emb = timestep_embedding(timesteps, self.model_channels) # [N, 320]
694
- emb = self.time_embed(t_emb) # [N, 320*4]
695
- h = x.type(self.dtype) # cast x to torch.float32
696
-
697
- for module in self.input_blocks:
698
- h = module(h, emb, context)
699
- hs.append(h)
700
- h = self.middle_block(h, emb, context)
701
- for module in self.output_blocks:
702
- h = th.cat([h, hs.pop()], dim=1)
703
- h = module(h, emb, context)
704
- h = h.type(x.dtype)
705
-
706
- return self.out(h) # [N,4,H,W]
707
-
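That closes the deleted ldm/modules/diffusionmodules/openaimodel.py; UNetModel is the denoiser that consumes the timestep embedding and cross-attention context. A small forward-pass sketch, assuming a pre-deletion checkout; the inline comments in the class record the sizes the repository's configs actually used (model_channels=320, in_channels=9, context_dim=768), while the reduced numbers below are only meant to keep the example light.

import torch
from ldm.modules.diffusionmodules.openaimodel import UNetModel

unet = UNetModel(image_size=32, in_channels=9, out_channels=4, model_channels=32,
                 attention_resolutions=[1, 2], num_res_blocks=1,
                 channel_mult=(1, 2), num_heads=4,
                 use_spatial_transformer=True, transformer_depth=1,
                 context_dim=768, use_checkpoint=False, legacy=False)
x = torch.randn(1, 9, 32, 32)              # noisy latent concatenated with conditioning channels
t = torch.randint(0, 1000, (1,))           # one diffusion timestep per sample
ctx = torch.randn(1, 77, 768)              # cross-attention context, e.g. CLIP text embeddings
eps = unet(x, timesteps=t, context=ctx)    # model output, shape (1, 4, 32, 32)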
ldm/modules/diffusionmodules/util.py DELETED
@@ -1,255 +0,0 @@
- # adapted from
- # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
- # and
- # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
- # and
- # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
- #
- # thanks!
-
-
- import os
- import math
- import torch
- import torch.nn as nn
- import numpy as np
- from einops import repeat
-
- from ldm.util import instantiate_from_config
-
-
- def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
-     if schedule == "linear":
-         betas = (
-                 torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
-         )
-
-     elif schedule == "cosine":
-         timesteps = (
-                 torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
-         )
-         alphas = timesteps / (1 + cosine_s) * np.pi / 2
-         alphas = torch.cos(alphas).pow(2)
-         alphas = alphas / alphas[0]
-         betas = 1 - alphas[1:] / alphas[:-1]
-         betas = np.clip(betas, a_min=0, a_max=0.999)
-
-     elif schedule == "sqrt_linear":
-         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
-     elif schedule == "sqrt":
-         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
-     else:
-         raise ValueError(f"schedule '{schedule}' unknown.")
-     return betas.numpy()
-
-
- def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
-     if ddim_discr_method == 'uniform':
-         c = num_ddpm_timesteps // num_ddim_timesteps
-         ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
-     elif ddim_discr_method == 'quad':
-         ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
-     else:
-         raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
-
-     # assert ddim_timesteps.shape[0] == num_ddim_timesteps
-     # add one to get the final alpha values right (the ones from first scale to data during sampling)
-     steps_out = ddim_timesteps + 1
-     if verbose:
-         print(f'Selected timesteps for ddim sampler: {steps_out}')
-     return steps_out
-
-
- def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
-     # select alphas for computing the variance schedule
-     alphas = alphacums[ddim_timesteps]
-     alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
-
-     # according to the formula provided in https://arxiv.org/abs/2010.02502
-     sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
-     if verbose:
-         print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
-         print(f'For the chosen value of eta, which is {eta}, '
-               f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
-     return sigmas, alphas, alphas_prev
-
-
- def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
-     """
-     Create a beta schedule that discretizes the given alpha_t_bar function,
-     which defines the cumulative product of (1-beta) over time from t = [0,1].
-     :param num_diffusion_timesteps: the number of betas to produce.
-     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
-                       produces the cumulative product of (1-beta) up to that
-                       part of the diffusion process.
-     :param max_beta: the maximum beta to use; use values lower than 1 to
-                      prevent singularities.
-     """
-     betas = []
-     for i in range(num_diffusion_timesteps):
-         t1 = i / num_diffusion_timesteps
-         t2 = (i + 1) / num_diffusion_timesteps
-         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-     return np.array(betas)
-
-
- def extract_into_tensor(a, t, x_shape):
-     b, *_ = t.shape
-     out = a.gather(-1, t)
-     return out.reshape(b, *((1,) * (len(x_shape) - 1)))
-
-
- def checkpoint(func, inputs, params, flag):
-     """
-     Evaluate a function without caching intermediate activations, allowing for
-     reduced memory at the expense of extra compute in the backward pass.
-     :param func: the function to evaluate.
-     :param inputs: the argument sequence to pass to `func`.
-     :param params: a sequence of parameters `func` depends on but does not
-                    explicitly take as arguments.
-     :param flag: if False, disable gradient checkpointing.
-     """
-     if flag:
-         args = tuple(inputs) + tuple(params)
-         return CheckpointFunction.apply(func, len(inputs), *args)
-     else:
-         return func(*inputs)
-
-
- class CheckpointFunction(torch.autograd.Function):
-     @staticmethod
-     def forward(ctx, run_function, length, *args):
-         ctx.run_function = run_function
-         ctx.input_tensors = list(args[:length])
-         ctx.input_params = list(args[length:])
-
-         with torch.no_grad():
-             output_tensors = ctx.run_function(*ctx.input_tensors)
-         return output_tensors
-
-     @staticmethod
-     def backward(ctx, *output_grads):
-         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
-         with torch.enable_grad():
-             # Fixes a bug where the first op in run_function modifies the
-             # Tensor storage in place, which is not allowed for detach()'d
-             # Tensors.
-             shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
-             output_tensors = ctx.run_function(*shallow_copies)
-         input_grads = torch.autograd.grad(
-             output_tensors,
-             ctx.input_tensors + ctx.input_params,
-             output_grads,
-             allow_unused=True,
-         )
-         del ctx.input_tensors
-         del ctx.input_params
-         del output_tensors
-         return (None, None) + input_grads
-
-
- def timestep_embedding(timesteps, dim, max_period=10000):
-     half = dim // 2  # 160
-     # build a tensor of shape [160,] whose entries decay exponentially from 1 down to 1/max_period
-     freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=timesteps.device)
-     # [2, 1] * [1, 160] -> [2, 160]
-     args = timesteps[:, None].float() * freqs[None]
-     # [2, 320]
-     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-     return embedding
-
-
- def zero_module(module):
-     """
-     Zero out the parameters of a module and return it.
-     """
-     for p in module.parameters():
-         p.detach().zero_()
-     return module
-
-
- def scale_module(module, scale):
-     """
-     Scale the parameters of a module and return it.
-     """
-     for p in module.parameters():
-         p.detach().mul_(scale)
-     return module
-
-
- def mean_flat(tensor):
-     """
-     Take the mean over all non-batch dimensions.
-     """
-     return tensor.mean(dim=list(range(1, len(tensor.shape))))
-
-
- def normalization(channels):
-     """
-     Make a standard normalization layer.
-     :param channels: number of input channels.
-     :return: an nn.Module for normalization.
-     """
-     return GroupNorm32(32, channels)
-
-
- # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
- class SiLU(nn.Module):
-     def forward(self, x):
-         return x * torch.sigmoid(x)
-
-
- class GroupNorm32(nn.GroupNorm):
-     def forward(self, x):
-         return super().forward(x.float()).type(x.dtype)
-
- def conv_nd(dims, *args, **kwargs):
-     """
-     Create a 1D, 2D, or 3D convolution module.
-     """
-     if dims == 1:
-         return nn.Conv1d(*args, **kwargs)
-     elif dims == 2:
-         return nn.Conv2d(*args, **kwargs)
-     elif dims == 3:
-         return nn.Conv3d(*args, **kwargs)
-     raise ValueError(f"unsupported dimensions: {dims}")
-
-
- def linear(*args, **kwargs):
-     """
-     Create a linear module.
-     """
-     return nn.Linear(*args, **kwargs)
-
-
- def avg_pool_nd(dims, *args, **kwargs):
-     """
-     Create a 1D, 2D, or 3D average pooling module.
-     """
-     if dims == 1:
-         return nn.AvgPool1d(*args, **kwargs)
-     elif dims == 2:
-         return nn.AvgPool2d(*args, **kwargs)
-     elif dims == 3:
-         return nn.AvgPool3d(*args, **kwargs)
-     raise ValueError(f"unsupported dimensions: {dims}")
-
-
- class HybridConditioner(nn.Module):
-
-     def __init__(self, c_concat_config, c_crossattn_config):
-         super().__init__()
-         self.concat_conditioner = instantiate_from_config(c_concat_config)
-         self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
-
-     def forward(self, c_concat, c_crossattn):
-         c_concat = self.concat_conditioner(c_concat)
-         c_crossattn = self.crossattn_conditioner(c_crossattn)
-         return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
-
-
- def noise_like(shape, device, repeat=False):
-     repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
-     noise = lambda: torch.randn(shape, device=device)
-     return repeat_noise() if repeat else noise()
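
For reference, a minimal usage sketch of the helpers deleted above (illustrative only, not part of the diff; it assumes an older checkout where ldm.modules.diffusionmodules.util is still importable):

# Hypothetical usage sketch for the deleted schedule/embedding helpers.
import torch
from ldm.modules.diffusionmodules.util import (
    make_beta_schedule,
    make_ddim_timesteps,
    timestep_embedding,
)

# 1000-step linear beta schedule, as consumed by the DDPM training code.
betas = make_beta_schedule("linear", n_timestep=1000)        # numpy array, shape (1000,)

# 50 DDIM steps picked uniformly from the 1000 DDPM steps.
ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
                                 num_ddpm_timesteps=1000, verbose=False)

# Sinusoidal embedding of a batch of two timesteps into 320 dimensions.
t = torch.tensor([10, 500])
emb = timestep_embedding(t, dim=320)                         # shape (2, 320)
print(betas.shape, ddim_steps.shape, emb.shape)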
 
ldm/modules/distributions/.DS_Store DELETED
Binary file (6.15 kB)
 
ldm/modules/distributions/__init__.py DELETED
File without changes
ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (155 Bytes)
 
ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc DELETED
Binary file (3.8 kB)
 
ldm/modules/distributions/distributions.py DELETED
@@ -1,92 +0,0 @@
- import torch
- import numpy as np
-
-
- class AbstractDistribution:
-     def sample(self):
-         raise NotImplementedError()
-
-     def mode(self):
-         raise NotImplementedError()
-
-
- class DiracDistribution(AbstractDistribution):
-     def __init__(self, value):
-         self.value = value
-
-     def sample(self):
-         return self.value
-
-     def mode(self):
-         return self.value
-
-
- class DiagonalGaussianDistribution(object):
-     def __init__(self, parameters, deterministic=False):
-         self.parameters = parameters
-         self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
-         self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
-         self.deterministic = deterministic
-         self.std = torch.exp(0.5 * self.logvar)
-         self.var = torch.exp(self.logvar)
-         if self.deterministic:
-             self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
-
-     def sample(self):
-         x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
-         return x
-
-     def kl(self, other=None):
-         if self.deterministic:
-             return torch.Tensor([0.])
-         else:
-             if other is None:
-                 return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                        + self.var - 1.0 - self.logvar,
-                                        dim=[1, 2, 3])
-             else:
-                 return 0.5 * torch.sum(
-                     torch.pow(self.mean - other.mean, 2) / other.var
-                     + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                     dim=[1, 2, 3])
-
-     def nll(self, sample, dims=[1,2,3]):
-         if self.deterministic:
-             return torch.Tensor([0.])
-         logtwopi = np.log(2.0 * np.pi)
-         return 0.5 * torch.sum(
-             logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-             dim=dims)
-
-     def mode(self):
-         return self.mean
-
-
- def normal_kl(mean1, logvar1, mean2, logvar2):
-     """
-     source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
-     Compute the KL divergence between two gaussians.
-     Shapes are automatically broadcasted, so batches can be compared to
-     scalars, among other use cases.
-     """
-     tensor = None
-     for obj in (mean1, logvar1, mean2, logvar2):
-         if isinstance(obj, torch.Tensor):
-             tensor = obj
-             break
-     assert tensor is not None, "at least one argument must be a Tensor"
-
-     # Force variances to be Tensors. Broadcasting helps convert scalars to
-     # Tensors, but it does not work for torch.exp().
-     logvar1, logvar2 = [
-         x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
-         for x in (logvar1, logvar2)
-     ]
-
-     return 0.5 * (
-         -1.0
-         + logvar2
-         - logvar1
-         + torch.exp(logvar1 - logvar2)
-         + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
-     )
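
A minimal sketch of how the deleted DiagonalGaussianDistribution is typically used by the autoencoder (illustrative only; the shapes below are assumptions, not taken from the diff):

# Hypothetical usage sketch for the deleted DiagonalGaussianDistribution.
import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

# An encoder predicts mean and log-variance stacked along the channel axis:
# a batch of 2 latents with 4 channels each -> 8 channels of "moments".
moments = torch.randn(2, 8, 32, 32)
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()   # reparameterized sample, shape (2, 4, 32, 32)
kl = posterior.kl()      # KL against a standard normal, shape (2,)
print(z.shape, kl.shape)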
 
ldm/modules/ema.py DELETED
@@ -1,76 +0,0 @@
- import torch
- from torch import nn
-
-
- class LitEma(nn.Module):
-     def __init__(self, model, decay=0.9999, use_num_upates=True):
-         super().__init__()
-         if decay < 0.0 or decay > 1.0:
-             raise ValueError('Decay must be between 0 and 1')
-
-         self.m_name2s_name = {}
-         self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
-         self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
-                              else torch.tensor(-1,dtype=torch.int))
-
-         for name, p in model.named_parameters():
-             if p.requires_grad:
-                 # remove as '.'-character is not allowed in buffers
-                 s_name = name.replace('.','')
-                 self.m_name2s_name.update({name:s_name})
-                 self.register_buffer(s_name,p.clone().detach().data)
-
-         self.collected_params = []
-
-     def forward(self,model):
-         decay = self.decay
-
-         if self.num_updates >= 0:
-             self.num_updates += 1
-             decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
-
-         one_minus_decay = 1.0 - decay
-
-         with torch.no_grad():
-             m_param = dict(model.named_parameters())
-             shadow_params = dict(self.named_buffers())
-
-             for key in m_param:
-                 if m_param[key].requires_grad:
-                     sname = self.m_name2s_name[key]
-                     shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
-                     shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
-                 else:
-                     assert not key in self.m_name2s_name
-
-     def copy_to(self, model):
-         m_param = dict(model.named_parameters())
-         shadow_params = dict(self.named_buffers())
-         for key in m_param:
-             if m_param[key].requires_grad:
-                 m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
-             else:
-                 assert not key in self.m_name2s_name
-
-     def store(self, parameters):
-         """
-         Save the current parameters for restoring later.
-         Args:
-           parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-             temporarily stored.
-         """
-         self.collected_params = [param.clone() for param in parameters]
-
-     def restore(self, parameters):
-         """
-         Restore the parameters stored with the `store` method.
-         Useful to validate the model with EMA parameters without affecting the
-         original optimization process. Store the parameters before the
-         `copy_to` method. After validation (or model saving), use this to
-         restore the former parameters.
-         Args:
-           parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-             updated with the stored parameters.
-         """
-         for c_param, param in zip(self.collected_params, parameters):
-             param.data.copy_(c_param.data)
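
A minimal sketch of the typical LitEma lifecycle (illustrative only; the toy model and training loop are assumptions, not part of the original code):

# Hypothetical usage sketch for the deleted LitEma helper.
import torch
from torch import nn
from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.9999)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema(model)  # update the shadow (EMA) copies after each optimizer step

# Evaluate with EMA weights, then restore the raw training weights.
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation here ...
ema.restore(model.parameters())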
 
ldm/modules/encoders/.DS_Store DELETED
Binary file (6.15 kB)
 
ldm/modules/encoders/__init__.py DELETED
File without changes
ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (150 Bytes)
 
ldm/modules/encoders/__pycache__/modules.cpython-38.pyc DELETED
Binary file (7.15 kB)