Jannat24 committed on
Commit 9101b75 · 1 Parent(s): 202f3ae

taming_transformer

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +1 -0
  2. =1.0.8 +0 -0
  3. =2.0.0 +0 -0
  4. License.txt +19 -0
  5. __pycache__/main.cpython-312.pyc +0 -0
  6. environment.yaml +25 -0
  7. main.py +585 -0
  8. scripts/extract_depth.py +112 -0
  9. scripts/extract_segmentation.py +130 -0
  10. scripts/extract_submodel.py +17 -0
  11. scripts/make_samples.py +292 -0
  12. scripts/make_scene_samples.py +198 -0
  13. scripts/sample_conditional.py +355 -0
  14. scripts/sample_fast.py +260 -0
  15. scripts/taming-transformers.ipynb +0 -0
  16. setup.py +13 -0
  17. taming/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
  18. taming/__pycache__/util.cpython-312.pyc +0 -0
  19. taming/data/.ipynb_checkpoints/utils-checkpoint.py +171 -0
  20. taming/data/__pycache__/helper_types.cpython-312.pyc +0 -0
  21. taming/data/__pycache__/utils.cpython-312.pyc +0 -0
  22. taming/data/ade20k.py +124 -0
  23. taming/data/annotated_objects_coco.py +139 -0
  24. taming/data/annotated_objects_dataset.py +218 -0
  25. taming/data/annotated_objects_open_images.py +137 -0
  26. taming/data/base.py +70 -0
  27. taming/data/coco.py +176 -0
  28. taming/data/conditional_builder/objects_bbox.py +60 -0
  29. taming/data/conditional_builder/objects_center_points.py +168 -0
  30. taming/data/conditional_builder/utils.py +105 -0
  31. taming/data/custom.py +38 -0
  32. taming/data/faceshq.py +134 -0
  33. taming/data/helper_types.py +49 -0
  34. taming/data/image_transforms.py +132 -0
  35. taming/data/imagenet.py +558 -0
  36. taming/data/open_images_helper.py +379 -0
  37. taming/data/sflckr.py +91 -0
  38. taming/data/utils.py +171 -0
  39. taming/lr_scheduler.py +34 -0
  40. taming/models/__pycache__/vqgan.cpython-312.pyc +0 -0
  41. taming/models/cond_transformer.py +352 -0
  42. taming/models/dummy_cond_stage.py +22 -0
  43. taming/models/vqgan.py +404 -0
  44. taming/modules/__pycache__/util.cpython-312.pyc +0 -0
  45. taming/modules/diffusionmodules/__pycache__/model.cpython-312.pyc +0 -0
  46. taming/modules/diffusionmodules/model.py +776 -0
  47. taming/modules/discriminator/__pycache__/model.cpython-312.pyc +0 -0
  48. taming/modules/discriminator/model.py +67 -0
  49. taming/modules/losses/__init__.py +2 -0
  50. taming/modules/losses/__pycache__/__init__.cpython-312.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ scripts/reconstruction_usage.ipynb filter=lfs diff=lfs merge=lfs -text
=1.0.8 ADDED
The diff for this file is too large to render. See raw diff
 
=2.0.0 ADDED
File without changes
License.txt ADDED
@@ -0,0 +1,19 @@
+ Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
+ OR OTHER DEALINGS IN THE SOFTWARE./
__pycache__/main.cpython-312.pyc ADDED
Binary file (27.2 kB).
 
environment.yaml ADDED
@@ -0,0 +1,25 @@
+ name: taming
+ channels:
+   - pytorch
+   - defaults
+ dependencies:
+   - python=3.8.5
+   - pip=20.3
+   - cudatoolkit=10.2
+   - pytorch=1.7.0
+   - torchvision=0.8.1
+   - numpy=1.19.2
+   - pip:
+     - albumentations==0.4.3
+     - opencv-python==4.1.2.30
+     - pudb==2019.2
+     - imageio==2.9.0
+     - imageio-ffmpeg==0.4.2
+     - pytorch-lightning==1.0.8
+     - omegaconf==2.0.0
+     - test-tube>=0.7.5
+     - streamlit>=0.73.1
+     - einops==0.3.0
+     - more-itertools>=8.0.0
+     - transformers==4.3.1
+     - -e .
main.py ADDED
@@ -0,0 +1,585 @@
+ import argparse, os, sys, datetime, glob, importlib
+ from omegaconf import OmegaConf
+ import numpy as np
+ from PIL import Image
+ import torch
+ import torchvision
+ from torch.utils.data import random_split, DataLoader, Dataset
+ import pytorch_lightning as pl
+ from pytorch_lightning import seed_everything
+ from pytorch_lightning.trainer import Trainer
+ from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
+ from pytorch_lightning.utilities import rank_zero_only
+
+ from taming.data.utils import custom_collate
+
+
+ def get_obj_from_str(string, reload=False):
+     module, cls = string.rsplit(".", 1)
+     if reload:
+         module_imp = importlib.import_module(module)
+         importlib.reload(module_imp)
+     return getattr(importlib.import_module(module, package=None), cls)
+
+
+ def get_parser(**parser_kwargs):
+     def str2bool(v):
+         if isinstance(v, bool):
+             return v
+         if v.lower() in ("yes", "true", "t", "y", "1"):
+             return True
+         elif v.lower() in ("no", "false", "f", "n", "0"):
+             return False
+         else:
+             raise argparse.ArgumentTypeError("Boolean value expected.")
+
+     parser = argparse.ArgumentParser(**parser_kwargs)
+     parser.add_argument(
+         "-n",
+         "--name",
+         type=str,
+         const=True,
+         default="",
+         nargs="?",
+         help="postfix for logdir",
+     )
+     parser.add_argument(
+         "-r",
+         "--resume",
+         type=str,
+         const=True,
+         default="",
+         nargs="?",
+         help="resume from logdir or checkpoint in logdir",
+     )
+     parser.add_argument(
+         "-b",
+         "--base",
+         nargs="*",
+         metavar="base_config.yaml",
+         help="paths to base configs. Loaded from left-to-right. "
+         "Parameters can be overwritten or added with command-line options of the form `--key value`.",
+         default=list(),
+     )
+     parser.add_argument(
+         "-t",
+         "--train",
+         type=str2bool,
+         const=True,
+         default=False,
+         nargs="?",
+         help="train",
+     )
+     parser.add_argument(
+         "--no-test",
+         type=str2bool,
+         const=True,
+         default=False,
+         nargs="?",
+         help="disable test",
+     )
+     parser.add_argument("-p", "--project", help="name of new or path to existing project")
+     parser.add_argument(
+         "-d",
+         "--debug",
+         type=str2bool,
+         nargs="?",
+         const=True,
+         default=False,
+         help="enable post-mortem debugging",
+     )
+     parser.add_argument(
+         "-s",
+         "--seed",
+         type=int,
+         default=23,
+         help="seed for seed_everything",
+     )
+     parser.add_argument(
+         "-f",
+         "--postfix",
+         type=str,
+         default="",
+         help="post-postfix for default name",
+     )
+
+     return parser
+
+
+ def nondefault_trainer_args(opt):
+     parser = argparse.ArgumentParser()
+     parser = Trainer.add_argparse_args(parser)
+     args = parser.parse_args([])
+     return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
+
+
+ def instantiate_from_config(config):
+     if not "target" in config:
+         raise KeyError("Expected key `target` to instantiate.")
+     return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+ class WrappedDataset(Dataset):
+     """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
+     def __init__(self, dataset):
+         self.data = dataset
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         return self.data[idx]
+
+
+ class DataModuleFromConfig(pl.LightningDataModule):
+     def __init__(self, batch_size, train=None, validation=None, test=None,
+                  wrap=False, num_workers=None):
+         super().__init__()
+         self.batch_size = batch_size
+         self.dataset_configs = dict()
+         self.num_workers = num_workers if num_workers is not None else batch_size*2
+         if train is not None:
+             self.dataset_configs["train"] = train
+             self.train_dataloader = self._train_dataloader
+         if validation is not None:
+             self.dataset_configs["validation"] = validation
+             self.val_dataloader = self._val_dataloader
+         if test is not None:
+             self.dataset_configs["test"] = test
+             self.test_dataloader = self._test_dataloader
+         self.wrap = wrap
+
+     def prepare_data(self):
+         for data_cfg in self.dataset_configs.values():
+             instantiate_from_config(data_cfg)
+
+     def setup(self, stage=None):
+         self.datasets = dict(
+             (k, instantiate_from_config(self.dataset_configs[k]))
+             for k in self.dataset_configs)
+         if self.wrap:
+             for k in self.datasets:
+                 self.datasets[k] = WrappedDataset(self.datasets[k])
+
+     def _train_dataloader(self):
+         return DataLoader(self.datasets["train"], batch_size=self.batch_size,
+                           num_workers=self.num_workers, shuffle=True, collate_fn=custom_collate)
+
+     def _val_dataloader(self):
+         return DataLoader(self.datasets["validation"],
+                           batch_size=self.batch_size,
+                           num_workers=self.num_workers, collate_fn=custom_collate)
+
+     def _test_dataloader(self):
+         return DataLoader(self.datasets["test"], batch_size=self.batch_size,
+                           num_workers=self.num_workers, collate_fn=custom_collate)
+
+
+ class SetupCallback(Callback):
+     def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
+         super().__init__()
+         self.resume = resume
+         self.now = now
+         self.logdir = logdir
+         self.ckptdir = ckptdir
+         self.cfgdir = cfgdir
+         self.config = config
+         self.lightning_config = lightning_config
+
+     def on_pretrain_routine_start(self, trainer, pl_module):
+         if trainer.global_rank == 0:
+             # Create logdirs and save configs
+             os.makedirs(self.logdir, exist_ok=True)
+             os.makedirs(self.ckptdir, exist_ok=True)
+             os.makedirs(self.cfgdir, exist_ok=True)
+
+             print("Project config")
+             print(self.config.pretty())
+             OmegaConf.save(self.config,
+                            os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
+
+             print("Lightning config")
+             print(self.lightning_config.pretty())
+             OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
+                            os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
+
+         else:
+             # ModelCheckpoint callback created log directory --- remove it
+             if not self.resume and os.path.exists(self.logdir):
+                 dst, name = os.path.split(self.logdir)
+                 dst = os.path.join(dst, "child_runs", name)
+                 os.makedirs(os.path.split(dst)[0], exist_ok=True)
+                 try:
+                     os.rename(self.logdir, dst)
+                 except FileNotFoundError:
+                     pass
+
+
+ class ImageLogger(Callback):
+     def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True):
+         super().__init__()
+         self.batch_freq = batch_frequency
+         self.max_images = max_images
+         self.logger_log_images = {
+             pl.loggers.WandbLogger: self._wandb,
+             pl.loggers.TestTubeLogger: self._testtube,
+         }
+         self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
+         if not increase_log_steps:
+             self.log_steps = [self.batch_freq]
+         self.clamp = clamp
+
+     @rank_zero_only
+     def _wandb(self, pl_module, images, batch_idx, split):
+         raise ValueError("No way wandb")
+         grids = dict()
+         for k in images:
+             grid = torchvision.utils.make_grid(images[k])
+             grids[f"{split}/{k}"] = wandb.Image(grid)
+         pl_module.logger.experiment.log(grids)
+
+     @rank_zero_only
+     def _testtube(self, pl_module, images, batch_idx, split):
+         for k in images:
+             grid = torchvision.utils.make_grid(images[k])
+             grid = (grid+1.0)/2.0  # -1,1 -> 0,1; c,h,w
+
+             tag = f"{split}/{k}"
+             pl_module.logger.experiment.add_image(
+                 tag, grid,
+                 global_step=pl_module.global_step)
+
+     @rank_zero_only
+     def log_local(self, save_dir, split, images,
+                   global_step, current_epoch, batch_idx):
+         root = os.path.join(save_dir, "images", split)
+         for k in images:
+             grid = torchvision.utils.make_grid(images[k], nrow=4)
+
+             grid = (grid+1.0)/2.0  # -1,1 -> 0,1; c,h,w
+             grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)
+             grid = grid.numpy()
+             grid = (grid*255).astype(np.uint8)
+             filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
+                 k,
+                 global_step,
+                 current_epoch,
+                 batch_idx)
+             path = os.path.join(root, filename)
+             os.makedirs(os.path.split(path)[0], exist_ok=True)
+             Image.fromarray(grid).save(path)
+
+     def log_img(self, pl_module, batch, batch_idx, split="train"):
+         if (self.check_frequency(batch_idx) and  # batch_idx % self.batch_freq == 0
+                 hasattr(pl_module, "log_images") and
+                 callable(pl_module.log_images) and
+                 self.max_images > 0):
+             logger = type(pl_module.logger)
+
+             is_train = pl_module.training
+             if is_train:
+                 pl_module.eval()
+
+             with torch.no_grad():
+                 images = pl_module.log_images(batch, split=split, pl_module=pl_module)
+
+             for k in images:
+                 N = min(images[k].shape[0], self.max_images)
+                 images[k] = images[k][:N]
+                 if isinstance(images[k], torch.Tensor):
+                     images[k] = images[k].detach().cpu()
+                     if self.clamp:
+                         images[k] = torch.clamp(images[k], -1., 1.)
+
+             self.log_local(pl_module.logger.save_dir, split, images,
+                            pl_module.global_step, pl_module.current_epoch, batch_idx)
+
+             logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
+             logger_log_images(pl_module, images, pl_module.global_step, split)
+
+             if is_train:
+                 pl_module.train()
+
+     def check_frequency(self, batch_idx):
+         if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
+             try:
+                 self.log_steps.pop(0)
+             except IndexError:
+                 pass
+             return True
+         return False
+
+     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+         self.log_img(pl_module, batch, batch_idx, split="train")
+
+     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+         self.log_img(pl_module, batch, batch_idx, split="val")
+
+
+
+ if __name__ == "__main__":
+     # custom parser to specify config files, train, test and debug mode,
+     # postfix, resume.
+     # `--key value` arguments are interpreted as arguments to the trainer.
+     # `nested.key=value` arguments are interpreted as config parameters.
+     # configs are merged from left-to-right followed by command line parameters.
+
+     # model:
+     #   base_learning_rate: float
+     #   target: path to lightning module
+     #   params:
+     #       key: value
+     # data:
+     #   target: main.DataModuleFromConfig
+     #   params:
+     #      batch_size: int
+     #      wrap: bool
+     #      train:
+     #          target: path to train dataset
+     #          params:
+     #              key: value
+     #      validation:
+     #          target: path to validation dataset
+     #          params:
+     #              key: value
+     #      test:
+     #          target: path to test dataset
+     #          params:
+     #              key: value
+     # lightning: (optional, has sane defaults and can be specified on cmdline)
+     #   trainer:
+     #       additional arguments to trainer
+     #   logger:
+     #       logger to instantiate
+     #   modelcheckpoint:
+     #       modelcheckpoint to instantiate
+     #   callbacks:
+     #       callback1:
+     #           target: importpath
+     #           params:
+     #               key: value
+
+     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+
+     # add cwd for convenience and to make classes in this file available when
+     # running as `python main.py`
+     # (in particular `main.DataModuleFromConfig`)
+     sys.path.append(os.getcwd())
+
+     parser = get_parser()
+     parser = Trainer.add_argparse_args(parser)
+
+     opt, unknown = parser.parse_known_args()
+     if opt.name and opt.resume:
+         raise ValueError(
+             "-n/--name and -r/--resume cannot be specified both."
+             "If you want to resume training in a new log folder, "
+             "use -n/--name in combination with --resume_from_checkpoint"
+         )
+     if opt.resume:
+         if not os.path.exists(opt.resume):
+             raise ValueError("Cannot find {}".format(opt.resume))
+         if os.path.isfile(opt.resume):
+             paths = opt.resume.split("/")
+             idx = len(paths)-paths[::-1].index("logs")+1
+             logdir = "/".join(paths[:idx])
+             ckpt = opt.resume
+         else:
+             assert os.path.isdir(opt.resume), opt.resume
+             logdir = opt.resume.rstrip("/")
+             ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
+
+         opt.resume_from_checkpoint = ckpt
+         base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
+         opt.base = base_configs+opt.base
+         _tmp = logdir.split("/")
+         nowname = _tmp[_tmp.index("logs")+1]
+     else:
+         if opt.name:
+             name = "_"+opt.name
+         elif opt.base:
+             cfg_fname = os.path.split(opt.base[0])[-1]
+             cfg_name = os.path.splitext(cfg_fname)[0]
+             name = "_"+cfg_name
+         else:
+             name = ""
+         nowname = now+name+opt.postfix
+         logdir = os.path.join("logs", nowname)
+
+     ckptdir = os.path.join(logdir, "checkpoints")
+     cfgdir = os.path.join(logdir, "configs")
+     seed_everything(opt.seed)
+
+     try:
+         # init and save configs
+         configs = [OmegaConf.load(cfg) for cfg in opt.base]
+         cli = OmegaConf.from_dotlist(unknown)
+         config = OmegaConf.merge(*configs, cli)
+         lightning_config = config.pop("lightning", OmegaConf.create())
+         # merge trainer cli with config
+         trainer_config = lightning_config.get("trainer", OmegaConf.create())
+         # default to ddp
+         trainer_config["distributed_backend"] = "ddp"
+         for k in nondefault_trainer_args(opt):
+             trainer_config[k] = getattr(opt, k)
+         if not "gpus" in trainer_config:
+             del trainer_config["distributed_backend"]
+             cpu = True
+         else:
+             gpuinfo = trainer_config["gpus"]
+             print(f"Running on GPUs {gpuinfo}")
+             cpu = False
+         trainer_opt = argparse.Namespace(**trainer_config)
+         lightning_config.trainer = trainer_config
+
+         # model
+         model = instantiate_from_config(config.model)
+
+         # trainer and callbacks
+         trainer_kwargs = dict()
+
+         # default logger configs
+         # NOTE wandb < 0.10.0 interferes with shutdown
+         # wandb >= 0.10.0 seems to fix it but still interferes with pudb
+         # debugging (wrongly sized pudb ui)
+         # thus prefer testtube for now
+         default_logger_cfgs = {
+             "wandb": {
+                 "target": "pytorch_lightning.loggers.WandbLogger",
+                 "params": {
+                     "name": nowname,
+                     "save_dir": logdir,
+                     "offline": opt.debug,
+                     "id": nowname,
+                 }
+             },
+             "testtube": {
+                 "target": "pytorch_lightning.loggers.TestTubeLogger",
+                 "params": {
+                     "name": "testtube",
+                     "save_dir": logdir,
+                 }
+             },
+         }
+         default_logger_cfg = default_logger_cfgs["testtube"]
+         logger_cfg = lightning_config.logger or OmegaConf.create()
+         logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
+         trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
+
+         # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
+         # specify which metric is used to determine best models
+         default_modelckpt_cfg = {
+             "target": "pytorch_lightning.callbacks.ModelCheckpoint",
+             "params": {
+                 "dirpath": ckptdir,
+                 "filename": "{epoch:06}",
+                 "verbose": True,
+                 "save_last": True,
+             }
+         }
+         if hasattr(model, "monitor"):
+             print(f"Monitoring {model.monitor} as checkpoint metric.")
+             default_modelckpt_cfg["params"]["monitor"] = model.monitor
+             default_modelckpt_cfg["params"]["save_top_k"] = 3
+
+         modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create()
+         modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
+         trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
+
+         # add callback which sets up log directory
+         default_callbacks_cfg = {
+             "setup_callback": {
+                 "target": "main.SetupCallback",
+                 "params": {
+                     "resume": opt.resume,
+                     "now": now,
+                     "logdir": logdir,
+                     "ckptdir": ckptdir,
+                     "cfgdir": cfgdir,
+                     "config": config,
+                     "lightning_config": lightning_config,
+                 }
+             },
+             "image_logger": {
+                 "target": "main.ImageLogger",
+                 "params": {
+                     "batch_frequency": 750,
+                     "max_images": 4,
+                     "clamp": True
+                 }
+             },
+             "learning_rate_logger": {
+                 "target": "main.LearningRateMonitor",
+                 "params": {
+                     "logging_interval": "step",
+                     #"log_momentum": True
+                 }
+             },
+         }
+         callbacks_cfg = lightning_config.callbacks or OmegaConf.create()
+         callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
+         trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
+
+         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
+
+         # data
+         data = instantiate_from_config(config.data)
+         # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
+         # calling these ourselves should not be necessary but it is.
+         # lightning still takes care of proper multiprocessing though
+         data.prepare_data()
+         data.setup()
+
+         # configure learning rate
+         bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
+         if not cpu:
+             ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
+         else:
+             ngpu = 1
+         accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
+         print(f"accumulate_grad_batches = {accumulate_grad_batches}")
+         lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
+         model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
+         print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
+             model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
+
+         # allow checkpointing via USR1
+         def melk(*args, **kwargs):
+             # run all checkpoint hooks
+             if trainer.global_rank == 0:
+                 print("Summoning checkpoint.")
+                 ckpt_path = os.path.join(ckptdir, "last.ckpt")
+                 trainer.save_checkpoint(ckpt_path)
+
+         def divein(*args, **kwargs):
+             if trainer.global_rank == 0:
+                 import pudb; pudb.set_trace()
+
+         import signal
+         signal.signal(signal.SIGUSR1, melk)
+         signal.signal(signal.SIGUSR2, divein)
+
+         # run
+         if opt.train:
+             try:
+                 trainer.fit(model, data)
+             except Exception:
+                 melk()
+                 raise
+         if not opt.no_test and not trainer.interrupted:
+             trainer.test(model, data)
+     except Exception:
+         if opt.debug and trainer.global_rank==0:
+             try:
+                 import pudb as debugger
+             except ImportError:
+                 import pdb as debugger
+             debugger.post_mortem()
+         raise
+     finally:
+         # move newly created debug project to debug_runs
+         if opt.debug and not opt.resume and trainer.global_rank==0:
+             dst, name = os.path.split(logdir)
+             dst = os.path.join(dst, "debug_runs", name)
+             os.makedirs(os.path.split(dst)[0], exist_ok=True)
+             os.rename(logdir, dst)
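Editor's note: everything in this training script is wired together through the `target`/`params` convention handled by `get_obj_from_str` and `instantiate_from_config`. The sketch below is not part of the commit; it only illustrates how one of the default callback configs above resolves into an object, assuming the environment from environment.yaml is installed and the interpreter is started in the repository root so that `main` imports cleanly.

    # Minimal sketch of the config-driven instantiation used throughout main.py.
    from main import instantiate_from_config

    cfg = {
        "target": "main.ImageLogger",                         # dotted import path
        "params": {"batch_frequency": 750, "max_images": 4},  # passed as **kwargs
    }
    image_logger = instantiate_from_config(cfg)               # == main.ImageLogger(batch_frequency=750, max_images=4)

    # The effective learning rate is then scaled as
    #   model.learning_rate = accumulate_grad_batches * ngpu * batch_size * base_learning_rate
    # e.g. 1 * 2 * 4 * 4.5e-6 = 3.6e-5 (illustrative numbers only).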
scripts/extract_depth.py ADDED
@@ -0,0 +1,112 @@
+ import os
+ import torch
+ import numpy as np
+ from tqdm import trange
+ from PIL import Image
+
+
+ def get_state(gpu):
+     import torch
+     midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
+     if gpu:
+         midas.cuda()
+     midas.eval()
+
+     midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
+     transform = midas_transforms.default_transform
+
+     state = {"model": midas,
+              "transform": transform}
+     return state
+
+
+ def depth_to_rgba(x):
+     assert x.dtype == np.float32
+     assert len(x.shape) == 2
+     y = x.copy()
+     y.dtype = np.uint8
+     y = y.reshape(x.shape+(4,))
+     return np.ascontiguousarray(y)
+
+
+ def rgba_to_depth(x):
+     assert x.dtype == np.uint8
+     assert len(x.shape) == 3 and x.shape[2] == 4
+     y = x.copy()
+     y.dtype = np.float32
+     y = y.reshape(x.shape[:2])
+     return np.ascontiguousarray(y)
+
+
+ def run(x, state):
+     model = state["model"]
+     transform = state["transform"]
+     hw = x.shape[:2]
+     with torch.no_grad():
+         prediction = model(transform((x + 1.0) * 127.5).cuda())
+         prediction = torch.nn.functional.interpolate(
+             prediction.unsqueeze(1),
+             size=hw,
+             mode="bicubic",
+             align_corners=False,
+         ).squeeze()
+     output = prediction.cpu().numpy()
+     return output
+
+
+ def get_filename(relpath, level=-2):
+     # save class folder structure and filename:
+     fn = relpath.split(os.sep)[level:]
+     folder = fn[-2]
+     file = fn[-1].split('.')[0]
+     return folder, file
+
+
+ def save_depth(dataset, path, debug=False):
+     os.makedirs(path)
+     N = len(dset)
+     if debug:
+         N = 10
+     state = get_state(gpu=True)
+     for idx in trange(N, desc="Data"):
+         ex = dataset[idx]
+         image, relpath = ex["image"], ex["relpath"]
+         folder, filename = get_filename(relpath)
+         # prepare
+         folderabspath = os.path.join(path, folder)
+         os.makedirs(folderabspath, exist_ok=True)
+         savepath = os.path.join(folderabspath, filename)
+         # run model
+         xout = run(image, state)
+         I = depth_to_rgba(xout)
+         Image.fromarray(I).save("{}.png".format(savepath))
+
+
+ if __name__ == "__main__":
+     from taming.data.imagenet import ImageNetTrain, ImageNetValidation
+     out = "data/imagenet_depth"
+     if not os.path.exists(out):
+         print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
+               "(be prepared that the output size will be larger than ImageNet itself).")
+         exit(1)
+
+     # go
+     dset = ImageNetValidation()
+     abspath = os.path.join(out, "val")
+     if os.path.exists(abspath):
+         print("{} exists - not doing anything.".format(abspath))
+     else:
+         print("preparing {}".format(abspath))
+         save_depth(dset, abspath)
+     print("done with validation split")
+
+     dset = ImageNetTrain()
+     abspath = os.path.join(out, "train")
+     if os.path.exists(abspath):
+         print("{} exists - not doing anything.".format(abspath))
+     else:
+         print("preparing {}".format(abspath))
+         save_depth(dset, abspath)
+     print("done with train split")
+
+     print("done done.")
scripts/extract_segmentation.py ADDED
@@ -0,0 +1,130 @@
+ import sys, os
+ import numpy as np
+ import scipy
+ import torch
+ import torch.nn as nn
+ from scipy import ndimage
+ from tqdm import tqdm, trange
+ from PIL import Image
+ import torch.hub
+ import torchvision
+ import torch.nn.functional as F
+
+ # download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from
+ # https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth
+ # and put the path here
+ CKPT_PATH = "TODO"
+
+ rescale = lambda x: (x + 1.) / 2.
+
+ def rescale_bgr(x):
+     x = (x+1)*127.5
+     x = torch.flip(x, dims=[0])
+     return x
+
+
+ class COCOStuffSegmenter(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.n_labels = 182
+         model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels)
+         ckpt_path = CKPT_PATH
+         model.load_state_dict(torch.load(ckpt_path))
+         self.model = model
+
+         normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
+         self.image_transform = torchvision.transforms.Compose([
+             torchvision.transforms.Lambda(lambda image: torch.stack(
+                 [normalize(rescale_bgr(x)) for x in image]))
+         ])
+
+     def forward(self, x, upsample=None):
+         x = self._pre_process(x)
+         x = self.model(x)
+         if upsample is not None:
+             x = torch.nn.functional.upsample_bilinear(x, size=upsample)
+         return x
+
+     def _pre_process(self, x):
+         x = self.image_transform(x)
+         return x
+
+     @property
+     def mean(self):
+         # bgr
+         return [104.008, 116.669, 122.675]
+
+     @property
+     def std(self):
+         return [1.0, 1.0, 1.0]
+
+     @property
+     def input_size(self):
+         return [3, 224, 224]
+
+
+ def run_model(img, model):
+     model = model.eval()
+     with torch.no_grad():
+         segmentation = model(img, upsample=(img.shape[2], img.shape[3]))
+         segmentation = torch.argmax(segmentation, dim=1, keepdim=True)
+     return segmentation.detach().cpu()
+
+
+ def get_input(batch, k):
+     x = batch[k]
+     if len(x.shape) == 3:
+         x = x[..., None]
+     x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+     return x.float()
+
+
+ def save_segmentation(segmentation, path):
+     # --> class label to uint8, save as png
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+     assert len(segmentation.shape)==4
+     assert segmentation.shape[0]==1
+     for seg in segmentation:
+         seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8)
+         seg = Image.fromarray(seg)
+         seg.save(path)
+
+
+ def iterate_dataset(dataloader, destpath, model):
+     os.makedirs(destpath, exist_ok=True)
+     num_processed = 0
+     for i, batch in tqdm(enumerate(dataloader), desc="Data"):
+         try:
+             img = get_input(batch, "image")
+             img = img.cuda()
+             seg = run_model(img, model)
+
+             path = batch["relative_file_path_"][0]
+             path = os.path.splitext(path)[0]
+
+             path = os.path.join(destpath, path + ".png")
+             save_segmentation(seg, path)
+             num_processed += 1
+         except Exception as e:
+             print(e)
+             print("but anyhow..")
+
+     print("Processed {} files. Bye.".format(num_processed))
+
+
+ from taming.data.sflckr import Examples
+ from torch.utils.data import DataLoader
+
+ if __name__ == "__main__":
+     dest = sys.argv[1]
+     batchsize = 1
+     print("Running with batch-size {}, saving to {}...".format(batchsize, dest))
+
+     model = COCOStuffSegmenter({}).cuda()
+     print("Instantiated model.")
+
+     dataset = Examples()
+     dloader = DataLoader(dataset, batch_size=batchsize)
+     iterate_dataset(dataloader=dloader, destpath=dest, model=model)
+     print("done.")
scripts/extract_submodel.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+ import sys
+
+ if __name__ == "__main__":
+     inpath = sys.argv[1]
+     outpath = sys.argv[2]
+     submodel = "cond_stage_model"
+     if len(sys.argv) > 3:
+         submodel = sys.argv[3]
+
+     print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
+
+     sd = torch.load(inpath, map_location="cpu")
+     new_sd = {"state_dict": dict((k.split(".", 1)[-1],v)
+                                  for k,v in sd["state_dict"].items()
+                                  if k.startswith("cond_stage_model"))}
+     torch.save(new_sd, outpath)
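Editor's note: the dict comprehension above keeps only the `cond_stage_model.*` entries of the checkpoint and strips that prefix from the keys. A toy illustration, not part of the commit, using a made-up state dict:

    sd = {"state_dict": {"cond_stage_model.encoder.weight": 1,
                         "first_stage_model.decoder.weight": 2}}
    new_sd = {"state_dict": dict((k.split(".", 1)[-1], v)
                                 for k, v in sd["state_dict"].items()
                                 if k.startswith("cond_stage_model"))}
    print(new_sd)   # {'state_dict': {'encoder.weight': 1}}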
scripts/make_samples.py ADDED
@@ -0,0 +1,292 @@
+ import argparse, os, sys, glob, math, time
+ import torch
+ import numpy as np
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from main import instantiate_from_config, DataModuleFromConfig
+ from torch.utils.data import DataLoader
+ from torch.utils.data.dataloader import default_collate
+ from tqdm import trange
+
+
+ def save_image(x, path):
+     c,h,w = x.shape
+     assert c==3
+     x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
+     Image.fromarray(x).save(path)
+
+
+ @torch.no_grad()
+ def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1):
+     if len(dsets.datasets) > 1:
+         split = sorted(dsets.datasets.keys())[0]
+         dset = dsets.datasets[split]
+     else:
+         dset = next(iter(dsets.datasets.values()))
+     print("Dataset: ", dset.__class__.__name__)
+     for start_idx in trange(0,len(dset)-batch_size+1,batch_size):
+         indices = list(range(start_idx, start_idx+batch_size))
+         example = default_collate([dset[i] for i in indices])
+
+         x = model.get_input("image", example).to(model.device)
+         for i in range(x.shape[0]):
+             save_image(x[i], os.path.join(outdir, "originals",
+                                           "{:06}.png".format(indices[i])))
+
+         cond_key = model.cond_stage_key
+         c = model.get_input(cond_key, example).to(model.device)
+
+         scale_factor = 1.0
+         quant_z, z_indices = model.encode_to_z(x)
+         quant_c, c_indices = model.encode_to_c(c)
+
+         cshape = quant_z.shape
+
+         xrec = model.first_stage_model.decode(quant_z)
+         for i in range(xrec.shape[0]):
+             save_image(xrec[i], os.path.join(outdir, "reconstructions",
+                                              "{:06}.png".format(indices[i])))
+
+         if cond_key == "segmentation":
+             # get image from segmentation mask
+             num_classes = c.shape[1]
+             c = torch.argmax(c, dim=1, keepdim=True)
+             c = torch.nn.functional.one_hot(c, num_classes=num_classes)
+             c = c.squeeze(1).permute(0, 3, 1, 2).float()
+             c = model.cond_stage_model.to_rgb(c)
+
+         idx = z_indices
+
+         half_sample = False
+         if half_sample:
+             start = idx.shape[1]//2
+         else:
+             start = 0
+
+         idx[:,start:] = 0
+         idx = idx.reshape(cshape[0],cshape[2],cshape[3])
+         start_i = start//cshape[3]
+         start_j = start %cshape[3]
+
+         cidx = c_indices
+         cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
+
+         sample = True
+
+         for i in range(start_i,cshape[2]-0):
+             if i <= 8:
+                 local_i = i
+             elif cshape[2]-i < 8:
+                 local_i = 16-(cshape[2]-i)
+             else:
+                 local_i = 8
+             for j in range(start_j,cshape[3]-0):
+                 if j <= 8:
+                     local_j = j
+                 elif cshape[3]-j < 8:
+                     local_j = 16-(cshape[3]-j)
+                 else:
+                     local_j = 8
+
+                 i_start = i-local_i
+                 i_end = i_start+16
+                 j_start = j-local_j
+                 j_end = j_start+16
+                 patch = idx[:,i_start:i_end,j_start:j_end]
+                 patch = patch.reshape(patch.shape[0],-1)
+                 cpatch = cidx[:, i_start:i_end, j_start:j_end]
+                 cpatch = cpatch.reshape(cpatch.shape[0], -1)
+                 patch = torch.cat((cpatch, patch), dim=1)
+                 logits,_ = model.transformer(patch[:,:-1])
+                 logits = logits[:, -256:, :]
+                 logits = logits.reshape(cshape[0],16,16,-1)
+                 logits = logits[:,local_i,local_j,:]
+
+                 logits = logits/temperature
+
+                 if top_k is not None:
+                     logits = model.top_k_logits(logits, top_k)
+                 # apply softmax to convert to probabilities
+                 probs = torch.nn.functional.softmax(logits, dim=-1)
+                 # sample from the distribution or take the most likely
+                 if sample:
+                     ix = torch.multinomial(probs, num_samples=1)
+                 else:
+                     _, ix = torch.topk(probs, k=1, dim=-1)
+                 idx[:,i,j] = ix
+
+         xsample = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
+         for i in range(xsample.shape[0]):
+             save_image(xsample[i], os.path.join(outdir, "samples",
+                                                 "{:06}.png".format(indices[i])))
+
+
+ def get_parser():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "-r",
+         "--resume",
+         type=str,
+         nargs="?",
+         help="load from logdir or checkpoint in logdir",
+     )
+     parser.add_argument(
+         "-b",
+         "--base",
+         nargs="*",
+         metavar="base_config.yaml",
+         help="paths to base configs. Loaded from left-to-right. "
+         "Parameters can be overwritten or added with command-line options of the form `--key value`.",
+         default=list(),
+     )
+     parser.add_argument(
+         "-c",
+         "--config",
+         nargs="?",
+         metavar="single_config.yaml",
+         help="path to single config. If specified, base configs will be ignored "
+         "(except for the last one if left unspecified).",
+         const=True,
+         default="",
+     )
+     parser.add_argument(
+         "--ignore_base_data",
+         action="store_true",
+         help="Ignore data specification from base configs. Useful if you want "
+         "to specify a custom datasets on the command line.",
+     )
+     parser.add_argument(
+         "--outdir",
+         required=True,
+         type=str,
+         help="Where to write outputs to.",
+     )
+     parser.add_argument(
+         "--top_k",
+         type=int,
+         default=100,
+         help="Sample from among top-k predictions.",
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=1.0,
+         help="Sampling temperature.",
+     )
+     return parser
+
+
+ def load_model_from_config(config, sd, gpu=True, eval_mode=True):
+     if "ckpt_path" in config.params:
+         print("Deleting the restore-ckpt path from the config...")
+         config.params.ckpt_path = None
+     if "downsample_cond_size" in config.params:
+         print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
+         config.params.downsample_cond_size = -1
+         config.params["downsample_cond_factor"] = 0.5
+     try:
+         if "ckpt_path" in config.params.first_stage_config.params:
+             config.params.first_stage_config.params.ckpt_path = None
+             print("Deleting the first-stage restore-ckpt path from the config...")
+         if "ckpt_path" in config.params.cond_stage_config.params:
+             config.params.cond_stage_config.params.ckpt_path = None
+             print("Deleting the cond-stage restore-ckpt path from the config...")
+     except:
+         pass
+
+     model = instantiate_from_config(config)
+     if sd is not None:
+         missing, unexpected = model.load_state_dict(sd, strict=False)
+         print(f"Missing Keys in State Dict: {missing}")
+         print(f"Unexpected Keys in State Dict: {unexpected}")
+     if gpu:
+         model.cuda()
+     if eval_mode:
+         model.eval()
+     return {"model": model}
+
+
+ def get_data(config):
+     # get data
+     data = instantiate_from_config(config.data)
+     data.prepare_data()
+     data.setup()
+     return data
+
+
+ def load_model_and_dset(config, ckpt, gpu, eval_mode):
+     # get data
+     dsets = get_data(config)   # calls data.config ...
+
+     # now load the specified checkpoint
+     if ckpt:
+         pl_sd = torch.load(ckpt, map_location="cpu")
+         global_step = pl_sd["global_step"]
+     else:
+         pl_sd = {"state_dict": None}
+         global_step = None
+     model = load_model_from_config(config.model,
+                                    pl_sd["state_dict"],
+                                    gpu=gpu,
+                                    eval_mode=eval_mode)["model"]
+     return dsets, model, global_step
+
+
+ if __name__ == "__main__":
+     sys.path.append(os.getcwd())
+
+     parser = get_parser()
+
+     opt, unknown = parser.parse_known_args()
+
+     ckpt = None
+     if opt.resume:
+         if not os.path.exists(opt.resume):
+             raise ValueError("Cannot find {}".format(opt.resume))
+         if os.path.isfile(opt.resume):
+             paths = opt.resume.split("/")
+             try:
+                 idx = len(paths)-paths[::-1].index("logs")+1
+             except ValueError:
+                 idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
+             logdir = "/".join(paths[:idx])
+             ckpt = opt.resume
+         else:
+             assert os.path.isdir(opt.resume), opt.resume
+             logdir = opt.resume.rstrip("/")
+             ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
+         print(f"logdir:{logdir}")
+         base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
+         opt.base = base_configs+opt.base
+
+     if opt.config:
+         if type(opt.config) == str:
+             opt.base = [opt.config]
+         else:
+             opt.base = [opt.base[-1]]
+
+     configs = [OmegaConf.load(cfg) for cfg in opt.base]
+     cli = OmegaConf.from_dotlist(unknown)
+     if opt.ignore_base_data:
+         for config in configs:
+             if hasattr(config, "data"): del config["data"]
+     config = OmegaConf.merge(*configs, cli)
+
+     print(ckpt)
+     gpu = True
+     eval_mode = True
+     show_config = False
+     if show_config:
+         print(OmegaConf.to_container(config))
+
+     dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
+     print(f"Global step: {global_step}")
+
+     outdir = os.path.join(opt.outdir, "{:06}_{}_{}".format(global_step,
+                                                            opt.top_k,
+                                                            opt.temperature))
+     os.makedirs(outdir, exist_ok=True)
+     print("Writing samples to ", outdir)
+     for k in ["originals", "reconstructions", "samples"]:
+         os.makedirs(os.path.join(outdir, k), exist_ok=True)
+     run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)
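Editor's note: at each grid position the loop above predicts a single codebook index from a 16x16 crop of conditioning and image indices, then applies temperature scaling, top-k filtering and multinomial sampling. The standalone sketch below is not part of the commit; `top_k_filter` merely stands in for what `model.top_k_logits` is used for here.

    import torch

    def top_k_filter(logits, k):
        v, _ = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[..., [-1]]] = -float("inf")   # keep only the k largest logits
        return out

    temperature, top_k = 1.0, 100
    logits = torch.randn(1, 1024)                 # dummy logits over a 1024-entry codebook
    probs = torch.softmax(top_k_filter(logits / temperature, top_k), dim=-1)
    ix = torch.multinomial(probs, num_samples=1)  # sampled index, written back into idx[:, i, j] above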
scripts/make_scene_samples.py ADDED
@@ -0,0 +1,198 @@
+ import glob
+ import os
+ import sys
+ from itertools import product
+ from pathlib import Path
+ from typing import Literal, List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ from omegaconf import OmegaConf
+ from pytorch_lightning import seed_everything
+ from torch import Tensor
+ from torchvision.utils import save_image
+ from tqdm import tqdm
+
+ from scripts.make_samples import get_parser, load_model_and_dset
+ from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+ from taming.data.helper_types import BoundingBox, Annotation
+ from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+ from taming.models.cond_transformer import Net2NetTransformer
+
+ seed_everything(42424242)
+ device: Literal['cuda', 'cpu'] = 'cuda'
+ first_stage_factor = 16
+ trained_on_res = 256
+
+
+ def _helper(coord: int, coord_max: int, coord_window: int) -> (int, int):
+     assert 0 <= coord < coord_max
+     coord_desired_center = (coord_window - 1) // 2
+     return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)
+
+
+ def get_crop_coordinates(x: int, y: int) -> BoundingBox:
+     WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
+     x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
+     y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
+     w = first_stage_factor / WIDTH
+     h = first_stage_factor / HEIGHT
+     return x0, y0, w, h
+
+
+ def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
+     WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
+     x0 = _helper(predict_x, WIDTH, first_stage_factor)
+     y0 = _helper(predict_y, HEIGHT, first_stage_factor)
+     no_images = z_indices.shape[0]
+     cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
+     cut_out_2 = z_indices[:, predict_y, x0:predict_x]
+     return torch.cat((cut_out_1, cut_out_2), dim=1)
+
+
+ @torch.no_grad()
+ def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
+            conditional_builder: ObjectsCenterPointsConditionalBuilder, no_samples: int,
+            temperature: float, top_k: int) -> Tensor:
+     x_max, y_max = desired_z_shape[1], desired_z_shape[0]
+
+     annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]
+
+     recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
+     if not recompute_conditional:
+         crop_coordinates = get_crop_coordinates(0, 0)
+         conditional_indices = conditional_builder.build(annotations, crop_coordinates)
+         c_indices = conditional_indices.to(device).repeat(no_samples, 1)
+         z_indices = torch.zeros((no_samples, 0), device=device).long()
+         output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
+                                       sample=True, top_k=top_k)
+     else:
+         output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
+         for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image', total=x_max*y_max):
+             crop_coordinates = get_crop_coordinates(predict_x, predict_y)
+             z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
+             conditional_indices = conditional_builder.build(annotations, crop_coordinates)
+             c_indices = conditional_indices.to(device).repeat(no_samples, 1)
+             new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
+             output_indices[:, predict_y, predict_x] = new_index[:, -1]
+     z_shape = (
+         no_samples,
+         model.first_stage_model.quantize.e_dim,  # codebook embed_dim
+         desired_z_shape[0],  # z_height
+         desired_z_shape[1]  # z_width
+     )
+     x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
+     x_sample = x_sample.to('cpu')
+
+     plotter = conditional_builder.plot
+     figure_size = (x_sample.shape[2], x_sample.shape[3])
+     scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
+     plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
+     return torch.cat((x_sample, plot.unsqueeze(0)))
+
+
+ def get_resolution(resolution_str: str) -> (Tuple[int, int], Tuple[int, int]):
+     if not resolution_str.count(',') == 1:
+         raise ValueError("Give resolution as in 'height,width'")
+     res_h, res_w = resolution_str.split(',')
+     res_h = max(int(res_h), trained_on_res)
+     res_w = max(int(res_w), trained_on_res)
+     z_h = int(round(res_h/first_stage_factor))
+     z_w = int(round(res_w/first_stage_factor))
+     return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)
+
+
+ def add_arg_to_parser(parser):
+     parser.add_argument(
+         "-R",
+         "--resolution",
+         type=str,
+         default='256,256',
+         help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
+     )
+     parser.add_argument(
+         "-C",
+         "--conditional",
+         type=str,
+         default='objects_bbox',
+         help=f"objects_bbox or objects_center_points",
+     )
+     parser.add_argument(
+         "-N",
+         "--n_samples_per_layout",
+         type=int,
+         default=4,
+         help=f"how many samples to generate per layout",
+     )
+     return parser
+
+
+ if __name__ == "__main__":
+     sys.path.append(os.getcwd())
+
+     parser = get_parser()
+     parser = add_arg_to_parser(parser)
+
+     opt, unknown = parser.parse_known_args()
+
+     ckpt = None
+     if opt.resume:
+         if not os.path.exists(opt.resume):
+             raise ValueError("Cannot find {}".format(opt.resume))
+         if os.path.isfile(opt.resume):
+             paths = opt.resume.split("/")
+             try:
+                 idx = len(paths)-paths[::-1].index("logs")+1
+             except ValueError:
+                 idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
+             logdir = "/".join(paths[:idx])
+             ckpt = opt.resume
+         else:
+             assert os.path.isdir(opt.resume), opt.resume
+             logdir = opt.resume.rstrip("/")
+             ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
+         print(f"logdir:{logdir}")
+         base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
+         opt.base = base_configs+opt.base
+
+     if opt.config:
+         if type(opt.config) == str:
+             opt.base = [opt.config]
+         else:
+             opt.base = [opt.base[-1]]
+
+     configs = [OmegaConf.load(cfg) for cfg in opt.base]
+     cli = OmegaConf.from_dotlist(unknown)
+     if opt.ignore_base_data:
+         for config in configs:
+             if hasattr(config, "data"):
+                 del config["data"]
+     config = OmegaConf.merge(*configs, cli)
+     desired_z_shape, desired_resolution = get_resolution(opt.resolution)
+     conditional = opt.conditional
+
+     print(ckpt)
+     gpu = True
+     eval_mode = True
+     show_config = False
+     if show_config:
+         print(OmegaConf.to_container(config))
+
+     dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
+     print(f"Global step: {global_step}")
+
+     data_loader = dsets.val_dataloader()
+     print(dsets.datasets["validation"].conditional_builders)
+     conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]
+
+     outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
+     outdir.mkdir(exist_ok=True, parents=True)
+     print("Writing samples to ", outdir)
+
+     p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
+     for batch_no, batch in p_bar_1:
+         save_img: Optional[Tensor] = None
+         for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch', total=data_loader.batch_size):
+             imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
+                           opt.n_samples_per_layout, opt.temperature, opt.top_k)
+             save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), n_row=opt.n_samples_per_layout+1)
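Editor's note: `get_resolution` rounds the requested 'height,width' string to multiples of the first-stage factor (16) and never goes below the training resolution of 256; the first tuple it returns is the latent grid shape, the second the pixel resolution that will actually be generated. Illustrative usage, not part of the commit, assuming the repository's dependencies are installed:

    from scripts.make_scene_samples import get_resolution

    z_shape, resolution = get_resolution("512,320")
    print(z_shape, resolution)          # (32, 20) (512, 320): a 512/16 x 320/16 latent grid
    print(get_resolution("128,128")[1]) # (256, 256): bumped up to the training resolution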
scripts/sample_conditional.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, os, sys, glob, math, time
2
+ import torch
3
+ import numpy as np
4
+ from omegaconf import OmegaConf
5
+ import streamlit as st
6
+ from streamlit import caching
7
+ from PIL import Image
8
+ from main import instantiate_from_config, DataModuleFromConfig
9
+ from torch.utils.data import DataLoader
10
+ from torch.utils.data.dataloader import default_collate
11
+
12
+
13
+ rescale = lambda x: (x + 1.) / 2.
14
+
15
+
16
+ def bchw_to_st(x):
17
+ return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))
18
+
19
+ def save_img(xstart, fname):
20
+ I = (xstart.clip(0,1)[0]*255).astype(np.uint8)
21
+ Image.fromarray(I).save(fname)
22
+
23
+
24
+
25
+ def get_interactive_image(resize=False):
26
+ image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
27
+ if image is not None:
28
+ image = Image.open(image)
29
+ if not image.mode == "RGB":
30
+ image = image.convert("RGB")
31
+ image = np.array(image).astype(np.uint8)
32
+ print("upload image shape: {}".format(image.shape))
33
+ img = Image.fromarray(image)
34
+ if resize:
35
+ img = img.resize((256, 256))
36
+ image = np.array(img)
37
+ return image
38
+
39
+
40
+ def single_image_to_torch(x, permute=True):
41
+ assert x is not None, "Please provide an image through the upload function"
42
+ x = np.array(x)
43
+ x = torch.FloatTensor(x/255.*2. - 1.)[None,...]
44
+ if permute:
45
+ x = x.permute(0, 3, 1, 2)
46
+ return x
47
+
48
+
49
+ def pad_to_M(x, M):
50
+ hp = math.ceil(x.shape[2]/M)*M-x.shape[2]
51
+ wp = math.ceil(x.shape[3]/M)*M-x.shape[3]
52
+ x = torch.nn.functional.pad(x, (0,wp,0,hp,0,0,0,0))
53
+ return x
54
+
55
+ @torch.no_grad()
56
+ def run_conditional(model, dsets):
57
+ if len(dsets.datasets) > 1:
58
+ split = st.sidebar.radio("Split", sorted(dsets.datasets.keys()))
59
+ dset = dsets.datasets[split]
60
+ else:
61
+ dset = next(iter(dsets.datasets.values()))
62
+ batch_size = 1
63
+ start_index = st.sidebar.number_input("Example Index (Size: {})".format(len(dset)), value=0,
64
+ min_value=0,
65
+ max_value=len(dset)-batch_size)
66
+ indices = list(range(start_index, start_index+batch_size))
67
+
68
+ example = default_collate([dset[i] for i in indices])
69
+
70
+ x = model.get_input("image", example).to(model.device)
71
+
72
+ cond_key = model.cond_stage_key
73
+ c = model.get_input(cond_key, example).to(model.device)
74
+
75
+ scale_factor = st.sidebar.slider("Scale Factor", min_value=0.5, max_value=4.0, step=0.25, value=1.00)
76
+ if scale_factor != 1.0:
77
+ x = torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="bicubic")
78
+ c = torch.nn.functional.interpolate(c, scale_factor=scale_factor, mode="bicubic")
79
+
80
+ quant_z, z_indices = model.encode_to_z(x)
81
+ quant_c, c_indices = model.encode_to_c(c)
82
+
83
+ cshape = quant_z.shape
84
+
85
+ xrec = model.first_stage_model.decode(quant_z)
86
+ st.write("image: {}".format(x.shape))
87
+ st.image(bchw_to_st(x), clamp=True, output_format="PNG")
88
+ st.write("image reconstruction: {}".format(xrec.shape))
89
+ st.image(bchw_to_st(xrec), clamp=True, output_format="PNG")
90
+
91
+ if cond_key == "segmentation":
92
+ # get image from segmentation mask
93
+ num_classes = c.shape[1]
94
+ c = torch.argmax(c, dim=1, keepdim=True)
95
+ c = torch.nn.functional.one_hot(c, num_classes=num_classes)
96
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
97
+ c = model.cond_stage_model.to_rgb(c)
98
+
99
+ st.write(f"{cond_key}: {tuple(c.shape)}")
100
+ st.image(bchw_to_st(c), clamp=True, output_format="PNG")
101
+
102
+ idx = z_indices
103
+
104
+ half_sample = st.sidebar.checkbox("Image Completion", value=False)
105
+ if half_sample:
106
+ start = idx.shape[1]//2
107
+ else:
108
+ start = 0
109
+
110
+ idx[:,start:] = 0
111
+ idx = idx.reshape(cshape[0],cshape[2],cshape[3])
112
+ start_i = start//cshape[3]
113
+ start_j = start %cshape[3]
114
+
115
+ if not half_sample and quant_z.shape == quant_c.shape:
116
+ st.info("Setting idx to c_indices")
117
+ idx = c_indices.clone().reshape(cshape[0],cshape[2],cshape[3])
118
+
119
+ cidx = c_indices
120
+ cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
121
+
122
+ xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
123
+ st.image(bchw_to_st(xstart), clamp=True, output_format="PNG")
124
+
125
+ temperature = st.number_input("Temperature", value=1.0)
126
+ top_k = st.number_input("Top k", value=100)
127
+ sample = st.checkbox("Sample", value=True)
128
+ update_every = st.number_input("Update every", value=75)
129
+
130
+ st.text(f"Sampling shape ({cshape[2]},{cshape[3]})")
131
+
132
+ animate = st.checkbox("animate")
133
+ if animate:
134
+ import imageio
135
+ outvid = "sampling.mp4"
136
+ writer = imageio.get_writer(outvid, fps=25)
137
+ elapsed_t = st.empty()
138
+ info = st.empty()
139
+ st.text("Sampled")
140
+ if st.button("Sample"):
141
+ output = st.empty()
142
+ start_t = time.time()
143
+ for i in range(start_i,cshape[2]-0):
144
+ if i <= 8:
145
+ local_i = i
146
+ elif cshape[2]-i < 8:
147
+ local_i = 16-(cshape[2]-i)
148
+ else:
149
+ local_i = 8
150
+ for j in range(start_j,cshape[3]-0):
151
+ if j <= 8:
152
+ local_j = j
153
+ elif cshape[3]-j < 8:
154
+ local_j = 16-(cshape[3]-j)
155
+ else:
156
+ local_j = 8
157
+
158
+ i_start = i-local_i
159
+ i_end = i_start+16
160
+ j_start = j-local_j
161
+ j_end = j_start+16
162
+ elapsed_t.text(f"Time: {time.time() - start_t} seconds")
163
+ info.text(f"Step: ({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})")
164
+ patch = idx[:,i_start:i_end,j_start:j_end]
165
+ patch = patch.reshape(patch.shape[0],-1)
166
+ cpatch = cidx[:, i_start:i_end, j_start:j_end]
167
+ cpatch = cpatch.reshape(cpatch.shape[0], -1)
168
+ patch = torch.cat((cpatch, patch), dim=1)
169
+ logits,_ = model.transformer(patch[:,:-1])
170
+ logits = logits[:, -256:, :]
171
+ logits = logits.reshape(cshape[0],16,16,-1)
172
+ logits = logits[:,local_i,local_j,:]
173
+
174
+ logits = logits/temperature
175
+
176
+ if top_k is not None:
177
+ logits = model.top_k_logits(logits, top_k)
178
+ # apply softmax to convert to probabilities
179
+ probs = torch.nn.functional.softmax(logits, dim=-1)
180
+ # sample from the distribution or take the most likely
181
+ if sample:
182
+ ix = torch.multinomial(probs, num_samples=1)
183
+ else:
184
+ _, ix = torch.topk(probs, k=1, dim=-1)
185
+ idx[:,i,j] = ix
186
+
187
+ if (i*cshape[3]+j)%update_every==0:
188
+ xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape,)
189
+
190
+ xstart = bchw_to_st(xstart)
191
+ output.image(xstart, clamp=True, output_format="PNG")
192
+
193
+ if animate:
194
+ writer.append_data((xstart[0]*255).clip(0, 255).astype(np.uint8))
195
+
196
+ xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
197
+ xstart = bchw_to_st(xstart)
198
+ output.image(xstart, clamp=True, output_format="PNG")
199
+ #save_img(xstart, "full_res_sample.png")
200
+ if animate:
201
+ writer.close()
202
+ st.video(outvid)
203
+
204
+
205
+ def get_parser():
206
+ parser = argparse.ArgumentParser()
207
+ parser.add_argument(
208
+ "-r",
209
+ "--resume",
210
+ type=str,
211
+ nargs="?",
212
+ help="load from logdir or checkpoint in logdir",
213
+ )
214
+ parser.add_argument(
215
+ "-b",
216
+ "--base",
217
+ nargs="*",
218
+ metavar="base_config.yaml",
219
+ help="paths to base configs. Loaded from left-to-right. "
220
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
221
+ default=list(),
222
+ )
223
+ parser.add_argument(
224
+ "-c",
225
+ "--config",
226
+ nargs="?",
227
+ metavar="single_config.yaml",
228
+ help="path to single config. If specified, base configs will be ignored "
229
+ "(except for the last one if left unspecified).",
230
+ const=True,
231
+ default="",
232
+ )
233
+ parser.add_argument(
234
+ "--ignore_base_data",
235
+ action="store_true",
236
+ help="Ignore data specification from base configs. Useful if you want "
237
+ "to specify a custom datasets on the command line.",
238
+ )
239
+ return parser
240
+
241
+
242
+ def load_model_from_config(config, sd, gpu=True, eval_mode=True):
243
+ if "ckpt_path" in config.params:
244
+ st.warning("Deleting the restore-ckpt path from the config...")
245
+ config.params.ckpt_path = None
246
+ if "downsample_cond_size" in config.params:
247
+ st.warning("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
248
+ config.params.downsample_cond_size = -1
249
+ config.params["downsample_cond_factor"] = 0.5
250
+ try:
251
+ if "ckpt_path" in config.params.first_stage_config.params:
252
+ config.params.first_stage_config.params.ckpt_path = None
253
+ st.warning("Deleting the first-stage restore-ckpt path from the config...")
254
+ if "ckpt_path" in config.params.cond_stage_config.params:
255
+ config.params.cond_stage_config.params.ckpt_path = None
256
+ st.warning("Deleting the cond-stage restore-ckpt path from the config...")
257
+ except:
258
+ pass
259
+
260
+ model = instantiate_from_config(config)
261
+ if sd is not None:
262
+ missing, unexpected = model.load_state_dict(sd, strict=False)
263
+ st.info(f"Missing Keys in State Dict: {missing}")
264
+ st.info(f"Unexpected Keys in State Dict: {unexpected}")
265
+ if gpu:
266
+ model.cuda()
267
+ if eval_mode:
268
+ model.eval()
269
+ return {"model": model}
270
+
271
+
272
+ def get_data(config):
273
+ # get data
274
+ data = instantiate_from_config(config.data)
275
+ data.prepare_data()
276
+ data.setup()
277
+ return data
278
+
279
+
280
+ @st.cache(allow_output_mutation=True, suppress_st_warning=True)
281
+ def load_model_and_dset(config, ckpt, gpu, eval_mode):
282
+ # get data
283
+ dsets = get_data(config) # calls data.config ...
284
+
285
+ # now load the specified checkpoint
286
+ if ckpt:
287
+ pl_sd = torch.load(ckpt, map_location="cpu")
288
+ global_step = pl_sd["global_step"]
289
+ else:
290
+ pl_sd = {"state_dict": None}
291
+ global_step = None
292
+ model = load_model_from_config(config.model,
293
+ pl_sd["state_dict"],
294
+ gpu=gpu,
295
+ eval_mode=eval_mode)["model"]
296
+ return dsets, model, global_step
297
+
298
+
299
+ if __name__ == "__main__":
300
+ sys.path.append(os.getcwd())
301
+
302
+ parser = get_parser()
303
+
304
+ opt, unknown = parser.parse_known_args()
305
+
306
+ ckpt = None
307
+ if opt.resume:
308
+ if not os.path.exists(opt.resume):
309
+ raise ValueError("Cannot find {}".format(opt.resume))
310
+ if os.path.isfile(opt.resume):
311
+ paths = opt.resume.split("/")
312
+ try:
313
+ idx = len(paths)-paths[::-1].index("logs")+1
314
+ except ValueError:
315
+ idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
316
+ logdir = "/".join(paths[:idx])
317
+ ckpt = opt.resume
318
+ else:
319
+ assert os.path.isdir(opt.resume), opt.resume
320
+ logdir = opt.resume.rstrip("/")
321
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
322
+ print(f"logdir:{logdir}")
323
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
324
+ opt.base = base_configs+opt.base
325
+
326
+ if opt.config:
327
+ if type(opt.config) == str:
328
+ opt.base = [opt.config]
329
+ else:
330
+ opt.base = [opt.base[-1]]
331
+
332
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
333
+ cli = OmegaConf.from_dotlist(unknown)
334
+ if opt.ignore_base_data:
335
+ for config in configs:
336
+ if hasattr(config, "data"): del config["data"]
337
+ config = OmegaConf.merge(*configs, cli)
338
+
339
+ st.sidebar.text(ckpt)
340
+ gs = st.sidebar.empty()
341
+ gs.text(f"Global step: ?")
342
+ st.sidebar.text("Options")
343
+ #gpu = st.sidebar.checkbox("GPU", value=True)
344
+ gpu = True
345
+ #eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
346
+ eval_mode = True
347
+ #show_config = st.sidebar.checkbox("Show Config", value=False)
348
+ show_config = False
349
+ if show_config:
350
+ st.info("Checkpoint: {}".format(ckpt))
351
+ st.json(OmegaConf.to_container(config))
352
+
353
+ dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
354
+ gs.text(f"Global step: {global_step}")
355
+ run_conditional(model, dsets)
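
The file above is a Streamlit app, so it is launched through Streamlit rather than plain Python; everything after "--" is forwarded to the script's own argument parser. A minimal usage sketch, assuming the repository root is on PYTHONPATH and using a placeholder log directory:

# Launching the app (the log directory name is a placeholder):
#   streamlit run scripts/sample_conditional.py -- -r logs/2021-01-01T00-00-00_custom_vqgan
# The display helpers can also be reused on their own:
import torch
from scripts.sample_conditional import bchw_to_st, save_img  # assumes scripts/ is importable

decoded = torch.rand(1, 3, 256, 256) * 2. - 1.   # stand-in for a decoded sample in [-1, 1]
save_img(bchw_to_st(decoded), "preview.png")     # rescale to [0, 1], take the first image, write a PNG
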
scripts/sample_fast.py ADDED
@@ -0,0 +1,260 @@
1
+ import argparse, os, sys, glob
2
+ import torch
3
+ import time
4
+ import numpy as np
5
+ from omegaconf import OmegaConf
6
+ from PIL import Image
7
+ from tqdm import tqdm, trange
8
+ from einops import repeat
9
+
10
+ from main import instantiate_from_config
11
+ from taming.modules.transformer.mingpt import sample_with_past
12
+
13
+
14
+ rescale = lambda x: (x + 1.) / 2.
15
+
16
+
17
+ def chw_to_pillow(x):
18
+ return Image.fromarray((255*rescale(x.detach().cpu().numpy().transpose(1,2,0))).clip(0,255).astype(np.uint8))
19
+
20
+
21
+ @torch.no_grad()
22
+ def sample_classconditional(model, batch_size, class_label, steps=256, temperature=None, top_k=None, callback=None,
23
+ dim_z=256, h=16, w=16, verbose_time=False, top_p=None):
24
+ log = dict()
25
+ assert type(class_label) == int, f'expecting type int but type is {type(class_label)}'
26
+ qzshape = [batch_size, dim_z, h, w]
27
+ assert not model.be_unconditional, 'Expecting a class-conditional Net2NetTransformer.'
28
+ c_indices = repeat(torch.tensor([class_label]), '1 -> b 1', b=batch_size).to(model.device) # class token
29
+ t1 = time.time()
30
+ index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
31
+ sample_logits=True, top_k=top_k, callback=callback,
32
+ temperature=temperature, top_p=top_p)
33
+ if verbose_time:
34
+ sampling_time = time.time() - t1
35
+ print(f"Full sampling takes about {sampling_time:.2f} seconds.")
36
+ x_sample = model.decode_to_img(index_sample, qzshape)
37
+ log["samples"] = x_sample
38
+ log["class_label"] = c_indices
39
+ return log
40
+
41
+
42
+ @torch.no_grad()
43
+ def sample_unconditional(model, batch_size, steps=256, temperature=None, top_k=None, top_p=None, callback=None,
44
+ dim_z=256, h=16, w=16, verbose_time=False):
45
+ log = dict()
46
+ qzshape = [batch_size, dim_z, h, w]
47
+ assert model.be_unconditional, 'Expecting an unconditional model.'
48
+ c_indices = repeat(torch.tensor([model.sos_token]), '1 -> b 1', b=batch_size).to(model.device) # sos token
49
+ t1 = time.time()
50
+ index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
51
+ sample_logits=True, top_k=top_k, callback=callback,
52
+ temperature=temperature, top_p=top_p)
53
+ if verbose_time:
54
+ sampling_time = time.time() - t1
55
+ print(f"Full sampling takes about {sampling_time:.2f} seconds.")
56
+ x_sample = model.decode_to_img(index_sample, qzshape)
57
+ log["samples"] = x_sample
58
+ return log
59
+
60
+
61
+ @torch.no_grad()
62
+ def run(logdir, model, batch_size, temperature, top_k, unconditional=True, num_samples=50000,
63
+ given_classes=None, top_p=None):
64
+ batches = [batch_size for _ in range(num_samples//batch_size)] + [num_samples % batch_size]
65
+ if not unconditional:
66
+ assert given_classes is not None
67
+ print("Running in pure class-conditional sampling mode. I will produce "
68
+ f"{num_samples} samples for each of the {len(given_classes)} classes, "
69
+ f"i.e. {num_samples*len(given_classes)} in total.")
70
+ for class_label in tqdm(given_classes, desc="Classes"):
71
+ for n, bs in tqdm(enumerate(batches), desc="Sampling Class"):
72
+ if bs == 0: break
73
+ logs = sample_classconditional(model, batch_size=bs, class_label=class_label,
74
+ temperature=temperature, top_k=top_k, top_p=top_p)
75
+ save_from_logs(logs, logdir, base_count=n * batch_size, cond_key=logs["class_label"])
76
+ else:
77
+ print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
78
+ for n, bs in tqdm(enumerate(batches), desc="Sampling"):
79
+ if bs == 0: break
80
+ logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
81
+ save_from_logs(logs, logdir, base_count=n * batch_size)
82
+
83
+
84
+ def save_from_logs(logs, logdir, base_count, key="samples", cond_key=None):
85
+ xx = logs[key]
86
+ for i, x in enumerate(xx):
87
+ x = chw_to_pillow(x)
88
+ count = base_count + i
89
+ if cond_key is None:
90
+ x.save(os.path.join(logdir, f"{count:06}.png"))
91
+ else:
92
+ condlabel = cond_key[i]
93
+ if type(condlabel) == torch.Tensor: condlabel = condlabel.item()
94
+ os.makedirs(os.path.join(logdir, str(condlabel)), exist_ok=True)
95
+ x.save(os.path.join(logdir, str(condlabel), f"{count:06}.png"))
96
+
97
+
98
+ def get_parser():
99
+ def str2bool(v):
100
+ if isinstance(v, bool):
101
+ return v
102
+ if v.lower() in ("yes", "true", "t", "y", "1"):
103
+ return True
104
+ elif v.lower() in ("no", "false", "f", "n", "0"):
105
+ return False
106
+ else:
107
+ raise argparse.ArgumentTypeError("Boolean value expected.")
108
+
109
+ parser = argparse.ArgumentParser()
110
+ parser.add_argument(
111
+ "-r",
112
+ "--resume",
113
+ type=str,
114
+ nargs="?",
115
+ help="load from logdir or checkpoint in logdir",
116
+ )
117
+ parser.add_argument(
118
+ "-o",
119
+ "--outdir",
120
+ type=str,
121
+ nargs="?",
122
+ help="path where the samples will be logged to.",
123
+ default=""
124
+ )
125
+ parser.add_argument(
126
+ "-b",
127
+ "--base",
128
+ nargs="*",
129
+ metavar="base_config.yaml",
130
+ help="paths to base configs. Loaded from left-to-right. "
131
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
132
+ default=list(),
133
+ )
134
+ parser.add_argument(
135
+ "-n",
136
+ "--num_samples",
137
+ type=int,
138
+ nargs="?",
139
+ help="num_samples to draw",
140
+ default=50000
141
+ )
142
+ parser.add_argument(
143
+ "--batch_size",
144
+ type=int,
145
+ nargs="?",
146
+ help="the batch size",
147
+ default=25
148
+ )
149
+ parser.add_argument(
150
+ "-k",
151
+ "--top_k",
152
+ type=int,
153
+ nargs="?",
154
+ help="top-k value to sample with",
155
+ default=250,
156
+ )
157
+ parser.add_argument(
158
+ "-t",
159
+ "--temperature",
160
+ type=float,
161
+ nargs="?",
162
+ help="temperature value to sample with",
163
+ default=1.0
164
+ )
165
+ parser.add_argument(
166
+ "-p",
167
+ "--top_p",
168
+ type=float,
169
+ nargs="?",
170
+ help="top-p value to sample with",
171
+ default=1.0
172
+ )
173
+ parser.add_argument(
174
+ "--classes",
175
+ type=str,
176
+ nargs="?",
177
+ help="specify comma-separated classes to sample from. Uses 1000 classes per default.",
178
+ default="imagenet"
179
+ )
180
+ return parser
181
+
182
+
183
+ def load_model_from_config(config, sd, gpu=True, eval_mode=True):
184
+ model = instantiate_from_config(config)
185
+ if sd is not None:
186
+ model.load_state_dict(sd)
187
+ if gpu:
188
+ model.cuda()
189
+ if eval_mode:
190
+ model.eval()
191
+ return {"model": model}
192
+
193
+
194
+ def load_model(config, ckpt, gpu, eval_mode):
195
+ # load the specified checkpoint
196
+ if ckpt:
197
+ pl_sd = torch.load(ckpt, map_location="cpu")
198
+ global_step = pl_sd["global_step"]
199
+ print(f"loaded model from global step {global_step}.")
200
+ else:
201
+ pl_sd = {"state_dict": None}
202
+ global_step = None
203
+ model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
204
+ return model, global_step
205
+
206
+
207
+ if __name__ == "__main__":
208
+ sys.path.append(os.getcwd())
209
+ parser = get_parser()
210
+
211
+ opt, unknown = parser.parse_known_args()
212
+ assert opt.resume
213
+
214
+ ckpt = None
215
+
216
+ if not os.path.exists(opt.resume):
217
+ raise ValueError("Cannot find {}".format(opt.resume))
218
+ if os.path.isfile(opt.resume):
219
+ paths = opt.resume.split("/")
220
+ try:
221
+ idx = len(paths)-paths[::-1].index("logs")+1
222
+ except ValueError:
223
+ idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
224
+ logdir = "/".join(paths[:idx])
225
+ ckpt = opt.resume
226
+ else:
227
+ assert os.path.isdir(opt.resume), opt.resume
228
+ logdir = opt.resume.rstrip("/")
229
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
230
+
231
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
232
+ opt.base = base_configs+opt.base
233
+
234
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
235
+ cli = OmegaConf.from_dotlist(unknown)
236
+ config = OmegaConf.merge(*configs, cli)
237
+
238
+ model, global_step = load_model(config, ckpt, gpu=True, eval_mode=True)
239
+
240
+ if opt.outdir:
241
+ print(f"Switching logdir from '{logdir}' to '{opt.outdir}'")
242
+ logdir = opt.outdir
243
+
244
+ if opt.classes == "imagenet":
245
+ given_classes = [i for i in range(1000)]
246
+ else:
247
+ cls_str = opt.classes
248
+ assert not cls_str.endswith(","), 'class string should not end with a ","'
249
+ given_classes = [int(c) for c in cls_str.split(",")]
250
+
251
+ logdir = os.path.join(logdir, "samples", f"top_k_{opt.top_k}_temp_{opt.temperature:.2f}_top_p_{opt.top_p}",
252
+ f"{global_step}")
253
+
254
+ print(f"Logging to {logdir}")
255
+ os.makedirs(logdir, exist_ok=True)
256
+
257
+ run(logdir, model, opt.batch_size, opt.temperature, opt.top_k, unconditional=model.be_unconditional,
258
+ given_classes=given_classes, num_samples=opt.num_samples, top_p=opt.top_p)
259
+
260
+ print("done.")
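
For reference, the same sampler can also be driven programmatically instead of through the CLI above. A minimal sketch; the config/checkpoint paths and sampling settings are illustrative assumptions, not values shipped in this commit:

from omegaconf import OmegaConf
from scripts.sample_fast import load_model, sample_classconditional, save_from_logs  # assumes scripts/ is importable

config = OmegaConf.load("logs/example_run/configs/example-project.yaml")        # placeholder path
model, global_step = load_model(config, "logs/example_run/checkpoints/last.ckpt",
                                gpu=True, eval_mode=True)
logs = sample_classconditional(model, batch_size=4, class_label=207,            # 207 = one ImageNet class id, chosen arbitrarily
                               temperature=1.0, top_k=250, top_p=0.92)
save_from_logs(logs, logdir="samples_out", base_count=0, cond_key=logs["class_label"])
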
scripts/taming-transformers.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
setup.py ADDED
@@ -0,0 +1,13 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name='taming-transformers',
5
+ version='0.0.1',
6
+ description='Taming Transformers for High-Resolution Image Synthesis',
7
+ packages=find_packages(),
8
+ install_requires=[
9
+ 'torch',
10
+ 'numpy',
11
+ 'tqdm',
12
+ ],
13
+ )
taming/__pycache__/lr_scheduler.cpython-312.pyc ADDED
Binary file (2.19 kB).
 
taming/__pycache__/util.cpython-312.pyc ADDED
Binary file (6.33 kB).
 
taming/data/.ipynb_checkpoints/utils-checkpoint.py ADDED
@@ -0,0 +1,171 @@
1
+ import collections
2
+ import os
3
+ import tarfile
4
+ import urllib
5
+ import zipfile
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import torch
10
+ from taming.data.helper_types import Annotation
11
+ #from torch._six import string_classes
12
+ from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
13
+ from tqdm import tqdm
14
+
15
+ string_classes = (str,bytes)
16
+
17
+
18
+ def unpack(path):
19
+ if path.endswith("tar.gz"):
20
+ with tarfile.open(path, "r:gz") as tar:
21
+ tar.extractall(path=os.path.split(path)[0])
22
+ elif path.endswith("tar"):
23
+ with tarfile.open(path, "r:") as tar:
24
+ tar.extractall(path=os.path.split(path)[0])
25
+ elif path.endswith("zip"):
26
+ with zipfile.ZipFile(path, "r") as f:
27
+ f.extractall(path=os.path.split(path)[0])
28
+ else:
29
+ raise NotImplementedError(
30
+ "Unknown file extension: {}".format(os.path.splitext(path)[1])
31
+ )
32
+
33
+
34
+ def reporthook(bar):
35
+ """tqdm progress bar for downloads."""
36
+
37
+ def hook(b=1, bsize=1, tsize=None):
38
+ if tsize is not None:
39
+ bar.total = tsize
40
+ bar.update(b * bsize - bar.n)
41
+
42
+ return hook
43
+
44
+
45
+ def get_root(name):
46
+ base = "data/"
47
+ root = os.path.join(base, name)
48
+ os.makedirs(root, exist_ok=True)
49
+ return root
50
+
51
+
52
+ def is_prepared(root):
53
+ return Path(root).joinpath(".ready").exists()
54
+
55
+
56
+ def mark_prepared(root):
57
+ Path(root).joinpath(".ready").touch()
58
+
59
+
60
+ def prompt_download(file_, source, target_dir, content_dir=None):
61
+ targetpath = os.path.join(target_dir, file_)
62
+ while not os.path.exists(targetpath):
63
+ if content_dir is not None and os.path.exists(
64
+ os.path.join(target_dir, content_dir)
65
+ ):
66
+ break
67
+ print(
68
+ "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
69
+ )
70
+ if content_dir is not None:
71
+ print(
72
+ "Or place its content into '{}'.".format(
73
+ os.path.join(target_dir, content_dir)
74
+ )
75
+ )
76
+ input("Press Enter when done...")
77
+ return targetpath
78
+
79
+
80
+ def download_url(file_, url, target_dir):
81
+ targetpath = os.path.join(target_dir, file_)
82
+ os.makedirs(target_dir, exist_ok=True)
83
+ with tqdm(
84
+ unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
85
+ ) as bar:
86
+ urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
87
+ return targetpath
88
+
89
+
90
+ def download_urls(urls, target_dir):
91
+ paths = dict()
92
+ for fname, url in urls.items():
93
+ outpath = download_url(fname, url, target_dir)
94
+ paths[fname] = outpath
95
+ return paths
96
+
97
+
98
+ def quadratic_crop(x, bbox, alpha=1.0):
99
+ """bbox is xmin, ymin, xmax, ymax"""
100
+ im_h, im_w = x.shape[:2]
101
+ bbox = np.array(bbox, dtype=np.float32)
102
+ bbox = np.clip(bbox, 0, max(im_h, im_w))
103
+ center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
104
+ w = bbox[2] - bbox[0]
105
+ h = bbox[3] - bbox[1]
106
+ l = int(alpha * max(w, h))
107
+ l = max(l, 2)
108
+
109
+ required_padding = -1 * min(
110
+ center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
111
+ )
112
+ required_padding = int(np.ceil(required_padding))
113
+ if required_padding > 0:
114
+ padding = [
115
+ [required_padding, required_padding],
116
+ [required_padding, required_padding],
117
+ ]
118
+ padding += [[0, 0]] * (len(x.shape) - 2)
119
+ x = np.pad(x, padding, "reflect")
120
+ center = center[0] + required_padding, center[1] + required_padding
121
+ xmin = int(center[0] - l / 2)
122
+ ymin = int(center[1] - l / 2)
123
+ return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
124
+
125
+
126
+ def custom_collate(batch):
127
+ r"""source: pytorch 1.9.0, only one modification to original code """
128
+
129
+ elem = batch[0]
130
+ elem_type = type(elem)
131
+ if isinstance(elem, torch.Tensor):
132
+ out = None
133
+ if torch.utils.data.get_worker_info() is not None:
134
+ # If we're in a background process, concatenate directly into a
135
+ # shared memory tensor to avoid an extra copy
136
+ numel = sum([x.numel() for x in batch])
137
+ storage = elem.storage()._new_shared(numel)
138
+ out = elem.new(storage)
139
+ return torch.stack(batch, 0, out=out)
140
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
141
+ and elem_type.__name__ != 'string_':
142
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
143
+ # array of string classes and object
144
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
145
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
146
+
147
+ return custom_collate([torch.as_tensor(b) for b in batch])
148
+ elif elem.shape == (): # scalars
149
+ return torch.as_tensor(batch)
150
+ elif isinstance(elem, float):
151
+ return torch.tensor(batch, dtype=torch.float64)
152
+ elif isinstance(elem, int):
153
+ return torch.tensor(batch)
154
+ elif isinstance(elem, string_classes):
155
+ return batch
156
+ elif isinstance(elem, collections.abc.Mapping):
157
+ return {key: custom_collate([d[key] for d in batch]) for key in elem}
158
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
159
+ return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
160
+ if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
161
+ return batch # added
162
+ elif isinstance(elem, collections.abc.Sequence):
163
+ # check to make sure that the elements in batch have consistent size
164
+ it = iter(batch)
165
+ elem_size = len(next(it))
166
+ if not all(len(elem) == elem_size for elem in it):
167
+ raise RuntimeError('each element in list of batch should be of equal size')
168
+ transposed = zip(*batch)
169
+ return [custom_collate(samples) for samples in transposed]
170
+
171
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
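
The custom collate function above exists so that lists of Annotation tuples pass through a DataLoader uncollated while tensors are still stacked as usual. A small self-contained sketch; the toy samples are made up, and the import path assumes taming/data/utils.py (which this checkpoint file mirrors) defines the same function:

import torch
from torch.utils.data import DataLoader
from taming.data.utils import custom_collate   # assumed to match the function defined above

toy_data = [{"image": torch.rand(3, 16, 16), "caption": "a"},
            {"image": torch.rand(3, 16, 16), "caption": "b"}]
loader = DataLoader(toy_data, batch_size=2, collate_fn=custom_collate)
batch = next(iter(loader))
print(batch["image"].shape)     # torch.Size([2, 3, 16, 16]); "caption" stays a plain list of strings
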
taming/data/__pycache__/helper_types.cpython-312.pyc ADDED
Binary file (2.43 kB).
 
taming/data/__pycache__/utils.cpython-312.pyc ADDED
Binary file (10.6 kB).
 
taming/data/ade20k.py ADDED
@@ -0,0 +1,124 @@
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
+ import albumentations
5
+ from PIL import Image
6
+ from torch.utils.data import Dataset
7
+
8
+ from taming.data.sflckr import SegmentationBase # for examples included in repo
9
+
10
+
11
+ class Examples(SegmentationBase):
12
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
13
+ super().__init__(data_csv="data/ade20k_examples.txt",
14
+ data_root="data/ade20k_images",
15
+ segmentation_root="data/ade20k_segmentations",
16
+ size=size, random_crop=random_crop,
17
+ interpolation=interpolation,
18
+ n_labels=151, shift_segmentation=False)
19
+
20
+
21
+ # With semantic map and scene label
22
+ class ADE20kBase(Dataset):
23
+ def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
24
+ self.split = self.get_split()
25
+ self.n_labels = 151 # unknown + 150
26
+ self.data_csv = {"train": "data/ade20k_train.txt",
27
+ "validation": "data/ade20k_test.txt"}[self.split]
28
+ self.data_root = "data/ade20k_root"
29
+ with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
30
+ self.scene_categories = f.read().splitlines()
31
+ self.scene_categories = dict(line.split() for line in self.scene_categories)
32
+ with open(self.data_csv, "r") as f:
33
+ self.image_paths = f.read().splitlines()
34
+ self._length = len(self.image_paths)
35
+ self.labels = {
36
+ "relative_file_path_": [l for l in self.image_paths],
37
+ "file_path_": [os.path.join(self.data_root, "images", l)
38
+ for l in self.image_paths],
39
+ "relative_segmentation_path_": [l.replace(".jpg", ".png")
40
+ for l in self.image_paths],
41
+ "segmentation_path_": [os.path.join(self.data_root, "annotations",
42
+ l.replace(".jpg", ".png"))
43
+ for l in self.image_paths],
44
+ "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
45
+ for l in self.image_paths],
46
+ }
47
+
48
+ size = None if size is not None and size<=0 else size
49
+ self.size = size
50
+ if crop_size is None:
51
+ self.crop_size = size if size is not None else None
52
+ else:
53
+ self.crop_size = crop_size
54
+ if self.size is not None:
55
+ self.interpolation = interpolation
56
+ self.interpolation = {
57
+ "nearest": cv2.INTER_NEAREST,
58
+ "bilinear": cv2.INTER_LINEAR,
59
+ "bicubic": cv2.INTER_CUBIC,
60
+ "area": cv2.INTER_AREA,
61
+ "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
62
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
63
+ interpolation=self.interpolation)
64
+ self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
65
+ interpolation=cv2.INTER_NEAREST)
66
+
67
+ if crop_size is not None:
68
+ self.center_crop = not random_crop
69
+ if self.center_crop:
70
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
71
+ else:
72
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
73
+ self.preprocessor = self.cropper
74
+
75
+ def __len__(self):
76
+ return self._length
77
+
78
+ def __getitem__(self, i):
79
+ example = dict((k, self.labels[k][i]) for k in self.labels)
80
+ image = Image.open(example["file_path_"])
81
+ if not image.mode == "RGB":
82
+ image = image.convert("RGB")
83
+ image = np.array(image).astype(np.uint8)
84
+ if self.size is not None:
85
+ image = self.image_rescaler(image=image)["image"]
86
+ segmentation = Image.open(example["segmentation_path_"])
87
+ segmentation = np.array(segmentation).astype(np.uint8)
88
+ if self.size is not None:
89
+ segmentation = self.segmentation_rescaler(image=segmentation)["image"]
90
+ if self.size is not None:
91
+ processed = self.preprocessor(image=image, mask=segmentation)
92
+ else:
93
+ processed = {"image": image, "mask": segmentation}
94
+ example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
95
+ segmentation = processed["mask"]
96
+ onehot = np.eye(self.n_labels)[segmentation]
97
+ example["segmentation"] = onehot
98
+ return example
99
+
100
+
101
+ class ADE20kTrain(ADE20kBase):
102
+ # default to random_crop=True
103
+ def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
104
+ super().__init__(config=config, size=size, random_crop=random_crop,
105
+ interpolation=interpolation, crop_size=crop_size)
106
+
107
+ def get_split(self):
108
+ return "train"
109
+
110
+
111
+ class ADE20kValidation(ADE20kBase):
112
+ def get_split(self):
113
+ return "validation"
114
+
115
+
116
+ if __name__ == "__main__":
117
+ dset = ADE20kValidation()
118
+ ex = dset[0]
119
+ for k in ["image", "scene_category", "segmentation"]:
120
+ print(type(ex[k]))
121
+ try:
122
+ print(ex[k].shape)
123
+ except:
124
+ print(ex[k])
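
In practice these dataset classes are usually not constructed by hand but referenced from a training config and built via instantiate_from_config from main.py (imported the same way in the sampling scripts above). A sketch under the assumption that main.py follows the usual target/params convention; the parameter values are illustrative, not a config from this commit:

from main import instantiate_from_config   # assumed target/params convention

data_cfg = {
    "target": "taming.data.ade20k.ADE20kTrain",
    "params": {"size": 256, "crop_size": 256, "random_crop": True},
}
train_dataset = instantiate_from_config(data_cfg)   # equivalent to ADE20kTrain(size=256, crop_size=256, random_crop=True)
print(len(train_dataset))                           # requires the ADE20k files laid out under data/ as expected by ADE20kBase
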
taming/data/annotated_objects_coco.py ADDED
@@ -0,0 +1,139 @@
1
+ import json
2
+ from itertools import chain
3
+ from pathlib import Path
4
+ from typing import Iterable, Dict, List, Callable, Any
5
+ from collections import defaultdict
6
+
7
+ from tqdm import tqdm
8
+
9
+ from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
10
+ from taming.data.helper_types import Annotation, ImageDescription, Category
11
+
12
+ COCO_PATH_STRUCTURE = {
13
+ 'train': {
14
+ 'top_level': '',
15
+ 'instances_annotations': 'annotations/instances_train2017.json',
16
+ 'stuff_annotations': 'annotations/stuff_train2017.json',
17
+ 'files': 'train2017'
18
+ },
19
+ 'validation': {
20
+ 'top_level': '',
21
+ 'instances_annotations': 'annotations/instances_val2017.json',
22
+ 'stuff_annotations': 'annotations/stuff_val2017.json',
23
+ 'files': 'val2017'
24
+ }
25
+ }
26
+
27
+
28
+ def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
29
+ return {
30
+ str(img['id']): ImageDescription(
31
+ id=img['id'],
32
+ license=img.get('license'),
33
+ file_name=img['file_name'],
34
+ coco_url=img['coco_url'],
35
+ original_size=(img['width'], img['height']),
36
+ date_captured=img.get('date_captured'),
37
+ flickr_url=img.get('flickr_url')
38
+ )
39
+ for img in description_json
40
+ }
41
+
42
+
43
+ def load_categories(category_json: Iterable) -> Dict[str, Category]:
44
+ return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
45
+ for cat in category_json if cat['name'] != 'other'}
46
+
47
+
48
+ def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
49
+ category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
50
+ annotations = defaultdict(list)
51
+ total = sum(len(a) for a in annotations_json)
52
+ for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
53
+ image_id = str(ann['image_id'])
54
+ if image_id not in image_descriptions:
55
+ raise ValueError(f'image_id [{image_id}] has no image description.')
56
+ category_id = ann['category_id']
57
+ try:
58
+ category_no = category_no_for_id(str(category_id))
59
+ except KeyError:
60
+ continue
61
+
62
+ width, height = image_descriptions[image_id].original_size
63
+ bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
64
+
65
+ annotations[image_id].append(
66
+ Annotation(
67
+ id=ann['id'],
68
+ area=bbox[2]*bbox[3], # use bbox area
69
+ is_group_of=ann['iscrowd'],
70
+ image_id=ann['image_id'],
71
+ bbox=bbox,
72
+ category_id=str(category_id),
73
+ category_no=category_no
74
+ )
75
+ )
76
+ return dict(annotations)
77
+
78
+
79
+ class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
80
+ def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
81
+ """
82
+ @param data_path: is the path to the following folder structure:
83
+ coco/
84
+ ├── annotations
85
+ │ ├── instances_train2017.json
86
+ │ ├── instances_val2017.json
87
+ │ ├── stuff_train2017.json
88
+ │ └── stuff_val2017.json
89
+ ├── train2017
90
+ │ ├── 000000000009.jpg
91
+ │ ├── 000000000025.jpg
92
+ │ └── ...
93
+ ├── val2017
94
+ │ ├── 000000000139.jpg
95
+ │ ├── 000000000285.jpg
96
+ │ └── ...
97
+ @param: split: one of 'train' or 'validation'
98
+ @param: desired image size (give square images)
99
+ """
100
+ super().__init__(**kwargs)
101
+ self.use_things = use_things
102
+ self.use_stuff = use_stuff
103
+
104
+ with open(self.paths['instances_annotations']) as f:
105
+ inst_data_json = json.load(f)
106
+ with open(self.paths['stuff_annotations']) as f:
107
+ stuff_data_json = json.load(f)
108
+
109
+ category_jsons = []
110
+ annotation_jsons = []
111
+ if self.use_things:
112
+ category_jsons.append(inst_data_json['categories'])
113
+ annotation_jsons.append(inst_data_json['annotations'])
114
+ if self.use_stuff:
115
+ category_jsons.append(stuff_data_json['categories'])
116
+ annotation_jsons.append(stuff_data_json['annotations'])
117
+
118
+ self.categories = load_categories(chain(*category_jsons))
119
+ self.filter_categories()
120
+ self.setup_category_id_and_number()
121
+
122
+ self.image_descriptions = load_image_descriptions(inst_data_json['images'])
123
+ annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
124
+ self.annotations = self.filter_object_number(annotations, self.min_object_area,
125
+ self.min_objects_per_image, self.max_objects_per_image)
126
+ self.image_ids = list(self.annotations.keys())
127
+ self.clean_up_annotations_and_image_descriptions()
128
+
129
+ def get_path_structure(self) -> Dict[str, str]:
130
+ if self.split not in COCO_PATH_STRUCTURE:
131
+ raise ValueError(f'Split [{self.split} does not exist for COCO data.]')
132
+ return COCO_PATH_STRUCTURE[self.split]
133
+
134
+ def get_image_path(self, image_id: str) -> Path:
135
+ return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
136
+
137
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
138
+ # noinspection PyProtectedMember
139
+ return self.image_descriptions[image_id]._asdict()
taming/data/annotated_objects_dataset.py ADDED
@@ -0,0 +1,218 @@
1
+ from pathlib import Path
2
+ from typing import Optional, List, Callable, Dict, Any, Union
3
+ import warnings
4
+
5
+ import PIL.Image as pil_image
6
+ from torch import Tensor
7
+ from torch.utils.data import Dataset
8
+ from torchvision import transforms
9
+
10
+ from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
11
+ from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
12
+ from taming.data.conditional_builder.utils import load_object_from_string
13
+ from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
14
+ from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
15
+ Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
16
+
17
+
18
+ class AnnotatedObjectsDataset(Dataset):
19
+ def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
20
+ min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
21
+ crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
22
+ encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
23
+ no_object_classes: Optional[int] = None):
24
+ self.data_path = data_path
25
+ self.split = split
26
+ self.keys = keys
27
+ self.target_image_size = target_image_size
28
+ self.min_object_area = min_object_area
29
+ self.min_objects_per_image = min_objects_per_image
30
+ self.max_objects_per_image = max_objects_per_image
31
+ self.crop_method = crop_method
32
+ self.random_flip = random_flip
33
+ self.no_tokens = no_tokens
34
+ self.use_group_parameter = use_group_parameter
35
+ self.encode_crop = encode_crop
36
+
37
+ self.annotations = None
38
+ self.image_descriptions = None
39
+ self.categories = None
40
+ self.category_ids = None
41
+ self.category_number = None
42
+ self.image_ids = None
43
+ self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
44
+ self.paths = self.build_paths(self.data_path)
45
+ self._conditional_builders = None
46
+ self.category_allow_list = None
47
+ if category_allow_list_target:
48
+ allow_list = load_object_from_string(category_allow_list_target)
49
+ self.category_allow_list = {name for name, _ in allow_list}
50
+ self.category_mapping = {}
51
+ if category_mapping_target:
52
+ self.category_mapping = load_object_from_string(category_mapping_target)
53
+ self.no_object_classes = no_object_classes
54
+
55
+ def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
56
+ top_level = Path(top_level)
57
+ sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
58
+ for path in sub_paths.values():
59
+ if not path.exists():
60
+ raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
61
+ return sub_paths
62
+
63
+ @staticmethod
64
+ def load_image_from_disk(path: Path) -> Image:
65
+ return pil_image.open(path).convert('RGB')
66
+
67
+ @staticmethod
68
+ def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
69
+ transform_functions = []
70
+ if crop_method == 'none':
71
+ transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
72
+ elif crop_method == 'center':
73
+ transform_functions.extend([
74
+ transforms.Resize(target_image_size),
75
+ CenterCropReturnCoordinates(target_image_size)
76
+ ])
77
+ elif crop_method == 'random-1d':
78
+ transform_functions.extend([
79
+ transforms.Resize(target_image_size),
80
+ RandomCrop1dReturnCoordinates(target_image_size)
81
+ ])
82
+ elif crop_method == 'random-2d':
83
+ transform_functions.extend([
84
+ Random2dCropReturnCoordinates(target_image_size),
85
+ transforms.Resize(target_image_size)
86
+ ])
87
+ elif crop_method is None:
88
+ return None
89
+ else:
90
+ raise ValueError(f'Received invalid crop method [{crop_method}].')
91
+ if random_flip:
92
+ transform_functions.append(RandomHorizontalFlipReturn())
93
+ transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
94
+ return transform_functions
95
+
96
+ def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor):
97
+ crop_bbox = None
98
+ flipped = None
99
+ for t in self.transform_functions:
100
+ if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
101
+ crop_bbox, x = t(x)
102
+ elif isinstance(t, RandomHorizontalFlipReturn):
103
+ flipped, x = t(x)
104
+ else:
105
+ x = t(x)
106
+ return crop_bbox, flipped, x
107
+
108
+ @property
109
+ def no_classes(self) -> int:
110
+ return self.no_object_classes if self.no_object_classes else len(self.categories)
111
+
112
+ @property
113
+ def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
114
+ # cannot set this up in init because no_classes is only known after loading data in init of superclass
115
+ if self._conditional_builders is None:
116
+ self._conditional_builders = {
117
+ 'objects_center_points': ObjectsCenterPointsConditionalBuilder(
118
+ self.no_classes,
119
+ self.max_objects_per_image,
120
+ self.no_tokens,
121
+ self.encode_crop,
122
+ self.use_group_parameter,
123
+ getattr(self, 'use_additional_parameters', False)
124
+ ),
125
+ 'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
126
+ self.no_classes,
127
+ self.max_objects_per_image,
128
+ self.no_tokens,
129
+ self.encode_crop,
130
+ self.use_group_parameter,
131
+ getattr(self, 'use_additional_parameters', False)
132
+ )
133
+ }
134
+ return self._conditional_builders
135
+
136
+ def filter_categories(self) -> None:
137
+ if self.category_allow_list:
138
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
139
+ if self.category_mapping:
140
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}
141
+
142
+ def setup_category_id_and_number(self) -> None:
143
+ self.category_ids = list(self.categories.keys())
144
+ self.category_ids.sort()
145
+ if '/m/01s55n' in self.category_ids:
146
+ self.category_ids.remove('/m/01s55n')
147
+ self.category_ids.append('/m/01s55n')
148
+ self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
149
+ if self.category_allow_list is not None and self.category_mapping is None \
150
+ and len(self.category_ids) != len(self.category_allow_list):
151
+ warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
152
+ 'Make sure all names in category_allow_list exist.')
153
+
154
+ def clean_up_annotations_and_image_descriptions(self) -> None:
155
+ image_id_set = set(self.image_ids)
156
+ self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
157
+ self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}
158
+
159
+ @staticmethod
160
+ def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
161
+ min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
162
+ filtered = {}
163
+ for image_id, annotations in all_annotations.items():
164
+ annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
165
+ if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
166
+ filtered[image_id] = annotations_with_min_area
167
+ return filtered
168
+
169
+ def __len__(self):
170
+ return len(self.image_ids)
171
+
172
+ def __getitem__(self, n: int) -> Dict[str, Any]:
173
+ image_id = self.get_image_id(n)
174
+ sample = self.get_image_description(image_id)
175
+ sample['annotations'] = self.get_annotation(image_id)
176
+
177
+ if 'image' in self.keys:
178
+ sample['image_path'] = str(self.get_image_path(image_id))
179
+ sample['image'] = self.load_image_from_disk(sample['image_path'])
180
+ sample['image'] = convert_pil_to_tensor(sample['image'])
181
+ sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
182
+ sample['image'] = sample['image'].permute(1, 2, 0)
183
+
184
+ for conditional, builder in self.conditional_builders.items():
185
+ if conditional in self.keys:
186
+ sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped'])
187
+
188
+ if self.keys:
189
+ # only return specified keys
190
+ sample = {key: sample[key] for key in self.keys}
191
+ return sample
192
+
193
+ def get_image_id(self, no: int) -> str:
194
+ return self.image_ids[no]
195
+
196
+ def get_annotation(self, image_id: str) -> str:
197
+ return self.annotations[image_id]
198
+
199
+ def get_textual_label_for_category_id(self, category_id: str) -> str:
200
+ return self.categories[category_id].name
201
+
202
+ def get_textual_label_for_category_no(self, category_no: int) -> str:
203
+ return self.categories[self.get_category_id(category_no)].name
204
+
205
+ def get_category_number(self, category_id: str) -> int:
206
+ return self.category_number[category_id]
207
+
208
+ def get_category_id(self, category_no: int) -> str:
209
+ return self.category_ids[category_no]
210
+
211
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
212
+ raise NotImplementedError()
213
+
214
+ def get_path_structure(self):
215
+ raise NotImplementedError
216
+
217
+ def get_image_path(self, image_id: str) -> Path:
218
+ raise NotImplementedError
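
The class above is abstract (path structure and image lookup are left to subclasses), so it is instantiated through one of the dataset-specific subclasses. A hedged sketch using the COCO variant from this commit; data_path, sizes and thresholds below are illustrative assumptions:

from taming.data.annotated_objects_coco import AnnotatedObjectsCoco

dataset = AnnotatedObjectsCoco(
    use_things=True, use_stuff=True,
    data_path="data/coco",              # must contain the folder layout documented in the COCO class docstring
    split="validation",
    keys=["image", "objects_bbox"],     # __getitem__ will return only these entries
    target_image_size=256,
    min_object_area=0.00001,
    min_objects_per_image=2,
    max_objects_per_image=30,
    crop_method="center",
    random_flip=False,
    no_tokens=1024,
    use_group_parameter=True,
    encode_crop=True,
)
sample = dataset[0]                      # dict with an image tensor and the bounding-box conditioning
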
taming/data/annotated_objects_open_images.py ADDED
@@ -0,0 +1,137 @@
1
+ from collections import defaultdict
2
+ from csv import DictReader, reader as TupleReader
3
+ from pathlib import Path
4
+ from typing import Dict, List, Any
5
+ import warnings
6
+
7
+ from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
8
+ from taming.data.helper_types import Annotation, Category
9
+ from tqdm import tqdm
10
+
11
+ OPEN_IMAGES_STRUCTURE = {
12
+ 'train': {
13
+ 'top_level': '',
14
+ 'class_descriptions': 'class-descriptions-boxable.csv',
15
+ 'annotations': 'oidv6-train-annotations-bbox.csv',
16
+ 'file_list': 'train-images-boxable.csv',
17
+ 'files': 'train'
18
+ },
19
+ 'validation': {
20
+ 'top_level': '',
21
+ 'class_descriptions': 'class-descriptions-boxable.csv',
22
+ 'annotations': 'validation-annotations-bbox.csv',
23
+ 'file_list': 'validation-images.csv',
24
+ 'files': 'validation'
25
+ },
26
+ 'test': {
27
+ 'top_level': '',
28
+ 'class_descriptions': 'class-descriptions-boxable.csv',
29
+ 'annotations': 'test-annotations-bbox.csv',
30
+ 'file_list': 'test-images.csv',
31
+ 'files': 'test'
32
+ }
33
+ }
34
+
35
+
36
+ def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
37
+ category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
38
+ annotations: Dict[str, List[Annotation]] = defaultdict(list)
39
+ with open(descriptor_path) as file:
40
+ reader = DictReader(file)
41
+ for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
42
+ width = float(row['XMax']) - float(row['XMin'])
43
+ height = float(row['YMax']) - float(row['YMin'])
44
+ area = width * height
45
+ category_id = row['LabelName']
46
+ if category_id in category_mapping:
47
+ category_id = category_mapping[category_id]
48
+ if area >= min_object_area and category_id in category_no_for_id:
49
+ annotations[row['ImageID']].append(
50
+ Annotation(
51
+ id=i,
52
+ image_id=row['ImageID'],
53
+ source=row['Source'],
54
+ category_id=category_id,
55
+ category_no=category_no_for_id[category_id],
56
+ confidence=float(row['Confidence']),
57
+ bbox=(float(row['XMin']), float(row['YMin']), width, height),
58
+ area=area,
59
+ is_occluded=bool(int(row['IsOccluded'])),
60
+ is_truncated=bool(int(row['IsTruncated'])),
61
+ is_group_of=bool(int(row['IsGroupOf'])),
62
+ is_depiction=bool(int(row['IsDepiction'])),
63
+ is_inside=bool(int(row['IsInside']))
64
+ )
65
+ )
66
+ if 'train' in str(descriptor_path) and i < 14000000:
67
+ warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].')
68
+ return dict(annotations)
69
+
70
+
71
+ def load_image_ids(csv_path: Path) -> List[str]:
72
+ with open(csv_path) as file:
73
+ reader = DictReader(file)
74
+ return [row['image_name'] for row in reader]
75
+
76
+
77
+ def load_categories(csv_path: Path) -> Dict[str, Category]:
78
+ with open(csv_path) as file:
79
+ reader = TupleReader(file)
80
+ return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}
81
+
82
+
83
+ class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
84
+ def __init__(self, use_additional_parameters: bool, **kwargs):
85
+ """
86
+ @param data_path: is the path to the following folder structure:
87
+ open_images/
88
+ │ oidv6-train-annotations-bbox.csv
89
+ ├── class-descriptions-boxable.csv
90
+ ├── oidv6-train-annotations-bbox.csv
91
+ ├── test
92
+ │ ├── 000026e7ee790996.jpg
93
+ │ ├── 000062a39995e348.jpg
94
+ │ └── ...
95
+ ├── test-annotations-bbox.csv
96
+ ├── test-images.csv
97
+ ├── train
98
+ │ ├── 000002b66c9c498e.jpg
99
+ │ ├── 000002b97e5471a0.jpg
100
+ │ └── ...
101
+ ├── train-images-boxable.csv
102
+ ├── validation
103
+ │ ├── 0001eeaf4aed83f9.jpg
104
+ │ ├── 0004886b7d043cfd.jpg
105
+ │ └── ...
106
+ ├── validation-annotations-bbox.csv
107
+ └── validation-images.csv
108
+ @param: split: one of 'train', 'validation' or 'test'
109
+ @param: desired image size (returns square images)
110
+ """
111
+
112
+ super().__init__(**kwargs)
113
+ self.use_additional_parameters = use_additional_parameters
114
+
115
+ self.categories = load_categories(self.paths['class_descriptions'])
116
+ self.filter_categories()
117
+ self.setup_category_id_and_number()
118
+
119
+ self.image_descriptions = {}
120
+ annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
121
+ self.category_number)
122
+ self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
123
+ self.max_objects_per_image)
124
+ self.image_ids = list(self.annotations.keys())
125
+ self.clean_up_annotations_and_image_descriptions()
126
+
127
+ def get_path_structure(self) -> Dict[str, str]:
128
+ if self.split not in OPEN_IMAGES_STRUCTURE:
129
+ raise ValueError(f'Split [{self.split}] does not exist for Open Images data.')
130
+ return OPEN_IMAGES_STRUCTURE[self.split]
131
+
132
+ def get_image_path(self, image_id: str) -> Path:
133
+ return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')
134
+
135
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
136
+ image_path = self.get_image_path(image_id)
137
+ return {'file_path': str(image_path), 'file_name': image_path.name}
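
The CSV helpers above can be exercised on their own without the full Open Images download. A tiny self-contained check; the two rows below are made-up examples of the boxable class-description format:

from pathlib import Path
from taming.data.annotated_objects_open_images import load_categories

csv_path = Path("class-descriptions-boxable.csv")
csv_path.write_text("/m/011k07,Tortoise\n/m/0120dh,Container\n")   # MID,label rows (illustrative)
categories = load_categories(csv_path)
print(categories["/m/011k07"].name)    # -> Tortoise
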
taming/data/base.py ADDED
@@ -0,0 +1,70 @@
1
+ import bisect
2
+ import numpy as np
3
+ import albumentations
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset, ConcatDataset
6
+
7
+
8
+ class ConcatDatasetWithIndex(ConcatDataset):
9
+ """Modified from original pytorch code to return dataset idx"""
10
+ def __getitem__(self, idx):
11
+ if idx < 0:
12
+ if -idx > len(self):
13
+ raise ValueError("absolute value of index should not exceed dataset length")
14
+ idx = len(self) + idx
15
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
16
+ if dataset_idx == 0:
17
+ sample_idx = idx
18
+ else:
19
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
20
+ return self.datasets[dataset_idx][sample_idx], dataset_idx
21
+
22
+
23
+ class ImagePaths(Dataset):
24
+ def __init__(self, paths, size=None, random_crop=False, labels=None):
25
+ self.size = size
26
+ self.random_crop = random_crop
27
+
28
+ self.labels = dict() if labels is None else labels
29
+ self.labels["file_path_"] = paths
30
+ self._length = len(paths)
31
+
32
+ if self.size is not None and self.size > 0:
33
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
34
+ if not self.random_crop:
35
+ self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
36
+ else:
37
+ self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
38
+ self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
39
+ else:
40
+ self.preprocessor = lambda **kwargs: kwargs
41
+
42
+ def __len__(self):
43
+ return self._length
44
+
45
+ def preprocess_image(self, image_path):
46
+ image = Image.open(image_path)
47
+ if not image.mode == "RGB":
48
+ image = image.convert("RGB")
49
+ image = np.array(image).astype(np.uint8)
50
+ image = self.preprocessor(image=image)["image"]
51
+ image = (image/127.5 - 1.0).astype(np.float32)
52
+ return image
53
+
54
+ def __getitem__(self, i):
55
+ example = dict()
56
+ example["image"] = self.preprocess_image(self.labels["file_path_"][i])
57
+ for k in self.labels:
58
+ example[k] = self.labels[k][i]
59
+ return example
60
+
61
+
62
+ class NumpyPaths(ImagePaths):
63
+ def preprocess_image(self, image_path):
64
+ image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
65
+ image = np.transpose(image, (1,2,0))
66
+ image = Image.fromarray(image, mode="RGB")
67
+ image = np.array(image).astype(np.uint8)
68
+ image = self.preprocessor(image=image)["image"]
69
+ image = (image/127.5 - 1.0).astype(np.float32)
70
+ return image
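
ImagePaths turns a plain list of image files into dataset examples with HWC float images scaled to [-1, 1], and works with the default DataLoader collation. A quick sketch; the file names are placeholders:

from torch.utils.data import DataLoader
from taming.data.base import ImagePaths

dataset = ImagePaths(paths=["data/custom/img_0001.png", "data/custom/img_0002.png"],
                     size=256, random_crop=False)
example = dataset[0]
print(example["image"].shape, example["image"].dtype)   # (256, 256, 3) float32 in [-1, 1]
loader = DataLoader(dataset, batch_size=2)               # default collate stacks the numpy images into a tensor
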
taming/data/coco.py ADDED
@@ -0,0 +1,176 @@
1
+ import os
2
+ import json
3
+ import albumentations
4
+ import numpy as np
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ from torch.utils.data import Dataset
8
+
9
+ from taming.data.sflckr import SegmentationBase # for examples included in repo
10
+
11
+
12
+ class Examples(SegmentationBase):
13
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
14
+ super().__init__(data_csv="data/coco_examples.txt",
15
+ data_root="data/coco_images",
16
+ segmentation_root="data/coco_segmentations",
17
+ size=size, random_crop=random_crop,
18
+ interpolation=interpolation,
19
+ n_labels=183, shift_segmentation=True)
20
+
21
+
22
+ class CocoBase(Dataset):
23
+ """needed for (image, caption, segmentation) pairs"""
24
+ def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
25
+ crop_size=None, force_no_crop=False, given_files=None):
26
+ self.split = self.get_split()
27
+ self.size = size
28
+ if crop_size is None:
29
+ self.crop_size = size
30
+ else:
31
+ self.crop_size = crop_size
32
+
33
+ self.onehot = onehot_segmentation # return segmentation as rgb or one hot
34
+ self.stuffthing = use_stuffthing # include thing in segmentation
35
+ if self.onehot and not self.stuffthing:
36
+ raise NotImplementedError("One hot mode is only supported for the "
37
+ "stuffthings version because labels are stored "
38
+ "a bit different.")
39
+
40
+ data_json = datajson
41
+ with open(data_json) as json_file:
42
+ self.json_data = json.load(json_file)
43
+ self.img_id_to_captions = dict()
44
+ self.img_id_to_filepath = dict()
45
+ self.img_id_to_segmentation_filepath = dict()
46
+
47
+ assert data_json.split("/")[-1] in ["captions_train2017.json",
48
+ "captions_val2017.json"]
49
+ if self.stuffthing:
50
+ self.segmentation_prefix = (
51
+ "data/cocostuffthings/val2017" if
52
+ data_json.endswith("captions_val2017.json") else
53
+ "data/cocostuffthings/train2017")
54
+ else:
55
+ self.segmentation_prefix = (
56
+ "data/coco/annotations/stuff_val2017_pixelmaps" if
57
+ data_json.endswith("captions_val2017.json") else
58
+ "data/coco/annotations/stuff_train2017_pixelmaps")
59
+
60
+ imagedirs = self.json_data["images"]
61
+ self.labels = {"image_ids": list()}
62
+ for imgdir in tqdm(imagedirs, desc="ImgToPath"):
63
+ self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
64
+ self.img_id_to_captions[imgdir["id"]] = list()
65
+ pngfilename = imgdir["file_name"].replace("jpg", "png")
66
+ self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
67
+ self.segmentation_prefix, pngfilename)
68
+ if given_files is not None:
69
+ if pngfilename in given_files:
70
+ self.labels["image_ids"].append(imgdir["id"])
71
+ else:
72
+ self.labels["image_ids"].append(imgdir["id"])
73
+
74
+ capdirs = self.json_data["annotations"]
75
+ for capdir in tqdm(capdirs, desc="ImgToCaptions"):
76
+ # there are on average 5 captions per image
77
+ self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
78
+
79
+ self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
80
+ if self.split=="validation":
81
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
82
+ else:
83
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
84
+ self.preprocessor = albumentations.Compose(
85
+ [self.rescaler, self.cropper],
86
+ additional_targets={"segmentation": "image"})
87
+ if force_no_crop:
88
+ self.rescaler = albumentations.Resize(height=self.size, width=self.size)
89
+ self.preprocessor = albumentations.Compose(
90
+ [self.rescaler],
91
+ additional_targets={"segmentation": "image"})
92
+
93
+ def __len__(self):
94
+ return len(self.labels["image_ids"])
95
+
96
+ def preprocess_image(self, image_path, segmentation_path):
97
+ image = Image.open(image_path)
98
+ if not image.mode == "RGB":
99
+ image = image.convert("RGB")
100
+ image = np.array(image).astype(np.uint8)
101
+
102
+ segmentation = Image.open(segmentation_path)
103
+ if not self.onehot and not segmentation.mode == "RGB":
104
+ segmentation = segmentation.convert("RGB")
105
+ segmentation = np.array(segmentation).astype(np.uint8)
106
+ if self.onehot:
107
+ assert self.stuffthing
108
+ # stored in caffe format: unlabeled==255. stuff and thing from
109
+ # 0-181. to be compatible with the labels in
110
+ # https://github.com/nightrome/cocostuff/blob/master/labels.txt
111
+ # we shift stuffthing one to the right and put unlabeled at zero;
112
+ # as long as the segmentation is uint8, the +1 shift wraps 255 to 0
113
+ # and therefore handles the unlabeled class as well
114
+ assert segmentation.dtype == np.uint8
115
+ segmentation = segmentation + 1
116
+
117
+ processed = self.preprocessor(image=image, segmentation=segmentation)
118
+ image, segmentation = processed["image"], processed["segmentation"]
119
+ image = (image / 127.5 - 1.0).astype(np.float32)
120
+
121
+ if self.onehot:
122
+ assert segmentation.dtype == np.uint8
123
+ # make it one hot
124
+ n_labels = 183
125
+ flatseg = np.ravel(segmentation)
126
+ onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
127
+ onehot[np.arange(flatseg.size), flatseg] = True
128
+ onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
129
+ segmentation = onehot
130
+ else:
131
+ segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
132
+ return image, segmentation
133
+
134
+ def __getitem__(self, i):
135
+ img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
136
+ seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
137
+ image, segmentation = self.preprocess_image(img_path, seg_path)
138
+ captions = self.img_id_to_captions[self.labels["image_ids"][i]]
139
+ # randomly draw one of the available captions for this image
140
+ caption = captions[np.random.randint(0, len(captions))]
141
+ example = {"image": image,
142
+ "caption": [str(caption[0])],
143
+ "segmentation": segmentation,
144
+ "img_path": img_path,
145
+ "seg_path": seg_path,
146
+ "filename_": img_path.split(os.sep)[-1]
147
+ }
148
+ return example
149
+
150
+
151
+ class CocoImagesAndCaptionsTrain(CocoBase):
152
+ """returns a pair of (image, caption)"""
153
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
154
+ super().__init__(size=size,
155
+ dataroot="data/coco/train2017",
156
+ datajson="data/coco/annotations/captions_train2017.json",
157
+ onehot_segmentation=onehot_segmentation,
158
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
159
+
160
+ def get_split(self):
161
+ return "train"
162
+
163
+
164
+ class CocoImagesAndCaptionsValidation(CocoBase):
165
+ """returns a pair of (image, caption)"""
166
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
167
+ given_files=None):
168
+ super().__init__(size=size,
169
+ dataroot="data/coco/val2017",
170
+ datajson="data/coco/annotations/captions_val2017.json",
171
+ onehot_segmentation=onehot_segmentation,
172
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
173
+ given_files=given_files)
174
+
175
+ def get_split(self):
176
+ return "validation"
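
Illustrative sketch (editorial addition, not part of the commit): the COCO wrappers plug directly into a PyTorch DataLoader. This assumes the COCO 2017 images, caption annotations, and stuff segmentation maps are available under data/coco as in the default paths above.

from torch.utils.data import DataLoader
from taming.data.coco import CocoImagesAndCaptionsValidation

dataset = CocoImagesAndCaptionsValidation(size=256)
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2)
batch = next(iter(loader))
print(batch["image"].shape)   # torch.Size([4, 256, 256, 3]), channels-last, values in [-1, 1]
print(batch["caption"])       # one randomly drawn caption per image
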
taming/data/conditional_builder/objects_bbox.py ADDED
@@ -0,0 +1,60 @@
1
+ from itertools import cycle
2
+ from typing import List, Tuple, Callable, Optional
3
+
4
+ from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
5
+ from more_itertools.recipes import grouper
6
+ from taming.data.image_transforms import convert_pil_to_tensor
7
+ from torch import LongTensor, Tensor
8
+
9
+ from taming.data.helper_types import BoundingBox, Annotation
10
+ from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
11
+ from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
12
+ pad_list, get_plot_font_size, absolute_bbox
13
+
14
+
15
+ class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
16
+ @property
17
+ def object_descriptor_length(self) -> int:
18
+ return 3
19
+
20
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
21
+ object_triples = [
22
+ (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
23
+ for ann in annotations
24
+ ]
25
+ empty_triple = (self.none, self.none, self.none)
26
+ object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
27
+ return object_triples
28
+
29
+ def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
30
+ conditional_list = conditional.tolist()
31
+ crop_coordinates = None
32
+ if self.encode_crop:
33
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
34
+ conditional_list = conditional_list[:-2]
35
+ object_triples = grouper(conditional_list, 3)
36
+ assert conditional.shape[0] == self.embedding_dim
37
+ return [
38
+ (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
39
+ for object_triple in object_triples if object_triple[0] != self.none
40
+ ], crop_coordinates
41
+
42
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
43
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
44
+ plot = pil_image.new('RGB', figure_size, WHITE)
45
+ draw = pil_img_draw.Draw(plot)
46
+ font = ImageFont.truetype(
47
+ "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
48
+ size=get_plot_font_size(font_size, figure_size)
49
+ )
50
+ width, height = plot.size
51
+ description, crop_coordinates = self.inverse_build(conditional)
52
+ for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
53
+ annotation = self.representation_to_annotation(representation)
54
+ class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
55
+ bbox = absolute_bbox(bbox, width, height)
56
+ draw.rectangle(bbox, outline=color, width=line_width)
57
+ draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
58
+ if crop_coordinates is not None:
59
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
60
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
taming/data/conditional_builder/objects_center_points.py ADDED
@@ -0,0 +1,168 @@
1
+ import math
2
+ import random
3
+ import warnings
4
+ from itertools import cycle
5
+ from typing import List, Optional, Tuple, Callable
6
+
7
+ from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
8
+ from more_itertools.recipes import grouper
9
+ from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
10
+ additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
11
+ absolute_bbox, rescale_annotations
12
+ from taming.data.helper_types import BoundingBox, Annotation
13
+ from taming.data.image_transforms import convert_pil_to_tensor
14
+ from torch import LongTensor, Tensor
15
+
16
+
17
+ class ObjectsCenterPointsConditionalBuilder:
18
+ def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
19
+ use_group_parameter: bool, use_additional_parameters: bool):
20
+ self.no_object_classes = no_object_classes
21
+ self.no_max_objects = no_max_objects
22
+ self.no_tokens = no_tokens
23
+ self.encode_crop = encode_crop
24
+ self.no_sections = int(math.sqrt(self.no_tokens))
25
+ self.use_group_parameter = use_group_parameter
26
+ self.use_additional_parameters = use_additional_parameters
27
+
28
+ @property
29
+ def none(self) -> int:
30
+ return self.no_tokens - 1
31
+
32
+ @property
33
+ def object_descriptor_length(self) -> int:
34
+ return 2
35
+
36
+ @property
37
+ def embedding_dim(self) -> int:
38
+ extra_length = 2 if self.encode_crop else 0
39
+ return self.no_max_objects * self.object_descriptor_length + extra_length
40
+
41
+ def tokenize_coordinates(self, x: float, y: float) -> int:
42
+ """
43
+ Express 2d coordinates with one number.
44
+ Example: assume self.no_tokens = 16, then no_sections = 4:
45
+ 0 0 0 0
46
+ 0 0 # 0
47
+ 0 0 0 0
48
+ 0 0 0 x
49
+ Then the # position corresponds to token 6, the x position to token 15.
50
+ @param x: float in [0, 1]
51
+ @param y: float in [0, 1]
52
+ @return: discrete tokenized coordinate
53
+ """
54
+ x_discrete = int(round(x * (self.no_sections - 1)))
55
+ y_discrete = int(round(y * (self.no_sections - 1)))
56
+ return y_discrete * self.no_sections + x_discrete
57
+
58
+ def coordinates_from_token(self, token: int) -> (float, float):
59
+ x = token % self.no_sections
60
+ y = token // self.no_sections
61
+ return x / (self.no_sections - 1), y / (self.no_sections - 1)
62
+
63
+ def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
64
+ x0, y0 = self.coordinates_from_token(token1)
65
+ x1, y1 = self.coordinates_from_token(token2)
66
+ return x0, y0, x1 - x0, y1 - y0
67
+
68
+ def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
69
+ return self.tokenize_coordinates(bbox[0], bbox[1]), \
70
+ self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])
71
+
72
+ def inverse_build(self, conditional: LongTensor) \
73
+ -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
74
+ conditional_list = conditional.tolist()
75
+ crop_coordinates = None
76
+ if self.encode_crop:
77
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
78
+ conditional_list = conditional_list[:-2]
79
+ table_of_content = grouper(conditional_list, self.object_descriptor_length)
80
+ assert conditional.shape[0] == self.embedding_dim
81
+ return [
82
+ (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
83
+ for object_tuple in table_of_content if object_tuple[0] != self.none
84
+ ], crop_coordinates
85
+
86
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
87
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
88
+ plot = pil_image.new('RGB', figure_size, WHITE)
89
+ draw = pil_img_draw.Draw(plot)
90
+ circle_size = get_circle_size(figure_size)
91
+ font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
92
+ size=get_plot_font_size(font_size, figure_size))
93
+ width, height = plot.size
94
+ description, crop_coordinates = self.inverse_build(conditional)
95
+ for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
96
+ x_abs, y_abs = x * width, y * height
97
+ ann = self.representation_to_annotation(representation)
98
+ label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
99
+ ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
100
+ draw.ellipse(ellipse_bbox, fill=color, width=0)
101
+ draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
102
+ if crop_coordinates is not None:
103
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
104
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
105
+
106
+ def object_representation(self, annotation: Annotation) -> int:
107
+ modifier = 0
108
+ if self.use_group_parameter:
109
+ modifier |= 1 * (annotation.is_group_of is True)
110
+ if self.use_additional_parameters:
111
+ modifier |= 2 * (annotation.is_occluded is True)
112
+ modifier |= 4 * (annotation.is_depiction is True)
113
+ modifier |= 8 * (annotation.is_inside is True)
114
+ return annotation.category_no + self.no_object_classes * modifier
115
+
116
+ def representation_to_annotation(self, representation: int) -> Annotation:
117
+ category_no = representation % self.no_object_classes
118
+ modifier = representation // self.no_object_classes
119
+ # noinspection PyTypeChecker
120
+ return Annotation(
121
+ area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
122
+ category_no=category_no,
123
+ is_group_of=bool((modifier & 1) * self.use_group_parameter),
124
+ is_occluded=bool((modifier & 2) * self.use_additional_parameters),
125
+ is_depiction=bool((modifier & 4) * self.use_additional_parameters),
126
+ is_inside=bool((modifier & 8) * self.use_additional_parameters)
127
+ )
128
+
129
+ def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
130
+ return list(self.token_pair_from_bbox(crop_coordinates))
131
+
132
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
133
+ object_tuples = [
134
+ (self.object_representation(a),
135
+ self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
136
+ for a in annotations
137
+ ]
138
+ empty_tuple = (self.none, self.none)
139
+ object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
140
+ return object_tuples
141
+
142
+ def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
143
+ -> LongTensor:
144
+ if len(annotations) == 0:
145
+ warnings.warn('Did not receive any annotations.')
146
+ if len(annotations) > self.no_max_objects:
147
+ warnings.warn('Received more annotations than allowed.')
148
+ annotations = annotations[:self.no_max_objects]
149
+
150
+ if not crop_coordinates:
151
+ crop_coordinates = FULL_CROP
152
+
153
+ random.shuffle(annotations)
154
+ annotations = filter_annotations(annotations, crop_coordinates)
155
+ if self.encode_crop:
156
+ annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
157
+ if horizontal_flip:
158
+ crop_coordinates = horizontally_flip_bbox(crop_coordinates)
159
+ extra = self._crop_encoder(crop_coordinates)
160
+ else:
161
+ annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
162
+ extra = []
163
+
164
+ object_tuples = self._make_object_descriptors(annotations)
165
+ flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
166
+ assert len(flattened) == self.embedding_dim
167
+ assert all(0 <= value < self.no_tokens for value in flattened)
168
+ return LongTensor(flattened)
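
A small sketch (editorial addition, not in the commit) of the coordinate tokenization round trip; the parameter values are arbitrary examples. With no_tokens=1024 the builder quantizes coordinates onto a 32x32 grid, so a position token decodes back to (x, y) up to that grid resolution.

from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder

builder = ObjectsCenterPointsConditionalBuilder(
    no_object_classes=183, no_max_objects=30, no_tokens=1024,
    encode_crop=False, use_group_parameter=False, use_additional_parameters=False)

token = builder.tokenize_coordinates(0.5, 0.25)
print(token)                                  # an integer in [0, 1023]
print(builder.coordinates_from_token(token))  # approximately (0.5, 0.25), quantized to the 32x32 grid
print(builder.embedding_dim)                  # 30 objects * 2 tokens each = 60
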
taming/data/conditional_builder/utils.py ADDED
@@ -0,0 +1,105 @@
1
+ import importlib
2
+ from typing import List, Any, Tuple, Optional
3
+
4
+ from taming.data.helper_types import BoundingBox, Annotation
5
+
6
+ # source: seaborn, color palette tab10
7
+ COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
8
+ (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
9
+ BLACK = (0, 0, 0)
10
+ GRAY_75 = (63, 63, 63)
11
+ GRAY_50 = (127, 127, 127)
12
+ GRAY_25 = (191, 191, 191)
13
+ WHITE = (255, 255, 255)
14
+ FULL_CROP = (0., 0., 1., 1.)
15
+
16
+
17
+ def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
18
+ """
19
+ Give intersection area of two rectangles.
20
+ @param rectangle1: (x0, y0, w, h) of first rectangle
21
+ @param rectangle2: (x0, y0, w, h) of second rectangle
22
+ """
23
+ rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
24
+ rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
25
+ x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
26
+ y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
27
+ return x_overlap * y_overlap
28
+
29
+
30
+ def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
31
+ return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
32
+
33
+
34
+ def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
35
+ bbox = relative_bbox
36
+ bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
37
+ return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
38
+
39
+
40
+ def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
41
+ return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
42
+
43
+
44
+ def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
45
+ List[Annotation]:
46
+ def clamp(x: float):
47
+ return max(min(x, 1.), 0.)
48
+
49
+ def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
50
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
51
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
52
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
53
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
54
+ if flip:
55
+ x0 = 1 - (x0 + w)
56
+ return x0, y0, w, h
57
+
58
+ return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
59
+
60
+
61
+ def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
62
+ return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
63
+
64
+
65
+ def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
66
+ sl = slice(1) if short else slice(None)
67
+ string = ''
68
+ if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
69
+ return string
70
+ if annotation.is_group_of:
71
+ string += 'group'[sl] + ','
72
+ if annotation.is_occluded:
73
+ string += 'occluded'[sl] + ','
74
+ if annotation.is_depiction:
75
+ string += 'depiction'[sl] + ','
76
+ if annotation.is_inside:
77
+ string += 'inside'[sl]
78
+ return '(' + string.strip(",") + ')'
79
+
80
+
81
+ def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
82
+ if font_size is None:
83
+ font_size = 10
84
+ if max(figure_size) >= 256:
85
+ font_size = 12
86
+ if max(figure_size) >= 512:
87
+ font_size = 15
88
+ return font_size
89
+
90
+
91
+ def get_circle_size(figure_size: Tuple[int, int]) -> int:
92
+ circle_size = 2
93
+ if max(figure_size) >= 256:
94
+ circle_size = 3
95
+ if max(figure_size) >= 512:
96
+ circle_size = 4
97
+ return circle_size
98
+
99
+
100
+ def load_object_from_string(object_string: str) -> Any:
101
+ """
102
+ Source: https://stackoverflow.com/a/10773699
103
+ """
104
+ module_name, class_name = object_string.rsplit(".", 1)
105
+ return getattr(importlib.import_module(module_name), class_name)
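
Sketch (editorial addition, not part of the commit): the bbox helpers operate on relative (x0, y0, w, h) tuples, and load_object_from_string resolves a dotted import path to the named object.

from taming.data.conditional_builder.utils import (
    intersection_area, horizontally_flip_bbox, load_object_from_string)

a = (0.1, 0.1, 0.4, 0.4)                  # x0, y0, w, h in relative coordinates
b = (0.3, 0.3, 0.4, 0.4)
print(intersection_area(a, b))            # ~0.04 (a 0.2 x 0.2 overlap)
print(horizontally_flip_bbox(a))          # (0.5, 0.1, 0.4, 0.4)
print(load_object_from_string("collections.OrderedDict"))   # <class 'collections.OrderedDict'>
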
taming/data/custom.py ADDED
@@ -0,0 +1,38 @@
1
+ import os
2
+ import numpy as np
3
+ import albumentations
4
+ from torch.utils.data import Dataset
5
+
6
+ from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
7
+
8
+
9
+ class CustomBase(Dataset):
10
+ def __init__(self, *args, **kwargs):
11
+ super().__init__()
12
+ self.data = None
13
+
14
+ def __len__(self):
15
+ return len(self.data)
16
+
17
+ def __getitem__(self, i):
18
+ example = self.data[i]
19
+ return example
20
+
21
+
22
+
23
+ class CustomTrain(CustomBase):
24
+ def __init__(self, size, training_images_list_file):
25
+ super().__init__()
26
+ with open(training_images_list_file, "r") as f:
27
+ paths = f.read().splitlines()
28
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
29
+
30
+
31
+ class CustomTest(CustomBase):
32
+ def __init__(self, size, test_images_list_file):
33
+ super().__init__()
34
+ with open(test_images_list_file, "r") as f:
35
+ paths = f.read().splitlines()
36
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
37
+
38
+
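
A minimal sketch (editorial addition, not part of the commit): training on custom images only needs a text file listing one image path per line; "train_images.txt" below is a hypothetical file.

from taming.data.custom import CustomTrain

dataset = CustomTrain(size=256, training_images_list_file="train_images.txt")
print(len(dataset))
print(dataset[0]["image"].shape)   # (256, 256, 3), float32 in [-1, 1]
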
taming/data/faceshq.py ADDED
@@ -0,0 +1,134 @@
1
+ import os
2
+ import numpy as np
3
+ import albumentations
4
+ from torch.utils.data import Dataset
5
+
6
+ from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
7
+
8
+
9
+ class FacesBase(Dataset):
10
+ def __init__(self, *args, **kwargs):
11
+ super().__init__()
12
+ self.data = None
13
+ self.keys = None
14
+
15
+ def __len__(self):
16
+ return len(self.data)
17
+
18
+ def __getitem__(self, i):
19
+ example = self.data[i]
20
+ ex = {}
21
+ if self.keys is not None:
22
+ for k in self.keys:
23
+ ex[k] = example[k]
24
+ else:
25
+ ex = example
26
+ return ex
27
+
28
+
29
+ class CelebAHQTrain(FacesBase):
30
+ def __init__(self, size, keys=None):
31
+ super().__init__()
32
+ root = "data/celebahq"
33
+ with open("data/celebahqtrain.txt", "r") as f:
34
+ relpaths = f.read().splitlines()
35
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
36
+ self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
37
+ self.keys = keys
38
+
39
+
40
+ class CelebAHQValidation(FacesBase):
41
+ def __init__(self, size, keys=None):
42
+ super().__init__()
43
+ root = "data/celebahq"
44
+ with open("data/celebahqvalidation.txt", "r") as f:
45
+ relpaths = f.read().splitlines()
46
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
47
+ self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
48
+ self.keys = keys
49
+
50
+
51
+ class FFHQTrain(FacesBase):
52
+ def __init__(self, size, keys=None):
53
+ super().__init__()
54
+ root = "data/ffhq"
55
+ with open("data/ffhqtrain.txt", "r") as f:
56
+ relpaths = f.read().splitlines()
57
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
58
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
59
+ self.keys = keys
60
+
61
+
62
+ class FFHQValidation(FacesBase):
63
+ def __init__(self, size, keys=None):
64
+ super().__init__()
65
+ root = "data/ffhq"
66
+ with open("data/ffhqvalidation.txt", "r") as f:
67
+ relpaths = f.read().splitlines()
68
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
69
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
70
+ self.keys = keys
71
+
72
+
73
+ class FacesHQTrain(Dataset):
74
+ # CelebAHQ [0] + FFHQ [1]
75
+ def __init__(self, size, keys=None, crop_size=None, coord=False):
76
+ d1 = CelebAHQTrain(size=size, keys=keys)
77
+ d2 = FFHQTrain(size=size, keys=keys)
78
+ self.data = ConcatDatasetWithIndex([d1, d2])
79
+ self.coord = coord
80
+ if crop_size is not None:
81
+ self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
82
+ if self.coord:
83
+ self.cropper = albumentations.Compose([self.cropper],
84
+ additional_targets={"coord": "image"})
85
+
86
+ def __len__(self):
87
+ return len(self.data)
88
+
89
+ def __getitem__(self, i):
90
+ ex, y = self.data[i]
91
+ if hasattr(self, "cropper"):
92
+ if not self.coord:
93
+ out = self.cropper(image=ex["image"])
94
+ ex["image"] = out["image"]
95
+ else:
96
+ h,w,_ = ex["image"].shape
97
+ coord = np.arange(h*w).reshape(h,w,1)/(h*w)
98
+ out = self.cropper(image=ex["image"], coord=coord)
99
+ ex["image"] = out["image"]
100
+ ex["coord"] = out["coord"]
101
+ ex["class"] = y
102
+ return ex
103
+
104
+
105
+ class FacesHQValidation(Dataset):
106
+ # CelebAHQ [0] + FFHQ [1]
107
+ def __init__(self, size, keys=None, crop_size=None, coord=False):
108
+ d1 = CelebAHQValidation(size=size, keys=keys)
109
+ d2 = FFHQValidation(size=size, keys=keys)
110
+ self.data = ConcatDatasetWithIndex([d1, d2])
111
+ self.coord = coord
112
+ if crop_size is not None:
113
+ self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
114
+ if self.coord:
115
+ self.cropper = albumentations.Compose([self.cropper],
116
+ additional_targets={"coord": "image"})
117
+
118
+ def __len__(self):
119
+ return len(self.data)
120
+
121
+ def __getitem__(self, i):
122
+ ex, y = self.data[i]
123
+ if hasattr(self, "cropper"):
124
+ if not self.coord:
125
+ out = self.cropper(image=ex["image"])
126
+ ex["image"] = out["image"]
127
+ else:
128
+ h,w,_ = ex["image"].shape
129
+ coord = np.arange(h*w).reshape(h,w,1)/(h*w)
130
+ out = self.cropper(image=ex["image"], coord=coord)
131
+ ex["image"] = out["image"]
132
+ ex["coord"] = out["coord"]
133
+ ex["class"] = y
134
+ return ex
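
Usage sketch (editorial addition, not part of the commit): FacesHQ concatenates CelebA-HQ and FFHQ and records the source dataset index under "class". It assumes the file lists and image folders referenced above exist under data/.

from taming.data.faceshq import FacesHQTrain

dataset = FacesHQTrain(size=256, crop_size=256, coord=True)
example = dataset[0]
print(example["image"].shape)   # (256, 256, 3)
print(example["coord"].shape)   # (256, 256, 1), a normalized pixel-index map
print(example["class"])         # 0 for CelebA-HQ, 1 for FFHQ
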
taming/data/helper_types.py ADDED
@@ -0,0 +1,49 @@
1
+ from typing import Dict, Tuple, Optional, NamedTuple, Union
2
+ from PIL.Image import Image as pil_image
3
+ from torch import Tensor
4
+
5
+ try:
6
+ from typing import Literal
7
+ except ImportError:
8
+ from typing_extensions import Literal
9
+
10
+ Image = Union[Tensor, pil_image]
11
+ BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
12
+ CropMethodType = Literal['none', 'random', 'center', 'random-2d']
13
+ SplitType = Literal['train', 'validation', 'test']
14
+
15
+
16
+ class ImageDescription(NamedTuple):
17
+ id: int
18
+ file_name: str
19
+ original_size: Tuple[int, int] # w, h
20
+ url: Optional[str] = None
21
+ license: Optional[int] = None
22
+ coco_url: Optional[str] = None
23
+ date_captured: Optional[str] = None
24
+ flickr_url: Optional[str] = None
25
+ flickr_id: Optional[str] = None
26
+ coco_id: Optional[str] = None
27
+
28
+
29
+ class Category(NamedTuple):
30
+ id: str
31
+ super_category: Optional[str]
32
+ name: str
33
+
34
+
35
+ class Annotation(NamedTuple):
36
+ area: float
37
+ image_id: str
38
+ bbox: BoundingBox
39
+ category_no: int
40
+ category_id: str
41
+ id: Optional[int] = None
42
+ source: Optional[str] = None
43
+ confidence: Optional[float] = None
44
+ is_group_of: Optional[bool] = None
45
+ is_truncated: Optional[bool] = None
46
+ is_occluded: Optional[bool] = None
47
+ is_depiction: Optional[bool] = None
48
+ is_inside: Optional[bool] = None
49
+ segmentation: Optional[Dict] = None
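
Sketch (editorial addition, not part of the commit): the helper types are plain NamedTuples, so annotations can be constructed with keyword arguments and copied with _replace. The field values below are made up.

from taming.data.helper_types import Annotation

ann = Annotation(area=100.0, image_id="img-0", bbox=(0.1, 0.2, 0.3, 0.4),
                 category_no=5, category_id="/m/01g317", is_group_of=False)
print(ann.bbox)                                     # (0.1, 0.2, 0.3, 0.4), i.e. x0, y0, w, h
print(ann._replace(is_group_of=True).is_group_of)   # True; _replace returns a modified copy
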
taming/data/image_transforms.py ADDED
@@ -0,0 +1,132 @@
1
+ import random
2
+ import warnings
3
+ from typing import Union
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
8
+ from torchvision.transforms.functional import _get_image_size as get_image_size
9
+
10
+ from taming.data.helper_types import BoundingBox, Image
11
+
12
+ pil_to_tensor = PILToTensor()
13
+
14
+
15
+ def convert_pil_to_tensor(image: Image) -> Tensor:
16
+ with warnings.catch_warnings():
17
+ # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
18
+ warnings.simplefilter("ignore")
19
+ return pil_to_tensor(image)
20
+
21
+
22
+ class RandomCrop1dReturnCoordinates(RandomCrop):
23
+ def forward(self, img: Image) -> (BoundingBox, Image):
24
+ """
25
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
26
+ Args:
27
+ img (PIL Image or Tensor): Image to be cropped.
28
+
29
+ Returns:
30
+ Bounding box: x0, y0, w, h
31
+ PIL Image or Tensor: Cropped image.
32
+
33
+ Based on:
34
+ torchvision.transforms.RandomCrop, torchvision 1.7.0
35
+ """
36
+ if self.padding is not None:
37
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
38
+
39
+ width, height = get_image_size(img)
40
+ # pad the width if needed
41
+ if self.pad_if_needed and width < self.size[1]:
42
+ padding = [self.size[1] - width, 0]
43
+ img = F.pad(img, padding, self.fill, self.padding_mode)
44
+ # pad the height if needed
45
+ if self.pad_if_needed and height < self.size[0]:
46
+ padding = [0, self.size[0] - height]
47
+ img = F.pad(img, padding, self.fill, self.padding_mode)
48
+
49
+ i, j, h, w = self.get_params(img, self.size)
50
+ bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
51
+ return bbox, F.crop(img, i, j, h, w)
52
+
53
+
54
+ class Random2dCropReturnCoordinates(torch.nn.Module):
55
+ """
56
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
57
+ Args:
58
+ img (PIL Image or Tensor): Image to be cropped.
59
+
60
+ Returns:
61
+ Bounding box: x0, y0, w, h
62
+ PIL Image or Tensor: Cropped image.
63
+
64
+ Based on:
65
+ torchvision.transforms.RandomCrop, torchvision 1.7.0
66
+ """
67
+
68
+ def __init__(self, min_size: int):
69
+ super().__init__()
70
+ self.min_size = min_size
71
+
72
+ def forward(self, img: Image) -> (BoundingBox, Image):
73
+ width, height = get_image_size(img)
74
+ max_size = min(width, height)
75
+ if max_size <= self.min_size:
76
+ size = max_size
77
+ else:
78
+ size = random.randint(self.min_size, max_size)
79
+ top = random.randint(0, height - size)
80
+ left = random.randint(0, width - size)
81
+ bbox = left / width, top / height, size / width, size / height
82
+ return bbox, F.crop(img, top, left, size, size)
83
+
84
+
85
+ class CenterCropReturnCoordinates(CenterCrop):
86
+ @staticmethod
87
+ def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
88
+ if width > height:
89
+ w = height / width
90
+ h = 1.0
91
+ x0 = 0.5 - w / 2
92
+ y0 = 0.
93
+ else:
94
+ w = 1.0
95
+ h = width / height
96
+ x0 = 0.
97
+ y0 = 0.5 - h / 2
98
+ return x0, y0, w, h
99
+
100
+ def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
101
+ """
102
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
103
+ Args:
104
+ img (PIL Image or Tensor): Image to be cropped.
105
+
106
+ Returns:
107
+ Bounding box: x0, y0, w, h
108
+ PIL Image or Tensor: Cropped image.
109
+ Based on:
110
+ torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
111
+ """
112
+ width, height = get_image_size(img)
113
+ return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
114
+
115
+
116
+ class RandomHorizontalFlipReturn(RandomHorizontalFlip):
117
+ def forward(self, img: Image) -> (bool, Image):
118
+ """
119
+ Additionally to flipping, returns a boolean whether it was flipped or not.
120
+ Args:
121
+ img (PIL Image or Tensor): Image to be flipped.
122
+
123
+ Returns:
124
+ flipped: whether the image was flipped or not
125
+ PIL Image or Tensor: Randomly flipped image.
126
+
127
+ Based on:
128
+ torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
129
+ """
130
+ if torch.rand(1) < self.p:
131
+ return True, F.hflip(img)
132
+ return False, img
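
Usage sketch (editorial addition, not part of the commit), assuming a torchvision version that still provides the private _get_image_size helper imported above. The *ReturnCoordinates transforms behave like their torchvision counterparts but additionally return a relative (x0, y0, w, h) crop box that the conditional builders use to rescale annotations.

from PIL import Image
from taming.data.image_transforms import CenterCropReturnCoordinates

img = Image.new("RGB", (640, 480))   # a blank 640x480 test image
crop = CenterCropReturnCoordinates(256)
bbox, cropped = crop(img)
print(bbox)           # (0.125, 0.0, 0.75, 1.0): the centered square region of the 640x480 input
print(cropped.size)   # (256, 256)
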
taming/data/imagenet.py ADDED
@@ -0,0 +1,558 @@
1
+ import os, tarfile, glob, shutil
2
+ import yaml
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from PIL import Image
6
+ import albumentations
7
+ from omegaconf import OmegaConf
8
+ from torch.utils.data import Dataset
9
+
10
+ from taming.data.base import ImagePaths
11
+ from taming.util import download, retrieve
12
+ import taming.data.utils as bdu
13
+
14
+
15
+ def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
16
+ synsets = []
17
+ with open(path_to_yaml) as f:
18
+ di2s = yaml.load(f, Loader=yaml.FullLoader)
19
+ for idx in indices:
20
+ synsets.append(str(di2s[idx]))
21
+ print("Using {} different synsets for construction of Restricted ImageNet.".format(len(synsets)))
22
+ return synsets
23
+
24
+
25
+ def str_to_indices(string):
26
+ """Expects a string in the format '32-123, 256, 280-321'"""
27
+ assert not string.endswith(","), "provided string '{}' ends with a comma, please remove it".format(string)
28
+ subs = string.split(",")
29
+ indices = []
30
+ for sub in subs:
31
+ subsubs = sub.split("-")
32
+ assert len(subsubs) > 0
33
+ if len(subsubs) == 1:
34
+ indices.append(int(subsubs[0]))
35
+ else:
36
+ rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))]
37
+ indices.extend(rang)
38
+ return sorted(indices)
39
+
40
+
41
+ class ImageNetBase(Dataset):
42
+ def __init__(self, config=None):
43
+ self.config = config or OmegaConf.create()
44
+ if not type(self.config)==dict:
45
+ self.config = OmegaConf.to_container(self.config)
46
+ self._prepare()
47
+ self._prepare_synset_to_human()
48
+ self._prepare_idx_to_synset()
49
+ self._load()
50
+
51
+ def __len__(self):
52
+ return len(self.data)
53
+
54
+ def __getitem__(self, i):
55
+ return self.data[i]
56
+
57
+ def _prepare(self):
58
+ raise NotImplementedError()
59
+
60
+ def _filter_relpaths(self, relpaths):
61
+ ignore = set([
62
+ "n06596364_9591.JPEG",
63
+ ])
64
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
65
+ if "sub_indices" in self.config:
66
+ indices = str_to_indices(self.config["sub_indices"])
67
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
68
+ files = []
69
+ for rpath in relpaths:
70
+ syn = rpath.split("/")[0]
71
+ if syn in synsets:
72
+ files.append(rpath)
73
+ return files
74
+ else:
75
+ return relpaths
76
+
77
+ def _prepare_synset_to_human(self):
78
+ SIZE = 2655750
79
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
80
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
81
+ if (not os.path.exists(self.human_dict) or
82
+ not os.path.getsize(self.human_dict)==SIZE):
83
+ download(URL, self.human_dict)
84
+
85
+ def _prepare_idx_to_synset(self):
86
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
87
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
88
+ if (not os.path.exists(self.idx2syn)):
89
+ download(URL, self.idx2syn)
90
+
91
+ def _load(self):
92
+ with open(self.txt_filelist, "r") as f:
93
+ self.relpaths = f.read().splitlines()
94
+ l1 = len(self.relpaths)
95
+ self.relpaths = self._filter_relpaths(self.relpaths)
96
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
97
+
98
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
99
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
100
+
101
+ unique_synsets = np.unique(self.synsets)
102
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
103
+ self.class_labels = [class_dict[s] for s in self.synsets]
104
+
105
+ with open(self.human_dict, "r") as f:
106
+ human_dict = f.read().splitlines()
107
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
108
+
109
+ self.human_labels = [human_dict[s] for s in self.synsets]
110
+
111
+ labels = {
112
+ "relpath": np.array(self.relpaths),
113
+ "synsets": np.array(self.synsets),
114
+ "class_label": np.array(self.class_labels),
115
+ "human_label": np.array(self.human_labels),
116
+ }
117
+ self.data = ImagePaths(self.abspaths,
118
+ labels=labels,
119
+ size=retrieve(self.config, "size", default=0),
120
+ random_crop=self.random_crop)
121
+
122
+
123
+ class ImageNetTrain(ImageNetBase):
124
+ NAME = "ILSVRC2012_train"
125
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
126
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
127
+ FILES = [
128
+ "ILSVRC2012_img_train.tar",
129
+ ]
130
+ SIZES = [
131
+ 147897477120,
132
+ ]
133
+
134
+ def _prepare(self):
135
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
136
+ default=True)
137
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
138
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
139
+ self.datadir = os.path.join(self.root, "data")
140
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
141
+ self.expected_length = 1281167
142
+ if not bdu.is_prepared(self.root):
143
+ # prep
144
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
145
+
146
+ datadir = self.datadir
147
+ if not os.path.exists(datadir):
148
+ path = os.path.join(self.root, self.FILES[0])
149
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
150
+ import academictorrents as at
151
+ atpath = at.get(self.AT_HASH, datastore=self.root)
152
+ assert atpath == path
153
+
154
+ print("Extracting {} to {}".format(path, datadir))
155
+ os.makedirs(datadir, exist_ok=True)
156
+ with tarfile.open(path, "r:") as tar:
157
+ tar.extractall(path=datadir)
158
+
159
+ print("Extracting sub-tars.")
160
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
161
+ for subpath in tqdm(subpaths):
162
+ subdir = subpath[:-len(".tar")]
163
+ os.makedirs(subdir, exist_ok=True)
164
+ with tarfile.open(subpath, "r:") as tar:
165
+ tar.extractall(path=subdir)
166
+
167
+
168
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
169
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
170
+ filelist = sorted(filelist)
171
+ filelist = "\n".join(filelist)+"\n"
172
+ with open(self.txt_filelist, "w") as f:
173
+ f.write(filelist)
174
+
175
+ bdu.mark_prepared(self.root)
176
+
177
+
178
+ class ImageNetValidation(ImageNetBase):
179
+ NAME = "ILSVRC2012_validation"
180
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
181
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
182
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
183
+ FILES = [
184
+ "ILSVRC2012_img_val.tar",
185
+ "validation_synset.txt",
186
+ ]
187
+ SIZES = [
188
+ 6744924160,
189
+ 1950000,
190
+ ]
191
+
192
+ def _prepare(self):
193
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
194
+ default=False)
195
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
196
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
197
+ self.datadir = os.path.join(self.root, "data")
198
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
199
+ self.expected_length = 50000
200
+ if not bdu.is_prepared(self.root):
201
+ # prep
202
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
203
+
204
+ datadir = self.datadir
205
+ if not os.path.exists(datadir):
206
+ path = os.path.join(self.root, self.FILES[0])
207
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
208
+ import academictorrents as at
209
+ atpath = at.get(self.AT_HASH, datastore=self.root)
210
+ assert atpath == path
211
+
212
+ print("Extracting {} to {}".format(path, datadir))
213
+ os.makedirs(datadir, exist_ok=True)
214
+ with tarfile.open(path, "r:") as tar:
215
+ tar.extractall(path=datadir)
216
+
217
+ vspath = os.path.join(self.root, self.FILES[1])
218
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
219
+ download(self.VS_URL, vspath)
220
+
221
+ with open(vspath, "r") as f:
222
+ synset_dict = f.read().splitlines()
223
+ synset_dict = dict(line.split() for line in synset_dict)
224
+
225
+ print("Reorganizing into synset folders")
226
+ synsets = np.unique(list(synset_dict.values()))
227
+ for s in synsets:
228
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
229
+ for k, v in synset_dict.items():
230
+ src = os.path.join(datadir, k)
231
+ dst = os.path.join(datadir, v)
232
+ shutil.move(src, dst)
233
+
234
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
235
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
236
+ filelist = sorted(filelist)
237
+ filelist = "\n".join(filelist)+"\n"
238
+ with open(self.txt_filelist, "w") as f:
239
+ f.write(filelist)
240
+
241
+ bdu.mark_prepared(self.root)
242
+
243
+
244
+ def get_preprocessor(size=None, random_crop=False, additional_targets=None,
245
+ crop_size=None):
246
+ if size is not None and size > 0:
247
+ transforms = list()
248
+ rescaler = albumentations.SmallestMaxSize(max_size = size)
249
+ transforms.append(rescaler)
250
+ if not random_crop:
251
+ cropper = albumentations.CenterCrop(height=size,width=size)
252
+ transforms.append(cropper)
253
+ else:
254
+ cropper = albumentations.RandomCrop(height=size,width=size)
255
+ transforms.append(cropper)
256
+ flipper = albumentations.HorizontalFlip()
257
+ transforms.append(flipper)
258
+ preprocessor = albumentations.Compose(transforms,
259
+ additional_targets=additional_targets)
260
+ elif crop_size is not None and crop_size > 0:
261
+ if not random_crop:
262
+ cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
263
+ else:
264
+ cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
265
+ transforms = [cropper]
266
+ preprocessor = albumentations.Compose(transforms,
267
+ additional_targets=additional_targets)
268
+ else:
269
+ preprocessor = lambda **kwargs: kwargs
270
+ return preprocessor
271
+
272
+
273
+ def rgba_to_depth(x):
274
+ assert x.dtype == np.uint8
275
+ assert len(x.shape) == 3 and x.shape[2] == 4
276
+ y = x.copy()
277
+ y.dtype = np.float32
278
+ y = y.reshape(x.shape[:2])
279
+ return np.ascontiguousarray(y)
280
+
281
+
282
+ class BaseWithDepth(Dataset):
283
+ DEFAULT_DEPTH_ROOT="data/imagenet_depth"
284
+
285
+ def __init__(self, config=None, size=None, random_crop=False,
286
+ crop_size=None, root=None):
287
+ self.config = config
288
+ self.base_dset = self.get_base_dset()
289
+ self.preprocessor = get_preprocessor(
290
+ size=size,
291
+ crop_size=crop_size,
292
+ random_crop=random_crop,
293
+ additional_targets={"depth": "image"})
294
+ self.crop_size = crop_size
295
+ if self.crop_size is not None:
296
+ self.rescaler = albumentations.Compose(
297
+ [albumentations.SmallestMaxSize(max_size = self.crop_size)],
298
+ additional_targets={"depth": "image"})
299
+ if root is not None:
300
+ self.DEFAULT_DEPTH_ROOT = root
301
+
302
+ def __len__(self):
303
+ return len(self.base_dset)
304
+
305
+ def preprocess_depth(self, path):
306
+ rgba = np.array(Image.open(path))
307
+ depth = rgba_to_depth(rgba)
308
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
309
+ depth = 2.0*depth-1.0
310
+ return depth
311
+
312
+ def __getitem__(self, i):
313
+ e = self.base_dset[i]
314
+ e["depth"] = self.preprocess_depth(self.get_depth_path(e))
315
+ # upscale if necessary
316
+ h,w,c = e["image"].shape
317
+ if self.crop_size and min(h,w) < self.crop_size:
318
+ # have to upscale to be able to crop - this just uses bilinear
319
+ out = self.rescaler(image=e["image"], depth=e["depth"])
320
+ e["image"] = out["image"]
321
+ e["depth"] = out["depth"]
322
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
323
+ e["image"] = transformed["image"]
324
+ e["depth"] = transformed["depth"]
325
+ return e
326
+
327
+
328
+ class ImageNetTrainWithDepth(BaseWithDepth):
329
+ # default to random_crop=True
330
+ def __init__(self, random_crop=True, sub_indices=None, **kwargs):
331
+ self.sub_indices = sub_indices
332
+ super().__init__(random_crop=random_crop, **kwargs)
333
+
334
+ def get_base_dset(self):
335
+ if self.sub_indices is None:
336
+ return ImageNetTrain()
337
+ else:
338
+ return ImageNetTrain({"sub_indices": self.sub_indices})
339
+
340
+ def get_depth_path(self, e):
341
+ fid = os.path.splitext(e["relpath"])[0]+".png"
342
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
343
+ return fid
344
+
345
+
346
+ class ImageNetValidationWithDepth(BaseWithDepth):
347
+ def __init__(self, sub_indices=None, **kwargs):
348
+ self.sub_indices = sub_indices
349
+ super().__init__(**kwargs)
350
+
351
+ def get_base_dset(self):
352
+ if self.sub_indices is None:
353
+ return ImageNetValidation()
354
+ else:
355
+ return ImageNetValidation({"sub_indices": self.sub_indices})
356
+
357
+ def get_depth_path(self, e):
358
+ fid = os.path.splitext(e["relpath"])[0]+".png"
359
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
360
+ return fid
361
+
362
+
363
+ class RINTrainWithDepth(ImageNetTrainWithDepth):
364
+ def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
365
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
366
+ super().__init__(config=config, size=size, random_crop=random_crop,
367
+ sub_indices=sub_indices, crop_size=crop_size)
368
+
369
+
370
+ class RINValidationWithDepth(ImageNetValidationWithDepth):
371
+ def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
372
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
373
+ super().__init__(config=config, size=size, random_crop=random_crop,
374
+ sub_indices=sub_indices, crop_size=crop_size)
375
+
376
+
377
+ class DRINExamples(Dataset):
378
+ def __init__(self):
379
+ self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
380
+ with open("data/drin_examples.txt", "r") as f:
381
+ relpaths = f.read().splitlines()
382
+ self.image_paths = [os.path.join("data/drin_images",
383
+ relpath) for relpath in relpaths]
384
+ self.depth_paths = [os.path.join("data/drin_depth",
385
+ relpath.replace(".JPEG", ".png")) for relpath in relpaths]
386
+
387
+ def __len__(self):
388
+ return len(self.image_paths)
389
+
390
+ def preprocess_image(self, image_path):
391
+ image = Image.open(image_path)
392
+ if not image.mode == "RGB":
393
+ image = image.convert("RGB")
394
+ image = np.array(image).astype(np.uint8)
395
+ image = self.preprocessor(image=image)["image"]
396
+ image = (image/127.5 - 1.0).astype(np.float32)
397
+ return image
398
+
399
+ def preprocess_depth(self, path):
400
+ rgba = np.array(Image.open(path))
401
+ depth = rgba_to_depth(rgba)
402
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
403
+ depth = 2.0*depth-1.0
404
+ return depth
405
+
406
+ def __getitem__(self, i):
407
+ e = dict()
408
+ e["image"] = self.preprocess_image(self.image_paths[i])
409
+ e["depth"] = self.preprocess_depth(self.depth_paths[i])
410
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
411
+ e["image"] = transformed["image"]
412
+ e["depth"] = transformed["depth"]
413
+ return e
414
+
415
+
416
+ def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
417
+ if factor is None or factor==1:
418
+ return x
419
+
420
+ dtype = x.dtype
421
+ assert dtype in [np.float32, np.float64]
422
+ assert x.min() >= -1
423
+ assert x.max() <= 1
424
+
425
+ keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR,
426
+ "bicubic": Image.BICUBIC}[keepmode]
427
+
428
+ lr = (x+1.0)*127.5
429
+ lr = lr.clip(0,255).astype(np.uint8)
430
+ lr = Image.fromarray(lr)
431
+
432
+ h, w, _ = x.shape
433
+ nh = h//factor
434
+ nw = w//factor
435
+ assert nh > 0 and nw > 0, (nh, nw)
436
+
437
+ lr = lr.resize((nw,nh), Image.BICUBIC)
438
+ if keepshapes:
439
+ lr = lr.resize((w,h), keepmode)
440
+ lr = np.array(lr)/127.5-1.0
441
+ lr = lr.astype(dtype)
442
+
443
+ return lr
444
+
445
+
446
+ class ImageNetScale(Dataset):
447
+ def __init__(self, size=None, crop_size=None, random_crop=False,
448
+ up_factor=None, hr_factor=None, keep_mode="bicubic"):
449
+ self.base = self.get_base()
450
+
451
+ self.size = size
452
+ self.crop_size = crop_size if crop_size is not None else self.size
453
+ self.random_crop = random_crop
454
+ self.up_factor = up_factor
455
+ self.hr_factor = hr_factor
456
+ self.keep_mode = keep_mode
457
+
458
+ transforms = list()
459
+
460
+ if self.size is not None and self.size > 0:
461
+ rescaler = albumentations.SmallestMaxSize(max_size = self.size)
462
+ self.rescaler = rescaler
463
+ transforms.append(rescaler)
464
+
465
+ if self.crop_size is not None and self.crop_size > 0:
466
+ if len(transforms) == 0:
467
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size)
468
+
469
+ if not self.random_crop:
470
+ cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size)
471
+ else:
472
+ cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size)
473
+ transforms.append(cropper)
474
+
475
+ if len(transforms) > 0:
476
+ if self.up_factor is not None:
477
+ additional_targets = {"lr": "image"}
478
+ else:
479
+ additional_targets = None
480
+ self.preprocessor = albumentations.Compose(transforms,
481
+ additional_targets=additional_targets)
482
+ else:
483
+ self.preprocessor = lambda **kwargs: kwargs
484
+
485
+ def __len__(self):
486
+ return len(self.base)
487
+
488
+ def __getitem__(self, i):
489
+ example = self.base[i]
490
+ image = example["image"]
491
+ # adjust resolution
492
+ image = imscale(image, self.hr_factor, keepshapes=False)
493
+ h,w,c = image.shape
494
+ if self.crop_size and min(h,w) < self.crop_size:
495
+ # have to upscale to be able to crop - this just uses bilinear
496
+ image = self.rescaler(image=image)["image"]
497
+ if self.up_factor is None:
498
+ image = self.preprocessor(image=image)["image"]
499
+ example["image"] = image
500
+ else:
501
+ lr = imscale(image, self.up_factor, keepshapes=True,
502
+ keepmode=self.keep_mode)
503
+
504
+ out = self.preprocessor(image=image, lr=lr)
505
+ example["image"] = out["image"]
506
+ example["lr"] = out["lr"]
507
+
508
+ return example
509
+
510
+ class ImageNetScaleTrain(ImageNetScale):
511
+ def __init__(self, random_crop=True, **kwargs):
512
+ super().__init__(random_crop=random_crop, **kwargs)
513
+
514
+ def get_base(self):
515
+ return ImageNetTrain()
516
+
517
+ class ImageNetScaleValidation(ImageNetScale):
518
+ def get_base(self):
519
+ return ImageNetValidation()
520
+
521
+
522
+ from skimage.feature import canny
523
+ from skimage.color import rgb2gray
524
+
525
+
526
+ class ImageNetEdges(ImageNetScale):
527
+ def __init__(self, up_factor=1, **kwargs):
528
+ super().__init__(up_factor=1, **kwargs)
529
+
530
+ def __getitem__(self, i):
531
+ example = self.base[i]
532
+ image = example["image"]
533
+ h,w,c = image.shape
534
+ if self.crop_size and min(h,w) < self.crop_size:
535
+ # have to upscale to be able to crop - this just uses bilinear
536
+ image = self.rescaler(image=image)["image"]
537
+
538
+ lr = canny(rgb2gray(image), sigma=2)
539
+ lr = lr.astype(np.float32)
540
+ lr = lr[:,:,None][:,:,[0,0,0]]
541
+
542
+ out = self.preprocessor(image=image, lr=lr)
543
+ example["image"] = out["image"]
544
+ example["lr"] = out["lr"]
545
+
546
+ return example
547
+
548
+
549
+ class ImageNetEdgesTrain(ImageNetEdges):
550
+ def __init__(self, random_crop=True, **kwargs):
551
+ super().__init__(random_crop=random_crop, **kwargs)
552
+
553
+ def get_base(self):
554
+ return ImageNetTrain()
555
+
556
+ class ImageNetEdgesValidation(ImageNetEdges):
557
+ def get_base(self):
558
+ return ImageNetValidation()
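
Sketch (editorial addition, not part of the commit) of how a "sub_indices" string is expanded into class indices. Note that ranges are half-open, so the end index itself is excluded.

from taming.data.imagenet import str_to_indices

print(str_to_indices("30-32, 33-37, 151"))
# [30, 31, 33, 34, 35, 36, 151]
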
taming/data/open_images_helper.py ADDED
@@ -0,0 +1,379 @@
1
+ open_images_unify_categories_for_coco = {
2
+ '/m/03bt1vf': '/m/01g317',
3
+ '/m/04yx4': '/m/01g317',
4
+ '/m/05r655': '/m/01g317',
5
+ '/m/01bl7v': '/m/01g317',
6
+ '/m/0cnyhnx': '/m/01xq0k1',
7
+ '/m/01226z': '/m/018xm',
8
+ '/m/05ctyq': '/m/018xm',
9
+ '/m/058qzx': '/m/04ctx',
10
+ '/m/06pcq': '/m/0l515',
11
+ '/m/03m3pdh': '/m/02crq1',
12
+ '/m/046dlr': '/m/01x3z',
13
+ '/m/0h8mzrc': '/m/01x3z',
14
+ }
15
+
16
+
17
+ top_300_classes_plus_coco_compatibility = [
18
+ ('Man', 1060962),
19
+ ('Clothing', 986610),
20
+ ('Tree', 748162),
21
+ ('Woman', 611896),
22
+ ('Person', 610294),
23
+ ('Human face', 442948),
24
+ ('Girl', 175399),
25
+ ('Building', 162147),
26
+ ('Car', 159135),
27
+ ('Plant', 155704),
28
+ ('Human body', 137073),
29
+ ('Flower', 133128),
30
+ ('Window', 127485),
31
+ ('Human arm', 118380),
32
+ ('House', 114365),
33
+ ('Wheel', 111684),
34
+ ('Suit', 99054),
35
+ ('Human hair', 98089),
36
+ ('Human head', 92763),
37
+ ('Chair', 88624),
38
+ ('Boy', 79849),
39
+ ('Table', 73699),
40
+ ('Jeans', 57200),
41
+ ('Tire', 55725),
42
+ ('Skyscraper', 53321),
43
+ ('Food', 52400),
44
+ ('Footwear', 50335),
45
+ ('Dress', 50236),
46
+ ('Human leg', 47124),
47
+ ('Toy', 46636),
48
+ ('Tower', 45605),
49
+ ('Boat', 43486),
50
+ ('Land vehicle', 40541),
51
+ ('Bicycle wheel', 34646),
52
+ ('Palm tree', 33729),
53
+ ('Fashion accessory', 32914),
54
+ ('Glasses', 31940),
55
+ ('Bicycle', 31409),
56
+ ('Furniture', 30656),
57
+ ('Sculpture', 29643),
58
+ ('Bottle', 27558),
59
+ ('Dog', 26980),
60
+ ('Snack', 26796),
61
+ ('Human hand', 26664),
62
+ ('Bird', 25791),
63
+ ('Book', 25415),
64
+ ('Guitar', 24386),
65
+ ('Jacket', 23998),
66
+ ('Poster', 22192),
67
+ ('Dessert', 21284),
68
+ ('Baked goods', 20657),
69
+ ('Drink', 19754),
70
+ ('Flag', 18588),
71
+ ('Houseplant', 18205),
72
+ ('Tableware', 17613),
73
+ ('Airplane', 17218),
74
+ ('Door', 17195),
75
+ ('Sports uniform', 17068),
76
+ ('Shelf', 16865),
77
+ ('Drum', 16612),
78
+ ('Vehicle', 16542),
79
+ ('Microphone', 15269),
80
+ ('Street light', 14957),
81
+ ('Cat', 14879),
82
+ ('Fruit', 13684),
83
+ ('Fast food', 13536),
84
+ ('Animal', 12932),
85
+ ('Vegetable', 12534),
86
+ ('Train', 12358),
87
+ ('Horse', 11948),
88
+ ('Flowerpot', 11728),
89
+ ('Motorcycle', 11621),
90
+ ('Fish', 11517),
91
+ ('Desk', 11405),
92
+ ('Helmet', 10996),
93
+ ('Truck', 10915),
94
+ ('Bus', 10695),
95
+ ('Hat', 10532),
96
+ ('Auto part', 10488),
97
+ ('Musical instrument', 10303),
98
+ ('Sunglasses', 10207),
99
+ ('Picture frame', 10096),
100
+ ('Sports equipment', 10015),
101
+ ('Shorts', 9999),
102
+ ('Wine glass', 9632),
103
+ ('Duck', 9242),
104
+ ('Wine', 9032),
105
+ ('Rose', 8781),
106
+ ('Tie', 8693),
107
+ ('Butterfly', 8436),
108
+ ('Beer', 7978),
109
+ ('Cabinetry', 7956),
110
+ ('Laptop', 7907),
111
+ ('Insect', 7497),
112
+ ('Goggles', 7363),
113
+ ('Shirt', 7098),
114
+ ('Dairy Product', 7021),
115
+ ('Marine invertebrates', 7014),
116
+ ('Cattle', 7006),
117
+ ('Trousers', 6903),
118
+ ('Van', 6843),
119
+ ('Billboard', 6777),
120
+ ('Balloon', 6367),
121
+ ('Human nose', 6103),
122
+ ('Tent', 6073),
123
+ ('Camera', 6014),
124
+ ('Doll', 6002),
125
+ ('Coat', 5951),
126
+ ('Mobile phone', 5758),
127
+ ('Swimwear', 5729),
128
+ ('Strawberry', 5691),
129
+ ('Stairs', 5643),
130
+ ('Goose', 5599),
131
+ ('Umbrella', 5536),
132
+ ('Cake', 5508),
133
+ ('Sun hat', 5475),
134
+ ('Bench', 5310),
135
+ ('Bookcase', 5163),
136
+ ('Bee', 5140),
137
+ ('Computer monitor', 5078),
138
+ ('Hiking equipment', 4983),
139
+ ('Office building', 4981),
140
+ ('Coffee cup', 4748),
141
+ ('Curtain', 4685),
142
+ ('Plate', 4651),
143
+ ('Box', 4621),
144
+ ('Tomato', 4595),
145
+ ('Coffee table', 4529),
146
+ ('Office supplies', 4473),
147
+ ('Maple', 4416),
148
+ ('Muffin', 4365),
149
+ ('Cocktail', 4234),
150
+ ('Castle', 4197),
151
+ ('Couch', 4134),
152
+ ('Pumpkin', 3983),
153
+ ('Computer keyboard', 3960),
154
+ ('Human mouth', 3926),
155
+ ('Christmas tree', 3893),
156
+ ('Mushroom', 3883),
157
+ ('Swimming pool', 3809),
158
+ ('Pastry', 3799),
159
+ ('Lavender (Plant)', 3769),
160
+ ('Football helmet', 3732),
161
+ ('Bread', 3648),
162
+ ('Traffic sign', 3628),
163
+ ('Common sunflower', 3597),
164
+ ('Television', 3550),
165
+ ('Bed', 3525),
166
+ ('Cookie', 3485),
167
+ ('Fountain', 3484),
168
+ ('Paddle', 3447),
169
+ ('Bicycle helmet', 3429),
170
+ ('Porch', 3420),
171
+ ('Deer', 3387),
172
+ ('Fedora', 3339),
173
+ ('Canoe', 3338),
174
+ ('Carnivore', 3266),
175
+ ('Bowl', 3202),
176
+ ('Human eye', 3166),
177
+ ('Ball', 3118),
178
+ ('Pillow', 3077),
179
+ ('Salad', 3061),
180
+ ('Beetle', 3060),
181
+ ('Orange', 3050),
182
+ ('Drawer', 2958),
183
+ ('Platter', 2937),
184
+ ('Elephant', 2921),
185
+ ('Seafood', 2921),
186
+ ('Monkey', 2915),
187
+ ('Countertop', 2879),
188
+ ('Watercraft', 2831),
189
+ ('Helicopter', 2805),
190
+ ('Kitchen appliance', 2797),
191
+ ('Personal flotation device', 2781),
192
+ ('Swan', 2739),
193
+ ('Lamp', 2711),
194
+ ('Boot', 2695),
195
+ ('Bronze sculpture', 2693),
196
+ ('Chicken', 2677),
197
+ ('Taxi', 2643),
198
+ ('Juice', 2615),
199
+ ('Cowboy hat', 2604),
200
+ ('Apple', 2600),
201
+ ('Tin can', 2590),
202
+ ('Necklace', 2564),
203
+ ('Ice cream', 2560),
204
+ ('Human beard', 2539),
205
+ ('Coin', 2536),
206
+ ('Candle', 2515),
207
+ ('Cart', 2512),
208
+ ('High heels', 2441),
209
+ ('Weapon', 2433),
210
+ ('Handbag', 2406),
211
+ ('Penguin', 2396),
212
+ ('Rifle', 2352),
213
+ ('Violin', 2336),
214
+ ('Skull', 2304),
215
+ ('Lantern', 2285),
216
+ ('Scarf', 2269),
217
+ ('Saucer', 2225),
218
+ ('Sheep', 2215),
219
+ ('Vase', 2189),
220
+ ('Lily', 2180),
221
+ ('Mug', 2154),
222
+ ('Parrot', 2140),
223
+ ('Human ear', 2137),
224
+ ('Sandal', 2115),
225
+ ('Lizard', 2100),
226
+ ('Kitchen & dining room table', 2063),
227
+ ('Spider', 1977),
228
+ ('Coffee', 1974),
229
+ ('Goat', 1926),
230
+ ('Squirrel', 1922),
231
+ ('Cello', 1913),
232
+ ('Sushi', 1881),
233
+ ('Tortoise', 1876),
234
+ ('Pizza', 1870),
235
+ ('Studio couch', 1864),
236
+ ('Barrel', 1862),
237
+ ('Cosmetics', 1841),
238
+ ('Moths and butterflies', 1841),
239
+ ('Convenience store', 1817),
240
+ ('Watch', 1792),
241
+ ('Home appliance', 1786),
242
+ ('Harbor seal', 1780),
243
+ ('Luggage and bags', 1756),
244
+ ('Vehicle registration plate', 1754),
245
+ ('Shrimp', 1751),
246
+ ('Jellyfish', 1730),
247
+ ('French fries', 1723),
248
+ ('Egg (Food)', 1698),
249
+ ('Football', 1697),
250
+ ('Musical keyboard', 1683),
251
+ ('Falcon', 1674),
252
+ ('Candy', 1660),
253
+ ('Medical equipment', 1654),
254
+ ('Eagle', 1651),
255
+ ('Dinosaur', 1634),
256
+ ('Surfboard', 1630),
257
+ ('Tank', 1628),
258
+ ('Grape', 1624),
259
+ ('Lion', 1624),
260
+ ('Owl', 1622),
261
+ ('Ski', 1613),
262
+ ('Waste container', 1606),
263
+ ('Frog', 1591),
264
+ ('Sparrow', 1585),
265
+ ('Rabbit', 1581),
266
+ ('Pen', 1546),
267
+ ('Sea lion', 1537),
268
+ ('Spoon', 1521),
269
+ ('Sink', 1512),
270
+ ('Teddy bear', 1507),
271
+ ('Bull', 1495),
272
+ ('Sofa bed', 1490),
273
+ ('Dragonfly', 1479),
274
+ ('Brassiere', 1478),
275
+ ('Chest of drawers', 1472),
276
+ ('Aircraft', 1466),
277
+ ('Human foot', 1463),
278
+ ('Pig', 1455),
279
+ ('Fork', 1454),
280
+ ('Antelope', 1438),
281
+ ('Tripod', 1427),
282
+ ('Tool', 1424),
283
+ ('Cheese', 1422),
284
+ ('Lemon', 1397),
285
+ ('Hamburger', 1393),
286
+ ('Dolphin', 1390),
287
+ ('Mirror', 1390),
288
+ ('Marine mammal', 1387),
289
+ ('Giraffe', 1385),
290
+ ('Snake', 1368),
291
+ ('Gondola', 1364),
292
+ ('Wheelchair', 1360),
293
+ ('Piano', 1358),
294
+ ('Cupboard', 1348),
295
+ ('Banana', 1345),
296
+ ('Trumpet', 1335),
297
+ ('Lighthouse', 1333),
298
+ ('Invertebrate', 1317),
299
+ ('Carrot', 1268),
300
+ ('Sock', 1260),
301
+ ('Tiger', 1241),
302
+ ('Camel', 1224),
303
+ ('Parachute', 1224),
304
+ ('Bathroom accessory', 1223),
305
+ ('Earrings', 1221),
306
+ ('Headphones', 1218),
307
+ ('Skirt', 1198),
308
+ ('Skateboard', 1190),
309
+ ('Sandwich', 1148),
310
+ ('Saxophone', 1141),
311
+ ('Goldfish', 1136),
312
+ ('Stool', 1104),
313
+ ('Traffic light', 1097),
314
+ ('Shellfish', 1081),
315
+ ('Backpack', 1079),
316
+ ('Sea turtle', 1078),
317
+ ('Cucumber', 1075),
318
+ ('Tea', 1051),
319
+ ('Toilet', 1047),
320
+ ('Roller skates', 1040),
321
+ ('Mule', 1039),
322
+ ('Bust', 1031),
323
+ ('Broccoli', 1030),
324
+ ('Crab', 1020),
325
+ ('Oyster', 1019),
326
+ ('Cannon', 1012),
327
+ ('Zebra', 1012),
328
+ ('French horn', 1008),
329
+ ('Grapefruit', 998),
330
+ ('Whiteboard', 997),
331
+ ('Zucchini', 997),
332
+ ('Crocodile', 992),
333
+
334
+ ('Clock', 960),
335
+ ('Wall clock', 958),
336
+
337
+ ('Doughnut', 869),
338
+ ('Snail', 868),
339
+
340
+ ('Baseball glove', 859),
341
+
342
+ ('Panda', 830),
343
+ ('Tennis racket', 830),
344
+
345
+ ('Pear', 652),
346
+
347
+ ('Bagel', 617),
348
+ ('Oven', 616),
349
+ ('Ladybug', 615),
350
+ ('Shark', 615),
351
+ ('Polar bear', 614),
352
+ ('Ostrich', 609),
353
+
354
+ ('Hot dog', 473),
355
+ ('Microwave oven', 467),
356
+ ('Fire hydrant', 20),
357
+ ('Stop sign', 20),
358
+ ('Parking meter', 20),
359
+ ('Bear', 20),
360
+ ('Flying disc', 20),
361
+ ('Snowboard', 20),
362
+ ('Tennis ball', 20),
363
+ ('Kite', 20),
364
+ ('Baseball bat', 20),
365
+ ('Kitchen knife', 20),
366
+ ('Knife', 20),
367
+ ('Submarine sandwich', 20),
368
+ ('Computer mouse', 20),
369
+ ('Remote control', 20),
370
+ ('Toaster', 20),
371
+ ('Sink', 20),
372
+ ('Refrigerator', 20),
373
+ ('Alarm clock', 20),
374
+ ('Wall clock', 20),
375
+ ('Scissors', 20),
376
+ ('Hair dryer', 20),
377
+ ('Toothbrush', 20),
378
+ ('Suitcase', 20)
379
+ ]
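
One way these tables can be consumed (a minimal sketch, assuming the module is importable as taming.data.open_images_helper): an Open Images MID is remapped to its COCO-compatible MID when a mapping exists, and the (name, count) pairs can be loaded into a dict of instance counts.

from taming.data.open_images_helper import (
    open_images_unify_categories_for_coco,
    top_300_classes_plus_coco_compatibility,
)

mid = '/m/03bt1vf'
unified = open_images_unify_categories_for_coco.get(mid, mid)  # fall back to the original MID
counts = dict(top_300_classes_plus_coco_compatibility)         # later duplicates overwrite earlier ones
print(unified, counts['Man'])                                  # /m/01g317 1060962
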
taming/data/sflckr.py ADDED
@@ -0,0 +1,91 @@
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
+ import albumentations
5
+ from PIL import Image
6
+ from torch.utils.data import Dataset
7
+
8
+
9
+ class SegmentationBase(Dataset):
10
+ def __init__(self,
11
+ data_csv, data_root, segmentation_root,
12
+ size=None, random_crop=False, interpolation="bicubic",
13
+ n_labels=182, shift_segmentation=False,
14
+ ):
15
+ self.n_labels = n_labels
16
+ self.shift_segmentation = shift_segmentation
17
+ self.data_csv = data_csv
18
+ self.data_root = data_root
19
+ self.segmentation_root = segmentation_root
20
+ with open(self.data_csv, "r") as f:
21
+ self.image_paths = f.read().splitlines()
22
+ self._length = len(self.image_paths)
23
+ self.labels = {
24
+ "relative_file_path_": [l for l in self.image_paths],
25
+ "file_path_": [os.path.join(self.data_root, l)
26
+ for l in self.image_paths],
27
+ "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
28
+ for l in self.image_paths]
29
+ }
30
+
31
+ size = None if size is not None and size<=0 else size
32
+ self.size = size
33
+ if self.size is not None:
34
+ self.interpolation = interpolation
35
+ self.interpolation = {
36
+ "nearest": cv2.INTER_NEAREST,
37
+ "bilinear": cv2.INTER_LINEAR,
38
+ "bicubic": cv2.INTER_CUBIC,
39
+ "area": cv2.INTER_AREA,
40
+ "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
41
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
42
+ interpolation=self.interpolation)
43
+ self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
44
+ interpolation=cv2.INTER_NEAREST)
45
+ self.center_crop = not random_crop
46
+ if self.center_crop:
47
+ self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
48
+ else:
49
+ self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
50
+ self.preprocessor = self.cropper
51
+
52
+ def __len__(self):
53
+ return self._length
54
+
55
+ def __getitem__(self, i):
56
+ example = dict((k, self.labels[k][i]) for k in self.labels)
57
+ image = Image.open(example["file_path_"])
58
+ if not image.mode == "RGB":
59
+ image = image.convert("RGB")
60
+ image = np.array(image).astype(np.uint8)
61
+ if self.size is not None:
62
+ image = self.image_rescaler(image=image)["image"]
63
+ segmentation = Image.open(example["segmentation_path_"])
64
+ assert segmentation.mode == "L", segmentation.mode
65
+ segmentation = np.array(segmentation).astype(np.uint8)
66
+ if self.shift_segmentation:
67
+ # used to support segmentations containing unlabeled==255 label
68
+ segmentation = segmentation+1
69
+ if self.size is not None:
70
+ segmentation = self.segmentation_rescaler(image=segmentation)["image"]
71
+ if self.size is not None:
72
+ processed = self.preprocessor(image=image,
73
+ mask=segmentation
74
+ )
75
+ else:
76
+ processed = {"image": image,
77
+ "mask": segmentation
78
+ }
79
+ example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
80
+ segmentation = processed["mask"]
81
+ onehot = np.eye(self.n_labels)[segmentation]
82
+ example["segmentation"] = onehot
83
+ return example
84
+
85
+
86
+ class Examples(SegmentationBase):
87
+ def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
88
+ super().__init__(data_csv="data/sflckr_examples.txt",
89
+ data_root="data/sflckr_images",
90
+ segmentation_root="data/sflckr_segmentations",
91
+ size=size, random_crop=random_crop, interpolation=interpolation)
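
A minimal usage sketch for the Examples dataset defined above, assuming the data/sflckr_examples.txt list and the corresponding images and segmentations are present locally:

from taming.data.sflckr import Examples

dataset = Examples(size=256)
example = dataset[0]
print(example["image"].shape)         # (256, 256, 3), float32 scaled to [-1, 1]
print(example["segmentation"].shape)  # (256, 256, 182), one-hot labels
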
taming/data/utils.py ADDED
@@ -0,0 +1,171 @@
1
+ import collections
2
+ import os
3
+ import tarfile
4
+ import urllib
5
+ import zipfile
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import torch
10
+ from taming.data.helper_types import Annotation
11
+ #from torch._six import string_classes
12
+ from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
13
+ from tqdm import tqdm
14
+
15
+ string_classes = (str,bytes)
16
+
17
+
18
+ def unpack(path):
19
+ if path.endswith("tar.gz"):
20
+ with tarfile.open(path, "r:gz") as tar:
21
+ tar.extractall(path=os.path.split(path)[0])
22
+ elif path.endswith("tar"):
23
+ with tarfile.open(path, "r:") as tar:
24
+ tar.extractall(path=os.path.split(path)[0])
25
+ elif path.endswith("zip"):
26
+ with zipfile.ZipFile(path, "r") as f:
27
+ f.extractall(path=os.path.split(path)[0])
28
+ else:
29
+ raise NotImplementedError(
30
+ "Unknown file extension: {}".format(os.path.splitext(path)[1])
31
+ )
32
+
33
+
34
+ def reporthook(bar):
35
+ """tqdm progress bar for downloads."""
36
+
37
+ def hook(b=1, bsize=1, tsize=None):
38
+ if tsize is not None:
39
+ bar.total = tsize
40
+ bar.update(b * bsize - bar.n)
41
+
42
+ return hook
43
+
44
+
45
+ def get_root(name):
46
+ base = "data/"
47
+ root = os.path.join(base, name)
48
+ os.makedirs(root, exist_ok=True)
49
+ return root
50
+
51
+
52
+ def is_prepared(root):
53
+ return Path(root).joinpath(".ready").exists()
54
+
55
+
56
+ def mark_prepared(root):
57
+ Path(root).joinpath(".ready").touch()
58
+
59
+
60
+ def prompt_download(file_, source, target_dir, content_dir=None):
61
+ targetpath = os.path.join(target_dir, file_)
62
+ while not os.path.exists(targetpath):
63
+ if content_dir is not None and os.path.exists(
64
+ os.path.join(target_dir, content_dir)
65
+ ):
66
+ break
67
+ print(
68
+ "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
69
+ )
70
+ if content_dir is not None:
71
+ print(
72
+ "Or place its content into '{}'.".format(
73
+ os.path.join(target_dir, content_dir)
74
+ )
75
+ )
76
+ input("Press Enter when done...")
77
+ return targetpath
78
+
79
+
80
+ def download_url(file_, url, target_dir):
81
+ targetpath = os.path.join(target_dir, file_)
82
+ os.makedirs(target_dir, exist_ok=True)
83
+ with tqdm(
84
+ unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
85
+ ) as bar:
86
+ urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
87
+ return targetpath
88
+
89
+
90
+ def download_urls(urls, target_dir):
91
+ paths = dict()
92
+ for fname, url in urls.items():
93
+ outpath = download_url(fname, url, target_dir)
94
+ paths[fname] = outpath
95
+ return paths
96
+
97
+
98
+ def quadratic_crop(x, bbox, alpha=1.0):
99
+ """bbox is xmin, ymin, xmax, ymax"""
100
+ im_h, im_w = x.shape[:2]
101
+ bbox = np.array(bbox, dtype=np.float32)
102
+ bbox = np.clip(bbox, 0, max(im_h, im_w))
103
+ center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
104
+ w = bbox[2] - bbox[0]
105
+ h = bbox[3] - bbox[1]
106
+ l = int(alpha * max(w, h))
107
+ l = max(l, 2)
108
+
109
+ required_padding = -1 * min(
110
+ center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
111
+ )
112
+ required_padding = int(np.ceil(required_padding))
113
+ if required_padding > 0:
114
+ padding = [
115
+ [required_padding, required_padding],
116
+ [required_padding, required_padding],
117
+ ]
118
+ padding += [[0, 0]] * (len(x.shape) - 2)
119
+ x = np.pad(x, padding, "reflect")
120
+ center = center[0] + required_padding, center[1] + required_padding
121
+ xmin = int(center[0] - l / 2)
122
+ ymin = int(center[1] - l / 2)
123
+ return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
124
+
125
+
126
+ def custom_collate(batch):
127
+ r"""source: pytorch 1.9.0, only one modification to original code """
128
+
129
+ elem = batch[0]
130
+ elem_type = type(elem)
131
+ if isinstance(elem, torch.Tensor):
132
+ out = None
133
+ if torch.utils.data.get_worker_info() is not None:
134
+ # If we're in a background process, concatenate directly into a
135
+ # shared memory tensor to avoid an extra copy
136
+ numel = sum([x.numel() for x in batch])
137
+ storage = elem.storage()._new_shared(numel)
138
+ out = elem.new(storage)
139
+ return torch.stack(batch, 0, out=out)
140
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
141
+ and elem_type.__name__ != 'string_':
142
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
143
+ # array of string classes and object
144
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
145
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
146
+
147
+ return custom_collate([torch.as_tensor(b) for b in batch])
148
+ elif elem.shape == (): # scalars
149
+ return torch.as_tensor(batch)
150
+ elif isinstance(elem, float):
151
+ return torch.tensor(batch, dtype=torch.float64)
152
+ elif isinstance(elem, int):
153
+ return torch.tensor(batch)
154
+ elif isinstance(elem, string_classes):
155
+ return batch
156
+ elif isinstance(elem, collections.abc.Mapping):
157
+ return {key: custom_collate([d[key] for d in batch]) for key in elem}
158
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
159
+ return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
160
+ if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
161
+ return batch # added
162
+ elif isinstance(elem, collections.abc.Sequence):
163
+ # check to make sure that the elements in batch have consistent size
164
+ it = iter(batch)
165
+ elem_size = len(next(it))
166
+ if not all(len(elem) == elem_size for elem in it):
167
+ raise RuntimeError('each element in list of batch should be of equal size')
168
+ transposed = zip(*batch)
169
+ return [custom_collate(samples) for samples in transposed]
170
+
171
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
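
The only change relative to PyTorch's default collate is the pass-through branch for lists of Annotation objects; everything else is batched as usual. A minimal sketch of plugging it into a DataLoader, using a toy TensorDataset as a stand-in for the annotated-objects datasets in this repository:

import torch
from torch.utils.data import DataLoader, TensorDataset
from taming.data.utils import custom_collate

dataset = TensorDataset(torch.randn(8, 3, 32, 32))
loader = DataLoader(dataset, batch_size=4, collate_fn=custom_collate)
batch = next(iter(loader))
print(batch[0].shape)  # torch.Size([4, 3, 32, 32])
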
taming/lr_scheduler.py ADDED
@@ -0,0 +1,34 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n):
33
+ return self.schedule(n)
34
+
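
A minimal usage sketch for the scheduler above, wiring it into torch.optim.lr_scheduler.LambdaLR as a learning-rate multiplier (hence the note about a base_lr of 1.0); the numeric values here are illustrative only:

import torch
from taming.lr_scheduler import LambdaWarmUpCosineScheduler

model = torch.nn.Linear(4, 4)  # stand-in module
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=1e-6, lr_max=1e-4, lr_start=1e-6, max_decay_steps=100000)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for step in range(3):
    optimizer.step()
    scheduler.step()  # effective lr follows the warm-up/cosine schedule
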
taming/models/__pycache__/vqgan.cpython-312.pyc ADDED
Binary file (21.7 kB).
 
taming/models/cond_transformer.py ADDED
@@ -0,0 +1,352 @@
1
+ import os, math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import pytorch_lightning as pl
5
+
6
+ from main import instantiate_from_config
7
+ from taming.modules.util import SOSProvider
8
+
9
+
10
+ def disabled_train(self, mode=True):
11
+ """Overwrite model.train with this function to make sure train/eval mode
12
+ does not change anymore."""
13
+ return self
14
+
15
+
16
+ class Net2NetTransformer(pl.LightningModule):
17
+ def __init__(self,
18
+ transformer_config,
19
+ first_stage_config,
20
+ cond_stage_config,
21
+ permuter_config=None,
22
+ ckpt_path=None,
23
+ ignore_keys=[],
24
+ first_stage_key="image",
25
+ cond_stage_key="depth",
26
+ downsample_cond_size=-1,
27
+ pkeep=1.0,
28
+ sos_token=0,
29
+ unconditional=False,
30
+ ):
31
+ super().__init__()
32
+ self.be_unconditional = unconditional
33
+ self.sos_token = sos_token
34
+ self.first_stage_key = first_stage_key
35
+ self.cond_stage_key = cond_stage_key
36
+ self.init_first_stage_from_ckpt(first_stage_config)
37
+ self.init_cond_stage_from_ckpt(cond_stage_config)
38
+ if permuter_config is None:
39
+ permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
40
+ self.permuter = instantiate_from_config(config=permuter_config)
41
+ self.transformer = instantiate_from_config(config=transformer_config)
42
+
43
+ if ckpt_path is not None:
44
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45
+ self.downsample_cond_size = downsample_cond_size
46
+ self.pkeep = pkeep
47
+
48
+ def init_from_ckpt(self, path, ignore_keys=list()):
49
+ sd = torch.load(path, map_location="cpu")["state_dict"]
50
+ for k in sd.keys():
51
+ for ik in ignore_keys:
52
+ if k.startswith(ik):
53
+ self.print("Deleting key {} from state_dict.".format(k))
54
+ del sd[k]
55
+ self.load_state_dict(sd, strict=False)
56
+ print(f"Restored from {path}")
57
+
58
+ def init_first_stage_from_ckpt(self, config):
59
+ model = instantiate_from_config(config)
60
+ model = model.eval()
61
+ model.train = disabled_train
62
+ self.first_stage_model = model
63
+
64
+ def init_cond_stage_from_ckpt(self, config):
65
+ if config == "__is_first_stage__":
66
+ print("Using first stage also as cond stage.")
67
+ self.cond_stage_model = self.first_stage_model
68
+ elif config == "__is_unconditional__" or self.be_unconditional:
69
+ print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
70
+ f"Prepending {self.sos_token} as a sos token.")
71
+ self.be_unconditional = True
72
+ self.cond_stage_key = self.first_stage_key
73
+ self.cond_stage_model = SOSProvider(self.sos_token)
74
+ else:
75
+ model = instantiate_from_config(config)
76
+ model = model.eval()
77
+ model.train = disabled_train
78
+ self.cond_stage_model = model
79
+
80
+ def forward(self, x, c):
81
+ # one step to produce the logits
82
+ _, z_indices = self.encode_to_z(x)
83
+ _, c_indices = self.encode_to_c(c)
84
+
85
+ if self.training and self.pkeep < 1.0:
86
+ mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
87
+ device=z_indices.device))
88
+ mask = mask.round().to(dtype=torch.int64)
89
+ r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
90
+ a_indices = mask*z_indices+(1-mask)*r_indices
91
+ else:
92
+ a_indices = z_indices
93
+
94
+ cz_indices = torch.cat((c_indices, a_indices), dim=1)
95
+
96
+ # target includes all sequence elements (no need to handle first one
97
+ # differently because we are conditioning)
98
+ target = z_indices
99
+ # make the prediction
100
+ logits, _ = self.transformer(cz_indices[:, :-1])
101
+ # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
102
+ logits = logits[:, c_indices.shape[1]-1:]
103
+
104
+ return logits, target
105
+
106
+ def top_k_logits(self, logits, k):
107
+ v, ix = torch.topk(logits, k)
108
+ out = logits.clone()
109
+ out[out < v[..., [-1]]] = -float('Inf')
110
+ return out
111
+
112
+ @torch.no_grad()
113
+ def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
114
+ callback=lambda k: None):
115
+ x = torch.cat((c,x),dim=1)
116
+ block_size = self.transformer.get_block_size()
117
+ assert not self.transformer.training
118
+ if self.pkeep <= 0.0:
119
+ # one pass suffices since input is pure noise anyway
120
+ assert len(x.shape)==2
121
+ noise_shape = (x.shape[0], steps-1)
122
+ #noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
123
+ noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
124
+ x = torch.cat((x,noise),dim=1)
125
+ logits, _ = self.transformer(x)
126
+ # take all logits for now and scale by temp
127
+ logits = logits / temperature
128
+ # optionally crop probabilities to only the top k options
129
+ if top_k is not None:
130
+ logits = self.top_k_logits(logits, top_k)
131
+ # apply softmax to convert to probabilities
132
+ probs = F.softmax(logits, dim=-1)
133
+ # sample from the distribution or take the most likely
134
+ if sample:
135
+ shape = probs.shape
136
+ probs = probs.reshape(shape[0]*shape[1],shape[2])
137
+ ix = torch.multinomial(probs, num_samples=1)
138
+ probs = probs.reshape(shape[0],shape[1],shape[2])
139
+ ix = ix.reshape(shape[0],shape[1])
140
+ else:
141
+ _, ix = torch.topk(probs, k=1, dim=-1)
142
+ # cut off conditioning
143
+ x = ix[:, c.shape[1]-1:]
144
+ else:
145
+ for k in range(steps):
146
+ callback(k)
147
+ assert x.size(1) <= block_size # make sure model can see conditioning
148
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
149
+ logits, _ = self.transformer(x_cond)
150
+ # pluck the logits at the final step and scale by temperature
151
+ logits = logits[:, -1, :] / temperature
152
+ # optionally crop probabilities to only the top k options
153
+ if top_k is not None:
154
+ logits = self.top_k_logits(logits, top_k)
155
+ # apply softmax to convert to probabilities
156
+ probs = F.softmax(logits, dim=-1)
157
+ # sample from the distribution or take the most likely
158
+ if sample:
159
+ ix = torch.multinomial(probs, num_samples=1)
160
+ else:
161
+ _, ix = torch.topk(probs, k=1, dim=-1)
162
+ # append to the sequence and continue
163
+ x = torch.cat((x, ix), dim=1)
164
+ # cut off conditioning
165
+ x = x[:, c.shape[1]:]
166
+ return x
167
+
168
+ @torch.no_grad()
169
+ def encode_to_z(self, x):
170
+ quant_z, _, info = self.first_stage_model.encode(x)
171
+ indices = info[2].view(quant_z.shape[0], -1)
172
+ indices = self.permuter(indices)
173
+ return quant_z, indices
174
+
175
+ @torch.no_grad()
176
+ def encode_to_c(self, c):
177
+ if self.downsample_cond_size > -1:
178
+ c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
179
+ quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
180
+ if len(indices.shape) > 2:
181
+ indices = indices.view(c.shape[0], -1)
182
+ return quant_c, indices
183
+
184
+ @torch.no_grad()
185
+ def decode_to_img(self, index, zshape):
186
+ index = self.permuter(index, reverse=True)
187
+ bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
188
+ quant_z = self.first_stage_model.quantize.get_codebook_entry(
189
+ index.reshape(-1), shape=bhwc)
190
+ x = self.first_stage_model.decode(quant_z)
191
+ return x
192
+
193
+ @torch.no_grad()
194
+ def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
195
+ log = dict()
196
+
197
+ N = 4
198
+ if lr_interface:
199
+ x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
200
+ else:
201
+ x, c = self.get_xc(batch, N)
202
+ x = x.to(device=self.device)
203
+ c = c.to(device=self.device)
204
+
205
+ quant_z, z_indices = self.encode_to_z(x)
206
+ quant_c, c_indices = self.encode_to_c(c)
207
+
208
+ # create a "half" sample
209
+ z_start_indices = z_indices[:,:z_indices.shape[1]//2]
210
+ index_sample = self.sample(z_start_indices, c_indices,
211
+ steps=z_indices.shape[1]-z_start_indices.shape[1],
212
+ temperature=temperature if temperature is not None else 1.0,
213
+ sample=True,
214
+ top_k=top_k if top_k is not None else 100,
215
+ callback=callback if callback is not None else lambda k: None)
216
+ x_sample = self.decode_to_img(index_sample, quant_z.shape)
217
+
218
+ # sample
219
+ z_start_indices = z_indices[:, :0]
220
+ index_sample = self.sample(z_start_indices, c_indices,
221
+ steps=z_indices.shape[1],
222
+ temperature=temperature if temperature is not None else 1.0,
223
+ sample=True,
224
+ top_k=top_k if top_k is not None else 100,
225
+ callback=callback if callback is not None else lambda k: None)
226
+ x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
227
+
228
+ # det sample
229
+ z_start_indices = z_indices[:, :0]
230
+ index_sample = self.sample(z_start_indices, c_indices,
231
+ steps=z_indices.shape[1],
232
+ sample=False,
233
+ callback=callback if callback is not None else lambda k: None)
234
+ x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
235
+
236
+ # reconstruction
237
+ x_rec = self.decode_to_img(z_indices, quant_z.shape)
238
+
239
+ log["inputs"] = x
240
+ log["reconstructions"] = x_rec
241
+
242
+ if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
243
+ figure_size = (x_rec.shape[2], x_rec.shape[3])
244
+ dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
245
+ label_for_category_no = dataset.get_textual_label_for_category_no
246
+ plotter = dataset.conditional_builders[self.cond_stage_key].plot
247
+ log["conditioning"] = torch.zeros_like(log["reconstructions"])
248
+ for i in range(quant_c.shape[0]):
249
+ log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
250
+ log["conditioning_rec"] = log["conditioning"]
251
+ elif self.cond_stage_key != "image":
252
+ cond_rec = self.cond_stage_model.decode(quant_c)
253
+ if self.cond_stage_key == "segmentation":
254
+ # get image from segmentation mask
255
+ num_classes = cond_rec.shape[1]
256
+
257
+ c = torch.argmax(c, dim=1, keepdim=True)
258
+ c = F.one_hot(c, num_classes=num_classes)
259
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
260
+ c = self.cond_stage_model.to_rgb(c)
261
+
262
+ cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
263
+ cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
264
+ cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
265
+ cond_rec = self.cond_stage_model.to_rgb(cond_rec)
266
+ log["conditioning_rec"] = cond_rec
267
+ log["conditioning"] = c
268
+
269
+ log["samples_half"] = x_sample
270
+ log["samples_nopix"] = x_sample_nopix
271
+ log["samples_det"] = x_sample_det
272
+ return log
273
+
274
+ def get_input(self, key, batch):
275
+ x = batch[key]
276
+ if len(x.shape) == 3:
277
+ x = x[..., None]
278
+ if len(x.shape) == 4:
279
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
280
+ if x.dtype == torch.double:
281
+ x = x.float()
282
+ return x
283
+
284
+ def get_xc(self, batch, N=None):
285
+ x = self.get_input(self.first_stage_key, batch)
286
+ c = self.get_input(self.cond_stage_key, batch)
287
+ if N is not None:
288
+ x = x[:N]
289
+ c = c[:N]
290
+ return x, c
291
+
292
+ def shared_step(self, batch, batch_idx):
293
+ x, c = self.get_xc(batch)
294
+ logits, target = self(x, c)
295
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
296
+ return loss
297
+
298
+ def training_step(self, batch, batch_idx):
299
+ loss = self.shared_step(batch, batch_idx)
300
+ self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
301
+ return loss
302
+
303
+ def validation_step(self, batch, batch_idx):
304
+ loss = self.shared_step(batch, batch_idx)
305
+ self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
306
+ return loss
307
+
308
+ def configure_optimizers(self):
309
+ """
310
+ Following minGPT:
311
+ This long function is unfortunately doing something very simple and is being very defensive:
312
+ We are separating out all parameters of the model into two buckets: those that will experience
313
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
314
+ We are then returning the PyTorch optimizer object.
315
+ """
316
+ # separate out all parameters to those that will and won't experience regularizing weight decay
317
+ decay = set()
318
+ no_decay = set()
319
+ whitelist_weight_modules = (torch.nn.Linear, )
320
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
321
+ for mn, m in self.transformer.named_modules():
322
+ for pn, p in m.named_parameters():
323
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
324
+
325
+ if pn.endswith('bias'):
326
+ # all biases will not be decayed
327
+ no_decay.add(fpn)
328
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
329
+ # weights of whitelist modules will be weight decayed
330
+ decay.add(fpn)
331
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
332
+ # weights of blacklist modules will NOT be weight decayed
333
+ no_decay.add(fpn)
334
+
335
+ # special case the position embedding parameter in the root GPT module as not decayed
336
+ no_decay.add('pos_emb')
337
+
338
+ # validate that we considered every parameter
339
+ param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
340
+ inter_params = decay & no_decay
341
+ union_params = decay | no_decay
342
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
343
+ assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
344
+ % (str(param_dict.keys() - union_params), )
345
+
346
+ # create the pytorch optimizer object
347
+ optim_groups = [
348
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
349
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
350
+ ]
351
+ optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
352
+ return optimizer
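
The top-k filtering used inside sample() above is easy to verify in isolation; this sketch simply repeats the top_k_logits logic on a toy tensor:

import torch

def top_k_logits(logits, k):
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[..., [-1]]] = -float('Inf')  # mask everything outside the k largest logits
    return out

logits = torch.tensor([[2.0, 0.5, -1.0, 3.0]])
print(top_k_logits(logits, 2))                     # tensor([[2., -inf, -inf, 3.]])
print(torch.softmax(top_k_logits(logits, 2), -1))  # probability mass only on the two kept tokens
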
taming/models/dummy_cond_stage.py ADDED
@@ -0,0 +1,22 @@
1
+ from torch import Tensor
2
+
3
+
4
+ class DummyCondStage:
5
+ def __init__(self, conditional_key):
6
+ self.conditional_key = conditional_key
7
+ self.train = None
8
+
9
+ def eval(self):
10
+ return self
11
+
12
+ @staticmethod
13
+ def encode(c: Tensor):
14
+ return c, None, (None, None, c)
15
+
16
+ @staticmethod
17
+ def decode(c: Tensor):
18
+ return c
19
+
20
+ @staticmethod
21
+ def to_rgb(c: Tensor):
22
+ return c
taming/models/vqgan.py ADDED
@@ -0,0 +1,404 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import pytorch_lightning as pl
4
+
5
+ from main import instantiate_from_config
6
+
7
+ from taming.modules.diffusionmodules.model import Encoder, Decoder
8
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9
+ from taming.modules.vqvae.quantize import GumbelQuantize
10
+ from taming.modules.vqvae.quantize import EMAVectorQuantizer
11
+
12
+ class VQModel(pl.LightningModule):
13
+ def __init__(self,
14
+ ddconfig,
15
+ lossconfig,
16
+ n_embed,
17
+ embed_dim,
18
+ ckpt_path=None,
19
+ ignore_keys=[],
20
+ image_key="image",
21
+ colorize_nlabels=None,
22
+ monitor=None,
23
+ remap=None,
24
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
25
+ ):
26
+ super().__init__()
27
+ self.image_key = image_key
28
+ self.encoder = Encoder(**ddconfig)
29
+ self.decoder = Decoder(**ddconfig)
30
+ self.loss = instantiate_from_config(lossconfig)
31
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
32
+ remap=remap, sane_index_shape=sane_index_shape)
33
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
34
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35
+ if ckpt_path is not None:
36
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
37
+ self.image_key = image_key
38
+ if colorize_nlabels is not None:
39
+ assert type(colorize_nlabels)==int
40
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
41
+ if monitor is not None:
42
+ self.monitor = monitor
43
+
44
+ def init_from_ckpt(self, path, ignore_keys=list()):
45
+ sd = torch.load(path, map_location="cpu")["state_dict"]
46
+ keys = list(sd.keys())
47
+ for k in keys:
48
+ for ik in ignore_keys:
49
+ if k.startswith(ik):
50
+ print("Deleting key {} from state_dict.".format(k))
51
+ del sd[k]
52
+ self.load_state_dict(sd, strict=False)
53
+ print(f"Restored from {path}")
54
+
55
+ def encode(self, x):
56
+ h = self.encoder(x)
57
+ h = self.quant_conv(h)
58
+ quant, emb_loss, info = self.quantize(h)
59
+ return quant, emb_loss, info
60
+
61
+ def decode(self, quant):
62
+ quant = self.post_quant_conv(quant)
63
+ dec = self.decoder(quant)
64
+ return dec
65
+
66
+ def decode_code(self, code_b):
67
+ quant_b = self.quantize.embed_code(code_b)
68
+ dec = self.decode(quant_b)
69
+ return dec
70
+
71
+ def forward(self, input):
72
+ quant, diff, _ = self.encode(input)
73
+ dec = self.decode(quant)
74
+ return dec, diff
75
+
76
+ def get_input(self, batch, k):
77
+ x = batch[k]
78
+ if len(x.shape) == 3:
79
+ x = x[..., None]
80
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
81
+ return x.float()
82
+
83
+ def training_step(self, batch, batch_idx, optimizer_idx):
84
+ x = self.get_input(batch, self.image_key)
85
+ xrec, qloss = self(x)
86
+
87
+ if optimizer_idx == 0:
88
+ # autoencode
89
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
90
+ last_layer=self.get_last_layer(), split="train")
91
+
92
+ self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
93
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
94
+ return aeloss
95
+
96
+ if optimizer_idx == 1:
97
+ # discriminator
98
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
99
+ last_layer=self.get_last_layer(), split="train")
100
+ self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
101
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
102
+ return discloss
103
+
104
+ def validation_step(self, batch, batch_idx):
105
+ x = self.get_input(batch, self.image_key)
106
+ xrec, qloss = self(x)
107
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
108
+ last_layer=self.get_last_layer(), split="val")
109
+
110
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
111
+ last_layer=self.get_last_layer(), split="val")
112
+ rec_loss = log_dict_ae["val/rec_loss"]
113
+ self.log("val/rec_loss", rec_loss,
114
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
115
+ self.log("val/aeloss", aeloss,
116
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
117
+ self.log_dict(log_dict_ae)
118
+ self.log_dict(log_dict_disc)
119
+ return self.log_dict
120
+
121
+ def configure_optimizers(self):
122
+ lr = self.learning_rate
123
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
124
+ list(self.decoder.parameters())+
125
+ list(self.quantize.parameters())+
126
+ list(self.quant_conv.parameters())+
127
+ list(self.post_quant_conv.parameters()),
128
+ lr=lr, betas=(0.5, 0.9))
129
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
130
+ lr=lr, betas=(0.5, 0.9))
131
+ return [opt_ae, opt_disc], []
132
+
133
+ def get_last_layer(self):
134
+ return self.decoder.conv_out.weight
135
+
136
+ def log_images(self, batch, **kwargs):
137
+ log = dict()
138
+ x = self.get_input(batch, self.image_key)
139
+ x = x.to(self.device)
140
+ xrec, _ = self(x)
141
+ if x.shape[1] > 3:
142
+ # colorize with random projection
143
+ assert xrec.shape[1] > 3
144
+ x = self.to_rgb(x)
145
+ xrec = self.to_rgb(xrec)
146
+ log["inputs"] = x
147
+ log["reconstructions"] = xrec
148
+ return log
149
+
150
+ def to_rgb(self, x):
151
+ assert self.image_key == "segmentation"
152
+ if not hasattr(self, "colorize"):
153
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
154
+ x = F.conv2d(x, weight=self.colorize)
155
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
156
+ return x
157
+
158
+
159
+ class VQSegmentationModel(VQModel):
160
+ def __init__(self, n_labels, *args, **kwargs):
161
+ super().__init__(*args, **kwargs)
162
+ self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
163
+
164
+ def configure_optimizers(self):
165
+ lr = self.learning_rate
166
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
167
+ list(self.decoder.parameters())+
168
+ list(self.quantize.parameters())+
169
+ list(self.quant_conv.parameters())+
170
+ list(self.post_quant_conv.parameters()),
171
+ lr=lr, betas=(0.5, 0.9))
172
+ return opt_ae
173
+
174
+ def training_step(self, batch, batch_idx):
175
+ x = self.get_input(batch, self.image_key)
176
+ xrec, qloss = self(x)
177
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
178
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
179
+ return aeloss
180
+
181
+ def validation_step(self, batch, batch_idx):
182
+ x = self.get_input(batch, self.image_key)
183
+ xrec, qloss = self(x)
184
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
185
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
186
+ total_loss = log_dict_ae["val/total_loss"]
187
+ self.log("val/total_loss", total_loss,
188
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
189
+ return aeloss
190
+
191
+ @torch.no_grad()
192
+ def log_images(self, batch, **kwargs):
193
+ log = dict()
194
+ x = self.get_input(batch, self.image_key)
195
+ x = x.to(self.device)
196
+ xrec, _ = self(x)
197
+ if x.shape[1] > 3:
198
+ # colorize with random projection
199
+ assert xrec.shape[1] > 3
200
+ # convert logits to indices
201
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
202
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
203
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
204
+ x = self.to_rgb(x)
205
+ xrec = self.to_rgb(xrec)
206
+ log["inputs"] = x
207
+ log["reconstructions"] = xrec
208
+ return log
209
+
210
+
211
+ class VQNoDiscModel(VQModel):
212
+ def __init__(self,
213
+ ddconfig,
214
+ lossconfig,
215
+ n_embed,
216
+ embed_dim,
217
+ ckpt_path=None,
218
+ ignore_keys=[],
219
+ image_key="image",
220
+ colorize_nlabels=None
221
+ ):
222
+ super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
223
+ ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
224
+ colorize_nlabels=colorize_nlabels)
225
+
226
+ def training_step(self, batch, batch_idx):
227
+ x = self.get_input(batch, self.image_key)
228
+ xrec, qloss = self(x)
229
+ # autoencode
230
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
231
+ output = pl.TrainResult(minimize=aeloss)
232
+ output.log("train/aeloss", aeloss,
233
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
234
+ output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
235
+ return output
236
+
237
+ def validation_step(self, batch, batch_idx):
238
+ x = self.get_input(batch, self.image_key)
239
+ xrec, qloss = self(x)
240
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
241
+ rec_loss = log_dict_ae["val/rec_loss"]
242
+ output = pl.EvalResult(checkpoint_on=rec_loss)
243
+ output.log("val/rec_loss", rec_loss,
244
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
245
+ output.log("val/aeloss", aeloss,
246
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
247
+ output.log_dict(log_dict_ae)
248
+
249
+ return output
250
+
251
+ def configure_optimizers(self):
252
+ optimizer = torch.optim.Adam(list(self.encoder.parameters())+
253
+ list(self.decoder.parameters())+
254
+ list(self.quantize.parameters())+
255
+ list(self.quant_conv.parameters())+
256
+ list(self.post_quant_conv.parameters()),
257
+ lr=self.learning_rate, betas=(0.5, 0.9))
258
+ return optimizer
259
+
260
+
261
+ class GumbelVQ(VQModel):
262
+ def __init__(self,
263
+ ddconfig,
264
+ lossconfig,
265
+ n_embed,
266
+ embed_dim,
267
+ temperature_scheduler_config,
268
+ ckpt_path=None,
269
+ ignore_keys=[],
270
+ image_key="image",
271
+ colorize_nlabels=None,
272
+ monitor=None,
273
+ kl_weight=1e-8,
274
+ remap=None,
275
+ ):
276
+
277
+ z_channels = ddconfig["z_channels"]
278
+ super().__init__(ddconfig,
279
+ lossconfig,
280
+ n_embed,
281
+ embed_dim,
282
+ ckpt_path=None,
283
+ ignore_keys=ignore_keys,
284
+ image_key=image_key,
285
+ colorize_nlabels=colorize_nlabels,
286
+ monitor=monitor,
287
+ )
288
+
289
+ self.loss.n_classes = n_embed
290
+ self.vocab_size = n_embed
291
+
292
+ self.quantize = GumbelQuantize(z_channels, embed_dim,
293
+ n_embed=n_embed,
294
+ kl_weight=kl_weight, temp_init=1.0,
295
+ remap=remap)
296
+
297
+ self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
298
+
299
+ if ckpt_path is not None:
300
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
301
+
302
+ def temperature_scheduling(self):
303
+ self.quantize.temperature = self.temperature_scheduler(self.global_step)
304
+
305
+ def encode_to_prequant(self, x):
306
+ h = self.encoder(x)
307
+ h = self.quant_conv(h)
308
+ return h
309
+
310
+ def decode_code(self, code_b):
311
+ raise NotImplementedError
312
+
313
+ def training_step(self, batch, batch_idx, optimizer_idx):
314
+ self.temperature_scheduling()
315
+ x = self.get_input(batch, self.image_key)
316
+ xrec, qloss = self(x)
317
+
318
+ if optimizer_idx == 0:
319
+ # autoencode
320
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
321
+ last_layer=self.get_last_layer(), split="train")
322
+
323
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
324
+ self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
325
+ return aeloss
326
+
327
+ if optimizer_idx == 1:
328
+ # discriminator
329
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
330
+ last_layer=self.get_last_layer(), split="train")
331
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
332
+ return discloss
333
+
334
+ def validation_step(self, batch, batch_idx):
335
+ x = self.get_input(batch, self.image_key)
336
+ xrec, qloss = self(x)
337
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
338
+ last_layer=self.get_last_layer(), split="val")
339
+
340
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
341
+ last_layer=self.get_last_layer(), split="val")
342
+ rec_loss = log_dict_ae["val/rec_loss"]
343
+ self.log("val/rec_loss", rec_loss,
344
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
345
+ self.log("val/aeloss", aeloss,
346
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
347
+ self.log_dict(log_dict_ae)
348
+ self.log_dict(log_dict_disc)
349
+ return self.log_dict
350
+
351
+ def log_images(self, batch, **kwargs):
352
+ log = dict()
353
+ x = self.get_input(batch, self.image_key)
354
+ x = x.to(self.device)
355
+ # encode
356
+ h = self.encoder(x)
357
+ h = self.quant_conv(h)
358
+ quant, _, _ = self.quantize(h)
359
+ # decode
360
+ x_rec = self.decode(quant)
361
+ log["inputs"] = x
362
+ log["reconstructions"] = x_rec
363
+ return log
364
+
365
+
366
+ class EMAVQ(VQModel):
367
+ def __init__(self,
368
+ ddconfig,
369
+ lossconfig,
370
+ n_embed,
371
+ embed_dim,
372
+ ckpt_path=None,
373
+ ignore_keys=[],
374
+ image_key="image",
375
+ colorize_nlabels=None,
376
+ monitor=None,
377
+ remap=None,
378
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
379
+ ):
380
+ super().__init__(ddconfig,
381
+ lossconfig,
382
+ n_embed,
383
+ embed_dim,
384
+ ckpt_path=None,
385
+ ignore_keys=ignore_keys,
386
+ image_key=image_key,
387
+ colorize_nlabels=colorize_nlabels,
388
+ monitor=monitor,
389
+ )
390
+ self.quantize = EMAVectorQuantizer(n_embed=n_embed,
391
+ embedding_dim=embed_dim,
392
+ beta=0.25,
393
+ remap=remap)
394
+ def configure_optimizers(self):
395
+ lr = self.learning_rate
396
+ #Remove self.quantize from parameter list since it is updated via EMA
397
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
398
+ list(self.decoder.parameters())+
399
+ list(self.quant_conv.parameters())+
400
+ list(self.post_quant_conv.parameters()),
401
+ lr=lr, betas=(0.5, 0.9))
402
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
403
+ lr=lr, betas=(0.5, 0.9))
404
+ return [opt_ae, opt_disc], []
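
A minimal reconstruction sketch for the VQModel defined above. Paths are hypothetical, and a yaml config with a model: {target, params} section (as consumed by instantiate_from_config in main.py) plus a matching checkpoint are assumed to exist:

import torch
from omegaconf import OmegaConf
from main import instantiate_from_config

config = OmegaConf.load("logs/vqgan/configs/model.yaml")  # hypothetical path
model = instantiate_from_config(config.model)
model.init_from_ckpt("logs/vqgan/checkpoints/last.ckpt")  # hypothetical path
model.eval()

x = torch.randn(1, 3, 256, 256)  # stand-in for an image batch in [-1, 1], NCHW
with torch.no_grad():
    quant, emb_loss, info = model.encode(x)
    xrec = model.decode(quant)
print(quant.shape, xrec.shape)  # latent codes and reconstruction
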
taming/modules/__pycache__/util.cpython-312.pyc ADDED
Binary file (7.4 kB).
 
taming/modules/diffusionmodules/__pycache__/model.cpython-312.pyc ADDED
Binary file (34.6 kB).
 
taming/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,776 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+
7
+
8
+ def get_timestep_embedding(timesteps, embedding_dim):
9
+ """
10
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
11
+ From Fairseq.
12
+ Build sinusoidal embeddings.
13
+ This matches the implementation in tensor2tensor, but differs slightly
14
+ from the description in Section 3.5 of "Attention Is All You Need".
15
+ """
16
+ assert len(timesteps.shape) == 1
17
+
18
+ half_dim = embedding_dim // 2
19
+ emb = math.log(10000) / (half_dim - 1)
20
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
21
+ emb = emb.to(device=timesteps.device)
22
+ emb = timesteps.float()[:, None] * emb[None, :]
23
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
24
+ if embedding_dim % 2 == 1: # zero pad
25
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
26
+ return emb
27
+
28
+
29
+ def nonlinearity(x):
30
+ # swish
31
+ return x*torch.sigmoid(x)
32
+
33
+
34
+ def Normalize(in_channels):
35
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
36
+
37
+
38
+ class Upsample(nn.Module):
39
+ def __init__(self, in_channels, with_conv):
40
+ super().__init__()
41
+ self.with_conv = with_conv
42
+ if self.with_conv:
43
+ self.conv = torch.nn.Conv2d(in_channels,
44
+ in_channels,
45
+ kernel_size=3,
46
+ stride=1,
47
+ padding=1)
48
+
49
+ def forward(self, x):
50
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
51
+ if self.with_conv:
52
+ x = self.conv(x)
53
+ return x
54
+
55
+
56
+ class Downsample(nn.Module):
57
+ def __init__(self, in_channels, with_conv):
58
+ super().__init__()
59
+ self.with_conv = with_conv
60
+ if self.with_conv:
61
+ # no asymmetric padding in torch conv, must do it ourselves
62
+ self.conv = torch.nn.Conv2d(in_channels,
63
+ in_channels,
64
+ kernel_size=3,
65
+ stride=2,
66
+ padding=0)
67
+
68
+ def forward(self, x):
69
+ if self.with_conv:
70
+ pad = (0,1,0,1)
71
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
72
+ x = self.conv(x)
73
+ else:
74
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
75
+ return x
76
+
77
+
78
+ class ResnetBlock(nn.Module):
79
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
80
+ dropout, temb_channels=512):
81
+ super().__init__()
82
+ self.in_channels = in_channels
83
+ out_channels = in_channels if out_channels is None else out_channels
84
+ self.out_channels = out_channels
85
+ self.use_conv_shortcut = conv_shortcut
86
+
87
+ self.norm1 = Normalize(in_channels)
88
+ self.conv1 = torch.nn.Conv2d(in_channels,
89
+ out_channels,
90
+ kernel_size=3,
91
+ stride=1,
92
+ padding=1)
93
+ if temb_channels > 0:
94
+ self.temb_proj = torch.nn.Linear(temb_channels,
95
+ out_channels)
96
+ self.norm2 = Normalize(out_channels)
97
+ self.dropout = torch.nn.Dropout(dropout)
98
+ self.conv2 = torch.nn.Conv2d(out_channels,
99
+ out_channels,
100
+ kernel_size=3,
101
+ stride=1,
102
+ padding=1)
103
+ if self.in_channels != self.out_channels:
104
+ if self.use_conv_shortcut:
105
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
106
+ out_channels,
107
+ kernel_size=3,
108
+ stride=1,
109
+ padding=1)
110
+ else:
111
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
112
+ out_channels,
113
+ kernel_size=1,
114
+ stride=1,
115
+ padding=0)
116
+
117
+ def forward(self, x, temb):
118
+ h = x
119
+ h = self.norm1(h)
120
+ h = nonlinearity(h)
121
+ h = self.conv1(h)
122
+
123
+ if temb is not None:
124
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
125
+
126
+ h = self.norm2(h)
127
+ h = nonlinearity(h)
128
+ h = self.dropout(h)
129
+ h = self.conv2(h)
130
+
131
+ if self.in_channels != self.out_channels:
132
+ if self.use_conv_shortcut:
133
+ x = self.conv_shortcut(x)
134
+ else:
135
+ x = self.nin_shortcut(x)
136
+
137
+ return x+h
138
+
139
+
140
+ class AttnBlock(nn.Module):
141
+ def __init__(self, in_channels):
142
+ super().__init__()
143
+ self.in_channels = in_channels
144
+
145
+ self.norm = Normalize(in_channels)
146
+ self.q = torch.nn.Conv2d(in_channels,
147
+ in_channels,
148
+ kernel_size=1,
149
+ stride=1,
150
+ padding=0)
151
+ self.k = torch.nn.Conv2d(in_channels,
152
+ in_channels,
153
+ kernel_size=1,
154
+ stride=1,
155
+ padding=0)
156
+ self.v = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.proj_out = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+
167
+
168
+ def forward(self, x):
169
+ h_ = x
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ # compute attention
176
+ b,c,h,w = q.shape
177
+ q = q.reshape(b,c,h*w)
178
+ q = q.permute(0,2,1) # b,hw,c
179
+ k = k.reshape(b,c,h*w) # b,c,hw
180
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
181
+ w_ = w_ * (int(c)**(-0.5))
182
+ w_ = torch.nn.functional.softmax(w_, dim=2)
183
+
184
+ # attend to values
185
+ v = v.reshape(b,c,h*w)
186
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
187
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
188
+ h_ = h_.reshape(b,c,h,w)
189
+
190
+ h_ = self.proj_out(h_)
191
+
192
+ return x+h_
193
+
194
+
195
+ class Model(nn.Module):
196
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
197
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
198
+ resolution, use_timestep=True):
199
+ super().__init__()
200
+ self.ch = ch
201
+ self.temb_ch = self.ch*4
202
+ self.num_resolutions = len(ch_mult)
203
+ self.num_res_blocks = num_res_blocks
204
+ self.resolution = resolution
205
+ self.in_channels = in_channels
206
+
207
+ self.use_timestep = use_timestep
208
+ if self.use_timestep:
209
+ # timestep embedding
210
+ self.temb = nn.Module()
211
+ self.temb.dense = nn.ModuleList([
212
+ torch.nn.Linear(self.ch,
213
+ self.temb_ch),
214
+ torch.nn.Linear(self.temb_ch,
215
+ self.temb_ch),
216
+ ])
217
+
218
+ # downsampling
219
+ self.conv_in = torch.nn.Conv2d(in_channels,
220
+ self.ch,
221
+ kernel_size=3,
222
+ stride=1,
223
+ padding=1)
224
+
225
+ curr_res = resolution
226
+ in_ch_mult = (1,)+tuple(ch_mult)
227
+ self.down = nn.ModuleList()
228
+ for i_level in range(self.num_resolutions):
229
+ block = nn.ModuleList()
230
+ attn = nn.ModuleList()
231
+ block_in = ch*in_ch_mult[i_level]
232
+ block_out = ch*ch_mult[i_level]
233
+ for i_block in range(self.num_res_blocks):
234
+ block.append(ResnetBlock(in_channels=block_in,
235
+ out_channels=block_out,
236
+ temb_channels=self.temb_ch,
237
+ dropout=dropout))
238
+ block_in = block_out
239
+ if curr_res in attn_resolutions:
240
+ attn.append(AttnBlock(block_in))
241
+ down = nn.Module()
242
+ down.block = block
243
+ down.attn = attn
244
+ if i_level != self.num_resolutions-1:
245
+ down.downsample = Downsample(block_in, resamp_with_conv)
246
+ curr_res = curr_res // 2
247
+ self.down.append(down)
248
+
249
+ # middle
250
+ self.mid = nn.Module()
251
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
252
+ out_channels=block_in,
253
+ temb_channels=self.temb_ch,
254
+ dropout=dropout)
255
+ self.mid.attn_1 = AttnBlock(block_in)
256
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
257
+ out_channels=block_in,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout)
260
+
261
+ # upsampling
262
+ self.up = nn.ModuleList()
263
+ for i_level in reversed(range(self.num_resolutions)):
264
+ block = nn.ModuleList()
265
+ attn = nn.ModuleList()
266
+ block_out = ch*ch_mult[i_level]
267
+ skip_in = ch*ch_mult[i_level]
268
+ for i_block in range(self.num_res_blocks+1):
269
+ if i_block == self.num_res_blocks:
270
+ skip_in = ch*in_ch_mult[i_level]
271
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
272
+ out_channels=block_out,
273
+ temb_channels=self.temb_ch,
274
+ dropout=dropout))
275
+ block_in = block_out
276
+ if curr_res in attn_resolutions:
277
+ attn.append(AttnBlock(block_in))
278
+ up = nn.Module()
279
+ up.block = block
280
+ up.attn = attn
281
+ if i_level != 0:
282
+ up.upsample = Upsample(block_in, resamp_with_conv)
283
+ curr_res = curr_res * 2
284
+ self.up.insert(0, up) # prepend to get consistent order
285
+
286
+ # end
287
+ self.norm_out = Normalize(block_in)
288
+ self.conv_out = torch.nn.Conv2d(block_in,
289
+ out_ch,
290
+ kernel_size=3,
291
+ stride=1,
292
+ padding=1)
293
+
294
+
295
+ def forward(self, x, t=None):
296
+ #assert x.shape[2] == x.shape[3] == self.resolution
297
+
298
+ if self.use_timestep:
299
+ # timestep embedding
300
+ assert t is not None
301
+ temb = get_timestep_embedding(t, self.ch)
302
+ temb = self.temb.dense[0](temb)
303
+ temb = nonlinearity(temb)
304
+ temb = self.temb.dense[1](temb)
305
+ else:
306
+ temb = None
307
+
308
+ # downsampling
309
+ hs = [self.conv_in(x)]
310
+ for i_level in range(self.num_resolutions):
311
+ for i_block in range(self.num_res_blocks):
312
+ h = self.down[i_level].block[i_block](hs[-1], temb)
313
+ if len(self.down[i_level].attn) > 0:
314
+ h = self.down[i_level].attn[i_block](h)
315
+ hs.append(h)
316
+ if i_level != self.num_resolutions-1:
317
+ hs.append(self.down[i_level].downsample(hs[-1]))
318
+
319
+ # middle
320
+ h = hs[-1]
321
+ h = self.mid.block_1(h, temb)
322
+ h = self.mid.attn_1(h)
323
+ h = self.mid.block_2(h, temb)
324
+
325
+ # upsampling
326
+ for i_level in reversed(range(self.num_resolutions)):
327
+ for i_block in range(self.num_res_blocks+1):
328
+ h = self.up[i_level].block[i_block](
329
+ torch.cat([h, hs.pop()], dim=1), temb)
330
+ if len(self.up[i_level].attn) > 0:
331
+ h = self.up[i_level].attn[i_block](h)
332
+ if i_level != 0:
333
+ h = self.up[i_level].upsample(h)
334
+
335
+ # end
336
+ h = self.norm_out(h)
337
+ h = nonlinearity(h)
338
+ h = self.conv_out(h)
339
+ return h
340
+
341
+
342
+ class Encoder(nn.Module):
343
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
344
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
345
+ resolution, z_channels, double_z=True, **ignore_kwargs):
346
+ super().__init__()
347
+ self.ch = ch
348
+ self.temb_ch = 0
349
+ self.num_resolutions = len(ch_mult)
350
+ self.num_res_blocks = num_res_blocks
351
+ self.resolution = resolution
352
+ self.in_channels = in_channels
353
+
354
+ # downsampling
355
+ self.conv_in = torch.nn.Conv2d(in_channels,
356
+ self.ch,
357
+ kernel_size=3,
358
+ stride=1,
359
+ padding=1)
360
+
361
+ curr_res = resolution
362
+ in_ch_mult = (1,)+tuple(ch_mult)
363
+ self.down = nn.ModuleList()
364
+ for i_level in range(self.num_resolutions):
365
+ block = nn.ModuleList()
366
+ attn = nn.ModuleList()
367
+ block_in = ch*in_ch_mult[i_level]
368
+ block_out = ch*ch_mult[i_level]
369
+ for i_block in range(self.num_res_blocks):
370
+ block.append(ResnetBlock(in_channels=block_in,
371
+ out_channels=block_out,
372
+ temb_channels=self.temb_ch,
373
+ dropout=dropout))
374
+ block_in = block_out
375
+ if curr_res in attn_resolutions:
376
+ attn.append(AttnBlock(block_in))
377
+ down = nn.Module()
378
+ down.block = block
379
+ down.attn = attn
380
+ if i_level != self.num_resolutions-1:
381
+ down.downsample = Downsample(block_in, resamp_with_conv)
382
+ curr_res = curr_res // 2
383
+ self.down.append(down)
384
+
385
+ # middle
386
+ self.mid = nn.Module()
387
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
388
+ out_channels=block_in,
389
+ temb_channels=self.temb_ch,
390
+ dropout=dropout)
391
+ self.mid.attn_1 = AttnBlock(block_in)
392
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
393
+ out_channels=block_in,
394
+ temb_channels=self.temb_ch,
395
+ dropout=dropout)
396
+
397
+ # end
398
+ self.norm_out = Normalize(block_in)
399
+ self.conv_out = torch.nn.Conv2d(block_in,
400
+ 2*z_channels if double_z else z_channels,
401
+ kernel_size=3,
402
+ stride=1,
403
+ padding=1)
404
+
405
+
406
+ def forward(self, x):
407
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
408
+
409
+ # timestep embedding
410
+ temb = None
411
+
412
+ # downsampling
413
+ hs = [self.conv_in(x)]
414
+ for i_level in range(self.num_resolutions):
415
+ for i_block in range(self.num_res_blocks):
416
+ h = self.down[i_level].block[i_block](hs[-1], temb)
417
+ if len(self.down[i_level].attn) > 0:
418
+ h = self.down[i_level].attn[i_block](h)
419
+ hs.append(h)
420
+ if i_level != self.num_resolutions-1:
421
+ hs.append(self.down[i_level].downsample(hs[-1]))
422
+
423
+ # middle
424
+ h = hs[-1]
425
+ h = self.mid.block_1(h, temb)
426
+ h = self.mid.attn_1(h)
427
+ h = self.mid.block_2(h, temb)
428
+
429
+ # end
430
+ h = self.norm_out(h)
431
+ h = nonlinearity(h)
432
+ h = self.conv_out(h)
433
+ return h
434
+
435
+
436
+ class Decoder(nn.Module):
437
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
438
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
439
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
440
+ super().__init__()
441
+ self.ch = ch
442
+ self.temb_ch = 0
443
+ self.num_resolutions = len(ch_mult)
444
+ self.num_res_blocks = num_res_blocks
445
+ self.resolution = resolution
446
+ self.in_channels = in_channels
447
+ self.give_pre_end = give_pre_end
448
+
449
+ # compute in_ch_mult, block_in and curr_res at lowest res
450
+ in_ch_mult = (1,)+tuple(ch_mult)
451
+ block_in = ch*ch_mult[self.num_resolutions-1]
452
+ curr_res = resolution // 2**(self.num_resolutions-1)
453
+ self.z_shape = (1,z_channels,curr_res,curr_res)
454
+ print("Working with z of shape {} = {} dimensions.".format(
455
+ self.z_shape, np.prod(self.z_shape)))
456
+
457
+ # z to block_in
458
+ self.conv_in = torch.nn.Conv2d(z_channels,
459
+ block_in,
460
+ kernel_size=3,
461
+ stride=1,
462
+ padding=1)
463
+
464
+ # middle
465
+ self.mid = nn.Module()
466
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
467
+ out_channels=block_in,
468
+ temb_channels=self.temb_ch,
469
+ dropout=dropout)
470
+ self.mid.attn_1 = AttnBlock(block_in)
471
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
472
+ out_channels=block_in,
473
+ temb_channels=self.temb_ch,
474
+ dropout=dropout)
475
+
476
+ # upsampling
477
+ self.up = nn.ModuleList()
478
+ for i_level in reversed(range(self.num_resolutions)):
479
+ block = nn.ModuleList()
480
+ attn = nn.ModuleList()
481
+ block_out = ch*ch_mult[i_level]
482
+ for i_block in range(self.num_res_blocks+1):
483
+ block.append(ResnetBlock(in_channels=block_in,
484
+ out_channels=block_out,
485
+ temb_channels=self.temb_ch,
486
+ dropout=dropout))
487
+ block_in = block_out
488
+ if curr_res in attn_resolutions:
489
+ attn.append(AttnBlock(block_in))
490
+ up = nn.Module()
491
+ up.block = block
492
+ up.attn = attn
493
+ if i_level != 0:
494
+ up.upsample = Upsample(block_in, resamp_with_conv)
495
+ curr_res = curr_res * 2
496
+ self.up.insert(0, up) # prepend to get consistent order
497
+
498
+ # end
499
+ self.norm_out = Normalize(block_in)
500
+ self.conv_out = torch.nn.Conv2d(block_in,
501
+ out_ch,
502
+ kernel_size=3,
503
+ stride=1,
504
+ padding=1)
505
+
506
+ def forward(self, z):
507
+ #assert z.shape[1:] == self.z_shape[1:]
508
+ self.last_z_shape = z.shape
509
+
510
+ # timestep embedding
511
+ temb = None
512
+
513
+ # z to block_in
514
+ h = self.conv_in(z)
515
+
516
+ # middle
517
+ h = self.mid.block_1(h, temb)
518
+ h = self.mid.attn_1(h)
519
+ h = self.mid.block_2(h, temb)
520
+
521
+ # upsampling
522
+ for i_level in reversed(range(self.num_resolutions)):
523
+ for i_block in range(self.num_res_blocks+1):
524
+ h = self.up[i_level].block[i_block](h, temb)
525
+ if len(self.up[i_level].attn) > 0:
526
+ h = self.up[i_level].attn[i_block](h)
527
+ if i_level != 0:
528
+ h = self.up[i_level].upsample(h)
529
+
530
+ # end
531
+ if self.give_pre_end:
532
+ return h
533
+
534
+ h = self.norm_out(h)
535
+ h = nonlinearity(h)
536
+ h = self.conv_out(h)
537
+ return h
538
+
539
+
540
+ class VUNet(nn.Module):
541
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
542
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
543
+ in_channels, c_channels,
544
+ resolution, z_channels, use_timestep=False, **ignore_kwargs):
545
+ super().__init__()
546
+ self.ch = ch
547
+ self.temb_ch = self.ch*4
548
+ self.num_resolutions = len(ch_mult)
549
+ self.num_res_blocks = num_res_blocks
550
+ self.resolution = resolution
551
+
552
+ self.use_timestep = use_timestep
553
+ if self.use_timestep:
554
+ # timestep embedding
555
+ self.temb = nn.Module()
556
+ self.temb.dense = nn.ModuleList([
557
+ torch.nn.Linear(self.ch,
558
+ self.temb_ch),
559
+ torch.nn.Linear(self.temb_ch,
560
+ self.temb_ch),
561
+ ])
562
+
563
+ # downsampling
564
+ self.conv_in = torch.nn.Conv2d(c_channels,
565
+ self.ch,
566
+ kernel_size=3,
567
+ stride=1,
568
+ padding=1)
569
+
570
+ curr_res = resolution
571
+ in_ch_mult = (1,)+tuple(ch_mult)
572
+ self.down = nn.ModuleList()
573
+ for i_level in range(self.num_resolutions):
574
+ block = nn.ModuleList()
575
+ attn = nn.ModuleList()
576
+ block_in = ch*in_ch_mult[i_level]
577
+ block_out = ch*ch_mult[i_level]
578
+ for i_block in range(self.num_res_blocks):
579
+ block.append(ResnetBlock(in_channels=block_in,
580
+ out_channels=block_out,
581
+ temb_channels=self.temb_ch,
582
+ dropout=dropout))
583
+ block_in = block_out
584
+ if curr_res in attn_resolutions:
585
+ attn.append(AttnBlock(block_in))
586
+ down = nn.Module()
587
+ down.block = block
588
+ down.attn = attn
589
+ if i_level != self.num_resolutions-1:
590
+ down.downsample = Downsample(block_in, resamp_with_conv)
591
+ curr_res = curr_res // 2
592
+ self.down.append(down)
593
+
594
+ self.z_in = torch.nn.Conv2d(z_channels,
595
+ block_in,
596
+ kernel_size=1,
597
+ stride=1,
598
+ padding=0)
599
+ # middle
600
+ self.mid = nn.Module()
601
+ self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
602
+ out_channels=block_in,
603
+ temb_channels=self.temb_ch,
604
+ dropout=dropout)
605
+ self.mid.attn_1 = AttnBlock(block_in)
606
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
607
+ out_channels=block_in,
608
+ temb_channels=self.temb_ch,
609
+ dropout=dropout)
610
+
611
+ # upsampling
612
+ self.up = nn.ModuleList()
613
+ for i_level in reversed(range(self.num_resolutions)):
614
+ block = nn.ModuleList()
615
+ attn = nn.ModuleList()
616
+ block_out = ch*ch_mult[i_level]
617
+ skip_in = ch*ch_mult[i_level]
618
+ for i_block in range(self.num_res_blocks+1):
619
+ if i_block == self.num_res_blocks:
620
+ skip_in = ch*in_ch_mult[i_level]
621
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
622
+ out_channels=block_out,
623
+ temb_channels=self.temb_ch,
624
+ dropout=dropout))
625
+ block_in = block_out
626
+ if curr_res in attn_resolutions:
627
+ attn.append(AttnBlock(block_in))
628
+ up = nn.Module()
629
+ up.block = block
630
+ up.attn = attn
631
+ if i_level != 0:
632
+ up.upsample = Upsample(block_in, resamp_with_conv)
633
+ curr_res = curr_res * 2
634
+ self.up.insert(0, up) # prepend to get consistent order
635
+
636
+ # end
637
+ self.norm_out = Normalize(block_in)
638
+ self.conv_out = torch.nn.Conv2d(block_in,
639
+ out_ch,
640
+ kernel_size=3,
641
+ stride=1,
642
+ padding=1)
643
+
644
+
645
+ def forward(self, x, z, t=None):  # t must be provided when use_timestep is True (mirrors Model.forward)
646
+ #assert x.shape[2] == x.shape[3] == self.resolution
647
+
648
+ if self.use_timestep:
649
+ # timestep embedding
650
+ assert t is not None
651
+ temb = get_timestep_embedding(t, self.ch)
652
+ temb = self.temb.dense[0](temb)
653
+ temb = nonlinearity(temb)
654
+ temb = self.temb.dense[1](temb)
655
+ else:
656
+ temb = None
657
+
658
+ # downsampling
659
+ hs = [self.conv_in(x)]
660
+ for i_level in range(self.num_resolutions):
661
+ for i_block in range(self.num_res_blocks):
662
+ h = self.down[i_level].block[i_block](hs[-1], temb)
663
+ if len(self.down[i_level].attn) > 0:
664
+ h = self.down[i_level].attn[i_block](h)
665
+ hs.append(h)
666
+ if i_level != self.num_resolutions-1:
667
+ hs.append(self.down[i_level].downsample(hs[-1]))
668
+
669
+ # middle
670
+ h = hs[-1]
671
+ z = self.z_in(z)
672
+ h = torch.cat((h,z),dim=1)
673
+ h = self.mid.block_1(h, temb)
674
+ h = self.mid.attn_1(h)
675
+ h = self.mid.block_2(h, temb)
676
+
677
+ # upsampling
678
+ for i_level in reversed(range(self.num_resolutions)):
679
+ for i_block in range(self.num_res_blocks+1):
680
+ h = self.up[i_level].block[i_block](
681
+ torch.cat([h, hs.pop()], dim=1), temb)
682
+ if len(self.up[i_level].attn) > 0:
683
+ h = self.up[i_level].attn[i_block](h)
684
+ if i_level != 0:
685
+ h = self.up[i_level].upsample(h)
686
+
687
+ # end
688
+ h = self.norm_out(h)
689
+ h = nonlinearity(h)
690
+ h = self.conv_out(h)
691
+ return h
692
+
693
+
694
+ class SimpleDecoder(nn.Module):
695
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
696
+ super().__init__()
697
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
698
+ ResnetBlock(in_channels=in_channels,
699
+ out_channels=2 * in_channels,
700
+ temb_channels=0, dropout=0.0),
701
+ ResnetBlock(in_channels=2 * in_channels,
702
+ out_channels=4 * in_channels,
703
+ temb_channels=0, dropout=0.0),
704
+ ResnetBlock(in_channels=4 * in_channels,
705
+ out_channels=2 * in_channels,
706
+ temb_channels=0, dropout=0.0),
707
+ nn.Conv2d(2*in_channels, in_channels, 1),
708
+ Upsample(in_channels, with_conv=True)])
709
+ # end
710
+ self.norm_out = Normalize(in_channels)
711
+ self.conv_out = torch.nn.Conv2d(in_channels,
712
+ out_channels,
713
+ kernel_size=3,
714
+ stride=1,
715
+ padding=1)
716
+
717
+ def forward(self, x):
718
+ for i, layer in enumerate(self.model):
719
+ if i in [1,2,3]:
720
+ x = layer(x, None)
721
+ else:
722
+ x = layer(x)
723
+
724
+ h = self.norm_out(x)
725
+ h = nonlinearity(h)
726
+ x = self.conv_out(h)
727
+ return x
728
+
729
+
730
+ class UpsampleDecoder(nn.Module):
731
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
732
+ ch_mult=(2,2), dropout=0.0):
733
+ super().__init__()
734
+ # upsampling
735
+ self.temb_ch = 0
736
+ self.num_resolutions = len(ch_mult)
737
+ self.num_res_blocks = num_res_blocks
738
+ block_in = in_channels
739
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
740
+ self.res_blocks = nn.ModuleList()
741
+ self.upsample_blocks = nn.ModuleList()
742
+ for i_level in range(self.num_resolutions):
743
+ res_block = []
744
+ block_out = ch * ch_mult[i_level]
745
+ for i_block in range(self.num_res_blocks + 1):
746
+ res_block.append(ResnetBlock(in_channels=block_in,
747
+ out_channels=block_out,
748
+ temb_channels=self.temb_ch,
749
+ dropout=dropout))
750
+ block_in = block_out
751
+ self.res_blocks.append(nn.ModuleList(res_block))
752
+ if i_level != self.num_resolutions - 1:
753
+ self.upsample_blocks.append(Upsample(block_in, True))
754
+ curr_res = curr_res * 2
755
+
756
+ # end
757
+ self.norm_out = Normalize(block_in)
758
+ self.conv_out = torch.nn.Conv2d(block_in,
759
+ out_channels,
760
+ kernel_size=3,
761
+ stride=1,
762
+ padding=1)
763
+
764
+ def forward(self, x):
765
+ # upsampling
766
+ h = x
767
+ for k, i_level in enumerate(range(self.num_resolutions)):
768
+ for i_block in range(self.num_res_blocks + 1):
769
+ h = self.res_blocks[i_level][i_block](h, None)
770
+ if i_level != self.num_resolutions - 1:
771
+ h = self.upsample_blocks[k](h)
772
+ h = self.norm_out(h)
773
+ h = nonlinearity(h)
774
+ h = self.conv_out(h)
775
+ return h
776
+
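
For orientation, here is a minimal usage sketch of the Encoder and Decoder defined above. The hyperparameter values and the random input are illustrative assumptions in the style of a typical VQGAN f=16 setup (they are not taken from this commit's configs), and the import assumes these classes live in taming.modules.diffusionmodules.model as in the upstream taming-transformers layout.

    import torch
    from taming.modules.diffusionmodules.model import Encoder, Decoder

    # Assumed hyperparameters: 4 downsampling steps (len(ch_mult) - 1) turn a
    # 256x256 image into a 16x16 latent, with attention applied at resolution 16.
    enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,
                  attn_resolutions=[16], in_channels=3, resolution=256,
                  z_channels=256, double_z=False)
    dec = Decoder(ch=128, out_ch=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,
                  attn_resolutions=[16], in_channels=3, resolution=256,
                  z_channels=256)

    x = torch.randn(1, 3, 256, 256)
    z = enc(x)      # (1, 256, 16, 16); double_z=False keeps z_channels channels
    x_rec = dec(z)  # (1, 3, 256, 256)

Note that out_ch and in_channels are keyword-only and required by both constructors even where a class does not use them, so they are passed explicitly here.
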
taming/modules/discriminator/__pycache__/model.cpython-312.pyc ADDED
Binary file (3.81 kB).
 
taming/modules/discriminator/model.py ADDED
@@ -0,0 +1,67 @@
1
+ import functools
2
+ import torch.nn as nn
3
+
4
+
5
+ from taming.modules.util import ActNorm
6
+
7
+
8
+ def weights_init(m):
9
+ classname = m.__class__.__name__
10
+ if classname.find('Conv') != -1:
11
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
12
+ elif classname.find('BatchNorm') != -1:
13
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
14
+ nn.init.constant_(m.bias.data, 0)
15
+
16
+
17
+ class NLayerDiscriminator(nn.Module):
18
+ """Defines a PatchGAN discriminator as in Pix2Pix
19
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
+ """
21
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22
+ """Construct a PatchGAN discriminator
23
+ Parameters:
24
+ input_nc (int) -- the number of channels in input images
25
+ ndf (int) -- the number of filters in the last conv layer
26
+ n_layers (int) -- the number of conv layers in the discriminator
27
+ use_actnorm (bool) -- whether to use ActNorm instead of BatchNorm2d as the normalization layer
28
+ """
29
+ super(NLayerDiscriminator, self).__init__()
30
+ if not use_actnorm:
31
+ norm_layer = nn.BatchNorm2d
32
+ else:
33
+ norm_layer = ActNorm
34
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35
+ use_bias = norm_layer.func != nn.BatchNorm2d
36
+ else:
37
+ use_bias = norm_layer != nn.BatchNorm2d
38
+
39
+ kw = 4
40
+ padw = 1
41
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42
+ nf_mult = 1
43
+ nf_mult_prev = 1
44
+ for n in range(1, n_layers): # gradually increase the number of filters
45
+ nf_mult_prev = nf_mult
46
+ nf_mult = min(2 ** n, 8)
47
+ sequence += [
48
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49
+ norm_layer(ndf * nf_mult),
50
+ nn.LeakyReLU(0.2, True)
51
+ ]
52
+
53
+ nf_mult_prev = nf_mult
54
+ nf_mult = min(2 ** n_layers, 8)
55
+ sequence += [
56
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57
+ norm_layer(ndf * nf_mult),
58
+ nn.LeakyReLU(0.2, True)
59
+ ]
60
+
61
+ sequence += [
62
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63
+ self.main = nn.Sequential(*sequence)
64
+
65
+ def forward(self, input):
66
+ """Standard forward."""
67
+ return self.main(input)
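
As a quick sanity check on the PatchGAN discriminator above, here is a hedged usage sketch; the batch size, image size, and choice of defaults are illustrative assumptions, not settings taken from this commit.

    import torch
    from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

    # Assumed settings: 3-channel 256x256 inputs with the default ndf/n_layers.
    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
    disc.apply(weights_init)  # normal-init Conv/BatchNorm weights as in Pix2Pix

    x = torch.randn(4, 3, 256, 256)
    logits = disc(x)  # (4, 1, 30, 30): per-patch real/fake logits

With n_layers=3 this is the usual 70x70 PatchGAN: the single-channel output map is scored patch-wise by the adversarial loss rather than reduced to one scalar per image.
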
taming/modules/losses/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from taming.modules.losses.vqperceptual import DummyLoss
2
+
taming/modules/losses/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (253 Bytes).