sejamenath2023 committed on
Commit
239ee43
1 Parent(s): 19c0ae1

Upload 12 files

Files changed (12)
  1. __init__.py +21 -0
  2. cli.py +180 -0
  3. configs.py +181 -0
  4. data.py +137 -0
  5. default_config.json +50 -0
  6. elucidated_imagen.py +940 -0
  7. imagen_pytorch.py +2731 -0
  8. imagen_video.py +1935 -0
  9. t5.py +119 -0
  10. trainer.py +992 -0
  11. utils.py +61 -0
  12. version.py +1 -0
__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ from imagen_pytorch.imagen_pytorch import Imagen, Unet
2
+ from imagen_pytorch.imagen_pytorch import NullUnet
3
+ from imagen_pytorch.imagen_pytorch import BaseUnet64, SRUnet256, SRUnet1024
4
+ from imagen_pytorch.trainer import ImagenTrainer
5
+ from imagen_pytorch.version import __version__
6
+
7
+ # imagen using the elucidated ddpm from Tero Karras' new paper
8
+
9
+ from imagen_pytorch.elucidated_imagen import ElucidatedImagen
10
+
11
+ # config driven creation of imagen instances
12
+
13
+ from imagen_pytorch.configs import UnetConfig, ImagenConfig, ElucidatedImagenConfig, ImagenTrainerConfig
14
+
15
+ # utils
16
+
17
+ from imagen_pytorch.utils import load_imagen_from_checkpoint
18
+
19
+ # video
20
+
21
+ from imagen_pytorch.imagen_video import Unet3D
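
These re-exports are the package's public surface. As a rough orientation, the typical flow is to build one or more `Unet`s, wrap them in an `Imagen` cascade, and train through `ImagenTrainer`. The sketch below follows upstream imagen-pytorch README conventions; the dimensions, image sizes and the dummy batch are illustrative assumptions, not values taken from this upload.

# minimal sketch, assuming README-style constructor arguments; all hyperparameters are illustrative
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

unet1 = Unet(dim = 32, dim_mults = (1, 2, 4, 8), layer_attns = (False, True, True, True), layer_cross_attns = (False, True, True, True))
unet2 = Unet(dim = 32, dim_mults = (1, 2, 4, 8), layer_attns = (False, False, False, True), layer_cross_attns = (False, False, False, True))

# two-stage cascade: 64px base unet, 256px upsampler
imagen = Imagen(unets = (unet1, unet2), image_sizes = (64, 256), timesteps = 1000, cond_drop_prob = 0.1)

trainer = ImagenTrainer(imagen)

images = torch.randn(4, 3, 256, 256)   # placeholder batch; real images should be float tensors in [0, 1]
texts = ['a photo of a dog'] * 4

loss = trainer(images, texts = texts, unet_number = 1)   # each unet of the cascade is trained separately
trainer.update(unet_number = 1)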
cli.py ADDED
@@ -0,0 +1,180 @@
1
+ import click
2
+ import torch
3
+ from pathlib import Path
4
+ import pkgutil
5
+
6
+ from imagen_pytorch import load_imagen_from_checkpoint
7
+ from imagen_pytorch.version import __version__
8
+ from imagen_pytorch.data import Collator
9
+ from imagen_pytorch.utils import safeget
10
+ from imagen_pytorch import ImagenTrainer, ElucidatedImagenConfig, ImagenConfig
11
+ from datasets import load_dataset
12
+
13
+ import json
14
+
15
+ def exists(val):
16
+ return val is not None
17
+
18
+ def simple_slugify(text, max_length = 255):
19
+ return text.replace('-', '_').replace(',', '').replace(' ', '_').replace('|', '--').strip('-_')[:max_length]
20
+
21
+ def main():
22
+ pass
23
+
24
+ @click.group()
25
+ def imagen():
26
+ pass
27
+
28
+ @imagen.command(help = 'Sample from the Imagen model checkpoint')
29
+ @click.option('--model', default = './imagen.pt', help = 'path to trained Imagen model')
30
+ @click.option('--cond_scale', default = 5, help = 'conditioning scale (classifier free guidance) in decoder')
31
+ @click.option('--load_ema', default = True, help = 'load EMA version of unets if available')
32
+ @click.argument('text')
33
+ def sample(
34
+ model,
35
+ cond_scale,
36
+ load_ema,
37
+ text
38
+ ):
39
+ model_path = Path(model)
40
+ full_model_path = str(model_path.resolve())
41
+ assert model_path.exists(), f'model not found at {full_model_path}'
42
+ loaded = torch.load(str(model_path))
43
+
44
+ # get version
45
+
46
+ version = safeget(loaded, 'version')
47
+ print(f'loading Imagen from {full_model_path}, saved at version {version} - current package version is {__version__}')
48
+
49
+ # get imagen parameters and type
50
+
51
+ imagen = load_imagen_from_checkpoint(str(model_path), load_ema_if_available = load_ema)
52
+ imagen.cuda()
53
+
54
+ # generate image
55
+
56
+ pil_image = imagen.sample(text, cond_scale = cond_scale, return_pil_images = True)
57
+
58
+ image_path = f'./{simple_slugify(text)}.png'
59
+ pil_image[0].save(image_path)
60
+
61
+ print(f'image saved to {str(image_path)}')
62
+ return
63
+
64
+ @imagen.command(help = 'Generate a config for the Imagen model')
65
+ @click.option('--path', default = './imagen_config.json', help = 'Path to the Imagen model config')
66
+ def config(
67
+ path
68
+ ):
69
+ data = pkgutil.get_data(__name__, 'default_config.json').decode("utf-8")
70
+ with open(path, 'w') as f:
71
+ f.write(data)
72
+
73
+ @imagen.command(help = 'Train the Imagen model')
74
+ @click.option('--config', default = './imagen_config.json', help = 'Path to the Imagen model config')
75
+ @click.option('--unet', default = 1, help = 'Unet to train', type = click.IntRange(1, 3, False, True, True))
76
+ @click.option('--epoches', default = 1000, help = 'Number of epochs to train for')
77
+ @click.option('--text', required = False, help = 'Text to sample with between epochs', type=str)
78
+ @click.option('--valid', is_flag = False, flag_value=50, default = 0, help = 'Run validation every N epochs (0 disables)', show_default = True)
79
+ def train(
80
+ config,
81
+ unet,
82
+ epoches,
83
+ text,
84
+ valid
85
+ ):
86
+ # check config path
87
+
88
+ config_path = Path(config)
89
+ full_config_path = str(config_path.resolve())
90
+ assert config_path.exists(), f'config not found at {full_config_path}'
91
+
92
+ with open(config_path, 'r') as f:
93
+ config_data = json.loads(f.read())
94
+
95
+ assert 'checkpoint_path' in config_data, 'checkpoint path not found in config'
96
+
97
+ model_path = Path(config_data['checkpoint_path'])
98
+ full_model_path = str(model_path.resolve())
99
+
100
+ # setup imagen config
101
+
102
+ imagen_config_klass = ElucidatedImagenConfig if config_data['type'] == 'elucidated' else ImagenConfig
103
+ imagen = imagen_config_klass(**config_data['imagen']).create()
104
+
105
+ trainer = ImagenTrainer(
106
+ imagen = imagen,
107
+ **config_data['trainer']
108
+ )
109
+
110
+ # load pt
111
+ if model_path.exists():
112
+ loaded = torch.load(str(model_path))
113
+ version = safeget(loaded, 'version')
114
+ print(f'loading Imagen from {full_model_path}, saved at version {version} - current package version is {__version__}')
115
+ trainer.load(model_path)
116
+
117
+ if torch.cuda.is_available():
118
+ trainer = trainer.cuda()
119
+
120
+ size = config_data['imagen']['image_sizes'][unet-1]
121
+
122
+ max_batch_size = config_data['max_batch_size'] if 'max_batch_size' in config_data else 1
123
+
124
+ channels = 'RGB'
125
+ if 'channels' in config_data['imagen']:
126
+ assert config_data['imagen']['channels'] > 0 and config_data['imagen']['channels'] < 5, 'Imagen only supports 1 to 4 channels (L, LA, RGB, RGBA)'
127
+ if config_data['imagen']['channels'] == 4:
128
+ channels = 'RGBA' # Color with alpha
129
+ elif config_data['imagen']['channels'] == 2:
130
+ channels = 'LA' # Luminance (Greyscale) with alpha
131
+ elif config_data['imagen']['channels'] == 1:
132
+ channels = 'L' # Luminance (Greyscale)
133
+
134
+
135
+ assert 'batch_size' in config_data['dataset'], 'A batch_size is required in the config file'
136
+
137
+ # load and add train dataset and valid dataset
138
+ ds = load_dataset(config_data['dataset_name'])
139
+ trainer.add_train_dataset(
140
+ ds = ds['train'],
141
+ collate_fn = Collator(
142
+ image_size = size,
143
+ image_label = config_data['image_label'],
144
+ text_label = config_data['text_label'],
145
+ url_label = config_data['url_label'],
146
+ name = imagen.text_encoder_name,
147
+ channels = channels
148
+ ),
149
+ **config_data['dataset']
150
+ )
151
+
152
+
153
+ if not trainer.split_valid_from_train and valid != 0:
154
+ assert 'valid' in ds, 'There is no validation split in the dataset'
155
+ trainer.add_valid_dataset(
156
+ ds = ds['valid'],
157
+ collate_fn = Collator(
158
+ image_size = size,
159
+ image_label = config_data['image_label'],
160
+ text_label= config_data['text_label'],
161
+ url_label = config_data['url_label'],
162
+ name = imagen.text_encoder_name,
163
+ channels = channels
164
+ ),
165
+ **config_data['dataset']
166
+ )
167
+
168
+ for i in range(epoches):
169
+ loss = trainer.train_step(unet_number = unet, max_batch_size = max_batch_size)
170
+ print(f'loss: {loss}')
171
+
172
+ if valid != 0 and not (i % valid) and i > 0:
173
+ valid_loss = trainer.valid_step(unet_number = unet, max_batch_size = max_batch_size)
174
+ print(f'valid loss: {valid_loss}')
175
+
176
+ if not (i % 100) and i > 0 and trainer.is_main and text is not None:
177
+ images = trainer.sample(texts = [text], batch_size = 1, return_pil_images = True, stop_at_unet_number = unet)
178
+ images[0].save(f'./sample-{i // 100}.png')
179
+
180
+ trainer.save(model_path)
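
The `sample` command above wraps a flow that can also be driven directly from Python. A minimal sketch of that same path, assuming a checkpoint saved by `trainer.save` ('./imagen.pt' is simply the default path used above):

# minimal sketch mirroring the `sample` command; './imagen.pt' is the default checkpoint path from above
import torch
from imagen_pytorch import load_imagen_from_checkpoint

imagen = load_imagen_from_checkpoint('./imagen.pt', load_ema_if_available = True)
if torch.cuda.is_available():
    imagen = imagen.cuda()

pil_images = imagen.sample(texts = ['a cherry blossom tree at sunset'], cond_scale = 5, return_pil_images = True)
pil_images[0].save('./cherry_blossom_tree_at_sunset.png')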
configs.py ADDED
@@ -0,0 +1,181 @@
1
+ import json
2
+ from pydantic import BaseModel, validator
3
+ from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
4
+ from enum import Enum
5
+
6
+ from imagen_pytorch.imagen_pytorch import Imagen, Unet, Unet3D, NullUnet
7
+ from imagen_pytorch.trainer import ImagenTrainer
8
+ from imagen_pytorch.elucidated_imagen import ElucidatedImagen
9
+ from imagen_pytorch.t5 import DEFAULT_T5_NAME, get_encoded_dim
10
+
11
+ # helper functions
12
+
13
+ def exists(val):
14
+ return val is not None
15
+
16
+ def default(val, d):
17
+ return val if exists(val) else d
18
+
19
+ def ListOrTuple(inner_type):
20
+ return Union[List[inner_type], Tuple[inner_type]]
21
+
22
+ def SingleOrList(inner_type):
23
+ return Union[inner_type, ListOrTuple(inner_type)]
24
+
25
+ # noise schedule
26
+
27
+ class NoiseSchedule(Enum):
28
+ cosine = 'cosine'
29
+ linear = 'linear'
30
+
31
+ class AllowExtraBaseModel(BaseModel):
32
+ class Config:
33
+ extra = "allow"
34
+ use_enum_values = True
35
+
36
+ # imagen pydantic classes
37
+
38
+ class NullUnetConfig(BaseModel):
39
+ is_null: bool
40
+
41
+ def create(self):
42
+ return NullUnet()
43
+
44
+ class UnetConfig(AllowExtraBaseModel):
45
+ dim: int
46
+ dim_mults: ListOrTuple(int)
47
+ text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
48
+ cond_dim: int = None
49
+ channels: int = 3
50
+ attn_dim_head: int = 32
51
+ attn_heads: int = 16
52
+
53
+ def create(self):
54
+ return Unet(**self.dict())
55
+
56
+ class Unet3DConfig(AllowExtraBaseModel):
57
+ dim: int
58
+ dim_mults: ListOrTuple(int)
59
+ text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
60
+ cond_dim: int = None
61
+ channels: int = 3
62
+ attn_dim_head: int = 32
63
+ attn_heads: int = 16
64
+
65
+ def create(self):
66
+ return Unet3D(**self.dict())
67
+
68
+ class ImagenConfig(AllowExtraBaseModel):
69
+ unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
70
+ image_sizes: ListOrTuple(int)
71
+ video: bool = False
72
+ timesteps: SingleOrList(int) = 1000
73
+ noise_schedules: SingleOrList(NoiseSchedule) = 'cosine'
74
+ text_encoder_name: str = DEFAULT_T5_NAME
75
+ channels: int = 3
76
+ loss_type: str = 'l2'
77
+ cond_drop_prob: float = 0.5
78
+
79
+ @validator('image_sizes')
80
+ def check_image_sizes(cls, image_sizes, values):
81
+ unets = values.get('unets')
82
+ if len(image_sizes) != len(unets):
83
+ raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}')
84
+ return image_sizes
85
+
86
+ def create(self):
87
+ decoder_kwargs = self.dict()
88
+ unets_kwargs = decoder_kwargs.pop('unets')
89
+ is_video = decoder_kwargs.pop('video', False)
90
+
91
+ unets = []
92
+
93
+ for unet, unet_kwargs in zip(self.unets, unets_kwargs):
94
+ if isinstance(unet, NullUnetConfig):
95
+ unet_klass = NullUnet
96
+ elif is_video:
97
+ unet_klass = Unet3D
98
+ else:
99
+ unet_klass = Unet
100
+
101
+ unets.append(unet_klass(**unet_kwargs))
102
+
103
+ imagen = Imagen(unets, **decoder_kwargs)
104
+
105
+ imagen._config = self.dict().copy()
106
+ return imagen
107
+
108
+ class ElucidatedImagenConfig(AllowExtraBaseModel):
109
+ unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
110
+ image_sizes: ListOrTuple(int)
111
+ video: bool = False
112
+ text_encoder_name: str = DEFAULT_T5_NAME
113
+ channels: int = 3
114
+ cond_drop_prob: float = 0.5
115
+ num_sample_steps: SingleOrList(int) = 32
116
+ sigma_min: SingleOrList(float) = 0.002
117
+ sigma_max: SingleOrList(int) = 80
118
+ sigma_data: SingleOrList(float) = 0.5
119
+ rho: SingleOrList(int) = 7
120
+ P_mean: SingleOrList(float) = -1.2
121
+ P_std: SingleOrList(float) = 1.2
122
+ S_churn: SingleOrList(int) = 80
123
+ S_tmin: SingleOrList(float) = 0.05
124
+ S_tmax: SingleOrList(int) = 50
125
+ S_noise: SingleOrList(float) = 1.003
126
+
127
+ @validator('image_sizes')
128
+ def check_image_sizes(cls, image_sizes, values):
129
+ unets = values.get('unets')
130
+ if len(image_sizes) != len(unets):
131
+ raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}')
132
+ return image_sizes
133
+
134
+ def create(self):
135
+ decoder_kwargs = self.dict()
136
+ unets_kwargs = decoder_kwargs.pop('unets')
137
+ is_video = decoder_kwargs.pop('video', False)
138
+
139
+ unet_klass = Unet3D if is_video else Unet
140
+
141
+ unets = []
142
+
143
+ for unet, unet_kwargs in zip(self.unets, unets_kwargs):
144
+ if isinstance(unet, NullUnetConfig):
145
+ unet_klass = NullUnet
146
+ elif is_video:
147
+ unet_klass = Unet3D
148
+ else:
149
+ unet_klass = Unet
150
+
151
+ unets.append(unet_klass(**unet_kwargs))
152
+
153
+ imagen = ElucidatedImagen(unets, **decoder_kwargs)
154
+
155
+ imagen._config = self.dict().copy()
156
+ return imagen
157
+
158
+ class ImagenTrainerConfig(AllowExtraBaseModel):
159
+ imagen: dict
160
+ elucidated: bool = False
161
+ video: bool = False
162
+ use_ema: bool = True
163
+ lr: SingleOrList(float) = 1e-4
164
+ eps: SingleOrList(float) = 1e-8
165
+ beta1: float = 0.9
166
+ beta2: float = 0.99
167
+ max_grad_norm: Optional[float] = None
168
+ group_wd_params: bool = True
169
+ warmup_steps: SingleOrList(Optional[int]) = None
170
+ cosine_decay_max_steps: SingleOrList(Optional[int]) = None
171
+
172
+ def create(self):
173
+ trainer_kwargs = self.dict()
174
+
175
+ imagen_config = trainer_kwargs.pop('imagen')
176
+ elucidated = trainer_kwargs.pop('elucidated')
+ video = trainer_kwargs.pop('video')
177
+
178
+ imagen_config_klass = ElucidatedImagenConfig if elucidated else ImagenConfig
179
+ imagen = imagen_config_klass(**{**imagen_config, 'video': video}).create()
180
+
181
+ return ImagenTrainer(imagen, **trainer_kwargs)
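
These pydantic models let an Imagen instance be described as plain data (for example, parsed from JSON) and materialized with `.create()`. A minimal sketch, with purely illustrative hyperparameters:

# minimal sketch of config-driven creation; all hyperparameter values are illustrative
from imagen_pytorch.configs import ImagenConfig

imagen = ImagenConfig(
    unets = [
        dict(dim = 32, dim_mults = (1, 2, 4, 8)),   # coerced into UnetConfig
        dict(dim = 32, dim_mults = (1, 2, 4, 8)),
    ],
    image_sizes = (64, 256),    # must match the number of unets (see the validator above)
    timesteps = 1000,
    cond_drop_prob = 0.1
).create()                      # returns a ready-to-use Imagen, with the config stashed on imagen._config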
data.py ADDED
@@ -0,0 +1,137 @@
1
+ from pathlib import Path
2
+ from functools import partial
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from torchvision import transforms as T, utils
8
+ import torch.nn.functional as F
9
+ from imagen_pytorch import t5
10
+ from torch.nn.utils.rnn import pad_sequence
11
+
12
+ from PIL import Image
13
+
14
+ from datasets.utils.file_utils import get_datasets_user_agent
15
+ import io
16
+ import urllib
17
+
18
+ USER_AGENT = get_datasets_user_agent()
19
+
20
+ # helper functions
21
+
22
+ def exists(val):
23
+ return val is not None
24
+
25
+ def cycle(dl):
26
+ while True:
27
+ for data in dl:
28
+ yield data
29
+
30
+ def convert_image_to(img_type, image):
31
+ if image.mode != img_type:
32
+ return image.convert(img_type)
33
+ return image
34
+
35
+ # dataset, dataloader, collator
36
+
37
+ class Collator:
38
+ def __init__(self, image_size, url_label, text_label, image_label, name, channels):
39
+ self.url_label = url_label
40
+ self.text_label = text_label
41
+ self.image_label = image_label
42
+ self.download = url_label is not None
43
+ self.name = name
44
+ self.channels = channels
45
+ self.transform = T.Compose([
46
+ T.Resize(image_size),
47
+ T.CenterCrop(image_size),
48
+ T.ToTensor(),
49
+ ])
50
+ def __call__(self, batch):
51
+
52
+ texts = []
53
+ images = []
54
+ for item in batch:
55
+ try:
56
+ if self.download:
57
+ image = self.fetch_single_image(item[self.url_label])
58
+ else:
59
+ image = item[self.image_label]
60
+ image = self.transform(image.convert(self.channels))
61
+ except Exception:
62
+ continue
63
+
64
+ text = t5.t5_encode_text([item[self.text_label]], name=self.name)
65
+ texts.append(torch.squeeze(text))
66
+ images.append(image)
67
+
68
+ if len(texts) == 0:
69
+ return None
70
+
71
+ texts = pad_sequence(texts, True)
72
+
73
+ newbatch = []
74
+ for i in range(len(texts)):
75
+ newbatch.append((images[i], texts[i]))
76
+
77
+ return torch.utils.data.dataloader.default_collate(newbatch)
78
+
79
+ def fetch_single_image(self, image_url, timeout=1):
80
+ try:
81
+ request = urllib.request.Request(
82
+ image_url,
83
+ data=None,
84
+ headers={"user-agent": USER_AGENT},
85
+ )
86
+ with urllib.request.urlopen(request, timeout=timeout) as req:
87
+ image = Image.open(io.BytesIO(req.read())).convert('RGB')
88
+ except Exception:
89
+ image = None
90
+ return image
91
+
92
+ class Dataset(Dataset):
93
+ def __init__(
94
+ self,
95
+ folder,
96
+ image_size,
97
+ exts = ['jpg', 'jpeg', 'png', 'tiff'],
98
+ convert_image_to_type = None
99
+ ):
100
+ super().__init__()
101
+ self.folder = folder
102
+ self.image_size = image_size
103
+ self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
104
+
105
+ convert_fn = partial(convert_image_to, convert_image_to_type) if exists(convert_image_to_type) else nn.Identity()
106
+
107
+ self.transform = T.Compose([
108
+ T.Lambda(convert_fn),
109
+ T.Resize(image_size),
110
+ T.RandomHorizontalFlip(),
111
+ T.CenterCrop(image_size),
112
+ T.ToTensor()
113
+ ])
114
+
115
+ def __len__(self):
116
+ return len(self.paths)
117
+
118
+ def __getitem__(self, index):
119
+ path = self.paths[index]
120
+ img = Image.open(path)
121
+ return self.transform(img)
122
+
123
+ def get_images_dataloader(
124
+ folder,
125
+ *,
126
+ batch_size,
127
+ image_size,
128
+ shuffle = True,
129
+ cycle_dl = False,
130
+ pin_memory = True
131
+ ):
132
+ ds = Dataset(folder, image_size)
133
+ dl = DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
134
+
135
+ if cycle_dl:
136
+ dl = cycle(dl)
137
+ return dl
default_config.json ADDED
@@ -0,0 +1,50 @@
 
1
+ {
2
+ "type": "original",
3
+ "imagen": {
4
+ "video": false,
5
+ "timesteps": [1024, 512, 512],
6
+ "image_sizes": [64, 256, 1024],
7
+ "random_crop_sizes": [null, 64, 256],
8
+ "condition_on_text": true,
9
+ "cond_drop_prob": 0.1,
10
+ "text_encoder_name": "google/t5-v1_1-large",
11
+ "unets": [
12
+ {
13
+ "dim": 512,
14
+ "dim_mults": [1, 2, 3, 4],
15
+ "num_resnet_blocks": 3,
16
+ "layer_attns": [false, true, true, true],
17
+ "layer_cross_attns": [false, true, true, true],
18
+ "attn_heads": 8
19
+ },
20
+ {
21
+ "dim": 128,
22
+ "dim_mults": [1, 2, 4, 8],
23
+ "num_resnet_blocks": [2, 4, 8, 8],
24
+ "layer_attns": [false, false, false, true],
25
+ "layer_cross_attns": [false, false, false, true],
26
+ "attn_heads": 8
27
+ },
28
+ {
29
+ "dim": 128,
30
+ "dim_mults": [1, 2, 4, 8],
31
+ "num_resnet_blocks": [2, 4, 8, 8],
32
+ "layer_attns": false,
33
+ "layer_cross_attns": [false, false, false, true],
34
+ "attn_heads": 8
35
+ }
36
+ ]
37
+ },
38
+ "trainer": {
39
+ "lr": 1e-4
40
+ },
41
+ "dataset_name": "laion/laion2B-en",
42
+ "dataset": {
43
+ "batch_size": 2048,
44
+ "shuffle": true
45
+ },
46
+ "image_label": null,
47
+ "url_label": "URL",
48
+ "text_label": "TEXT",
49
+ "checkpoint_path": "./imagen.pt"
50
+ }
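
In the `train` command above, `type` selects between `ImagenConfig` and `ElucidatedImagenConfig`, the `imagen` block is passed to that config class, `trainer` goes to `ImagenTrainer`, `dataset_name`/`dataset` drive `load_dataset` and the dataloader, and the `*_label` fields configure the `Collator`. The same file can be consumed outside the CLI; a minimal sketch:

# minimal sketch mirroring how the `train` command consumes this file
import json
from imagen_pytorch import ImagenConfig, ElucidatedImagenConfig, ImagenTrainer

with open('./imagen_config.json') as f:
    config_data = json.load(f)

config_klass = ElucidatedImagenConfig if config_data['type'] == 'elucidated' else ImagenConfig
imagen = config_klass(**config_data['imagen']).create()
trainer = ImagenTrainer(imagen = imagen, **config_data['trainer'])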
elucidated_imagen.py ADDED
@@ -0,0 +1,940 @@
1
+ from math import sqrt
2
+ from random import random
3
+ from functools import partial
4
+ from contextlib import contextmanager, nullcontext
5
+ from typing import List, Union
6
+ from collections import namedtuple
7
+ from tqdm.auto import tqdm
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn, einsum
12
+ from torch.cuda.amp import autocast
13
+ from torch.nn.parallel import DistributedDataParallel
14
+ import torchvision.transforms as T
15
+
16
+ import kornia.augmentation as K
17
+
18
+ from einops import rearrange, repeat, reduce
19
+
20
+ from imagen_pytorch.imagen_pytorch import (
21
+ GaussianDiffusionContinuousTimes,
22
+ Unet,
23
+ NullUnet,
24
+ first,
25
+ exists,
26
+ identity,
27
+ maybe,
28
+ default,
29
+ cast_tuple,
30
+ cast_uint8_images_to_float,
31
+ eval_decorator,
32
+ pad_tuple_to_length,
33
+ resize_image_to,
34
+ calc_all_frame_dims,
35
+ safe_get_tuple_index,
36
+ right_pad_dims_to,
37
+ module_device,
38
+ normalize_neg_one_to_one,
39
+ unnormalize_zero_to_one,
40
+ compact,
41
+ maybe_transform_dict_key
42
+ )
43
+
44
+ from imagen_pytorch.imagen_video import (
45
+ Unet3D,
46
+ resize_video_to,
47
+ scale_video_time
48
+ )
49
+
50
+ from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
51
+
52
+ # constants
53
+
54
+ Hparams_fields = [
55
+ 'num_sample_steps',
56
+ 'sigma_min',
57
+ 'sigma_max',
58
+ 'sigma_data',
59
+ 'rho',
60
+ 'P_mean',
61
+ 'P_std',
62
+ 'S_churn',
63
+ 'S_tmin',
64
+ 'S_tmax',
65
+ 'S_noise'
66
+ ]
67
+
68
+ Hparams = namedtuple('Hparams', Hparams_fields)
69
+
70
+ # helper functions
71
+
72
+ def log(t, eps = 1e-20):
73
+ return torch.log(t.clamp(min = eps))
74
+
75
+ # main class
76
+
77
+ class ElucidatedImagen(nn.Module):
78
+ def __init__(
79
+ self,
80
+ unets,
81
+ *,
82
+ image_sizes, # for cascading ddpm, image size at each stage
83
+ text_encoder_name = DEFAULT_T5_NAME,
84
+ text_embed_dim = None,
85
+ channels = 3,
86
+ cond_drop_prob = 0.1,
87
+ random_crop_sizes = None,
88
+ resize_mode = 'nearest',
89
+ temporal_downsample_factor = 1,
90
+ resize_cond_video_frames = True,
91
+ lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
92
+ per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
93
+ condition_on_text = True,
94
+ auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
95
+ dynamic_thresholding = True,
96
+ dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper
97
+ only_train_unet_number = None,
98
+ lowres_noise_schedule = 'linear',
99
+ num_sample_steps = 32, # number of sampling steps
100
+ sigma_min = 0.002, # min noise level
101
+ sigma_max = 80, # max noise level
102
+ sigma_data = 0.5, # standard deviation of data distribution
103
+ rho = 7, # controls the sampling schedule
104
+ P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
105
+ P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
106
+ S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in paper
107
+ S_tmin = 0.05,
108
+ S_tmax = 50,
109
+ S_noise = 1.003,
110
+ ):
111
+ super().__init__()
112
+
113
+ self.only_train_unet_number = only_train_unet_number
114
+
115
+ # conditioning hparams
116
+
117
+ self.condition_on_text = condition_on_text
118
+ self.unconditional = not condition_on_text
119
+
120
+ # channels
121
+
122
+ self.channels = channels
123
+
124
+ # automatically take care of ensuring that first unet is unconditional
125
+ # while the rest of the unets are conditioned on the low resolution image produced by previous unet
126
+
127
+ unets = cast_tuple(unets)
128
+ num_unets = len(unets)
129
+
130
+ # randomly cropping for upsampler training
131
+
132
+ self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
133
+ assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'
134
+
135
+ # lowres augmentation noise schedule
136
+
137
+ self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule)
138
+
139
+ # get text encoder
140
+
141
+ self.text_encoder_name = text_encoder_name
142
+ self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name))
143
+
144
+ self.encode_text = partial(t5_encode_text, name = text_encoder_name)
145
+
146
+ # construct unets
147
+
148
+ self.unets = nn.ModuleList([])
149
+ self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment
150
+
151
+ for ind, one_unet in enumerate(unets):
152
+ assert isinstance(one_unet, (Unet, Unet3D, NullUnet))
153
+ is_first = ind == 0
154
+
155
+ one_unet = one_unet.cast_model_parameters(
156
+ lowres_cond = not is_first,
157
+ cond_on_text = self.condition_on_text,
158
+ text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
159
+ channels = self.channels,
160
+ channels_out = self.channels
161
+ )
162
+
163
+ self.unets.append(one_unet)
164
+
165
+ # determine whether we are training on images or video
166
+
167
+ is_video = any([isinstance(unet, Unet3D) for unet in self.unets])
168
+ self.is_video = is_video
169
+
170
+ self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))
171
+
172
+ self.resize_to = resize_video_to if is_video else resize_image_to
173
+ self.resize_to = partial(self.resize_to, mode = resize_mode)
174
+
175
+ # unet image sizes
176
+
177
+ self.image_sizes = cast_tuple(image_sizes)
178
+ assert num_unets == len(self.image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {self.image_sizes}'
179
+
180
+ self.sample_channels = cast_tuple(self.channels, num_unets)
181
+
182
+ # cascading ddpm related stuff
183
+
184
+ lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
185
+ assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
186
+
187
+ self.lowres_sample_noise_level = lowres_sample_noise_level
188
+ self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level
189
+
190
+ # classifier free guidance
191
+
192
+ self.cond_drop_prob = cond_drop_prob
193
+ self.can_classifier_guidance = cond_drop_prob > 0.
194
+
195
+ # normalize and unnormalize image functions
196
+
197
+ self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
198
+ self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
199
+ self.input_image_range = (0. if auto_normalize_img else -1., 1.)
200
+
201
+ # dynamic thresholding
202
+
203
+ self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
204
+ self.dynamic_thresholding_percentile = dynamic_thresholding_percentile
205
+
206
+ # temporal interpolations
207
+
208
+ temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets)
209
+ self.temporal_downsample_factor = temporal_downsample_factor
210
+
211
+ self.resize_cond_video_frames = resize_cond_video_frames
212
+ self.temporal_downsample_divisor = temporal_downsample_factor[0]
213
+
214
+ assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1'
215
+ assert tuple(sorted(temporal_downsample_factor, reverse = True)) == temporal_downsample_factor, 'temporal downsample factors must be in descending order'
216
+
217
+ # elucidating parameters
218
+
219
+ hparams = [
220
+ num_sample_steps,
221
+ sigma_min,
222
+ sigma_max,
223
+ sigma_data,
224
+ rho,
225
+ P_mean,
226
+ P_std,
227
+ S_churn,
228
+ S_tmin,
229
+ S_tmax,
230
+ S_noise,
231
+ ]
232
+
233
+ hparams = [cast_tuple(hp, num_unets) for hp in hparams]
234
+ self.hparams = [Hparams(*unet_hp) for unet_hp in zip(*hparams)]
235
+
236
+ # one temp parameter for keeping track of device
237
+
238
+ self.register_buffer('_temp', torch.tensor([0.]), persistent = False)
239
+
240
+ # default to device of unets passed in
241
+
242
+ self.to(next(self.unets.parameters()).device)
243
+
244
+ def force_unconditional_(self):
245
+ self.condition_on_text = False
246
+ self.unconditional = True
247
+
248
+ for unet in self.unets:
249
+ unet.cond_on_text = False
250
+
251
+ @property
252
+ def device(self):
253
+ return self._temp.device
254
+
255
+ def get_unet(self, unet_number):
256
+ assert 0 < unet_number <= len(self.unets)
257
+ index = unet_number - 1
258
+
259
+ if isinstance(self.unets, nn.ModuleList):
260
+ unets_list = [unet for unet in self.unets]
261
+ delattr(self, 'unets')
262
+ self.unets = unets_list
263
+
264
+ if index != self.unet_being_trained_index:
265
+ for unet_index, unet in enumerate(self.unets):
266
+ unet.to(self.device if unet_index == index else 'cpu')
267
+
268
+ self.unet_being_trained_index = index
269
+ return self.unets[index]
270
+
271
+ def reset_unets_all_one_device(self, device = None):
272
+ device = default(device, self.device)
273
+ self.unets = nn.ModuleList([*self.unets])
274
+ self.unets.to(device)
275
+
276
+ self.unet_being_trained_index = -1
277
+
278
+ @contextmanager
279
+ def one_unet_in_gpu(self, unet_number = None, unet = None):
280
+ assert exists(unet_number) ^ exists(unet)
281
+
282
+ if exists(unet_number):
283
+ unet = self.unets[unet_number - 1]
284
+
285
+ cpu = torch.device('cpu')
286
+
287
+ devices = [module_device(unet) for unet in self.unets]
288
+
289
+ self.unets.to(cpu)
290
+ unet.to(self.device)
291
+
292
+ yield
293
+
294
+ for unet, device in zip(self.unets, devices):
295
+ unet.to(device)
296
+
297
+ # overriding state dict functions
298
+
299
+ def state_dict(self, *args, **kwargs):
300
+ self.reset_unets_all_one_device()
301
+ return super().state_dict(*args, **kwargs)
302
+
303
+ def load_state_dict(self, *args, **kwargs):
304
+ self.reset_unets_all_one_device()
305
+ return super().load_state_dict(*args, **kwargs)
306
+
307
+ # dynamic thresholding
308
+
309
+ def threshold_x_start(self, x_start, dynamic_threshold = True):
310
+ if not dynamic_threshold:
311
+ return x_start.clamp(-1., 1.)
312
+
313
+ s = torch.quantile(
314
+ rearrange(x_start, 'b ... -> b (...)').abs(),
315
+ self.dynamic_thresholding_percentile,
316
+ dim = -1
317
+ )
318
+
319
+ s.clamp_(min = 1.)
320
+ s = right_pad_dims_to(x_start, s)
321
+ return x_start.clamp(-s, s) / s
322
+
323
+ # derived preconditioning params - Table 1
324
+
325
+ def c_skip(self, sigma_data, sigma):
326
+ return (sigma_data ** 2) / (sigma ** 2 + sigma_data ** 2)
327
+
328
+ def c_out(self, sigma_data, sigma):
329
+ return sigma * sigma_data * (sigma_data ** 2 + sigma ** 2) ** -0.5
330
+
331
+ def c_in(self, sigma_data, sigma):
332
+ return 1 * (sigma ** 2 + sigma_data ** 2) ** -0.5
333
+
334
+ def c_noise(self, sigma):
335
+ return log(sigma) * 0.25
336
+
337
+ # preconditioned network output
338
+ # equation (7) in the paper
339
+
340
+ def preconditioned_network_forward(
341
+ self,
342
+ unet_forward,
343
+ noised_images,
344
+ sigma,
345
+ *,
346
+ sigma_data,
347
+ clamp = False,
348
+ dynamic_threshold = True,
349
+ **kwargs
350
+ ):
351
+ batch, device = noised_images.shape[0], noised_images.device
352
+
353
+ if isinstance(sigma, float):
354
+ sigma = torch.full((batch,), sigma, device = device)
355
+
356
+ padded_sigma = self.right_pad_dims_to_datatype(sigma)
357
+
358
+ net_out = unet_forward(
359
+ self.c_in(sigma_data, padded_sigma) * noised_images,
360
+ self.c_noise(sigma),
361
+ **kwargs
362
+ )
363
+
364
+ out = self.c_skip(sigma_data, padded_sigma) * noised_images + self.c_out(sigma_data, padded_sigma) * net_out
365
+
366
+ if not clamp:
367
+ return out
368
+
369
+ return self.threshold_x_start(out, dynamic_threshold)
370
+
371
+ # sampling
372
+
373
+ # sample schedule
374
+ # equation (5) in the paper
375
+
376
+ def sample_schedule(
377
+ self,
378
+ num_sample_steps,
379
+ rho,
380
+ sigma_min,
381
+ sigma_max
382
+ ):
383
+ N = num_sample_steps
384
+ inv_rho = 1 / rho
385
+
386
+ steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
387
+ sigmas = (sigma_max ** inv_rho + steps / (N - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
388
+
389
+ sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
390
+ return sigmas
391
+
392
+ @torch.no_grad()
393
+ def one_unet_sample(
394
+ self,
395
+ unet,
396
+ shape,
397
+ *,
398
+ unet_number,
399
+ clamp = True,
400
+ dynamic_threshold = True,
401
+ cond_scale = 1.,
402
+ use_tqdm = True,
403
+ inpaint_videos = None,
404
+ inpaint_images = None,
405
+ inpaint_masks = None,
406
+ inpaint_resample_times = 5,
407
+ init_images = None,
408
+ skip_steps = None,
409
+ sigma_min = None,
410
+ sigma_max = None,
411
+ **kwargs
412
+ ):
413
+ # video
414
+
415
+ is_video = len(shape) == 5
416
+ frames = shape[-3] if is_video else None
417
+ resize_kwargs = dict(target_frames = frames) if exists(frames) else dict()
418
+
419
+ # get specific sampling hyperparameters for unet
420
+
421
+ hp = self.hparams[unet_number - 1]
422
+
423
+ sigma_min = default(sigma_min, hp.sigma_min)
424
+ sigma_max = default(sigma_max, hp.sigma_max)
425
+
426
+ # get the sigma schedule, compute the per-step gammas, and pair each sigma with the next sigma and its gamma
427
+
428
+ sigmas = self.sample_schedule(hp.num_sample_steps, hp.rho, sigma_min, sigma_max)
429
+
430
+ gammas = torch.where(
431
+ (sigmas >= hp.S_tmin) & (sigmas <= hp.S_tmax),
432
+ min(hp.S_churn / hp.num_sample_steps, sqrt(2) - 1),
433
+ 0.
434
+ )
435
+
436
+ sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))
437
+
438
+ # images is noise at the beginning
439
+
440
+ init_sigma = sigmas[0]
441
+
442
+ images = init_sigma * torch.randn(shape, device = self.device)
443
+
444
+ # initializing with an image
445
+
446
+ if exists(init_images):
447
+ images += init_images
448
+
449
+ # keeping track of x0, for self conditioning if needed
450
+
451
+ x_start = None
452
+
453
+ # prepare inpainting images and mask
454
+
455
+ inpaint_images = default(inpaint_videos, inpaint_images)
456
+ has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
457
+ resample_times = inpaint_resample_times if has_inpainting else 1
458
+
459
+ if has_inpainting:
460
+ inpaint_images = self.normalize_img(inpaint_images)
461
+ inpaint_images = self.resize_to(inpaint_images, shape[-1], **resize_kwargs)
462
+ inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1], **resize_kwargs).bool()
463
+
464
+ # unet kwargs
465
+
466
+ unet_kwargs = dict(
467
+ sigma_data = hp.sigma_data,
468
+ clamp = clamp,
469
+ dynamic_threshold = dynamic_threshold,
470
+ cond_scale = cond_scale,
471
+ **kwargs
472
+ )
473
+
474
+ # gradually denoise
475
+
476
+ initial_step = default(skip_steps, 0)
477
+ sigmas_and_gammas = sigmas_and_gammas[initial_step:]
478
+
479
+ total_steps = len(sigmas_and_gammas)
480
+
481
+ for ind, (sigma, sigma_next, gamma) in tqdm(enumerate(sigmas_and_gammas), total = total_steps, desc = 'sampling time step', disable = not use_tqdm):
482
+ is_last_timestep = ind == (total_steps - 1)
483
+
484
+ sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))
485
+
486
+ for r in reversed(range(resample_times)):
487
+ is_last_resample_step = r == 0
488
+
489
+ eps = hp.S_noise * torch.randn(shape, device = self.device) # stochastic sampling
490
+
491
+ sigma_hat = sigma + gamma * sigma
492
+ added_noise = sqrt(sigma_hat ** 2 - sigma ** 2) * eps
493
+
494
+ images_hat = images + added_noise
495
+
496
+ self_cond = x_start if unet.self_cond else None
497
+
498
+ if has_inpainting:
499
+ images_hat = images_hat * ~inpaint_masks + (inpaint_images + added_noise) * inpaint_masks
500
+
501
+ model_output = self.preconditioned_network_forward(
502
+ unet.forward_with_cond_scale,
503
+ images_hat,
504
+ sigma_hat,
505
+ self_cond = self_cond,
506
+ **unet_kwargs
507
+ )
508
+
509
+ denoised_over_sigma = (images_hat - model_output) / sigma_hat
510
+
511
+ images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma
512
+
513
+ # second order correction, if not the last timestep
514
+
515
+ has_second_order_correction = sigma_next != 0
516
+
517
+ if has_second_order_correction:
518
+ self_cond = model_output if unet.self_cond else None
519
+
520
+ model_output_next = self.preconditioned_network_forward(
521
+ unet.forward_with_cond_scale,
522
+ images_next,
523
+ sigma_next,
524
+ self_cond = self_cond,
525
+ **unet_kwargs
526
+ )
527
+
528
+ denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
529
+ images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)
530
+
531
+ images = images_next
532
+
533
+ if has_inpainting and not (is_last_resample_step or is_last_timestep):
534
+ # renoise in repaint and then resample
535
+ repaint_noise = torch.randn(shape, device = self.device)
536
+ images = images + (sigma - sigma_next) * repaint_noise
537
+
538
+ x_start = model_output if not has_second_order_correction else model_output_next # save model output for self conditioning
539
+
540
+ images = images.clamp(-1., 1.)
541
+
542
+ if has_inpainting:
543
+ images = images * ~inpaint_masks + inpaint_images * inpaint_masks
544
+
545
+ return self.unnormalize_img(images)
546
+
547
+ @torch.no_grad()
548
+ @eval_decorator
549
+ def sample(
550
+ self,
551
+ texts: List[str] = None,
552
+ text_masks = None,
553
+ text_embeds = None,
554
+ cond_images = None,
555
+ cond_video_frames = None,
556
+ post_cond_video_frames = None,
557
+ inpaint_videos = None,
558
+ inpaint_images = None,
559
+ inpaint_masks = None,
560
+ inpaint_resample_times = 5,
561
+ init_images = None,
562
+ skip_steps = None,
563
+ sigma_min = None,
564
+ sigma_max = None,
565
+ video_frames = None,
566
+ batch_size = 1,
567
+ cond_scale = 1.,
568
+ lowres_sample_noise_level = None,
569
+ start_at_unet_number = 1,
570
+ start_image_or_video = None,
571
+ stop_at_unet_number = None,
572
+ return_all_unet_outputs = False,
573
+ return_pil_images = False,
574
+ use_tqdm = True,
575
+ use_one_unet_in_gpu = True,
576
+ device = None,
577
+ ):
578
+ device = default(device, self.device)
579
+ self.reset_unets_all_one_device(device = device)
580
+
581
+ cond_images = maybe(cast_uint8_images_to_float)(cond_images)
582
+
583
+ if exists(texts) and not exists(text_embeds) and not self.unconditional:
584
+ assert all([*map(len, texts)]), 'text cannot be empty'
585
+
586
+ with autocast(enabled = False):
587
+ text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
588
+
589
+ text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
590
+
591
+ if not self.unconditional:
592
+ assert exists(text_embeds), 'text must be passed in, since the network was trained conditioned on text - to train without text, `condition_on_text` must be set to `False`'
593
+
594
+ text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
595
+ batch_size = text_embeds.shape[0]
596
+
597
+ # inpainting
598
+
599
+ inpaint_images = default(inpaint_videos, inpaint_images)
600
+
601
+ if exists(inpaint_images):
602
+ if self.unconditional:
603
+ if batch_size == 1: # assume researcher wants to broadcast along inpainted images
604
+ batch_size = inpaint_images.shape[0]
605
+
606
+ assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the batch size specified on sample, i.e. `sample(batch_size=<int>)`'
607
+ assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on'
608
+
609
+ assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
610
+ assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
611
+ assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
612
+
613
+ assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting'
614
+
615
+ outputs = []
616
+
617
+ is_cuda = next(self.parameters()).is_cuda
618
+ device = next(self.parameters()).device
619
+
620
+ lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level)
621
+
622
+ num_unets = len(self.unets)
623
+ cond_scale = cast_tuple(cond_scale, num_unets)
624
+
625
+ # handle video and frame dimension
626
+
627
+ if self.is_video and exists(inpaint_images):
628
+ video_frames = inpaint_images.shape[2]
629
+
630
+ if inpaint_masks.ndim == 3:
631
+ inpaint_masks = repeat(inpaint_masks, 'b h w -> b f h w', f = video_frames)
632
+
633
+ assert inpaint_masks.shape[1] == video_frames
634
+
635
+ assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'
636
+
637
+ # determine the frame dimensions, if needed
638
+
639
+ all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)
640
+
641
+ # initializing with an image or video
642
+
643
+ init_images = cast_tuple(init_images, num_unets)
644
+ init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images]
645
+
646
+ skip_steps = cast_tuple(skip_steps, num_unets)
647
+
648
+ sigma_min = cast_tuple(sigma_min, num_unets)
649
+ sigma_max = cast_tuple(sigma_max, num_unets)
650
+
651
+ # handle starting at a unet greater than 1, for training only-upscaler training
652
+
653
+ if start_at_unet_number > 1:
654
+ assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets'
655
+ assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
656
+ assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'
657
+
658
+ prev_image_size = self.image_sizes[start_at_unet_number - 2]
659
+ img = self.resize_to(start_image_or_video, prev_image_size)
660
+
661
+ # go through each unet in cascade
662
+
663
+ for unet_number, unet, channel, image_size, frame_dims, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, all_frame_dims, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm):
664
+ if unet_number < start_at_unet_number:
665
+ continue
666
+
667
+ assert not isinstance(unet, NullUnet), 'cannot sample from null unet'
668
+
669
+ context = self.one_unet_in_gpu(unet = unet) if is_cuda and use_one_unet_in_gpu else nullcontext()
670
+
671
+ with context:
672
+ lowres_cond_img = lowres_noise_times = None
673
+
674
+ shape = (batch_size, channel, *frame_dims, image_size, image_size)
675
+
676
+ resize_kwargs = dict()
677
+ video_kwargs = dict()
678
+
679
+ if self.is_video:
680
+ resize_kwargs = dict(target_frames = frame_dims[0])
681
+
682
+ video_kwargs = dict(
683
+ cond_video_frames = cond_video_frames,
684
+ post_cond_video_frames = post_cond_video_frames
685
+ )
686
+
687
+ video_kwargs = compact(video_kwargs)
688
+
689
+ # handle video conditioning frames
690
+
691
+ if self.is_video and self.resize_cond_video_frames:
692
+ downsample_scale = self.temporal_downsample_factor[unet_number - 1]
693
+ temporal_downsample_fn = partial(scale_video_time, downsample_scale = downsample_scale)
694
+ video_kwargs = maybe_transform_dict_key(video_kwargs, 'cond_video_frames', temporal_downsample_fn)
695
+ video_kwargs = maybe_transform_dict_key(video_kwargs, 'post_cond_video_frames', temporal_downsample_fn)
696
+
697
+ # low resolution conditioning
698
+
699
+ if unet.lowres_cond:
700
+ lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)
701
+
702
+ lowres_cond_img = self.resize_to(img, image_size, **resize_kwargs)
703
+ lowres_cond_img = self.normalize_img(lowres_cond_img)
704
+
705
+ lowres_cond_img, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))
706
+
707
+ if exists(unet_init_images):
708
+ unet_init_images = self.resize_to(unet_init_images, image_size, **resize_kwargs)
709
+
710
+ shape = (batch_size, self.channels, *frame_dims, image_size, image_size)
711
+
712
+ img = self.one_unet_sample(
713
+ unet,
714
+ shape,
715
+ unet_number = unet_number,
716
+ text_embeds = text_embeds,
717
+ text_mask = text_masks,
718
+ cond_images = cond_images,
719
+ inpaint_images = inpaint_images,
720
+ inpaint_masks = inpaint_masks,
721
+ inpaint_resample_times = inpaint_resample_times,
722
+ init_images = unet_init_images,
723
+ skip_steps = unet_skip_steps,
724
+ sigma_min = unet_sigma_min,
725
+ sigma_max = unet_sigma_max,
726
+ cond_scale = unet_cond_scale,
727
+ lowres_cond_img = lowres_cond_img,
728
+ lowres_noise_times = lowres_noise_times,
729
+ dynamic_threshold = dynamic_threshold,
730
+ use_tqdm = use_tqdm,
731
+ **video_kwargs
732
+ )
733
+
734
+ outputs.append(img)
735
+
736
+ if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
737
+ break
738
+
739
+ output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs
740
+
741
+ if not return_pil_images:
742
+ return outputs[output_index]
743
+
744
+ if not return_all_unet_outputs:
745
+ outputs = outputs[-1:]
746
+
747
+ assert not self.is_video, 'automatically converting video tensor to video file for saving is not built yet'
748
+
749
+ pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim = 0))), outputs))
750
+
751
+ return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)
752
+
753
+ # training
754
+
755
+ def loss_weight(self, sigma_data, sigma):
756
+ return (sigma ** 2 + sigma_data ** 2) * (sigma * sigma_data) ** -2
757
+
758
+ def noise_distribution(self, P_mean, P_std, batch_size):
759
+ return (P_mean + P_std * torch.randn((batch_size,), device = self.device)).exp()
760
+
761
+ def forward(
762
+ self,
763
+ images, # rename to images or video
764
+ unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,
765
+ texts: List[str] = None,
766
+ text_embeds = None,
767
+ text_masks = None,
768
+ unet_number = None,
769
+ cond_images = None,
770
+ **kwargs
771
+ ):
772
+ if self.is_video and images.ndim == 4:
773
+ images = rearrange(images, 'b c h w -> b c 1 h w')
774
+ kwargs.update(ignore_time = True)
775
+
776
+ assert images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[-2]}, {images.shape[-1]}'
777
+ assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
778
+ unet_number = default(unet_number, 1)
779
+ assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train on unet #{self.only_train_unet_number}'
780
+
781
+ images = cast_uint8_images_to_float(images)
782
+ cond_images = maybe(cast_uint8_images_to_float)(cond_images)
783
+
784
+ assert images.dtype == torch.float, f'images tensor needs to be floats but {images.dtype} dtype found instead'
785
+
786
+ unet_index = unet_number - 1
787
+
788
+ unet = default(unet, lambda: self.get_unet(unet_number))
789
+
790
+ assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'
791
+
792
+ target_image_size = self.image_sizes[unet_index]
793
+ random_crop_size = self.random_crop_sizes[unet_index]
794
+ prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None
795
+ hp = self.hparams[unet_index]
796
+
797
+ batch_size, c, *_, h, w, device, is_video = *images.shape, images.device, (images.ndim == 5)
798
+
799
+ frames = images.shape[2] if is_video else None
800
+ all_frame_dims = tuple(safe_get_tuple_index(el, 0) for el in calc_all_frame_dims(self.temporal_downsample_factor, frames))
801
+ ignore_time = kwargs.get('ignore_time', False)
802
+
803
+ target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None
804
+ prev_frame_size = all_frame_dims[unet_index - 1] if is_video and not ignore_time and unet_index > 0 else None
805
+ frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()
806
+
807
+ assert images.shape[1] == self.channels
808
+ assert h >= target_image_size and w >= target_image_size
809
+
810
+ if exists(texts) and not exists(text_embeds) and not self.unconditional:
811
+ assert all([*map(len, texts)]), 'text cannot be empty'
812
+ assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'
813
+
814
+ with autocast(enabled = False):
815
+ text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
816
+
817
+ text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))
818
+
819
+ if not self.unconditional:
820
+ text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
821
+
822
+ assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified'
823
+ assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented'
824
+
825
+ assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
826
+
827
+ # handle video conditioning frames
828
+
829
+ if self.is_video and self.resize_cond_video_frames:
830
+ downsample_scale = self.temporal_downsample_factor[unet_index]
831
+ temporal_downsample_fn = partial(scale_video_time, downsample_scale = downsample_scale)
832
+ kwargs = maybe_transform_dict_key(kwargs, 'cond_video_frames', temporal_downsample_fn)
833
+ kwargs = maybe_transform_dict_key(kwargs, 'post_cond_video_frames', temporal_downsample_fn)
834
+
835
+ # low resolution conditioning
836
+
837
+ lowres_cond_img = lowres_aug_times = None
838
+ if exists(prev_image_size):
839
+ lowres_cond_img = self.resize_to(images, prev_image_size, **frames_to_resize_kwargs(prev_frame_size), clamp_range = self.input_image_range)
840
+ lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, **frames_to_resize_kwargs(target_frame_size), clamp_range = self.input_image_range)
841
+
842
+ if self.per_sample_random_aug_noise_level:
843
+ lowres_aug_times = self.lowres_noise_schedule.sample_random_times(batch_size, device = device)
844
+ else:
845
+ lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
846
+ lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size)
847
+
848
+ images = self.resize_to(images, target_image_size, **frames_to_resize_kwargs(target_frame_size))
849
+
850
+ # normalize to [-1, 1]
851
+
852
+ images = self.normalize_img(images)
853
+ lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
854
+
855
+ # random cropping during training
856
+ # for upsamplers
857
+
858
+ if exists(random_crop_size):
859
+ aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)
860
+
861
+ if is_video:
862
+ images, lowres_cond_img = map(lambda t: rearrange(t, 'b c f h w -> (b f) c h w'), (images, lowres_cond_img))
863
+
864
+ # make sure low res conditioner and image both get augmented the same way
865
+ # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
866
+ images = aug(images)
867
+ lowres_cond_img = aug(lowres_cond_img, params = aug._params)
868
+
869
+ if is_video:
870
+ images, lowres_cond_img = map(lambda t: rearrange(t, '(b f) c h w -> b c f h w', f = frames), (images, lowres_cond_img))
871
+
872
+ # noise the lowres conditioning image
873
+ # at sample time, they then fix the noise level of 0.1 - 0.3
874
+
875
+ lowres_cond_img_noisy = None
876
+ if exists(lowres_cond_img):
877
+ lowres_cond_img_noisy, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))
878
+
879
+ # get the sigmas
880
+
881
+ sigmas = self.noise_distribution(hp.P_mean, hp.P_std, batch_size)
882
+ padded_sigmas = self.right_pad_dims_to_datatype(sigmas)
883
+
884
+ # noise
885
+
886
+ noise = torch.randn_like(images)
887
+ noised_images = images + padded_sigmas * noise # alphas are 1. in the paper
888
+
889
+ # unet kwargs
890
+
891
+ unet_kwargs = dict(
892
+ sigma_data = hp.sigma_data,
893
+ text_embeds = text_embeds,
894
+ text_mask = text_masks,
895
+ cond_images = cond_images,
896
+ lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
897
+ lowres_cond_img = lowres_cond_img_noisy,
898
+ cond_drop_prob = self.cond_drop_prob,
899
+ **kwargs
900
+ )
901
+
902
+ # self conditioning - https://arxiv.org/abs/2208.04202 - training will be 25% slower
903
+
904
+ # Because 'unet' can be an instance of DistributedDataParallel coming from the
905
+ # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to
906
+ # access the member 'module' of the wrapped unet instance.
907
+ self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond
908
+
909
+ if self_cond and random() < 0.5:
910
+ with torch.no_grad():
911
+ pred_x0 = self.preconditioned_network_forward(
912
+ unet.forward,
913
+ noised_images,
914
+ sigmas,
915
+ **unet_kwargs
916
+ ).detach()
917
+
918
+ unet_kwargs = {**unet_kwargs, 'self_cond': pred_x0}
919
+
920
+ # get prediction
921
+
922
+ denoised_images = self.preconditioned_network_forward(
923
+ unet.forward,
924
+ noised_images,
925
+ sigmas,
926
+ **unet_kwargs
927
+ )
928
+
929
+ # losses
930
+
931
+ losses = F.mse_loss(denoised_images, images, reduction = 'none')
932
+ losses = reduce(losses, 'b ... -> b', 'mean')
933
+
934
+ # loss weighting
935
+
936
+ losses = losses * self.loss_weight(hp.sigma_data, sigmas)
937
+
938
+ # return average loss
939
+
940
+ return losses.mean()
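
ElucidatedImagen exposes the EDM sampler hyperparameters (Karras et al., "Elucidating the Design Space of Diffusion-Based Generative Models") per unet, as seen in the constructor above. A minimal usage sketch; the unet definitions and image sizes are illustrative assumptions, and the EDM values simply restate the defaults above rather than a tuned configuration:

# minimal sketch; unet dims and image sizes are illustrative assumptions
from imagen_pytorch import Unet, ElucidatedImagen, ImagenTrainer

unet1 = Unet(dim = 32, dim_mults = (1, 2, 4, 8))
unet2 = Unet(dim = 32, dim_mults = (1, 2, 4, 8))

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    cond_drop_prob = 0.1,
    num_sample_steps = (64, 32),   # any of the EDM hparams can be set per unet
    sigma_min = 0.002,
    sigma_max = (80, 160),
    sigma_data = 0.5,
    rho = 7,
    P_mean = -1.2,
    P_std = 1.2,
    S_churn = 80,
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003
)

trainer = ImagenTrainer(imagen)   # trained the same way as the standard Imagen cascade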
imagen_pytorch.py ADDED
@@ -0,0 +1,2731 @@
1
+ import math
2
+ import copy
3
+ from random import random
4
+ from beartype.typing import List, Union
5
+ from beartype import beartype
6
+ from tqdm.auto import tqdm
7
+ from functools import partial, wraps
8
+ from contextlib import contextmanager, nullcontext
9
+ from collections import namedtuple
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch.nn.parallel import DistributedDataParallel
15
+ from torch import nn, einsum
16
+ from torch.cuda.amp import autocast
17
+ from torch.special import expm1
18
+ import torchvision.transforms as T
19
+
20
+ import kornia.augmentation as K
21
+
22
+ from einops import rearrange, repeat, reduce, pack, unpack
23
+ from einops.layers.torch import Rearrange, Reduce
24
+
25
+ from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
26
+
27
+ from imagen_pytorch.imagen_video import Unet3D, resize_video_to, scale_video_time
28
+
29
+ # helper functions
30
+
31
+ def exists(val):
32
+ return val is not None
33
+
34
+ def identity(t, *args, **kwargs):
35
+ return t
36
+
37
+ def divisible_by(numer, denom):
38
+ return (numer % denom) == 0
39
+
40
+ def first(arr, d = None):
41
+ if len(arr) == 0:
42
+ return d
43
+ return arr[0]
44
+
45
+ def maybe(fn):
46
+ @wraps(fn)
47
+ def inner(x):
48
+ if not exists(x):
49
+ return x
50
+ return fn(x)
51
+ return inner
52
+
53
+ def once(fn):
54
+ called = False
55
+ @wraps(fn)
56
+ def inner(x):
57
+ nonlocal called
58
+ if called:
59
+ return
60
+ called = True
61
+ return fn(x)
62
+ return inner
63
+
64
+ print_once = once(print)
65
+
66
+ def default(val, d):
67
+ if exists(val):
68
+ return val
69
+ return d() if callable(d) else d
70
+
71
+ def cast_tuple(val, length = None):
72
+ if isinstance(val, list):
73
+ val = tuple(val)
74
+
75
+ output = val if isinstance(val, tuple) else ((val,) * default(length, 1))
76
+
77
+ if exists(length):
78
+ assert len(output) == length
79
+
80
+ return output
81
+
82
+ def compact(input_dict):
83
+ return {key: value for key, value in input_dict.items() if exists(value)}
84
+
85
+ def maybe_transform_dict_key(input_dict, key, fn):
86
+ if key not in input_dict:
87
+ return input_dict
88
+
89
+ copied_dict = input_dict.copy()
90
+ copied_dict[key] = fn(copied_dict[key])
91
+ return copied_dict
92
+
93
+ def cast_uint8_images_to_float(images):
94
+ if not images.dtype == torch.uint8:
95
+ return images
96
+ return images / 255
97
+
98
+ def module_device(module):
99
+ return next(module.parameters()).device
100
+
101
+ def zero_init_(m):
102
+ nn.init.zeros_(m.weight)
103
+ if exists(m.bias):
104
+ nn.init.zeros_(m.bias)
105
+
106
+ def eval_decorator(fn):
107
+ def inner(model, *args, **kwargs):
108
+ was_training = model.training
109
+ model.eval()
110
+ out = fn(model, *args, **kwargs)
111
+ model.train(was_training)
112
+ return out
113
+ return inner
114
+
115
+ def pad_tuple_to_length(t, length, fillvalue = None):
116
+ remain_length = length - len(t)
117
+ if remain_length <= 0:
118
+ return t
119
+ return (*t, *((fillvalue,) * remain_length))
120
+
121
+ # helper classes
122
+
123
+ class Identity(nn.Module):
124
+ def __init__(self, *args, **kwargs):
125
+ super().__init__()
126
+
127
+ def forward(self, x, *args, **kwargs):
128
+ return x
129
+
130
+ # tensor helpers
131
+
132
+ def log(t, eps: float = 1e-12):
133
+ return torch.log(t.clamp(min = eps))
134
+
135
+ def l2norm(t):
136
+ return F.normalize(t, dim = -1)
137
+
138
+ def right_pad_dims_to(x, t):
139
+ padding_dims = x.ndim - t.ndim
140
+ if padding_dims <= 0:
141
+ return t
142
+ return t.view(*t.shape, *((1,) * padding_dims))
143
+
144
+ def masked_mean(t, *, dim, mask = None):
145
+ if not exists(mask):
146
+ return t.mean(dim = dim)
147
+
148
+ denom = mask.sum(dim = dim, keepdim = True)
149
+ mask = rearrange(mask, 'b n -> b n 1')
150
+ masked_t = t.masked_fill(~mask, 0.)
151
+
152
+ return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
153
+
154
+ def resize_image_to(
155
+ image,
156
+ target_image_size,
157
+ clamp_range = None,
158
+ mode = 'nearest'
159
+ ):
160
+ orig_image_size = image.shape[-1]
161
+
162
+ if orig_image_size == target_image_size:
163
+ return image
164
+
165
+ out = F.interpolate(image, target_image_size, mode = mode)
166
+
167
+ if exists(clamp_range):
168
+ out = out.clamp(*clamp_range)
169
+
170
+ return out
171
+
172
+ def calc_all_frame_dims(
173
+ downsample_factors: List[int],
174
+ frames
175
+ ):
176
+ if not exists(frames):
177
+ return (tuple(),) * len(downsample_factors)
178
+
179
+ all_frame_dims = []
180
+
181
+ for divisor in downsample_factors:
182
+ assert divisible_by(frames, divisor)
183
+ all_frame_dims.append((frames // divisor,))
184
+
185
+ return all_frame_dims
186
+
187
+ def safe_get_tuple_index(tup, index, default = None):
188
+ if len(tup) <= index:
189
+ return default
190
+ return tup[index]
191
+
192
+ # image normalization functions
193
+ # ddpms expect images to be in the range of -1 to 1
194
+
195
+ def normalize_neg_one_to_one(img):
196
+ return img * 2 - 1
197
+
198
+ def unnormalize_zero_to_one(normed_img):
199
+ return (normed_img + 1) * 0.5
200
+
201
+ # classifier free guidance functions
202
+
203
+ def prob_mask_like(shape, prob, device):
204
+ if prob == 1:
205
+ return torch.ones(shape, device = device, dtype = torch.bool)
206
+ elif prob == 0:
207
+ return torch.zeros(shape, device = device, dtype = torch.bool)
208
+ else:
209
+ return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
210
+
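+ # usage sketch (mirrors how the unet below uses it): prob_mask_like((batch,), 1 - cond_drop_prob, device)
+ # returns a boolean keep-mask, e.g. cond_drop_prob = 0.1 drops the text conditioning for roughly 10% of the batch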
211
+ # gaussian diffusion with continuous time helper functions and classes
212
+ # large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
213
+
214
+ @torch.jit.script
215
+ def beta_linear_log_snr(t):
216
+ return -torch.log(expm1(1e-4 + 10 * (t ** 2)))
217
+
218
+ @torch.jit.script
219
+ def alpha_cosine_log_snr(t, s: float = 0.008):
220
+ return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version
221
+
222
+ def log_snr_to_alpha_sigma(log_snr):
223
+ return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))
224
+
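+ # sanity check on the parameterization (illustrative): alpha = sqrt(sigmoid(log_snr)) and
+ # sigma = sqrt(sigmoid(-log_snr)) always satisfy alpha ** 2 + sigma ** 2 = 1 (variance preserving),
+ # e.g. log_snr = 0 gives alpha = sigma = sqrt(0.5) ~= 0.707, while log_snr -> +inf recovers the clean image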
225
+ class GaussianDiffusionContinuousTimes(nn.Module):
226
+ def __init__(self, *, noise_schedule, timesteps = 1000):
227
+ super().__init__()
228
+
229
+ if noise_schedule == "linear":
230
+ self.log_snr = beta_linear_log_snr
231
+ elif noise_schedule == "cosine":
232
+ self.log_snr = alpha_cosine_log_snr
233
+ else:
234
+ raise ValueError(f'invalid noise schedule {noise_schedule}')
235
+
236
+ self.num_timesteps = timesteps
237
+
238
+ def get_times(self, batch_size, noise_level, *, device):
239
+ return torch.full((batch_size,), noise_level, device = device, dtype = torch.float32)
240
+
241
+ def sample_random_times(self, batch_size, *, device):
242
+ return torch.zeros((batch_size,), device = device).float().uniform_(0, 1)
243
+
244
+ def get_condition(self, times):
245
+ return maybe(self.log_snr)(times)
246
+
247
+ def get_sampling_timesteps(self, batch, *, device):
248
+ times = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
249
+ times = repeat(times, 't -> b t', b = batch)
250
+ times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
251
+ times = times.unbind(dim = -1)
252
+ return times
253
+
254
+ def q_posterior(self, x_start, x_t, t, *, t_next = None):
255
+ t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))
256
+
257
+ """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """
258
+ log_snr = self.log_snr(t)
259
+ log_snr_next = self.log_snr(t_next)
260
+ log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))
261
+
262
+ alpha, sigma = log_snr_to_alpha_sigma(log_snr)
263
+ alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)
264
+
265
+ # c - as defined near eq 33
266
+ c = -expm1(log_snr - log_snr_next)
267
+ posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)
268
+
269
+ # following (eq. 33)
270
+ posterior_variance = (sigma_next ** 2) * c
271
+ posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
272
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
273
+
274
+ def q_sample(self, x_start, t, noise = None):
275
+ dtype = x_start.dtype
276
+
277
+ if isinstance(t, float):
278
+ batch = x_start.shape[0]
279
+ t = torch.full((batch,), t, device = x_start.device, dtype = dtype)
280
+
281
+ noise = default(noise, lambda: torch.randn_like(x_start))
282
+ log_snr = self.log_snr(t).type(dtype)
283
+ log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
284
+ alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)
285
+
286
+ return alpha * x_start + sigma * noise, log_snr, alpha, sigma
287
+
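+ # q_sample is the closed-form forward diffusion x_t = alpha_t * x_start + sigma_t * eps, with eps ~ N(0, I)
+ # minimal usage sketch (variable names here are illustrative, not taken from the training loop):
+ # schedule = GaussianDiffusionContinuousTimes(noise_schedule = 'cosine')
+ # x_start = torch.randn(4, 3, 64, 64)
+ # t = schedule.sample_random_times(4, device = x_start.device)
+ # x_noised, log_snr, alpha, sigma = schedule.q_sample(x_start, t)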
288
+ def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
289
+ shape, device, dtype = x_from.shape, x_from.device, x_from.dtype
290
+ batch = shape[0]
291
+
292
+ if isinstance(from_t, float):
293
+ from_t = torch.full((batch,), from_t, device = device, dtype = dtype)
294
+
295
+ if isinstance(to_t, float):
296
+ to_t = torch.full((batch,), to_t, device = device, dtype = dtype)
297
+
298
+ noise = default(noise, lambda: torch.randn_like(x_from))
299
+
300
+ log_snr = self.log_snr(from_t)
301
+ log_snr_padded_dim = right_pad_dims_to(x_from, log_snr)
302
+ alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)
303
+
304
+ log_snr_to = self.log_snr(to_t)
305
+ log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to)
306
+ alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_padded_dim_to)
307
+
308
+ return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha
309
+
310
+ def predict_start_from_v(self, x_t, t, v):
311
+ log_snr = self.log_snr(t)
312
+ log_snr = right_pad_dims_to(x_t, log_snr)
313
+ alpha, sigma = log_snr_to_alpha_sigma(log_snr)
314
+ return alpha * x_t - sigma * v
315
+
316
+ def predict_start_from_noise(self, x_t, t, noise):
317
+ log_snr = self.log_snr(t)
318
+ log_snr = right_pad_dims_to(x_t, log_snr)
319
+ alpha, sigma = log_snr_to_alpha_sigma(log_snr)
320
+ return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)
321
+
322
+ # norms and residuals
323
+
324
+ class LayerNorm(nn.Module):
325
+ def __init__(self, feats, stable = False, dim = -1):
326
+ super().__init__()
327
+ self.stable = stable
328
+ self.dim = dim
329
+
330
+ self.g = nn.Parameter(torch.ones(feats, *((1,) * (-dim - 1))))
331
+
332
+ def forward(self, x):
333
+ dtype, dim = x.dtype, self.dim
334
+
335
+ if self.stable:
336
+ x = x / x.amax(dim = dim, keepdim = True).detach()
337
+
338
+ eps = 1e-5 if x.dtype == torch.float32 else 1e-3
339
+ var = torch.var(x, dim = dim, unbiased = False, keepdim = True)
340
+ mean = torch.mean(x, dim = dim, keepdim = True)
341
+
342
+ return (x - mean) * (var + eps).rsqrt().type(dtype) * self.g.type(dtype)
343
+
344
+ ChanLayerNorm = partial(LayerNorm, dim = -3)
345
+
346
+ class Always():
347
+ def __init__(self, val):
348
+ self.val = val
349
+
350
+ def __call__(self, *args, **kwargs):
351
+ return self.val
352
+
353
+ class Residual(nn.Module):
354
+ def __init__(self, fn):
355
+ super().__init__()
356
+ self.fn = fn
357
+
358
+ def forward(self, x, **kwargs):
359
+ return self.fn(x, **kwargs) + x
360
+
361
+ class Parallel(nn.Module):
362
+ def __init__(self, *fns):
363
+ super().__init__()
364
+ self.fns = nn.ModuleList(fns)
365
+
366
+ def forward(self, x):
367
+ outputs = [fn(x) for fn in self.fns]
368
+ return sum(outputs)
369
+
370
+ # attention pooling
371
+
372
+ class PerceiverAttention(nn.Module):
373
+ def __init__(
374
+ self,
375
+ *,
376
+ dim,
377
+ dim_head = 64,
378
+ heads = 8,
379
+ scale = 8
380
+ ):
381
+ super().__init__()
382
+ self.scale = scale
383
+
384
+ self.heads = heads
385
+ inner_dim = dim_head * heads
386
+
387
+ self.norm = nn.LayerNorm(dim)
388
+ self.norm_latents = nn.LayerNorm(dim)
389
+
390
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
391
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
392
+
393
+ self.q_scale = nn.Parameter(torch.ones(dim_head))
394
+ self.k_scale = nn.Parameter(torch.ones(dim_head))
395
+
396
+ self.to_out = nn.Sequential(
397
+ nn.Linear(inner_dim, dim, bias = False),
398
+ nn.LayerNorm(dim)
399
+ )
400
+
401
+ def forward(self, x, latents, mask = None):
402
+ x = self.norm(x)
403
+ latents = self.norm_latents(latents)
404
+
405
+ b, h = x.shape[0], self.heads
406
+
407
+ q = self.to_q(latents)
408
+
409
+ # unlike the original Perceiver, the paper also concatenates the keys / values derived from the latents to those being attended to
410
+ kv_input = torch.cat((x, latents), dim = -2)
411
+ k, v = self.to_kv(kv_input).chunk(2, dim = -1)
412
+
413
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
414
+
415
+ # qk rmsnorm
416
+
417
+ q, k = map(l2norm, (q, k))
418
+ q = q * self.q_scale
419
+ k = k * self.k_scale
420
+
421
+ # similarities and masking
422
+
423
+ sim = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
424
+
425
+ if exists(mask):
426
+ max_neg_value = -torch.finfo(sim.dtype).max
427
+ mask = F.pad(mask, (0, latents.shape[-2]), value = True)
428
+ mask = rearrange(mask, 'b j -> b 1 1 j')
429
+ sim = sim.masked_fill(~mask, max_neg_value)
430
+
431
+ # attention
432
+
433
+ attn = sim.softmax(dim = -1, dtype = torch.float32)
434
+ attn = attn.to(sim.dtype)
435
+
436
+ out = einsum('... i j, ... j d -> ... i d', attn, v)
437
+ out = rearrange(out, 'b h n d -> b n (h d)', h = h)
438
+ return self.to_out(out)
439
+
440
+ class PerceiverResampler(nn.Module):
441
+ def __init__(
442
+ self,
443
+ *,
444
+ dim,
445
+ depth,
446
+ dim_head = 64,
447
+ heads = 8,
448
+ num_latents = 64,
449
+ num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
450
+ max_seq_len = 512,
451
+ ff_mult = 4
452
+ ):
453
+ super().__init__()
454
+ self.pos_emb = nn.Embedding(max_seq_len, dim)
455
+
456
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
457
+
458
+ self.to_latents_from_mean_pooled_seq = None
459
+
460
+ if num_latents_mean_pooled > 0:
461
+ self.to_latents_from_mean_pooled_seq = nn.Sequential(
462
+ LayerNorm(dim),
463
+ nn.Linear(dim, dim * num_latents_mean_pooled),
464
+ Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
465
+ )
466
+
467
+ self.layers = nn.ModuleList([])
468
+ for _ in range(depth):
469
+ self.layers.append(nn.ModuleList([
470
+ PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
471
+ FeedForward(dim = dim, mult = ff_mult)
472
+ ]))
473
+
474
+ def forward(self, x, mask = None):
475
+ n, device = x.shape[1], x.device
476
+ pos_emb = self.pos_emb(torch.arange(n, device = device))
477
+
478
+ x_with_pos = x + pos_emb
479
+
480
+ latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])
481
+
482
+ if exists(self.to_latents_from_mean_pooled_seq):
483
+ meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
484
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
485
+ latents = torch.cat((meanpooled_latents, latents), dim = -2)
486
+
487
+ for attn, ff in self.layers:
488
+ latents = attn(x_with_pos, latents, mask = mask) + latents
489
+ latents = ff(latents) + latents
490
+
491
+ return latents
492
+
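+ # shape sketch for the resampler above (using its defaults, values illustrative):
+ # resampler = PerceiverResampler(dim = 512, depth = 2)
+ # text_tokens = torch.randn(2, 77, 512)
+ # latents = resampler(text_tokens)   # (2, 4 + 64, 512) - mean pooled latents are prepended to the learned ones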
493
+ # attention
494
+
495
+ class Attention(nn.Module):
496
+ def __init__(
497
+ self,
498
+ dim,
499
+ *,
500
+ dim_head = 64,
501
+ heads = 8,
502
+ context_dim = None,
503
+ scale = 8
504
+ ):
505
+ super().__init__()
506
+ self.scale = scale
507
+
508
+ self.heads = heads
509
+ inner_dim = dim_head * heads
510
+
511
+ self.norm = LayerNorm(dim)
512
+
513
+ self.null_kv = nn.Parameter(torch.randn(2, dim_head))
514
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
515
+ self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
516
+
517
+ self.q_scale = nn.Parameter(torch.ones(dim_head))
518
+ self.k_scale = nn.Parameter(torch.ones(dim_head))
519
+
520
+ self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None
521
+
522
+ self.to_out = nn.Sequential(
523
+ nn.Linear(inner_dim, dim, bias = False),
524
+ LayerNorm(dim)
525
+ )
526
+
527
+ def forward(self, x, context = None, mask = None, attn_bias = None):
528
+ b, n, device = *x.shape[:2], x.device
529
+
530
+ x = self.norm(x)
531
+
532
+ q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
533
+
534
+ q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
535
+
536
+ # add null key / value for classifier free guidance in prior net
537
+
538
+ nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
539
+ k = torch.cat((nk, k), dim = -2)
540
+ v = torch.cat((nv, v), dim = -2)
541
+
542
+ # add text conditioning, if present
543
+
544
+ if exists(context):
545
+ assert exists(self.to_context)
546
+ ck, cv = self.to_context(context).chunk(2, dim = -1)
547
+ k = torch.cat((ck, k), dim = -2)
548
+ v = torch.cat((cv, v), dim = -2)
549
+
550
+ # qk rmsnorm
551
+
552
+ q, k = map(l2norm, (q, k))
553
+ q = q * self.q_scale
554
+ k = k * self.k_scale
555
+
556
+ # calculate query / key similarities
557
+
558
+ sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale
559
+
560
+ # relative positional encoding (T5 style)
561
+
562
+ if exists(attn_bias):
563
+ sim = sim + attn_bias
564
+
565
+ # masking
566
+
567
+ max_neg_value = -torch.finfo(sim.dtype).max
568
+
569
+ if exists(mask):
570
+ mask = F.pad(mask, (1, 0), value = True)
571
+ mask = rearrange(mask, 'b j -> b 1 1 j')
572
+ sim = sim.masked_fill(~mask, max_neg_value)
573
+
574
+ # attention
575
+
576
+ attn = sim.softmax(dim = -1, dtype = torch.float32)
577
+ attn = attn.to(sim.dtype)
578
+
579
+ # aggregate values
580
+
581
+ out = einsum('b h i j, b j d -> b h i d', attn, v)
582
+
583
+ out = rearrange(out, 'b h n d -> b n (h d)')
584
+ return self.to_out(out)
585
+
586
+ # decoder
587
+
588
+ def Upsample(dim, dim_out = None):
589
+ dim_out = default(dim_out, dim)
590
+
591
+ return nn.Sequential(
592
+ nn.Upsample(scale_factor = 2, mode = 'nearest'),
593
+ nn.Conv2d(dim, dim_out, 3, padding = 1)
594
+ )
595
+
596
+ class PixelShuffleUpsample(nn.Module):
597
+ """
598
+ code shared by @MalumaDev at DALLE2-pytorch for addressing checkerboard artifacts
599
+ https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
600
+ """
601
+ def __init__(self, dim, dim_out = None):
602
+ super().__init__()
603
+ dim_out = default(dim_out, dim)
604
+ conv = nn.Conv2d(dim, dim_out * 4, 1)
605
+
606
+ self.net = nn.Sequential(
607
+ conv,
608
+ nn.SiLU(),
609
+ nn.PixelShuffle(2)
610
+ )
611
+
612
+ self.init_conv_(conv)
613
+
614
+ def init_conv_(self, conv):
615
+ o, i, h, w = conv.weight.shape
616
+ conv_weight = torch.empty(o // 4, i, h, w)
617
+ nn.init.kaiming_uniform_(conv_weight)
618
+ conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
619
+
620
+ conv.weight.data.copy_(conv_weight)
621
+ nn.init.zeros_(conv.bias.data)
622
+
623
+ def forward(self, x):
624
+ return self.net(x)
625
+
626
+ def Downsample(dim, dim_out = None):
627
+ # https://arxiv.org/abs/2208.03641 shows this is the optimal way to downsample
628
+ # named SP-conv in the paper, but basically a pixel unshuffle
629
+ dim_out = default(dim_out, dim)
630
+ return nn.Sequential(
631
+ Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
632
+ nn.Conv2d(dim * 4, dim_out, 1)
633
+ )
634
+
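+ # shape sketch (illustrative): the rearrange packs each 2x2 spatial patch into channels, so
+ # (b, c, h, w) -> (b, 4c, h / 2, w / 2), and the 1x1 conv then maps 4c -> dim_out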
635
+ class SinusoidalPosEmb(nn.Module):
636
+ def __init__(self, dim):
637
+ super().__init__()
638
+ self.dim = dim
639
+
640
+ def forward(self, x):
641
+ half_dim = self.dim // 2
642
+ emb = math.log(10000) / (half_dim - 1)
643
+ emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
644
+ emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
645
+ return torch.cat((emb.sin(), emb.cos()), dim = -1)
646
+
647
+ class LearnedSinusoidalPosEmb(nn.Module):
648
+ """ following @crowsonkb 's lead with learned sinusoidal pos emb """
649
+ """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
650
+
651
+ def __init__(self, dim):
652
+ super().__init__()
653
+ assert (dim % 2) == 0
654
+ half_dim = dim // 2
655
+ self.weights = nn.Parameter(torch.randn(half_dim))
656
+
657
+ def forward(self, x):
658
+ x = rearrange(x, 'b -> b 1')
659
+ freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
660
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
661
+ fouriered = torch.cat((x, fouriered), dim = -1)
662
+ return fouriered
663
+
664
+ class Block(nn.Module):
665
+ def __init__(
666
+ self,
667
+ dim,
668
+ dim_out,
669
+ groups = 8,
670
+ norm = True
671
+ ):
672
+ super().__init__()
673
+ self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
674
+ self.activation = nn.SiLU()
675
+ self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
676
+
677
+ def forward(self, x, scale_shift = None):
678
+ x = self.groupnorm(x)
679
+
680
+ if exists(scale_shift):
681
+ scale, shift = scale_shift
682
+ x = x * (scale + 1) + shift
683
+
684
+ x = self.activation(x)
685
+ return self.project(x)
686
+
687
+ class ResnetBlock(nn.Module):
688
+ def __init__(
689
+ self,
690
+ dim,
691
+ dim_out,
692
+ *,
693
+ cond_dim = None,
694
+ time_cond_dim = None,
695
+ groups = 8,
696
+ linear_attn = False,
697
+ use_gca = False,
698
+ squeeze_excite = False,
699
+ **attn_kwargs
700
+ ):
701
+ super().__init__()
702
+
703
+ self.time_mlp = None
704
+
705
+ if exists(time_cond_dim):
706
+ self.time_mlp = nn.Sequential(
707
+ nn.SiLU(),
708
+ nn.Linear(time_cond_dim, dim_out * 2)
709
+ )
710
+
711
+ self.cross_attn = None
712
+
713
+ if exists(cond_dim):
714
+ attn_klass = CrossAttention if not linear_attn else LinearCrossAttention
715
+
716
+ self.cross_attn = attn_klass(
717
+ dim = dim_out,
718
+ context_dim = cond_dim,
719
+ **attn_kwargs
720
+ )
721
+
722
+ self.block1 = Block(dim, dim_out, groups = groups)
723
+ self.block2 = Block(dim_out, dim_out, groups = groups)
724
+
725
+ self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)
726
+
727
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()
728
+
729
+
730
+ def forward(self, x, time_emb = None, cond = None):
731
+
732
+ scale_shift = None
733
+ if exists(self.time_mlp) and exists(time_emb):
734
+ time_emb = self.time_mlp(time_emb)
735
+ time_emb = rearrange(time_emb, 'b c -> b c 1 1')
736
+ scale_shift = time_emb.chunk(2, dim = 1)
737
+
738
+ h = self.block1(x)
739
+
740
+ if exists(self.cross_attn):
741
+ assert exists(cond)
742
+ h = rearrange(h, 'b c h w -> b h w c')
743
+ h, ps = pack([h], 'b * c')
744
+ h = self.cross_attn(h, context = cond) + h
745
+ h, = unpack(h, ps, 'b * c')
746
+ h = rearrange(h, 'b h w c -> b c h w')
747
+
748
+ h = self.block2(h, scale_shift = scale_shift)
749
+
750
+ h = h * self.gca(h)
751
+
752
+ return h + self.res_conv(x)
753
+
754
+ class CrossAttention(nn.Module):
755
+ def __init__(
756
+ self,
757
+ dim,
758
+ *,
759
+ context_dim = None,
760
+ dim_head = 64,
761
+ heads = 8,
762
+ norm_context = False,
763
+ scale = 8
764
+ ):
765
+ super().__init__()
766
+ self.scale = scale
767
+
768
+ self.heads = heads
769
+ inner_dim = dim_head * heads
770
+
771
+ context_dim = default(context_dim, dim)
772
+
773
+ self.norm = LayerNorm(dim)
774
+ self.norm_context = LayerNorm(context_dim) if norm_context else Identity()
775
+
776
+ self.null_kv = nn.Parameter(torch.randn(2, dim_head))
777
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
778
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
779
+
780
+ self.q_scale = nn.Parameter(torch.ones(dim_head))
781
+ self.k_scale = nn.Parameter(torch.ones(dim_head))
782
+
783
+ self.to_out = nn.Sequential(
784
+ nn.Linear(inner_dim, dim, bias = False),
785
+ LayerNorm(dim)
786
+ )
787
+
788
+ def forward(self, x, context, mask = None):
789
+ b, n, device = *x.shape[:2], x.device
790
+
791
+ x = self.norm(x)
792
+ context = self.norm_context(context)
793
+
794
+ q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
795
+
796
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
797
+
798
+ # add null key / value for classifier free guidance in prior net
799
+
800
+ nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
801
+
802
+ k = torch.cat((nk, k), dim = -2)
803
+ v = torch.cat((nv, v), dim = -2)
804
+
805
+ # cosine sim attention
806
+
807
+ q, k = map(l2norm, (q, k))
808
+ q = q * self.q_scale
809
+ k = k * self.k_scale
810
+
811
+ # similarities
812
+
813
+ sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
814
+
815
+ # masking
816
+
817
+ max_neg_value = -torch.finfo(sim.dtype).max
818
+
819
+ if exists(mask):
820
+ mask = F.pad(mask, (1, 0), value = True)
821
+ mask = rearrange(mask, 'b j -> b 1 1 j')
822
+ sim = sim.masked_fill(~mask, max_neg_value)
823
+
824
+ attn = sim.softmax(dim = -1, dtype = torch.float32)
825
+ attn = attn.to(sim.dtype)
826
+
827
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
828
+ out = rearrange(out, 'b h n d -> b n (h d)')
829
+ return self.to_out(out)
830
+
831
+ class LinearCrossAttention(CrossAttention):
832
+ def forward(self, x, context, mask = None):
833
+ b, n, device = *x.shape[:2], x.device
834
+
835
+ x = self.norm(x)
836
+ context = self.norm_context(context)
837
+
838
+ q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
839
+
840
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))
841
+
842
+ # add null key / value for classifier free guidance in prior net
843
+
844
+ nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
845
+
846
+ k = torch.cat((nk, k), dim = -2)
847
+ v = torch.cat((nv, v), dim = -2)
848
+
849
+ # masking
850
+
851
+ max_neg_value = -torch.finfo(x.dtype).max
852
+
853
+ if exists(mask):
854
+ mask = F.pad(mask, (1, 0), value = True)
855
+ mask = rearrange(mask, 'b n -> b n 1')
856
+ k = k.masked_fill(~mask, max_neg_value)
857
+ v = v.masked_fill(~mask, 0.)
858
+
859
+ # linear attention
860
+
861
+ q = q.softmax(dim = -1)
862
+ k = k.softmax(dim = -2)
863
+
864
+ q = q * self.scale
865
+
866
+ context = einsum('b n d, b n e -> b d e', k, v)
867
+ out = einsum('b n d, b d e -> b n e', q, context)
868
+ out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
869
+ return self.to_out(out)
870
+
871
+ class LinearAttention(nn.Module):
872
+ def __init__(
873
+ self,
874
+ dim,
875
+ dim_head = 32,
876
+ heads = 8,
877
+ dropout = 0.05,
878
+ context_dim = None,
879
+ **kwargs
880
+ ):
881
+ super().__init__()
882
+ self.scale = dim_head ** -0.5
883
+ self.heads = heads
884
+ inner_dim = dim_head * heads
885
+ self.norm = ChanLayerNorm(dim)
886
+
887
+ self.nonlin = nn.SiLU()
888
+
889
+ self.to_q = nn.Sequential(
890
+ nn.Dropout(dropout),
891
+ nn.Conv2d(dim, inner_dim, 1, bias = False),
892
+ nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
893
+ )
894
+
895
+ self.to_k = nn.Sequential(
896
+ nn.Dropout(dropout),
897
+ nn.Conv2d(dim, inner_dim, 1, bias = False),
898
+ nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
899
+ )
900
+
901
+ self.to_v = nn.Sequential(
902
+ nn.Dropout(dropout),
903
+ nn.Conv2d(dim, inner_dim, 1, bias = False),
904
+ nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
905
+ )
906
+
907
+ self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None
908
+
909
+ self.to_out = nn.Sequential(
910
+ nn.Conv2d(inner_dim, dim, 1, bias = False),
911
+ ChanLayerNorm(dim)
912
+ )
913
+
914
+ def forward(self, fmap, context = None):
915
+ h, x, y = self.heads, *fmap.shape[-2:]
916
+
917
+ fmap = self.norm(fmap)
918
+ q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
919
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))
920
+
921
+ if exists(context):
922
+ assert exists(self.to_context)
923
+ ck, cv = self.to_context(context).chunk(2, dim = -1)
924
+ ck, cv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (ck, cv))
925
+ k = torch.cat((k, ck), dim = -2)
926
+ v = torch.cat((v, cv), dim = -2)
927
+
928
+ q = q.softmax(dim = -1)
929
+ k = k.softmax(dim = -2)
930
+
931
+ q = q * self.scale
932
+
933
+ context = einsum('b n d, b n e -> b d e', k, v)
934
+ out = einsum('b n d, b d e -> b n e', q, context)
935
+ out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
936
+
937
+ out = self.nonlin(out)
938
+ return self.to_out(out)
939
+
940
+ class GlobalContext(nn.Module):
941
+ """ basically a superior form of squeeze-excitation that is attention-esque """
942
+
943
+ def __init__(
944
+ self,
945
+ *,
946
+ dim_in,
947
+ dim_out
948
+ ):
949
+ super().__init__()
950
+ self.to_k = nn.Conv2d(dim_in, 1, 1)
951
+ hidden_dim = max(3, dim_out // 2)
952
+
953
+ self.net = nn.Sequential(
954
+ nn.Conv2d(dim_in, hidden_dim, 1),
955
+ nn.SiLU(),
956
+ nn.Conv2d(hidden_dim, dim_out, 1),
957
+ nn.Sigmoid()
958
+ )
959
+
960
+ def forward(self, x):
961
+ context = self.to_k(x)
962
+ x, context = map(lambda t: rearrange(t, 'b n ... -> b n (...)'), (x, context))
963
+ out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
964
+ out = rearrange(out, '... -> ... 1')
965
+ return self.net(out)
966
+
967
+ def FeedForward(dim, mult = 2):
968
+ hidden_dim = int(dim * mult)
969
+ return nn.Sequential(
970
+ LayerNorm(dim),
971
+ nn.Linear(dim, hidden_dim, bias = False),
972
+ nn.GELU(),
973
+ LayerNorm(hidden_dim),
974
+ nn.Linear(hidden_dim, dim, bias = False)
975
+ )
976
+
977
+ def ChanFeedForward(dim, mult = 2): # in paper, it seems for self attention layers they did feedforwards with twice channel width
978
+ hidden_dim = int(dim * mult)
979
+ return nn.Sequential(
980
+ ChanLayerNorm(dim),
981
+ nn.Conv2d(dim, hidden_dim, 1, bias = False),
982
+ nn.GELU(),
983
+ ChanLayerNorm(hidden_dim),
984
+ nn.Conv2d(hidden_dim, dim, 1, bias = False)
985
+ )
986
+
987
+ class TransformerBlock(nn.Module):
988
+ def __init__(
989
+ self,
990
+ dim,
991
+ *,
992
+ depth = 1,
993
+ heads = 8,
994
+ dim_head = 32,
995
+ ff_mult = 2,
996
+ context_dim = None
997
+ ):
998
+ super().__init__()
999
+ self.layers = nn.ModuleList([])
1000
+
1001
+ for _ in range(depth):
1002
+ self.layers.append(nn.ModuleList([
1003
+ Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
1004
+ FeedForward(dim = dim, mult = ff_mult)
1005
+ ]))
1006
+
1007
+ def forward(self, x, context = None):
1008
+ x = rearrange(x, 'b c h w -> b h w c')
1009
+ x, ps = pack([x], 'b * c')
1010
+
1011
+ for attn, ff in self.layers:
1012
+ x = attn(x, context = context) + x
1013
+ x = ff(x) + x
1014
+
1015
+ x, = unpack(x, ps, 'b * c')
1016
+ x = rearrange(x, 'b h w c -> b c h w')
1017
+ return x
1018
+
1019
+ class LinearAttentionTransformerBlock(nn.Module):
1020
+ def __init__(
1021
+ self,
1022
+ dim,
1023
+ *,
1024
+ depth = 1,
1025
+ heads = 8,
1026
+ dim_head = 32,
1027
+ ff_mult = 2,
1028
+ context_dim = None,
1029
+ **kwargs
1030
+ ):
1031
+ super().__init__()
1032
+ self.layers = nn.ModuleList([])
1033
+
1034
+ for _ in range(depth):
1035
+ self.layers.append(nn.ModuleList([
1036
+ LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
1037
+ ChanFeedForward(dim = dim, mult = ff_mult)
1038
+ ]))
1039
+
1040
+ def forward(self, x, context = None):
1041
+ for attn, ff in self.layers:
1042
+ x = attn(x, context = context) + x
1043
+ x = ff(x) + x
1044
+ return x
1045
+
1046
+ class CrossEmbedLayer(nn.Module):
1047
+ def __init__(
1048
+ self,
1049
+ dim_in,
1050
+ kernel_sizes,
1051
+ dim_out = None,
1052
+ stride = 2
1053
+ ):
1054
+ super().__init__()
1055
+ assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
1056
+ dim_out = default(dim_out, dim_in)
1057
+
1058
+ kernel_sizes = sorted(kernel_sizes)
1059
+ num_scales = len(kernel_sizes)
1060
+
1061
+ # calculate the dimension at each scale
1062
+ dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
1063
+ dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
1064
+
1065
+ self.convs = nn.ModuleList([])
1066
+ for kernel, dim_scale in zip(kernel_sizes, dim_scales):
1067
+ self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
1068
+
1069
+ def forward(self, x):
1070
+ fmaps = tuple(map(lambda conv: conv(x), self.convs))
1071
+ return torch.cat(fmaps, dim = 1)
1072
+
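+ # worked example (illustrative values): dim_out = 128 with kernel_sizes = (3, 7, 15) splits the output
+ # channels as dim_scales = [64, 32, 32], one group per kernel size, concatenated back to 128 channels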
1073
+ class UpsampleCombiner(nn.Module):
1074
+ def __init__(
1075
+ self,
1076
+ dim,
1077
+ *,
1078
+ enabled = False,
1079
+ dim_ins = tuple(),
1080
+ dim_outs = tuple()
1081
+ ):
1082
+ super().__init__()
1083
+ dim_outs = cast_tuple(dim_outs, len(dim_ins))
1084
+ assert len(dim_ins) == len(dim_outs)
1085
+
1086
+ self.enabled = enabled
1087
+
1088
+ if not self.enabled:
1089
+ self.dim_out = dim
1090
+ return
1091
+
1092
+ self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
1093
+ self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
1094
+
1095
+ def forward(self, x, fmaps = None):
1096
+ target_size = x.shape[-1]
1097
+
1098
+ fmaps = default(fmaps, tuple())
1099
+
1100
+ if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
1101
+ return x
1102
+
1103
+ fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
1104
+ outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
1105
+ return torch.cat((x, *outs), dim = 1)
1106
+
1107
+ class Unet(nn.Module):
1108
+ def __init__(
1109
+ self,
1110
+ *,
1111
+ dim,
1112
+ text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
1113
+ num_resnet_blocks = 1,
1114
+ cond_dim = None,
1115
+ num_image_tokens = 4,
1116
+ num_time_tokens = 2,
1117
+ learned_sinu_pos_emb_dim = 16,
1118
+ out_dim = None,
1119
+ dim_mults=(1, 2, 4, 8),
1120
+ cond_images_channels = 0,
1121
+ channels = 3,
1122
+ channels_out = None,
1123
+ attn_dim_head = 64,
1124
+ attn_heads = 8,
1125
+ ff_mult = 2.,
1126
+ lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
1127
+ layer_attns = True,
1128
+ layer_attns_depth = 1,
1129
+ layer_mid_attns_depth = 1,
1130
+ layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
1131
+ attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
1132
+ layer_cross_attns = True,
1133
+ use_linear_attn = False,
1134
+ use_linear_cross_attn = False,
1135
+ cond_on_text = True,
1136
+ max_text_len = 256,
1137
+ init_dim = None,
1138
+ resnet_groups = 8,
1139
+ init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed
1140
+ init_cross_embed = True,
1141
+ init_cross_embed_kernel_sizes = (3, 7, 15),
1142
+ cross_embed_downsample = False,
1143
+ cross_embed_downsample_kernel_sizes = (2, 4),
1144
+ attn_pool_text = True,
1145
+ attn_pool_num_latents = 32,
1146
+ dropout = 0.,
1147
+ memory_efficient = False,
1148
+ init_conv_to_final_conv_residual = False,
1149
+ use_global_context_attn = True,
1150
+ scale_skip_connection = True,
1151
+ final_resnet_block = True,
1152
+ final_conv_kernel_size = 3,
1153
+ self_cond = False,
1154
+ resize_mode = 'nearest',
1155
+ combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully
1156
+ pixel_shuffle_upsample = True, # may address checkerboard artifacts
1157
+ ):
1158
+ super().__init__()
1159
+
1160
+ # guide researchers
1161
+
1162
+ assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'
1163
+
1164
+ if dim < 128:
1165
+ print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')
1166
+
1167
+ # save locals to take care of some hyperparameters for cascading DDPM
1168
+
1169
+ self._locals = locals()
1170
+ self._locals.pop('self', None)
1171
+ self._locals.pop('__class__', None)
1172
+
1173
+ # determine dimensions
1174
+
1175
+ self.channels = channels
1176
+ self.channels_out = default(channels_out, channels)
1177
+
1178
+ # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
1179
+ # (2) in self conditioning, one appends the predicted x0 (x_start)
1180
+ init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
1181
+ init_dim = default(init_dim, dim)
1182
+
1183
+ self.self_cond = self_cond
1184
+
1185
+ # optional image conditioning
1186
+
1187
+ self.has_cond_image = cond_images_channels > 0
1188
+ self.cond_images_channels = cond_images_channels
1189
+
1190
+ init_channels += cond_images_channels
1191
+
1192
+ # initial convolution
1193
+
1194
+ self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
1195
+
1196
+ dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
1197
+ in_out = list(zip(dims[:-1], dims[1:]))
1198
+
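+ # worked example (illustrative, not this checkpoint's config): dim = 128 with dim_mults = (1, 2, 4, 8)
+ # gives dims = [128, 128, 256, 512, 1024] and in_out = [(128, 128), (128, 256), (256, 512), (512, 1024)],
+ # one (dim_in, dim_out) pair per resolution stage of the unet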
1199
+ # time conditioning
1200
+
1201
+ cond_dim = default(cond_dim, dim)
1202
+ time_cond_dim = dim * 4 * (2 if lowres_cond else 1)
1203
+
1204
+ # embedding time for log(snr) noise from continuous version
1205
+
1206
+ sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
1207
+ sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1
1208
+
1209
+ self.to_time_hiddens = nn.Sequential(
1210
+ sinu_pos_emb,
1211
+ nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
1212
+ nn.SiLU()
1213
+ )
1214
+
1215
+ self.to_time_cond = nn.Sequential(
1216
+ nn.Linear(time_cond_dim, time_cond_dim)
1217
+ )
1218
+
1219
+ # project to time tokens as well as time hiddens
1220
+
1221
+ self.to_time_tokens = nn.Sequential(
1222
+ nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
1223
+ Rearrange('b (r d) -> b r d', r = num_time_tokens)
1224
+ )
1225
+
1226
+ # low res aug noise conditioning
1227
+
1228
+ self.lowres_cond = lowres_cond
1229
+
1230
+ if lowres_cond:
1231
+ self.to_lowres_time_hiddens = nn.Sequential(
1232
+ LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
1233
+ nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
1234
+ nn.SiLU()
1235
+ )
1236
+
1237
+ self.to_lowres_time_cond = nn.Sequential(
1238
+ nn.Linear(time_cond_dim, time_cond_dim)
1239
+ )
1240
+
1241
+ self.to_lowres_time_tokens = nn.Sequential(
1242
+ nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
1243
+ Rearrange('b (r d) -> b r d', r = num_time_tokens)
1244
+ )
1245
+
1246
+ # normalizations
1247
+
1248
+ self.norm_cond = nn.LayerNorm(cond_dim)
1249
+
1250
+ # text encoding conditioning (optional)
1251
+
1252
+ self.text_to_cond = None
1253
+
1254
+ if cond_on_text:
1255
+ assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
1256
+ self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
1257
+
1258
+ # finer control over whether to condition on text encodings
1259
+
1260
+ self.cond_on_text = cond_on_text
1261
+
1262
+ # attention pooling
1263
+
1264
+ self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents) if attn_pool_text else None
1265
+
1266
+ # for classifier free guidance
1267
+
1268
+ self.max_text_len = max_text_len
1269
+
1270
+ self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
1271
+ self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))
1272
+
1273
+ # for non-attention based text conditioning at all points in the network where time is also conditioned
1274
+
1275
+ self.to_text_non_attn_cond = None
1276
+
1277
+ if cond_on_text:
1278
+ self.to_text_non_attn_cond = nn.Sequential(
1279
+ nn.LayerNorm(cond_dim),
1280
+ nn.Linear(cond_dim, time_cond_dim),
1281
+ nn.SiLU(),
1282
+ nn.Linear(time_cond_dim, time_cond_dim)
1283
+ )
1284
+
1285
+ # attention related params
1286
+
1287
+ attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
1288
+
1289
+ num_layers = len(in_out)
1290
+
1291
+ # resnet block klass
1292
+
1293
+ num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
1294
+ resnet_groups = cast_tuple(resnet_groups, num_layers)
1295
+
1296
+ resnet_klass = partial(ResnetBlock, **attn_kwargs)
1297
+
1298
+ layer_attns = cast_tuple(layer_attns, num_layers)
1299
+ layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
1300
+ layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)
1301
+
1302
+ use_linear_attn = cast_tuple(use_linear_attn, num_layers)
1303
+ use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers)
1304
+
1305
+ assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
1306
+
1307
+ # downsample klass
1308
+
1309
+ downsample_klass = Downsample
1310
+
1311
+ if cross_embed_downsample:
1312
+ downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
1313
+
1314
+ # initial resnet block (for memory efficient unet)
1315
+
1316
+ self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None
1317
+
1318
+ # scale for resnet skip connections
1319
+
1320
+ self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
1321
+
1322
+ # layers
1323
+
1324
+ self.downs = nn.ModuleList([])
1325
+ self.ups = nn.ModuleList([])
1326
+ num_resolutions = len(in_out)
1327
+
1328
+ layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
1329
+ reversed_layer_params = list(map(reversed, layer_params))
1330
+
1331
+ # downsampling layers
1332
+
1333
+ skip_connect_dims = [] # keep track of skip connection dimensions
1334
+
1335
+ for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
1336
+ is_last = ind >= (num_resolutions - 1)
1337
+
1338
+ layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
1339
+
1340
+ if layer_attn:
1341
+ transformer_block_klass = TransformerBlock
1342
+ elif layer_use_linear_attn:
1343
+ transformer_block_klass = LinearAttentionTransformerBlock
1344
+ else:
1345
+ transformer_block_klass = Identity
1346
+
1347
+ current_dim = dim_in
1348
+
1349
+ # whether to pre-downsample, from memory efficient unet
1350
+
1351
+ pre_downsample = None
1352
+
1353
+ if memory_efficient:
1354
+ pre_downsample = downsample_klass(dim_in, dim_out)
1355
+ current_dim = dim_out
1356
+
1357
+ skip_connect_dims.append(current_dim)
1358
+
1359
+ # whether to do post-downsample, for non-memory efficient unet
1360
+
1361
+ post_downsample = None
1362
+ if not memory_efficient:
1363
+ post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(nn.Conv2d(dim_in, dim_out, 3, padding = 1), nn.Conv2d(dim_in, dim_out, 1))
1364
+
1365
+ self.downs.append(nn.ModuleList([
1366
+ pre_downsample,
1367
+ resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
1368
+ nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
1369
+ transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
1370
+ post_downsample
1371
+ ]))
1372
+
1373
+ # middle layers
1374
+
1375
+ mid_dim = dims[-1]
1376
+
1377
+ self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
1378
+ self.mid_attn = TransformerBlock(mid_dim, depth = layer_mid_attns_depth, **attn_kwargs) if attend_at_middle else None
1379
+ self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
1380
+
1381
+ # upsample klass
1382
+
1383
+ upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample
1384
+
1385
+ # upsampling layers
1386
+
1387
+ upsample_fmap_dims = []
1388
+
1389
+ for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
1390
+ is_last = ind == (len(in_out) - 1)
1391
+
1392
+ layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
1393
+
1394
+ if layer_attn:
1395
+ transformer_block_klass = TransformerBlock
1396
+ elif layer_use_linear_attn:
1397
+ transformer_block_klass = LinearAttentionTransformerBlock
1398
+ else:
1399
+ transformer_block_klass = Identity
1400
+
1401
+ skip_connect_dim = skip_connect_dims.pop()
1402
+
1403
+ upsample_fmap_dims.append(dim_out)
1404
+
1405
+ self.ups.append(nn.ModuleList([
1406
+ resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
1407
+ nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
1408
+ transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
1409
+ upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
1410
+ ]))
1411
+
1412
+ # whether to combine feature maps from all upsample blocks before final resnet block out
1413
+
1414
+ self.upsample_combiner = UpsampleCombiner(
1415
+ dim = dim,
1416
+ enabled = combine_upsample_fmaps,
1417
+ dim_ins = upsample_fmap_dims,
1418
+ dim_outs = dim
1419
+ )
1420
+
1421
+ # whether to do a final residual from initial conv to the final resnet block out
1422
+
1423
+ self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
1424
+ final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)
1425
+
1426
+ # final optional resnet block and convolution out
1427
+
1428
+ self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None
1429
+
1430
+ final_conv_dim_in = dim if final_resnet_block else final_conv_dim
1431
+ final_conv_dim_in += (channels if lowres_cond else 0)
1432
+
1433
+ self.final_conv = nn.Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)
1434
+
1435
+ zero_init_(self.final_conv)
1436
+
1437
+ # resize mode
1438
+
1439
+ self.resize_mode = resize_mode
1440
+
1441
+ # if the current settings for the unet are not correct
1442
+ # for cascading DDPM, then reinit the unet with the right settings
1443
+ def cast_model_parameters(
1444
+ self,
1445
+ *,
1446
+ lowres_cond,
1447
+ text_embed_dim,
1448
+ channels,
1449
+ channels_out,
1450
+ cond_on_text
1451
+ ):
1452
+ if lowres_cond == self.lowres_cond and \
1453
+ channels == self.channels and \
1454
+ cond_on_text == self.cond_on_text and \
1455
+ text_embed_dim == self._locals['text_embed_dim'] and \
1456
+ channels_out == self.channels_out:
1457
+ return self
1458
+
1459
+ updated_kwargs = dict(
1460
+ lowres_cond = lowres_cond,
1461
+ text_embed_dim = text_embed_dim,
1462
+ channels = channels,
1463
+ channels_out = channels_out,
1464
+ cond_on_text = cond_on_text
1465
+ )
1466
+
1467
+ return self.__class__(**{**self._locals, **updated_kwargs})
1468
+
1469
+ # methods for returning the full unet config as well as its parameter state
1470
+
1471
+ def to_config_and_state_dict(self):
1472
+ return self._locals, self.state_dict()
1473
+
1474
+ # class method for rehydrating the unet from its config and state dict
1475
+
1476
+ @classmethod
1477
+ def from_config_and_state_dict(klass, config, state_dict):
1478
+ unet = klass(**config)
1479
+ unet.load_state_dict(state_dict)
1480
+ return unet
1481
+
1482
+ # methods for persisting unet to disk
1483
+
1484
+ def persist_to_file(self, path):
1485
+ path = Path(path)
1486
+ path.parents[0].mkdir(exist_ok = True, parents = True)
1487
+
1488
+ config, state_dict = self.to_config_and_state_dict()
1489
+ pkg = dict(config = config, state_dict = state_dict)
1490
+ torch.save(pkg, str(path))
1491
+
1492
+ # class method for rehydrating the unet from file saved with `persist_to_file`
1493
+
1494
+ @classmethod
1495
+ def hydrate_from_file(klass, path):
1496
+ path = Path(path)
1497
+ assert path.exists()
1498
+ pkg = torch.load(str(path))
1499
+
1500
+ assert 'config' in pkg and 'state_dict' in pkg
1501
+ config, state_dict = pkg['config'], pkg['state_dict']
1502
+
1503
+ return Unet.from_config_and_state_dict(config, state_dict)
1504
+
1505
+ # forward with classifier free guidance
1506
+
1507
+ def forward_with_cond_scale(
1508
+ self,
1509
+ *args,
1510
+ cond_scale = 1.,
1511
+ **kwargs
1512
+ ):
1513
+ logits = self.forward(*args, **kwargs)
1514
+
1515
+ if cond_scale == 1:
1516
+ return logits
1517
+
1518
+ null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
1519
+ return null_logits + (logits - null_logits) * cond_scale
1520
+
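+ # classifier free guidance sketch (cond_scale is supplied by the user at sampling time):
+ # with cond_scale = 5. the returned prediction is null + 5 * (cond - null), i.e. the text conditioned
+ # prediction pushed further away from the unconditional one; cond_scale = 1. reduces to the plain forward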
1521
+ def forward(
1522
+ self,
1523
+ x,
1524
+ time,
1525
+ *,
1526
+ lowres_cond_img = None,
1527
+ lowres_noise_times = None,
1528
+ text_embeds = None,
1529
+ text_mask = None,
1530
+ cond_images = None,
1531
+ self_cond = None,
1532
+ cond_drop_prob = 0.
1533
+ ):
1534
+ batch_size, device = x.shape[0], x.device
1535
+
1536
+ # condition on self
1537
+
1538
+ if self.self_cond:
1539
+ self_cond = default(self_cond, lambda: torch.zeros_like(x))
1540
+ x = torch.cat((x, self_cond), dim = 1)
1541
+
1542
+ # add low resolution conditioning, if present
1543
+
1544
+ assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
1545
+ assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'
1546
+
1547
+ if exists(lowres_cond_img):
1548
+ x = torch.cat((x, lowres_cond_img), dim = 1)
1549
+
1550
+ # condition on input image
1551
+
1552
+ assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'
1553
+
1554
+ if exists(cond_images):
1555
+ assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialization of the unet'
1556
+ cond_images = resize_image_to(cond_images, x.shape[-1], mode = self.resize_mode)
1557
+ x = torch.cat((cond_images, x), dim = 1)
1558
+
1559
+ # initial convolution
1560
+
1561
+ x = self.init_conv(x)
1562
+
1563
+ # init conv residual
1564
+
1565
+ if self.init_conv_to_final_conv_residual:
1566
+ init_conv_residual = x.clone()
1567
+
1568
+ # time conditioning
1569
+
1570
+ time_hiddens = self.to_time_hiddens(time)
1571
+
1572
+ # derive time tokens
1573
+
1574
+ time_tokens = self.to_time_tokens(time_hiddens)
1575
+ t = self.to_time_cond(time_hiddens)
1576
+
1577
+ # add lowres time conditioning to time hiddens
1578
+ # and add lowres time tokens along sequence dimension for attention
1579
+
1580
+ if self.lowres_cond:
1581
+ lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
1582
+ lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
1583
+ lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)
1584
+
1585
+ t = t + lowres_t
1586
+ time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)
1587
+
1588
+ # text conditioning
1589
+
1590
+ text_tokens = None
1591
+
1592
+ if exists(text_embeds) and self.cond_on_text:
1593
+
1594
+ # conditional dropout
1595
+
1596
+ text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
1597
+
1598
+ text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
1599
+ text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')
1600
+
1601
+ # calculate text embeds
1602
+
1603
+ text_tokens = self.text_to_cond(text_embeds)
1604
+
1605
+ text_tokens = text_tokens[:, :self.max_text_len]
1606
+
1607
+ if exists(text_mask):
1608
+ text_mask = text_mask[:, :self.max_text_len]
1609
+
1610
+ text_tokens_len = text_tokens.shape[1]
1611
+ remainder = self.max_text_len - text_tokens_len
1612
+
1613
+ if remainder > 0:
1614
+ text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
1615
+
1616
+ if exists(text_mask):
1617
+ if remainder > 0:
1618
+ text_mask = F.pad(text_mask, (0, remainder), value = False)
1619
+
1620
+ text_mask = rearrange(text_mask, 'b n -> b n 1')
1621
+ text_keep_mask_embed = text_mask & text_keep_mask_embed
1622
+
1623
+ null_text_embed = self.null_text_embed.to(text_tokens.dtype) # cast explicitly - for some reason pytorch AMP does not handle this buffer
1624
+
1625
+ text_tokens = torch.where(
1626
+ text_keep_mask_embed,
1627
+ text_tokens,
1628
+ null_text_embed
1629
+ )
1630
+
1631
+ if exists(self.attn_pool):
1632
+ text_tokens = self.attn_pool(text_tokens)
1633
+
1634
+ # extra non-attention conditioning by projecting and then summing text embeddings to time
1635
+ # termed as text hiddens
1636
+
1637
+ mean_pooled_text_tokens = text_tokens.mean(dim = -2)
1638
+
1639
+ text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)
1640
+
1641
+ null_text_hidden = self.null_text_hidden.to(t.dtype)
1642
+
1643
+ text_hiddens = torch.where(
1644
+ text_keep_mask_hidden,
1645
+ text_hiddens,
1646
+ null_text_hidden
1647
+ )
1648
+
1649
+ t = t + text_hiddens
1650
+
1651
+ # main conditioning tokens (c)
1652
+
1653
+ c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)
1654
+
1655
+ # normalize conditioning tokens
1656
+
1657
+ c = self.norm_cond(c)
1658
+
1659
+ # initial resnet block (for memory efficient unet)
1660
+
1661
+ if exists(self.init_resnet_block):
1662
+ x = self.init_resnet_block(x, t)
1663
+
1664
+ # go through the layers of the unet, down and up
1665
+
1666
+ hiddens = []
1667
+
1668
+ for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs:
1669
+ if exists(pre_downsample):
1670
+ x = pre_downsample(x)
1671
+
1672
+ x = init_block(x, t, c)
1673
+
1674
+ for resnet_block in resnet_blocks:
1675
+ x = resnet_block(x, t)
1676
+ hiddens.append(x)
1677
+
1678
+ x = attn_block(x, c)
1679
+ hiddens.append(x)
1680
+
1681
+ if exists(post_downsample):
1682
+ x = post_downsample(x)
1683
+
1684
+ x = self.mid_block1(x, t, c)
1685
+
1686
+ if exists(self.mid_attn):
1687
+ x = self.mid_attn(x)
1688
+
1689
+ x = self.mid_block2(x, t, c)
1690
+
1691
+ add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)
1692
+
1693
+ up_hiddens = []
1694
+
1695
+ for init_block, resnet_blocks, attn_block, upsample in self.ups:
1696
+ x = add_skip_connection(x)
1697
+ x = init_block(x, t, c)
1698
+
1699
+ for resnet_block in resnet_blocks:
1700
+ x = add_skip_connection(x)
1701
+ x = resnet_block(x, t)
1702
+
1703
+ x = attn_block(x, c)
1704
+ up_hiddens.append(x.contiguous())
1705
+ x = upsample(x)
1706
+
1707
+ # whether to combine all feature maps from upsample blocks
1708
+
1709
+ x = self.upsample_combiner(x, up_hiddens)
1710
+
1711
+ # final top-most residual if needed
1712
+
1713
+ if self.init_conv_to_final_conv_residual:
1714
+ x = torch.cat((x, init_conv_residual), dim = 1)
1715
+
1716
+ if exists(self.final_res_block):
1717
+ x = self.final_res_block(x, t)
1718
+
1719
+ if exists(lowres_cond_img):
1720
+ x = torch.cat((x, lowres_cond_img), dim = 1)
1721
+
1722
+ return self.final_conv(x)
1723
+
1724
+ # null unet
1725
+
1726
+ class NullUnet(nn.Module):
1727
+ def __init__(self, *args, **kwargs):
1728
+ super().__init__()
1729
+ self.lowres_cond = False
1730
+ self.dummy_parameter = nn.Parameter(torch.tensor([0.]))
1731
+
1732
+ def cast_model_parameters(self, *args, **kwargs):
1733
+ return self
1734
+
1735
+ def forward(self, x, *args, **kwargs):
1736
+ return x
1737
+
1738
+ # predefined unets, with configs lining up with hyperparameters in appendix of paper
1739
+
1740
+ class BaseUnet64(Unet):
1741
+ def __init__(self, *args, **kwargs):
1742
+ default_kwargs = dict(
1743
+ dim = 512,
1744
+ dim_mults = (1, 2, 3, 4),
1745
+ num_resnet_blocks = 3,
1746
+ layer_attns = (False, True, True, True),
1747
+ layer_cross_attns = (False, True, True, True),
1748
+ attn_heads = 8,
1749
+ ff_mult = 2.,
1750
+ memory_efficient = False
1751
+ )
1752
+ super().__init__(*args, **{**default_kwargs, **kwargs})
1753
+
1754
+ class SRUnet256(Unet):
1755
+ def __init__(self, *args, **kwargs):
1756
+ default_kwargs = dict(
1757
+ dim = 128,
1758
+ dim_mults = (1, 2, 4, 8),
1759
+ num_resnet_blocks = (2, 4, 8, 8),
1760
+ layer_attns = (False, False, False, True),
1761
+ layer_cross_attns = (False, False, False, True),
1762
+ attn_heads = 8,
1763
+ ff_mult = 2.,
1764
+ memory_efficient = True
1765
+ )
1766
+ super().__init__(*args, **{**default_kwargs, **kwargs})
1767
+
1768
+ class SRUnet1024(Unet):
1769
+ def __init__(self, *args, **kwargs):
1770
+ default_kwargs = dict(
1771
+ dim = 128,
1772
+ dim_mults = (1, 2, 4, 8),
1773
+ num_resnet_blocks = (2, 4, 8, 8),
1774
+ layer_attns = False,
1775
+ layer_cross_attns = (False, False, False, True),
1776
+ attn_heads = 8,
1777
+ ff_mult = 2.,
1778
+ memory_efficient = True
1779
+ )
1780
+ super().__init__(*args, **{**default_kwargs, **kwargs})
1781
+
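The three predefined unets above only differ in the defaults they merge into `Unet.__init__`, with caller kwargs winning on conflict. A small editorial sketch of that merge rule (values here are illustrative, not from the paper):

default_kwargs = dict(dim = 128, attn_heads = 8, memory_efficient = True)
overrides = dict(memory_efficient = False)
merged = {**default_kwargs, **overrides}
print(merged)   # {'dim': 128, 'attn_heads': 8, 'memory_efficient': False}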
1782
+ # main imagen ddpm class, which is a cascading DDPM from Ho et al.
1783
+
1784
+ class Imagen(nn.Module):
1785
+ def __init__(
1786
+ self,
1787
+ unets,
1788
+ *,
1789
+ image_sizes, # for cascading ddpm, image size at each stage
1790
+ text_encoder_name = DEFAULT_T5_NAME,
1791
+ text_embed_dim = None,
1792
+ channels = 3,
1793
+ timesteps = 1000,
1794
+ cond_drop_prob = 0.1,
1795
+ loss_type = 'l2',
1796
+ noise_schedules = 'cosine',
1797
+ pred_objectives = 'noise',
1798
+ random_crop_sizes = None,
1799
+ lowres_noise_schedule = 'linear',
1800
+ lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
1801
+ per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
1802
+ condition_on_text = True,
1803
+ auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
1804
+ dynamic_thresholding = True,
1805
+ dynamic_thresholding_percentile = 0.95, # exact percentile is not clear from a perusal of the paper
1806
+ only_train_unet_number = None,
1807
+ temporal_downsample_factor = 1,
1808
+ resize_cond_video_frames = True,
1809
+ resize_mode = 'nearest',
1810
+ min_snr_loss_weight = True, # https://arxiv.org/abs/2303.09556
1811
+ min_snr_gamma = 5
1812
+ ):
1813
+ super().__init__()
1814
+
1815
+ # loss
1816
+
1817
+ if loss_type == 'l1':
1818
+ loss_fn = F.l1_loss
1819
+ elif loss_type == 'l2':
1820
+ loss_fn = F.mse_loss
1821
+ elif loss_type == 'huber':
1822
+ loss_fn = F.smooth_l1_loss
1823
+ else:
1824
+ raise NotImplementedError()
1825
+
1826
+ self.loss_type = loss_type
1827
+ self.loss_fn = loss_fn
1828
+
1829
+ # conditioning hparams
1830
+
1831
+ self.condition_on_text = condition_on_text
1832
+ self.unconditional = not condition_on_text
1833
+
1834
+ # channels
1835
+
1836
+ self.channels = channels
1837
+
1838
+ # automatically take care of ensuring that the first unet is not conditioned on a low resolution image
1839
+ # while the rest of the unets are conditioned on the low resolution image produced by the previous unet
1840
+
1841
+ unets = cast_tuple(unets)
1842
+ num_unets = len(unets)
1843
+
1844
+ # determine noise schedules per unet
1845
+
1846
+ timesteps = cast_tuple(timesteps, num_unets)
1847
+
1848
+ # make sure noise schedule defaults to 'cosine', 'cosine', and then 'linear' for the remaining super-resolution unets
1849
+
1850
+ noise_schedules = cast_tuple(noise_schedules)
1851
+ noise_schedules = pad_tuple_to_length(noise_schedules, 2, 'cosine')
1852
+ noise_schedules = pad_tuple_to_length(noise_schedules, num_unets, 'linear')
1853
+
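A short editorial sketch of how the two padding calls above play out for a three-unet cascade, using the same `pad_tuple_to_length` helper defined in this repository:

def pad_tuple_to_length(t, length, fillvalue = None):
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

noise_schedules = ('cosine',)
noise_schedules = pad_tuple_to_length(noise_schedules, 2, 'cosine')   # ('cosine', 'cosine')
noise_schedules = pad_tuple_to_length(noise_schedules, 3, 'linear')   # ('cosine', 'cosine', 'linear')
print(noise_schedules)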
1854
+ # construct noise schedulers
1855
+
1856
+ noise_scheduler_klass = GaussianDiffusionContinuousTimes
1857
+ self.noise_schedulers = nn.ModuleList([])
1858
+
1859
+ for timestep, noise_schedule in zip(timesteps, noise_schedules):
1860
+ noise_scheduler = noise_scheduler_klass(noise_schedule = noise_schedule, timesteps = timestep)
1861
+ self.noise_schedulers.append(noise_scheduler)
1862
+
1863
+ # randomly cropping for upsampler training
1864
+
1865
+ self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
1866
+ assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop the image during training for the base unet, only for the upsamplers - e.g. pass in `random_crop_sizes = (None, 128, 256)`'
1867
+
1868
+ # lowres augmentation noise schedule
1869
+
1870
+ self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule)
1871
+
1872
+ # ddpm objectives - predicting noise by default
1873
+
1874
+ self.pred_objectives = cast_tuple(pred_objectives, num_unets)
1875
+
1876
+ # get text encoder
1877
+
1878
+ self.text_encoder_name = text_encoder_name
1879
+ self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name))
1880
+
1881
+ self.encode_text = partial(t5_encode_text, name = text_encoder_name)
1882
+
1883
+ # construct unets
1884
+
1885
+ self.unets = nn.ModuleList([])
1886
+
1887
+ self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment
1888
+ self.only_train_unet_number = only_train_unet_number
1889
+
1890
+ for ind, one_unet in enumerate(unets):
1891
+ assert isinstance(one_unet, (Unet, Unet3D, NullUnet))
1892
+ is_first = ind == 0
1893
+
1894
+ one_unet = one_unet.cast_model_parameters(
1895
+ lowres_cond = not is_first,
1896
+ cond_on_text = self.condition_on_text,
1897
+ text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
1898
+ channels = self.channels,
1899
+ channels_out = self.channels
1900
+ )
1901
+
1902
+ self.unets.append(one_unet)
1903
+
1904
+ # unet image sizes
1905
+
1906
+ image_sizes = cast_tuple(image_sizes)
1907
+ self.image_sizes = image_sizes
1908
+
1909
+ assert num_unets == len(image_sizes), f'you did not supply the correct number of u-nets ({len(unets)}) for resolutions {image_sizes}'
1910
+
1911
+ self.sample_channels = cast_tuple(self.channels, num_unets)
1912
+
1913
+ # determine whether we are training on images or video
1914
+
1915
+ is_video = any([isinstance(unet, Unet3D) for unet in self.unets])
1916
+ self.is_video = is_video
1917
+
1918
+ self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))
1919
+
1920
+ self.resize_to = resize_video_to if is_video else resize_image_to
1921
+ self.resize_to = partial(self.resize_to, mode = resize_mode)
1922
+
1923
+ # temporal interpolation
1924
+
1925
+ temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets)
1926
+ self.temporal_downsample_factor = temporal_downsample_factor
1927
+
1928
+ self.resize_cond_video_frames = resize_cond_video_frames
1929
+ self.temporal_downsample_divisor = temporal_downsample_factor[0]
1930
+
1931
+ assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1'
1932
+ assert tuple(sorted(temporal_downsample_factor, reverse = True)) == temporal_downsample_factor, 'temporal downsample factors must be in descending order'
1933
+
1934
+ # cascading ddpm related stuff
1935
+
1936
+ lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
1937
+ assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
1938
+
1939
+ self.lowres_sample_noise_level = lowres_sample_noise_level
1940
+ self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level
1941
+
1942
+ # classifier free guidance
1943
+
1944
+ self.cond_drop_prob = cond_drop_prob
1945
+ self.can_classifier_guidance = cond_drop_prob > 0.
1946
+
1947
+ # normalize and unnormalize image functions
1948
+
1949
+ self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
1950
+ self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
1951
+ self.input_image_range = (0. if auto_normalize_img else -1., 1.)
1952
+
1953
+ # dynamic thresholding
1954
+
1955
+ self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
1956
+ self.dynamic_thresholding_percentile = dynamic_thresholding_percentile
1957
+
1958
+ # min snr loss weight
1959
+
1960
+ min_snr_loss_weight = cast_tuple(min_snr_loss_weight, num_unets)
1961
+ min_snr_gamma = cast_tuple(min_snr_gamma, num_unets)
1962
+
1963
+ assert len(min_snr_loss_weight) == len(min_snr_gamma) == num_unets
1964
+ self.min_snr_gamma = tuple((gamma if use_min_snr else None) for use_min_snr, gamma in zip(min_snr_loss_weight, min_snr_gamma))
1965
+
1966
+ # one temp parameter for keeping track of device
1967
+
1968
+ self.register_buffer('_temp', torch.tensor([0.]), persistent = False)
1969
+
1970
+ # default to device of unets passed in
1971
+
1972
+ self.to(next(self.unets.parameters()).device)
1973
+
1974
+ def force_unconditional_(self):
1975
+ self.condition_on_text = False
1976
+ self.unconditional = True
1977
+
1978
+ for unet in self.unets:
1979
+ unet.cond_on_text = False
1980
+
1981
+ @property
1982
+ def device(self):
1983
+ return self._temp.device
1984
+
1985
+ def get_unet(self, unet_number):
1986
+ assert 0 < unet_number <= len(self.unets)
1987
+ index = unet_number - 1
1988
+
1989
+ if isinstance(self.unets, nn.ModuleList):
1990
+ unets_list = [unet for unet in self.unets]
1991
+ delattr(self, 'unets')
1992
+ self.unets = unets_list
1993
+
1994
+ if index != self.unet_being_trained_index:
1995
+ for unet_index, unet in enumerate(self.unets):
1996
+ unet.to(self.device if unet_index == index else 'cpu')
1997
+
1998
+ self.unet_being_trained_index = index
1999
+ return self.unets[index]
2000
+
2001
+ def reset_unets_all_one_device(self, device = None):
2002
+ device = default(device, self.device)
2003
+ self.unets = nn.ModuleList([*self.unets])
2004
+ self.unets.to(device)
2005
+
2006
+ self.unet_being_trained_index = -1
2007
+
2008
+ @contextmanager
2009
+ def one_unet_in_gpu(self, unet_number = None, unet = None):
2010
+ assert exists(unet_number) ^ exists(unet)
2011
+
2012
+ if exists(unet_number):
2013
+ unet = self.unets[unet_number - 1]
2014
+
2015
+ cpu = torch.device('cpu')
2016
+
2017
+ devices = [module_device(unet) for unet in self.unets]
2018
+
2019
+ self.unets.to(cpu)
2020
+ unet.to(self.device)
2021
+
2022
+ yield
2023
+
2024
+ for unet, device in zip(self.unets, devices):
2025
+ unet.to(device)
2026
+
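The context manager above keeps only the active unet on the GPU and restores the original placement afterwards. A generic editorial sketch of the same pattern for any list of modules (names here are illustrative, not part of the uploaded file):

import torch
from torch import nn
from contextlib import contextmanager

@contextmanager
def one_module_on_device(modules, index, device):
    # remember where everything lives, park all modules on cpu, promote one to the target device
    original_devices = [next(m.parameters()).device for m in modules]
    for m in modules:
        m.to('cpu')
    modules[index].to(device)
    try:
        yield modules[index]
    finally:
        for m, d in zip(modules, original_devices):
            m.to(d)

nets = [nn.Linear(8, 8) for _ in range(3)]
with one_module_on_device(nets, 1, 'cpu') as net:   # pass 'cuda' when a GPU is available
    out = net(torch.randn(4, 8))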
2027
+ # overriding state dict functions
2028
+
2029
+ def state_dict(self, *args, **kwargs):
2030
+ self.reset_unets_all_one_device()
2031
+ return super().state_dict(*args, **kwargs)
2032
+
2033
+ def load_state_dict(self, *args, **kwargs):
2034
+ self.reset_unets_all_one_device()
2035
+ return super().load_state_dict(*args, **kwargs)
2036
+
2037
+ # gaussian diffusion methods
2038
+
2039
+ def p_mean_variance(
2040
+ self,
2041
+ unet,
2042
+ x,
2043
+ t,
2044
+ *,
2045
+ noise_scheduler,
2046
+ text_embeds = None,
2047
+ text_mask = None,
2048
+ cond_images = None,
2049
+ cond_video_frames = None,
2050
+ post_cond_video_frames = None,
2051
+ lowres_cond_img = None,
2052
+ self_cond = None,
2053
+ lowres_noise_times = None,
2054
+ cond_scale = 1.,
2055
+ model_output = None,
2056
+ t_next = None,
2057
+ pred_objective = 'noise',
2058
+ dynamic_threshold = True
2059
+ ):
2060
+ assert not (cond_scale != 1. and not self.can_classifier_guidance), 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
2061
+
2062
+ video_kwargs = dict()
2063
+ if self.is_video:
2064
+ video_kwargs = dict(
2065
+ cond_video_frames = cond_video_frames,
2066
+ post_cond_video_frames = post_cond_video_frames,
2067
+ )
2068
+
2069
+ pred = default(model_output, lambda: unet.forward_with_cond_scale(
2070
+ x,
2071
+ noise_scheduler.get_condition(t),
2072
+ text_embeds = text_embeds,
2073
+ text_mask = text_mask,
2074
+ cond_images = cond_images,
2075
+ cond_scale = cond_scale,
2076
+ lowres_cond_img = lowres_cond_img,
2077
+ self_cond = self_cond,
2078
+ lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times),
2079
+ **video_kwargs
2080
+ ))
2081
+
2082
+ if pred_objective == 'noise':
2083
+ x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
2084
+ elif pred_objective == 'x_start':
2085
+ x_start = pred
2086
+ elif pred_objective == 'v':
2087
+ x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
2088
+ else:
2089
+ raise ValueError(f'unknown objective {pred_objective}')
2090
+
2091
+ if dynamic_threshold:
2092
+ # following pseudocode in appendix
2093
+ # s is the dynamic threshold, determined by percentile of absolute values of reconstructed sample per batch element
2094
+ s = torch.quantile(
2095
+ rearrange(x_start, 'b ... -> b (...)').abs(),
2096
+ self.dynamic_thresholding_percentile,
2097
+ dim = -1
2098
+ )
2099
+
2100
+ s.clamp_(min = 1.)
2101
+ s = right_pad_dims_to(x_start, s)
2102
+ x_start = x_start.clamp(-s, s) / s
2103
+ else:
2104
+ x_start.clamp_(-1., 1.)
2105
+
2106
+ mean_and_variance = noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next)
2107
+ return mean_and_variance, x_start
2108
+
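The dynamic thresholding branch above can be exercised in isolation. An editorial sketch, assuming the same 0.95 percentile default:

import torch
from einops import rearrange

def dynamic_threshold(x_start, percentile = 0.95):
    # per-sample threshold s = percentile of |x0|, clamped to at least 1, then rescaled back into [-1, 1]
    s = torch.quantile(rearrange(x_start, 'b ... -> b (...)').abs(), percentile, dim = -1)
    s.clamp_(min = 1.)
    s = s.view(-1, *((1,) * (x_start.ndim - 1)))
    return x_start.clamp(-s, s) / s

x0 = torch.randn(2, 3, 8, 8) * 3.
assert dynamic_threshold(x0).abs().max() <= 1.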
2109
+ @torch.no_grad()
2110
+ def p_sample(
2111
+ self,
2112
+ unet,
2113
+ x,
2114
+ t,
2115
+ *,
2116
+ noise_scheduler,
2117
+ t_next = None,
2118
+ text_embeds = None,
2119
+ text_mask = None,
2120
+ cond_images = None,
2121
+ cond_video_frames = None,
2122
+ post_cond_video_frames = None,
2123
+ cond_scale = 1.,
2124
+ self_cond = None,
2125
+ lowres_cond_img = None,
2126
+ lowres_noise_times = None,
2127
+ pred_objective = 'noise',
2128
+ dynamic_threshold = True
2129
+ ):
2130
+ b, *_, device = *x.shape, x.device
2131
+
2132
+ video_kwargs = dict()
2133
+ if self.is_video:
2134
+ video_kwargs = dict(
2135
+ cond_video_frames = cond_video_frames,
2136
+ post_cond_video_frames = post_cond_video_frames,
2137
+ )
2138
+
2139
+ (model_mean, _, model_log_variance), x_start = self.p_mean_variance(
2140
+ unet,
2141
+ x = x,
2142
+ t = t,
2143
+ t_next = t_next,
2144
+ noise_scheduler = noise_scheduler,
2145
+ text_embeds = text_embeds,
2146
+ text_mask = text_mask,
2147
+ cond_images = cond_images,
2148
+ cond_scale = cond_scale,
2149
+ lowres_cond_img = lowres_cond_img,
2150
+ self_cond = self_cond,
2151
+ lowres_noise_times = lowres_noise_times,
2152
+ pred_objective = pred_objective,
2153
+ dynamic_threshold = dynamic_threshold,
2154
+ **video_kwargs
2155
+ )
2156
+
2157
+ noise = torch.randn_like(x)
2158
+ # no noise when t == 0
2159
+ is_last_sampling_timestep = (t_next == 0) if isinstance(noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0)
2160
+ nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(b, *((1,) * (len(x.shape) - 1)))
2161
+ pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
2162
+ return pred, x_start
2163
+
2164
+ @torch.no_grad()
2165
+ def p_sample_loop(
2166
+ self,
2167
+ unet,
2168
+ shape,
2169
+ *,
2170
+ noise_scheduler,
2171
+ lowres_cond_img = None,
2172
+ lowres_noise_times = None,
2173
+ text_embeds = None,
2174
+ text_mask = None,
2175
+ cond_images = None,
2176
+ cond_video_frames = None,
2177
+ post_cond_video_frames = None,
2178
+ inpaint_images = None,
2179
+ inpaint_videos = None,
2180
+ inpaint_masks = None,
2181
+ inpaint_resample_times = 5,
2182
+ init_images = None,
2183
+ skip_steps = None,
2184
+ cond_scale = 1,
2185
+ pred_objective = 'noise',
2186
+ dynamic_threshold = True,
2187
+ use_tqdm = True
2188
+ ):
2189
+ device = self.device
2190
+
2191
+ batch = shape[0]
2192
+ img = torch.randn(shape, device = device)
2193
+
2194
+ # video
2195
+
2196
+ is_video = len(shape) == 5
2197
+ frames = shape[-3] if is_video else None
2198
+ resize_kwargs = dict(target_frames = frames) if exists(frames) else dict()
2199
+
2200
+ # for initialization with an image or video
2201
+
2202
+ if exists(init_images):
2203
+ img += init_images
2204
+
2205
+ # keep track of x0, for self conditioning
2206
+
2207
+ x_start = None
2208
+
2209
+ # prepare inpainting
2210
+
2211
+ inpaint_images = default(inpaint_videos, inpaint_images)
2212
+
2213
+ has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
2214
+ resample_times = inpaint_resample_times if has_inpainting else 1
2215
+
2216
+ if has_inpainting:
2217
+ inpaint_images = self.normalize_img(inpaint_images)
2218
+ inpaint_images = self.resize_to(inpaint_images, shape[-1], **resize_kwargs)
2219
+ inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1], **resize_kwargs).bool()
2220
+
2221
+ # time
2222
+
2223
+ timesteps = noise_scheduler.get_sampling_timesteps(batch, device = device)
2224
+
2225
+ # whether to skip any steps
2226
+
2227
+ skip_steps = default(skip_steps, 0)
2228
+ timesteps = timesteps[skip_steps:]
2229
+
2230
+ # video conditioning kwargs
2231
+
2232
+ video_kwargs = dict()
2233
+ if self.is_video:
2234
+ video_kwargs = dict(
2235
+ cond_video_frames = cond_video_frames,
2236
+ post_cond_video_frames = post_cond_video_frames,
2237
+ )
2238
+
2239
+ for times, times_next in tqdm(timesteps, desc = 'sampling loop time step', total = len(timesteps), disable = not use_tqdm):
2240
+ is_last_timestep = times_next == 0
2241
+
2242
+ for r in reversed(range(resample_times)):
2243
+ is_last_resample_step = r == 0
2244
+
2245
+ if has_inpainting:
2246
+ noised_inpaint_images, *_ = noise_scheduler.q_sample(inpaint_images, t = times)
2247
+ img = img * ~inpaint_masks + noised_inpaint_images * inpaint_masks
2248
+
2249
+ self_cond = x_start if unet.self_cond else None
2250
+
2251
+ img, x_start = self.p_sample(
2252
+ unet,
2253
+ img,
2254
+ times,
2255
+ t_next = times_next,
2256
+ text_embeds = text_embeds,
2257
+ text_mask = text_mask,
2258
+ cond_images = cond_images,
2259
+ cond_scale = cond_scale,
2260
+ self_cond = self_cond,
2261
+ lowres_cond_img = lowres_cond_img,
2262
+ lowres_noise_times = lowres_noise_times,
2263
+ noise_scheduler = noise_scheduler,
2264
+ pred_objective = pred_objective,
2265
+ dynamic_threshold = dynamic_threshold,
2266
+ **video_kwargs
2267
+ )
2268
+
2269
+ if has_inpainting and not (is_last_resample_step or torch.all(is_last_timestep)):
2270
+ renoised_img = noise_scheduler.q_sample_from_to(img, times_next, times)
2271
+
2272
+ img = torch.where(
2273
+ self.right_pad_dims_to_datatype(is_last_timestep),
2274
+ img,
2275
+ renoised_img
2276
+ )
2277
+
2278
+ img.clamp_(-1., 1.)
2279
+
2280
+ # final inpainting
2281
+
2282
+ if has_inpainting:
2283
+ img = img * ~inpaint_masks + inpaint_images * inpaint_masks
2284
+
2285
+ img = self.unnormalize_img(img)
2286
+ return img
2287
+
2288
+ @torch.no_grad()
2289
+ @eval_decorator
2290
+ @beartype
2291
+ def sample(
2292
+ self,
2293
+ texts: List[str] = None,
2294
+ text_masks = None,
2295
+ text_embeds = None,
2296
+ video_frames = None,
2297
+ cond_images = None,
2298
+ cond_video_frames = None,
2299
+ post_cond_video_frames = None,
2300
+ inpaint_videos = None,
2301
+ inpaint_images = None,
2302
+ inpaint_masks = None,
2303
+ inpaint_resample_times = 5,
2304
+ init_images = None,
2305
+ skip_steps = None,
2306
+ batch_size = 1,
2307
+ cond_scale = 1.,
2308
+ lowres_sample_noise_level = None,
2309
+ start_at_unet_number = 1,
2310
+ start_image_or_video = None,
2311
+ stop_at_unet_number = None,
2312
+ return_all_unet_outputs = False,
2313
+ return_pil_images = False,
2314
+ device = None,
2315
+ use_tqdm = True,
2316
+ use_one_unet_in_gpu = True
2317
+ ):
2318
+ device = default(device, self.device)
2319
+ self.reset_unets_all_one_device(device = device)
2320
+
2321
+ cond_images = maybe(cast_uint8_images_to_float)(cond_images)
2322
+
2323
+ if exists(texts) and not exists(text_embeds) and not self.unconditional:
2324
+ assert all([*map(len, texts)]), 'text cannot be empty'
2325
+
2326
+ with autocast(enabled = False):
2327
+ text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
2328
+
2329
+ text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
2330
+
2331
+ if not self.unconditional:
2332
+ assert exists(text_embeds), 'text must be passed in if the network was trained with text conditioning - to sample without text, `condition_on_text` must be set to `False` when training'
2333
+
2334
+ text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
2335
+ batch_size = text_embeds.shape[0]
2336
+
2337
+ # inpainting
2338
+
2339
+ inpaint_images = default(inpaint_videos, inpaint_images)
2340
+
2341
+ if exists(inpaint_images):
2342
+ if self.unconditional:
2343
+ if batch_size == 1: # assume researcher wants to broadcast along inpainted images
2344
+ batch_size = inpaint_images.shape[0]
2345
+
2346
+ assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the batch size specified on sample, i.e. `sample(batch_size=<int>)`'
2347
+ assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on'
2348
+
2349
+ assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
2350
+ assert not (not self.condition_on_text and exists(text_embeds)), 'imagen was specified not to be conditioned on text, yet text embeddings were presented'
2351
+ assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
2352
+
2353
+ assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting'
2354
+
2355
+ outputs = []
2356
+
2357
+ is_cuda = next(self.parameters()).is_cuda
2358
+ device = next(self.parameters()).device
2359
+
2360
+ lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level)
2361
+
2362
+ num_unets = len(self.unets)
2363
+
2364
+ # condition scaling
2365
+
2366
+ cond_scale = cast_tuple(cond_scale, num_unets)
2367
+
2368
+ # add frame dimension for video
2369
+
2370
+ if self.is_video and exists(inpaint_images):
2371
+ video_frames = inpaint_images.shape[2]
2372
+
2373
+ if inpaint_masks.ndim == 3:
2374
+ inpaint_masks = repeat(inpaint_masks, 'b h w -> b f h w', f = video_frames)
2375
+
2376
+ assert inpaint_masks.shape[1] == video_frames
2377
+
2378
+ assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in at sample time if training on video'
2379
+
2380
+ all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)
2381
+
2382
+ frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()
2383
+
2384
+ # for initial image and skipping steps
2385
+
2386
+ init_images = cast_tuple(init_images, num_unets)
2387
+ init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images]
2388
+
2389
+ skip_steps = cast_tuple(skip_steps, num_unets)
2390
+
2391
+ # handle starting at a unet greater than 1, for upscaler-only training
2392
+
2393
+ if start_at_unet_number > 1:
2394
+ assert start_at_unet_number <= num_unets, 'cannot start at a unet number greater than the total number of unets'
2395
+ assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
2396
+ assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'
2397
+
2398
+ prev_image_size = self.image_sizes[start_at_unet_number - 2]
2399
+ prev_frame_size = all_frame_dims[start_at_unet_number - 2][0] if self.is_video else None
2400
+ img = self.resize_to(start_image_or_video, prev_image_size, **frames_to_resize_kwargs(prev_frame_size))
2401
+
2402
+
2403
+ # go through each unet in cascade
2404
+
2405
+ for unet_number, unet, channel, image_size, frame_dims, noise_scheduler, pred_objective, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, all_frame_dims, self.noise_schedulers, self.pred_objectives, self.dynamic_thresholding, cond_scale, init_images, skip_steps), disable = not use_tqdm):
2406
+
2407
+ if unet_number < start_at_unet_number:
2408
+ continue
2409
+
2410
+ assert not isinstance(unet, NullUnet), 'one cannot sample from null / placeholder unets'
2411
+
2412
+ context = self.one_unet_in_gpu(unet = unet) if is_cuda and use_one_unet_in_gpu else nullcontext()
2413
+
2414
+ with context:
2415
+ # video kwargs
2416
+
2417
+ video_kwargs = dict()
2418
+ if self.is_video:
2419
+ video_kwargs = dict(
2420
+ cond_video_frames = cond_video_frames,
2421
+ post_cond_video_frames = post_cond_video_frames,
2422
+ )
2423
+
2424
+ video_kwargs = compact(video_kwargs)
2425
+
2426
+ if self.is_video and self.resize_cond_video_frames:
2427
+ downsample_scale = self.temporal_downsample_factor[unet_number - 1]
2428
+ temporal_downsample_fn = partial(scale_video_time, downsample_scale = downsample_scale)
2429
+
2430
+ video_kwargs = maybe_transform_dict_key(video_kwargs, 'cond_video_frames', temporal_downsample_fn)
2431
+ video_kwargs = maybe_transform_dict_key(video_kwargs, 'post_cond_video_frames', temporal_downsample_fn)
2432
+
2433
+ # low resolution conditioning
2434
+
2435
+ lowres_cond_img = lowres_noise_times = None
2436
+ shape = (batch_size, channel, *frame_dims, image_size, image_size)
2437
+
2438
+ resize_kwargs = dict(target_frames = frame_dims[0]) if self.is_video else dict()
2439
+
2440
+ if unet.lowres_cond:
2441
+ lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)
2442
+
2443
+ lowres_cond_img = self.resize_to(img, image_size, **resize_kwargs)
2444
+
2445
+ lowres_cond_img = self.normalize_img(lowres_cond_img)
2446
+ lowres_cond_img, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))
2447
+
2448
+ # init images or video
2449
+
2450
+ if exists(unet_init_images):
2451
+ unet_init_images = self.resize_to(unet_init_images, image_size, **resize_kwargs)
2452
+
2453
+ # shape of stage
2454
+
2455
+ shape = (batch_size, self.channels, *frame_dims, image_size, image_size)
2456
+
2457
+ img = self.p_sample_loop(
2458
+ unet,
2459
+ shape,
2460
+ text_embeds = text_embeds,
2461
+ text_mask = text_masks,
2462
+ cond_images = cond_images,
2463
+ inpaint_images = inpaint_images,
2464
+ inpaint_masks = inpaint_masks,
2465
+ inpaint_resample_times = inpaint_resample_times,
2466
+ init_images = unet_init_images,
2467
+ skip_steps = unet_skip_steps,
2468
+ cond_scale = unet_cond_scale,
2469
+ lowres_cond_img = lowres_cond_img,
2470
+ lowres_noise_times = lowres_noise_times,
2471
+ noise_scheduler = noise_scheduler,
2472
+ pred_objective = pred_objective,
2473
+ dynamic_threshold = dynamic_threshold,
2474
+ use_tqdm = use_tqdm,
2475
+ **video_kwargs
2476
+ )
2477
+
2478
+ outputs.append(img)
2479
+
2480
+ if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
2481
+ break
2482
+
2483
+ output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs
2484
+
2485
+ if not return_pil_images:
2486
+ return outputs[output_index]
2487
+
2488
+ if not return_all_unet_outputs:
2489
+ outputs = outputs[-1:]
2490
+
2491
+ assert not self.is_video, 'converting sampled video tensor to video file is not supported yet'
2492
+
2493
+ pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim = 0))), outputs))
2494
+
2495
+ return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)
2496
+
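A hedged usage sketch for the `sample` method above, assuming `imagen` is an already trained `Imagen` instance on a GPU (the prompt and file path are placeholders):

images = imagen.sample(
    texts = ['a photo of a shiba inu wearing a backpack'],
    cond_scale = 3.,
    return_pil_images = True
)
images[0].save('./sample.png')   # list of PIL images, one per prompt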
2497
+ @beartype
2498
+ def p_losses(
2499
+ self,
2500
+ unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel],
2501
+ x_start,
2502
+ times,
2503
+ *,
2504
+ noise_scheduler,
2505
+ lowres_cond_img = None,
2506
+ lowres_aug_times = None,
2507
+ text_embeds = None,
2508
+ text_mask = None,
2509
+ cond_images = None,
2510
+ noise = None,
2511
+ times_next = None,
2512
+ pred_objective = 'noise',
2513
+ min_snr_gamma = None,
2514
+ random_crop_size = None,
2515
+ **kwargs
2516
+ ):
2517
+ is_video = x_start.ndim == 5
2518
+
2519
+ noise = default(noise, lambda: torch.randn_like(x_start))
2520
+
2521
+ # normalize to [-1, 1]
2522
+
2523
+ x_start = self.normalize_img(x_start)
2524
+ lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
2525
+
2526
+ # random cropping during training
2527
+ # for upsamplers
2528
+
2529
+ if exists(random_crop_size):
2530
+ if is_video:
2531
+ frames = x_start.shape[2]
2532
+ x_start, lowres_cond_img, noise = map(lambda t: rearrange(t, 'b c f h w -> (b f) c h w'), (x_start, lowres_cond_img, noise))
2533
+
2534
+ aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)
2535
+
2536
+ # make sure low res conditioner and image both get augmented the same way
2537
+ # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
2538
+ x_start = aug(x_start)
2539
+ lowres_cond_img = aug(lowres_cond_img, params = aug._params)
2540
+ noise = aug(noise, params = aug._params)
2541
+
2542
+ if is_video:
2543
+ x_start, lowres_cond_img, noise = map(lambda t: rearrange(t, '(b f) c h w -> b c f h w', f = frames), (x_start, lowres_cond_img, noise))
2544
+
2545
+ # get x_t
2546
+
2547
+ x_noisy, log_snr, alpha, sigma = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
2548
+
2549
+ # also noise the lowres conditioning image
2550
+ # at sample time, they then fix the noise level of 0.1 - 0.3
2551
+
2552
+ lowres_cond_img_noisy = None
2553
+ if exists(lowres_cond_img):
2554
+ lowres_aug_times = default(lowres_aug_times, times)
2555
+ lowres_cond_img_noisy, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))
2556
+
2557
+ # time condition
2558
+
2559
+ noise_cond = noise_scheduler.get_condition(times)
2560
+
2561
+ # unet kwargs
2562
+
2563
+ unet_kwargs = dict(
2564
+ text_embeds = text_embeds,
2565
+ text_mask = text_mask,
2566
+ cond_images = cond_images,
2567
+ lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
2568
+ lowres_cond_img = lowres_cond_img_noisy,
2569
+ cond_drop_prob = self.cond_drop_prob,
2570
+ **kwargs
2571
+ )
2572
+
2573
+ # self condition if needed
2574
+
2575
+ # Because 'unet' can be an instance of DistributedDataParallel coming from the
2576
+ # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to
2577
+ # access the member 'module' of the wrapped unet instance.
2578
+ self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond
2579
+
2580
+ if self_cond and random() < 0.5:
2581
+ with torch.no_grad():
2582
+ pred = unet.forward(
2583
+ x_noisy,
2584
+ noise_cond,
2585
+ **unet_kwargs
2586
+ ).detach()
2587
+
2588
+ x_start = noise_scheduler.predict_start_from_noise(x_noisy, t = times, noise = pred) if pred_objective == 'noise' else pred
2589
+
2590
+ unet_kwargs = {**unet_kwargs, 'self_cond': x_start}
2591
+
2592
+ # get prediction
2593
+
2594
+ pred = unet.forward(
2595
+ x_noisy,
2596
+ noise_cond,
2597
+ **unet_kwargs
2598
+ )
2599
+
2600
+ # prediction objective
2601
+
2602
+ if pred_objective == 'noise':
2603
+ target = noise
2604
+ elif pred_objective == 'x_start':
2605
+ target = x_start
2606
+ elif pred_objective == 'v':
2607
+ # derivation detailed in Appendix D of Progressive Distillation paper
2608
+ # https://arxiv.org/abs/2202.00512
2609
+ # this makes distillation viable as well as solves an issue with color shifting in the upsampling unets, noted in imagen-video
2610
+ target = alpha * noise - sigma * x_start
2611
+ else:
2612
+ raise ValueError(f'unknown objective {pred_objective}')
2613
+
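A tiny editorial check of the 'v' parameterization referenced above: with alpha ** 2 + sigma ** 2 = 1, a perfect v prediction recovers x_start exactly from the noised input (numbers here are illustrative):

import torch

alpha, sigma = torch.tensor(0.8), torch.tensor(0.6)   # alpha ** 2 + sigma ** 2 == 1
x_start, noise = torch.randn(4), torch.randn(4)

x_noisy  = alpha * x_start + sigma * noise             # forward diffusion at this timestep
v_target = alpha * noise - sigma * x_start             # the 'v' objective target
x_recon  = alpha * x_noisy - sigma * v_target          # invert a perfect v prediction
assert torch.allclose(x_recon, x_start, atol = 1e-5)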
2614
+ # losses
2615
+
2616
+ losses = self.loss_fn(pred, target, reduction = 'none')
2617
+ losses = reduce(losses, 'b ... -> b', 'mean')
2618
+
2619
+ # min snr loss reweighting
2620
+
2621
+ snr = log_snr.exp()
2622
+ maybe_clipped_snr = snr.clone()
2623
+
2624
+ if exists(min_snr_gamma):
2625
+ maybe_clipped_snr.clamp_(max = min_snr_gamma)
2626
+
2627
+ if pred_objective == 'noise':
2628
+ loss_weight = maybe_clipped_snr / snr
2629
+ elif pred_objective == 'x_start':
2630
+ loss_weight = maybe_clipped_snr
2631
+ elif pred_objective == 'v':
2632
+ loss_weight = maybe_clipped_snr / (snr + 1)
2633
+
2634
+ losses = losses * loss_weight
2635
+ return losses.mean()
2636
+
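The loss reweighting at the end of `p_losses` reduces to a pure function of the per-sample SNR. An editorial sketch with gamma = 5, the default above:

import torch

def min_snr_loss_weight(snr, pred_objective = 'noise', gamma = 5.):
    clipped = snr.clamp(max = gamma)
    if pred_objective == 'noise':
        return clipped / snr
    if pred_objective == 'x_start':
        return clipped
    if pred_objective == 'v':
        return clipped / (snr + 1)
    raise ValueError(f'unknown objective {pred_objective}')

snr = torch.tensor([0.1, 1., 25.])
print(min_snr_loss_weight(snr))   # easy, high-SNR timesteps get their loss down-weighted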
2637
+ @beartype
2638
+ def forward(
2639
+ self,
2640
+ images, # rename to images or video
2641
+ unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,
2642
+ texts: List[str] = None,
2643
+ text_embeds = None,
2644
+ text_masks = None,
2645
+ unet_number = None,
2646
+ cond_images = None,
2647
+ **kwargs
2648
+ ):
2649
+ if self.is_video and images.ndim == 4:
2650
+ images = rearrange(images, 'b c h w -> b c 1 h w')
2651
+ kwargs.update(ignore_time = True)
2652
+
2653
+ assert images.shape[-1] == images.shape[-2], f'the images you pass in must be square, but received dimensions of {images.shape[-2]}, {images.shape[-1]}'
2654
+ assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
2655
+ unet_number = default(unet_number, 1)
2656
+ assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train on unet #{self.only_train_unet_number}'
2657
+
2658
+ images = cast_uint8_images_to_float(images)
2659
+ cond_images = maybe(cast_uint8_images_to_float)(cond_images)
2660
+
2661
+ assert images.dtype == torch.float or images.dtype == torch.half, f'images tensor needs to be floats but {images.dtype} dtype found instead'
2662
+
2663
+ unet_index = unet_number - 1
2664
+
2665
+ unet = default(unet, lambda: self.get_unet(unet_number))
2666
+
2667
+ assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'
2668
+
2669
+ noise_scheduler = self.noise_schedulers[unet_index]
2670
+ min_snr_gamma = self.min_snr_gamma[unet_index]
2671
+ pred_objective = self.pred_objectives[unet_index]
2672
+ target_image_size = self.image_sizes[unet_index]
2673
+ random_crop_size = self.random_crop_sizes[unet_index]
2674
+ prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None
2675
+
2676
+ b, c, *_, h, w, device, is_video = *images.shape, images.device, images.ndim == 5
2677
+
2678
+ assert images.shape[1] == self.channels
2679
+ assert h >= target_image_size and w >= target_image_size
2680
+
2681
+ frames = images.shape[2] if is_video else None
2682
+ all_frame_dims = tuple(safe_get_tuple_index(el, 0) for el in calc_all_frame_dims(self.temporal_downsample_factor, frames))
2683
+ ignore_time = kwargs.get('ignore_time', False)
2684
+
2685
+ target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None
2686
+ prev_frame_size = all_frame_dims[unet_index - 1] if is_video and not ignore_time and unet_index > 0 else None
2687
+ frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()
2688
+
2689
+ times = noise_scheduler.sample_random_times(b, device = device)
2690
+
2691
+ if exists(texts) and not exists(text_embeds) and not self.unconditional:
2692
+ assert all([*map(len, texts)]), 'text cannot be empty'
2693
+ assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'
2694
+
2695
+ with autocast(enabled = False):
2696
+ text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
2697
+
2698
+ text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))
2699
+
2700
+ if not self.unconditional:
2701
+ text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
2702
+
2703
+ assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
2704
+ assert not (not self.condition_on_text and exists(text_embeds)), 'imagen was specified not to be conditioned on text, yet text embeddings were presented'
2705
+
2706
+ assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
2707
+
2708
+ # handle video frame conditioning
2709
+
2710
+ if self.is_video and self.resize_cond_video_frames:
2711
+ downsample_scale = self.temporal_downsample_factor[unet_index]
2712
+ temporal_downsample_fn = partial(scale_video_time, downsample_scale = downsample_scale)
2713
+ kwargs = maybe_transform_dict_key(kwargs, 'cond_video_frames', temporal_downsample_fn)
2714
+ kwargs = maybe_transform_dict_key(kwargs, 'post_cond_video_frames', temporal_downsample_fn)
2715
+
2716
+ # handle low resolution conditioning
2717
+
2718
+ lowres_cond_img = lowres_aug_times = None
2719
+ if exists(prev_image_size):
2720
+ lowres_cond_img = self.resize_to(images, prev_image_size, **frames_to_resize_kwargs(prev_frame_size), clamp_range = self.input_image_range)
2721
+ lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, **frames_to_resize_kwargs(target_frame_size), clamp_range = self.input_image_range)
2722
+
2723
+ if self.per_sample_random_aug_noise_level:
2724
+ lowres_aug_times = self.lowres_noise_schedule.sample_random_times(b, device = device)
2725
+ else:
2726
+ lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
2727
+ lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = b)
2728
+
2729
+ images = self.resize_to(images, target_image_size, **frames_to_resize_kwargs(target_frame_size))
2730
+
2731
+ return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, min_snr_gamma = min_snr_gamma, random_crop_size = random_crop_size, **kwargs)
imagen_video.py ADDED
@@ -0,0 +1,1935 @@
1
+ import math
2
+ import copy
3
+ import operator
4
+ import functools
5
+ from typing import List
6
+ from tqdm.auto import tqdm
7
+ from functools import partial, wraps
8
+ from contextlib import contextmanager, nullcontext
9
+ from collections import namedtuple
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn, einsum
15
+
16
+ from einops import rearrange, repeat, reduce, pack, unpack
17
+ from einops.layers.torch import Rearrange, Reduce
18
+ from einops_exts.torch import EinopsToAndFrom
19
+
20
+ from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
21
+
22
+ # helper functions
23
+
24
+ def exists(val):
25
+ return val is not None
26
+
27
+ def identity(t, *args, **kwargs):
28
+ return t
29
+
30
+ def first(arr, d = None):
31
+ if len(arr) == 0:
32
+ return d
33
+ return arr[0]
34
+
35
+ def divisible_by(numer, denom):
36
+ return (numer % denom) == 0
37
+
38
+ def maybe(fn):
39
+ @wraps(fn)
40
+ def inner(x):
41
+ if not exists(x):
42
+ return x
43
+ return fn(x)
44
+ return inner
45
+
46
+ def once(fn):
47
+ called = False
48
+ @wraps(fn)
49
+ def inner(x):
50
+ nonlocal called
51
+ if called:
52
+ return
53
+ called = True
54
+ return fn(x)
55
+ return inner
56
+
57
+ print_once = once(print)
58
+
59
+ def default(val, d):
60
+ if exists(val):
61
+ return val
62
+ return d() if callable(d) else d
63
+
64
+ def cast_tuple(val, length = None):
65
+ if isinstance(val, list):
66
+ val = tuple(val)
67
+
68
+ output = val if isinstance(val, tuple) else ((val,) * default(length, 1))
69
+
70
+ if exists(length):
71
+ assert len(output) == length
72
+
73
+ return output
74
+
75
+ def cast_uint8_images_to_float(images):
76
+ if images.dtype != torch.uint8:
77
+ return images
78
+ return images / 255
79
+
80
+ def module_device(module):
81
+ return next(module.parameters()).device
82
+
83
+ def zero_init_(m):
84
+ nn.init.zeros_(m.weight)
85
+ if exists(m.bias):
86
+ nn.init.zeros_(m.bias)
87
+
88
+ def eval_decorator(fn):
89
+ def inner(model, *args, **kwargs):
90
+ was_training = model.training
91
+ model.eval()
92
+ out = fn(model, *args, **kwargs)
93
+ model.train(was_training)
94
+ return out
95
+ return inner
96
+
97
+ def pad_tuple_to_length(t, length, fillvalue = None):
98
+ remain_length = length - len(t)
99
+ if remain_length <= 0:
100
+ return t
101
+ return (*t, *((fillvalue,) * remain_length))
102
+
103
+ # helper classes
104
+
105
+ class Identity(nn.Module):
106
+ def __init__(self, *args, **kwargs):
107
+ super().__init__()
108
+
109
+ def forward(self, x, *args, **kwargs):
110
+ return x
111
+
112
+ def Sequential(*modules):
113
+ return nn.Sequential(*filter(exists, modules))
114
+
115
+ # tensor helpers
116
+
117
+ def log(t, eps: float = 1e-12):
118
+ return torch.log(t.clamp(min = eps))
119
+
120
+ def l2norm(t):
121
+ return F.normalize(t, dim = -1)
122
+
123
+ def right_pad_dims_to(x, t):
124
+ padding_dims = x.ndim - t.ndim
125
+ if padding_dims <= 0:
126
+ return t
127
+ return t.view(*t.shape, *((1,) * padding_dims))
128
+
129
+ def masked_mean(t, *, dim, mask = None):
130
+ if not exists(mask):
131
+ return t.mean(dim = dim)
132
+
133
+ denom = mask.sum(dim = dim, keepdim = True)
134
+ mask = rearrange(mask, 'b n -> b n 1')
135
+ masked_t = t.masked_fill(~mask, 0.)
136
+
137
+ return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
138
+
139
+ def resize_video_to(
140
+ video,
141
+ target_image_size,
142
+ target_frames = None,
143
+ clamp_range = None,
144
+ mode = 'nearest'
145
+ ):
146
+ orig_video_size = video.shape[-1]
147
+
148
+ frames = video.shape[2]
149
+ target_frames = default(target_frames, frames)
150
+
151
+ target_shape = (target_frames, target_image_size, target_image_size)
152
+
153
+ if tuple(video.shape[-3:]) == target_shape:
154
+ return video
155
+
156
+ out = F.interpolate(video, target_shape, mode = mode)
157
+
158
+ if exists(clamp_range):
159
+ out = out.clamp(*clamp_range)
160
+
161
+ return out
162
+
163
+ def scale_video_time(
164
+ video,
165
+ downsample_scale = 1,
166
+ mode = 'nearest'
167
+ ):
168
+ if downsample_scale == 1:
169
+ return video
170
+
171
+ image_size, frames = video.shape[-1], video.shape[-3]
172
+ assert divisible_by(frames, downsample_scale), f'trying to temporally downsample conditioning video frames of length {frames} by {downsample_scale}, however it is not evenly divisible'
173
+
174
+ target_frames = frames // downsample_scale
175
+
176
+ resized_video = resize_video_to(
177
+ video,
178
+ image_size,
179
+ target_frames = target_frames,
180
+ mode = mode
181
+ )
182
+
183
+ return resized_video
184
+
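An editorial sketch of what `scale_video_time` does to a conditioning video tensor, using only `F.interpolate` (shapes are illustrative):

import torch
import torch.nn.functional as F

video = torch.randn(1, 3, 16, 32, 32)     # (batch, channels, frames, height, width)
downsample_scale = 2

target_shape = (video.shape[2] // downsample_scale, video.shape[-2], video.shape[-1])
out = F.interpolate(video, target_shape, mode = 'nearest')
print(out.shape)                           # torch.Size([1, 3, 8, 32, 32])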
185
+ # classifier free guidance functions
186
+
187
+ def prob_mask_like(shape, prob, device):
188
+ if prob == 1:
189
+ return torch.ones(shape, device = device, dtype = torch.bool)
190
+ elif prob == 0:
191
+ return torch.zeros(shape, device = device, dtype = torch.bool)
192
+ else:
193
+ return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
194
+
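For conditional dropout during training, each batch element keeps its text conditioning with probability 1 - cond_drop_prob via the boolean mask produced above. A quick editorial illustration of the same sampling rule:

import torch

torch.manual_seed(0)
cond_drop_prob = 0.1
keep_mask = torch.zeros(8).float().uniform_(0, 1) < (1 - cond_drop_prob)
print(keep_mask)   # roughly 90% of batch elements keep their text conditioning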
195
+ # norms and residuals
196
+
197
+ class LayerNorm(nn.Module):
198
+ def __init__(self, dim, stable = False):
199
+ super().__init__()
200
+ self.stable = stable
201
+ self.g = nn.Parameter(torch.ones(dim))
202
+
203
+ def forward(self, x):
204
+ if self.stable:
205
+ x = x / x.amax(dim = -1, keepdim = True).detach()
206
+
207
+ eps = 1e-5 if x.dtype == torch.float32 else 1e-3
208
+ var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
209
+ mean = torch.mean(x, dim = -1, keepdim = True)
210
+ return (x - mean) * (var + eps).rsqrt() * self.g
211
+
212
+ class ChanLayerNorm(nn.Module):
213
+ def __init__(self, dim, stable = False):
214
+ super().__init__()
215
+ self.stable = stable
216
+ self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))
217
+
218
+ def forward(self, x):
219
+ if self.stable:
220
+ x = x / x.amax(dim = 1, keepdim = True).detach()
221
+
222
+ eps = 1e-5 if x.dtype == torch.float32 else 1e-3
223
+ var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
224
+ mean = torch.mean(x, dim = 1, keepdim = True)
225
+ return (x - mean) * (var + eps).rsqrt() * self.g
226
+
227
+ class Always():
228
+ def __init__(self, val):
229
+ self.val = val
230
+
231
+ def __call__(self, *args, **kwargs):
232
+ return self.val
233
+
234
+ class Residual(nn.Module):
235
+ def __init__(self, fn):
236
+ super().__init__()
237
+ self.fn = fn
238
+
239
+ def forward(self, x, **kwargs):
240
+ return self.fn(x, **kwargs) + x
241
+
242
+ class Parallel(nn.Module):
243
+ def __init__(self, *fns):
244
+ super().__init__()
245
+ self.fns = nn.ModuleList(fns)
246
+
247
+ def forward(self, x):
248
+ outputs = [fn(x) for fn in self.fns]
249
+ return sum(outputs)
250
+
251
+ # rearranging
252
+
253
+ class RearrangeTimeCentric(nn.Module):
254
+ def __init__(self, fn):
255
+ super().__init__()
256
+ self.fn = fn
257
+
258
+ def forward(self, x):
259
+ x = rearrange(x, 'b c f ... -> b ... f c')
260
+ x, ps = pack([x], '* f c')
261
+
262
+ x = self.fn(x)
263
+
264
+ x, = unpack(x, ps, '* f c')
265
+ x = rearrange(x, 'b ... f c -> b c f ...')
266
+ return x
267
+
268
+ # attention pooling
269
+
270
+ class PerceiverAttention(nn.Module):
271
+ def __init__(
272
+ self,
273
+ *,
274
+ dim,
275
+ dim_head = 64,
276
+ heads = 8,
277
+ scale = 8
278
+ ):
279
+ super().__init__()
280
+ self.scale = scale
281
+
282
+ self.heads = heads
283
+ inner_dim = dim_head * heads
284
+
285
+ self.norm = nn.LayerNorm(dim)
286
+ self.norm_latents = nn.LayerNorm(dim)
287
+
288
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
289
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
290
+
291
+ self.q_scale = nn.Parameter(torch.ones(dim_head))
292
+ self.k_scale = nn.Parameter(torch.ones(dim_head))
293
+
294
+ self.to_out = nn.Sequential(
295
+ nn.Linear(inner_dim, dim, bias = False),
296
+ nn.LayerNorm(dim)
297
+ )
298
+
299
+ def forward(self, x, latents, mask = None):
300
+ x = self.norm(x)
301
+ latents = self.norm_latents(latents)
302
+
303
+ b, h = x.shape[0], self.heads
304
+
305
+ q = self.to_q(latents)
306
+
307
+ # the paper differs from the Perceiver in that the key / values derived from the latents are also concatenated onto the sequence being attended to
308
+ kv_input = torch.cat((x, latents), dim = -2)
309
+ k, v = self.to_kv(kv_input).chunk(2, dim = -1)
310
+
311
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
312
+
313
+ # qk rmsnorm
314
+
315
+ q, k = map(l2norm, (q, k))
316
+ q = q * self.q_scale
317
+ k = k * self.k_scale
318
+
319
+ # similarities and masking
320
+
321
+ sim = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
322
+
323
+ if exists(mask):
324
+ max_neg_value = -torch.finfo(sim.dtype).max
325
+ mask = F.pad(mask, (0, latents.shape[-2]), value = True)
326
+ mask = rearrange(mask, 'b j -> b 1 1 j')
327
+ sim = sim.masked_fill(~mask, max_neg_value)
328
+
329
+ # attention
330
+
331
+ attn = sim.softmax(dim = -1)
332
+
333
+ out = einsum('... i j, ... j d -> ... i d', attn, v)
334
+ out = rearrange(out, 'b h n d -> b n (h d)', h = h)
335
+ return self.to_out(out)
336
+
337
+ class PerceiverResampler(nn.Module):
338
+ def __init__(
339
+ self,
340
+ *,
341
+ dim,
342
+ depth,
343
+ dim_head = 64,
344
+ heads = 8,
345
+ num_latents = 64,
346
+ num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
347
+ max_seq_len = 512,
348
+ ff_mult = 4
349
+ ):
350
+ super().__init__()
351
+ self.pos_emb = nn.Embedding(max_seq_len, dim)
352
+
353
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
354
+
355
+ self.to_latents_from_mean_pooled_seq = None
356
+
357
+ if num_latents_mean_pooled > 0:
358
+ self.to_latents_from_mean_pooled_seq = nn.Sequential(
359
+ LayerNorm(dim),
360
+ nn.Linear(dim, dim * num_latents_mean_pooled),
361
+ Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
362
+ )
363
+
364
+ self.layers = nn.ModuleList([])
365
+ for _ in range(depth):
366
+ self.layers.append(nn.ModuleList([
367
+ PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
368
+ FeedForward(dim = dim, mult = ff_mult)
369
+ ]))
370
+
371
+ def forward(self, x, mask = None):
372
+ n, device = x.shape[1], x.device
373
+ pos_emb = self.pos_emb(torch.arange(n, device = device))
374
+
375
+ x_with_pos = x + pos_emb
376
+
377
+ latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])
378
+
379
+ if exists(self.to_latents_from_mean_pooled_seq):
380
+ meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
381
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
382
+ latents = torch.cat((meanpooled_latents, latents), dim = -2)
383
+
384
+ for attn, ff in self.layers:
385
+ latents = attn(x_with_pos, latents, mask = mask) + latents
386
+ latents = ff(latents) + latents
387
+
388
+ return latents
389
+
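# editor's note: hypothetical usage sketch, not part of the uploaded file (it assumes the
# masked_mean helper defined earlier in this module). PerceiverResampler pools a variable-length
# sequence of text embeddings into a fixed set of latents via the PerceiverAttention blocks above;
# num_latents_mean_pooled extra latents are seeded from the mean-pooled sequence.

def _example_perceiver_resampler():
    resampler = PerceiverResampler(dim = 512, depth = 2, dim_head = 64, heads = 8, num_latents = 32)
    text_embeds = torch.randn(2, 77, 512)          # (batch, seq, dim)
    pooled = resampler(text_embeds)                # (2, 32 + 4, 512): learned latents plus mean-pooled latents
    return pooled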
390
+ # main contribution from make-a-video - pseudo conv3d
391
+ # axial space-time convolutions, but made causal to keep in line with the design decisions of imagen-video paper
392
+
393
+ class Conv3d(nn.Module):
394
+ def __init__(
395
+ self,
396
+ dim,
397
+ dim_out = None,
398
+ kernel_size = 3,
399
+ *,
400
+ temporal_kernel_size = None,
401
+ **kwargs
402
+ ):
403
+ super().__init__()
404
+ dim_out = default(dim_out, dim)
405
+ temporal_kernel_size = default(temporal_kernel_size, kernel_size)
406
+
407
+ self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)
408
+ self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size) if kernel_size > 1 else None
409
+ self.kernel_size = kernel_size
410
+
411
+ if exists(self.temporal_conv):
412
+ nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
413
+ nn.init.zeros_(self.temporal_conv.bias.data)
414
+
415
+ def forward(
416
+ self,
417
+ x,
418
+ ignore_time = False
419
+ ):
420
+ b, c, *_, h, w = x.shape
421
+
422
+ is_video = x.ndim == 5
423
+ ignore_time &= is_video
424
+
425
+ if is_video:
426
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
427
+
428
+ x = self.spatial_conv(x)
429
+
430
+ if is_video:
431
+ x = rearrange(x, '(b f) c h w -> b c f h w', b = b)
432
+
433
+ if ignore_time or not exists(self.temporal_conv):
434
+ return x
435
+
436
+ x = rearrange(x, 'b c f h w -> (b h w) c f')
437
+
438
+ # causal temporal convolution - time is causal in imagen-video
439
+
440
+ if self.kernel_size > 1:
441
+ x = F.pad(x, (self.kernel_size - 1, 0))
442
+
443
+ x = self.temporal_conv(x)
444
+
445
+ x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)
446
+
447
+ return x
448
+
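# editor's note: hypothetical usage sketch, not part of the uploaded file. The factorized Conv3d
# above applies a 2D spatial convolution per frame followed by a causal, identity-initialized 1D
# convolution over time; passing ignore_time = True keeps only the per-frame spatial convolution.

def _example_pseudo_conv3d():
    conv = Conv3d(dim = 32, dim_out = 64, kernel_size = 3)
    video = torch.randn(1, 32, 8, 16, 16)          # (batch, channels, frames, height, width)
    out = conv(video)                              # (1, 64, 8, 16, 16): spatial conv + causal temporal conv
    out_spatial_only = conv(video, ignore_time = True)  # temporal conv skipped, frames treated independently
    return out, out_spatial_only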
449
+ # attention
450
+
451
+ class Attention(nn.Module):
452
+ def __init__(
453
+ self,
454
+ dim,
455
+ *,
456
+ dim_head = 64,
457
+ heads = 8,
458
+ causal = False,
459
+ context_dim = None,
460
+ rel_pos_bias = False,
461
+ rel_pos_bias_mlp_depth = 2,
462
+ init_zero = False,
463
+ scale = 8
464
+ ):
465
+ super().__init__()
466
+ self.scale = scale
467
+ self.causal = causal
468
+
469
+ self.rel_pos_bias = DynamicPositionBias(dim = dim, heads = heads, depth = rel_pos_bias_mlp_depth) if rel_pos_bias else None
470
+
471
+ self.heads = heads
472
+ inner_dim = dim_head * heads
473
+
474
+ self.norm = LayerNorm(dim)
475
+
476
+ self.null_attn_bias = nn.Parameter(torch.randn(heads))
477
+
478
+ self.null_kv = nn.Parameter(torch.randn(2, dim_head))
479
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
480
+ self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
481
+
482
+ self.q_scale = nn.Parameter(torch.ones(dim_head))
483
+ self.k_scale = nn.Parameter(torch.ones(dim_head))
484
+
485
+ self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None
486
+
487
+ self.to_out = nn.Sequential(
488
+ nn.Linear(inner_dim, dim, bias = False),
489
+ LayerNorm(dim)
490
+ )
491
+
492
+ if init_zero:
493
+ nn.init.zeros_(self.to_out[-1].g)
494
+
495
+ def forward(
496
+ self,
497
+ x,
498
+ context = None,
499
+ mask = None,
500
+ attn_bias = None
501
+ ):
502
+ b, n, device = *x.shape[:2], x.device
503
+
504
+ x = self.norm(x)
505
+ q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
506
+
507
+ q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
508
+
509
+ # add null key / value for classifier free guidance in prior net
510
+
511
+ nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
512
+ k = torch.cat((nk, k), dim = -2)
513
+ v = torch.cat((nv, v), dim = -2)
514
+
515
+ # add text conditioning, if present
516
+
517
+ if exists(context):
518
+ assert exists(self.to_context)
519
+ ck, cv = self.to_context(context).chunk(2, dim = -1)
520
+ k = torch.cat((ck, k), dim = -2)
521
+ v = torch.cat((cv, v), dim = -2)
522
+
523
+ # qk rmsnorm
524
+
525
+ q, k = map(l2norm, (q, k))
526
+ q = q * self.q_scale
527
+ k = k * self.k_scale
528
+
529
+ # calculate query / key similarities
530
+
531
+ sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale
532
+
533
+ # relative positional encoding (T5 style)
534
+
535
+ if not exists(attn_bias) and exists(self.rel_pos_bias):
536
+ attn_bias = self.rel_pos_bias(n, device = device, dtype = q.dtype)
537
+
538
+ if exists(attn_bias):
539
+ null_attn_bias = repeat(self.null_attn_bias, 'h -> h n 1', n = n)
540
+ attn_bias = torch.cat((null_attn_bias, attn_bias), dim = -1)
541
+ sim = sim + attn_bias
542
+
543
+ # masking
544
+
545
+ max_neg_value = -torch.finfo(sim.dtype).max
546
+
547
+ if self.causal:
548
+ i, j = sim.shape[-2:]
549
+ causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
550
+ sim = sim.masked_fill(causal_mask, max_neg_value)
551
+
552
+ if exists(mask):
553
+ mask = F.pad(mask, (1, 0), value = True)
554
+ mask = rearrange(mask, 'b j -> b 1 1 j')
555
+ sim = sim.masked_fill(~mask, max_neg_value)
556
+
557
+ # attention
558
+
559
+ attn = sim.softmax(dim = -1)
560
+
561
+ # aggregate values
562
+
563
+ out = einsum('b h i j, b j d -> b h i d', attn, v)
564
+
565
+ out = rearrange(out, 'b h n d -> b n (h d)')
566
+ return self.to_out(out)
567
+
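# editor's note: hypothetical usage sketch, not part of the uploaded file. Two things worth noting
# in the Attention class above: to_kv projects to a single head of keys / values (dim_head * 2)
# shared across all query heads (multi-query style), and a learned null key / value is always
# prepended, which the in-code comment ties to classifier-free guidance.

def _example_attention():
    attn = Attention(dim = 128, dim_head = 64, heads = 8)
    tokens = torch.randn(2, 10, 128)               # (batch, seq, dim)
    return attn(tokens)                            # (2, 10, 128)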
568
+ # pseudo conv2d that uses conv3d but with kernel size of 1 across frames dimension
569
+
570
+ def Conv2d(dim_in, dim_out, kernel, stride = 1, padding = 0, **kwargs):
571
+ kernel = cast_tuple(kernel, 2)
572
+ stride = cast_tuple(stride, 2)
573
+ padding = cast_tuple(padding, 2)
574
+
575
+ if len(kernel) == 2:
576
+ kernel = (1, *kernel)
577
+
578
+ if len(stride) == 2:
579
+ stride = (1, *stride)
580
+
581
+ if len(padding) == 2:
582
+ padding = (0, *padding)
583
+
584
+ return nn.Conv3d(dim_in, dim_out, kernel, stride = stride, padding = padding, **kwargs)
585
+
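# editor's note: hypothetical usage sketch, not part of the uploaded file. The Conv2d factory above
# builds an nn.Conv3d whose kernel / stride / padding along the frame axis are 1 / 1 / 0, i.e. a
# purely spatial convolution that still accepts 5-dimensional video tensors.

def _example_pseudo_conv2d():
    conv = Conv2d(16, 32, 3, padding = 1)          # actually nn.Conv3d with kernel (1, 3, 3)
    video = torch.randn(1, 16, 8, 32, 32)
    return conv(video)                             # (1, 32, 8, 32, 32): the frame axis is untouched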
586
+ class Pad(nn.Module):
587
+ def __init__(self, padding, value = 0.):
588
+ super().__init__()
589
+ self.padding = padding
590
+ self.value = value
591
+
592
+ def forward(self, x):
593
+ return F.pad(x, self.padding, value = self.value)
594
+
595
+ # decoder
596
+
597
+ def Upsample(dim, dim_out = None):
598
+ dim_out = default(dim_out, dim)
599
+
600
+ return nn.Sequential(
601
+ nn.Upsample(scale_factor = 2, mode = 'nearest'),
602
+ Conv2d(dim, dim_out, 3, padding = 1)
603
+ )
604
+
605
+ class PixelShuffleUpsample(nn.Module):
606
+ def __init__(self, dim, dim_out = None):
607
+ super().__init__()
608
+ dim_out = default(dim_out, dim)
609
+ conv = Conv2d(dim, dim_out * 4, 1)
610
+
611
+ self.net = nn.Sequential(
612
+ conv,
613
+ nn.SiLU()
614
+ )
615
+
616
+ self.pixel_shuffle = nn.PixelShuffle(2)
617
+
618
+ self.init_conv_(conv)
619
+
620
+ def init_conv_(self, conv):
621
+ o, i, f, h, w = conv.weight.shape
622
+ conv_weight = torch.empty(o // 4, i, f, h, w)
623
+ nn.init.kaiming_uniform_(conv_weight)
624
+ conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
625
+
626
+ conv.weight.data.copy_(conv_weight)
627
+ nn.init.zeros_(conv.bias.data)
628
+
629
+ def forward(self, x):
630
+ out = self.net(x)
631
+ frames = x.shape[2]
632
+ out = rearrange(out, 'b c f h w -> (b f) c h w')
633
+ out = self.pixel_shuffle(out)
634
+ return rearrange(out, '(b f) c h w -> b c f h w', f = frames)
635
+
636
+ def Downsample(dim, dim_out = None):
637
+ dim_out = default(dim_out, dim)
638
+ return nn.Sequential(
639
+ Rearrange('b c f (h p1) (w p2) -> b (c p1 p2) f h w', p1 = 2, p2 = 2),
640
+ Conv2d(dim * 4, dim_out, 1)
641
+ )
642
+
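# editor's note: hypothetical shape sketch, not part of the uploaded file. Both spatial resamplers
# operate frame-wise: Downsample folds each 2x2 patch into channels before a 1x1 conv, and
# PixelShuffleUpsample (whose repeated kaiming init is an ICNR-style trick against checkerboard
# artifacts) pixel-shuffles each frame back up by 2x.

def _example_spatial_resample():
    x = torch.randn(1, 64, 4, 32, 32)              # (batch, channels, frames, height, width)
    down = Downsample(64)(x)                       # (1, 64, 4, 16, 16)
    up = PixelShuffleUpsample(64)(down)            # (1, 64, 4, 32, 32)
    return up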
643
+ # temporal up and downsamples
644
+
645
+ class TemporalPixelShuffleUpsample(nn.Module):
646
+ def __init__(self, dim, dim_out = None, stride = 2):
647
+ super().__init__()
648
+ self.stride = stride
649
+ dim_out = default(dim_out, dim)
650
+ conv = nn.Conv1d(dim, dim_out * stride, 1)
651
+
652
+ self.net = nn.Sequential(
653
+ conv,
654
+ nn.SiLU()
655
+ )
656
+
657
+ self.pixel_shuffle = Rearrange('b (c r) n -> b c (n r)', r = stride)
658
+
659
+ self.init_conv_(conv)
660
+
661
+ def init_conv_(self, conv):
662
+ o, i, f = conv.weight.shape
663
+ conv_weight = torch.empty(o // self.stride, i, f)
664
+ nn.init.kaiming_uniform_(conv_weight)
665
+ conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = self.stride)
666
+
667
+ conv.weight.data.copy_(conv_weight)
668
+ nn.init.zeros_(conv.bias.data)
669
+
670
+ def forward(self, x):
671
+ b, c, f, h, w = x.shape
672
+ x = rearrange(x, 'b c f h w -> (b h w) c f')
673
+ out = self.net(x)
674
+ out = self.pixel_shuffle(out)
675
+ return rearrange(out, '(b h w) c f -> b c f h w', h = h, w = w)
676
+
677
+ def TemporalDownsample(dim, dim_out = None, stride = 2):
678
+ dim_out = default(dim_out, dim)
679
+ return nn.Sequential(
680
+ Rearrange('b c (f p) h w -> b (c p) f h w', p = stride),
681
+ Conv2d(dim * stride, dim_out, 1)
682
+ )
683
+
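# editor's note: hypothetical shape sketch, not part of the uploaded file. The temporal pair mirrors
# the spatial resamplers but along the frame axis: TemporalDownsample folds `stride` frames into
# channels, and the identity-initialized TemporalPixelShuffleUpsample expands them back.

def _example_temporal_resample():
    x = torch.randn(1, 64, 8, 16, 16)                        # (batch, channels, frames, height, width)
    down = TemporalDownsample(64, stride = 2)(x)             # (1, 64, 4, 16, 16)
    up = TemporalPixelShuffleUpsample(64, stride = 2)(down)  # (1, 64, 8, 16, 16)
    return up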
684
+ # positional embedding
685
+
686
+ class SinusoidalPosEmb(nn.Module):
687
+ def __init__(self, dim):
688
+ super().__init__()
689
+ self.dim = dim
690
+
691
+ def forward(self, x):
692
+ half_dim = self.dim // 2
693
+ emb = math.log(10000) / (half_dim - 1)
694
+ emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
695
+ emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
696
+ return torch.cat((emb.sin(), emb.cos()), dim = -1)
697
+
698
+ class LearnedSinusoidalPosEmb(nn.Module):
699
+ def __init__(self, dim):
700
+ super().__init__()
701
+ assert (dim % 2) == 0
702
+ half_dim = dim // 2
703
+ self.weights = nn.Parameter(torch.randn(half_dim))
704
+
705
+ def forward(self, x):
706
+ x = rearrange(x, 'b -> b 1')
707
+ freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
708
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
709
+ fouriered = torch.cat((x, fouriered), dim = -1)
710
+ return fouriered
711
+
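# editor's note: hypothetical shape sketch, not part of the uploaded file. LearnedSinusoidalPosEmb
# returns dim + 1 features (the raw time value is concatenated to the learned sin / cos pairs),
# which is why Unet3D below wires it into a Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim).

def _example_time_embedding():
    emb = LearnedSinusoidalPosEmb(16)
    times = torch.rand(4)                          # e.g. continuous noise times, shape (batch,)
    return emb(times)                              # (4, 17): 1 raw time + 8 sin + 8 cos features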
712
+ class Block(nn.Module):
713
+ def __init__(
714
+ self,
715
+ dim,
716
+ dim_out,
717
+ groups = 8,
718
+ norm = True
719
+ ):
720
+ super().__init__()
721
+ self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
722
+ self.activation = nn.SiLU()
723
+ self.project = Conv3d(dim, dim_out, 3, padding = 1)
724
+
725
+ def forward(
726
+ self,
727
+ x,
728
+ scale_shift = None,
729
+ ignore_time = False
730
+ ):
731
+ x = self.groupnorm(x)
732
+
733
+ if exists(scale_shift):
734
+ scale, shift = scale_shift
735
+ x = x * (scale + 1) + shift
736
+
737
+ x = self.activation(x)
738
+ return self.project(x, ignore_time = ignore_time)
739
+
740
+ class ResnetBlock(nn.Module):
741
+ def __init__(
742
+ self,
743
+ dim,
744
+ dim_out,
745
+ *,
746
+ cond_dim = None,
747
+ time_cond_dim = None,
748
+ groups = 8,
749
+ linear_attn = False,
750
+ use_gca = False,
751
+ squeeze_excite = False,
752
+ **attn_kwargs
753
+ ):
754
+ super().__init__()
755
+
756
+ self.time_mlp = None
757
+
758
+ if exists(time_cond_dim):
759
+ self.time_mlp = nn.Sequential(
760
+ nn.SiLU(),
761
+ nn.Linear(time_cond_dim, dim_out * 2)
762
+ )
763
+
764
+ self.cross_attn = None
765
+
766
+ if exists(cond_dim):
767
+ attn_klass = CrossAttention if not linear_attn else LinearCrossAttention
768
+
769
+ self.cross_attn = attn_klass(
770
+ dim = dim_out,
771
+ context_dim = cond_dim,
772
+ **attn_kwargs
773
+ )
774
+
775
+ self.block1 = Block(dim, dim_out, groups = groups)
776
+ self.block2 = Block(dim_out, dim_out, groups = groups)
777
+
778
+ self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)
779
+
780
+ self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()
781
+
782
+
783
+ def forward(
784
+ self,
785
+ x,
786
+ time_emb = None,
787
+ cond = None,
788
+ ignore_time = False
789
+ ):
790
+
791
+ scale_shift = None
792
+ if exists(self.time_mlp) and exists(time_emb):
793
+ time_emb = self.time_mlp(time_emb)
794
+ time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
795
+ scale_shift = time_emb.chunk(2, dim = 1)
796
+
797
+ h = self.block1(x, ignore_time = ignore_time)
798
+
799
+ if exists(self.cross_attn):
800
+ assert exists(cond)
801
+ h = rearrange(h, 'b c ... -> b ... c')
802
+ h, ps = pack([h], 'b * c')
803
+
804
+ h = self.cross_attn(h, context = cond) + h
805
+
806
+ h, = unpack(h, ps, 'b * c')
807
+ h = rearrange(h, 'b ... c -> b c ...')
808
+
809
+ h = self.block2(h, scale_shift = scale_shift, ignore_time = ignore_time)
810
+
811
+ h = h * self.gca(h)
812
+
813
+ return h + self.res_conv(x)
814
+
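# editor's note: hypothetical usage sketch, not part of the uploaded file. The ResnetBlock above
# conditions on time via a scale-shift applied after the second Block's groupnorm and, when
# cond_dim is given, cross-attends into the conditioning tokens between the two Blocks.

def _example_resnet_block():
    block = ResnetBlock(32, 64, time_cond_dim = 128)
    x = torch.randn(1, 32, 4, 16, 16)              # (batch, channels, frames, height, width)
    t = torch.randn(1, 128)                        # time conditioning vector
    return block(x, time_emb = t)                  # (1, 64, 4, 16, 16)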
815
+ class CrossAttention(nn.Module):
816
+ def __init__(
817
+ self,
818
+ dim,
819
+ *,
820
+ context_dim = None,
821
+ dim_head = 64,
822
+ heads = 8,
823
+ norm_context = False,
824
+ scale = 8
825
+ ):
826
+ super().__init__()
827
+ self.scale = scale
828
+
829
+ self.heads = heads
830
+ inner_dim = dim_head * heads
831
+
832
+ context_dim = default(context_dim, dim)
833
+
834
+ self.norm = LayerNorm(dim)
835
+ self.norm_context = LayerNorm(context_dim) if norm_context else Identity()
836
+
837
+ self.null_kv = nn.Parameter(torch.randn(2, dim_head))
838
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
839
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
840
+
841
+ self.q_scale = nn.Parameter(torch.ones(dim_head))
842
+ self.k_scale = nn.Parameter(torch.ones(dim_head))
843
+
844
+ self.to_out = nn.Sequential(
845
+ nn.Linear(inner_dim, dim, bias = False),
846
+ LayerNorm(dim)
847
+ )
848
+
849
+ def forward(self, x, context, mask = None):
850
+ b, n, device = *x.shape[:2], x.device
851
+
852
+ x = self.norm(x)
853
+ context = self.norm_context(context)
854
+
855
+ q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
856
+
857
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
858
+
859
+ # add null key / value for classifier free guidance in prior net
860
+
861
+ nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
862
+
863
+ k = torch.cat((nk, k), dim = -2)
864
+ v = torch.cat((nv, v), dim = -2)
865
+
866
+ # qk rmsnorm
867
+
868
+ q, k = map(l2norm, (q, k))
869
+ q = q * self.q_scale
870
+ k = k * self.k_scale
871
+
872
+ # similarities
873
+
874
+ sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
875
+
876
+ # masking
877
+
878
+ max_neg_value = -torch.finfo(sim.dtype).max
879
+
880
+ if exists(mask):
881
+ mask = F.pad(mask, (1, 0), value = True)
882
+ mask = rearrange(mask, 'b j -> b 1 1 j')
883
+ sim = sim.masked_fill(~mask, max_neg_value)
884
+
885
+ attn = sim.softmax(dim = -1, dtype = torch.float32)
886
+
887
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
888
+ out = rearrange(out, 'b h n d -> b n (h d)')
889
+ return self.to_out(out)
890
+
891
+ class LinearCrossAttention(CrossAttention):
892
+ def forward(self, x, context, mask = None):
893
+ b, n, device = *x.shape[:2], x.device
894
+
895
+ x = self.norm(x)
896
+ context = self.norm_context(context)
897
+
898
+ q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
899
+
900
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))
901
+
902
+ # add null key / value for classifier free guidance in prior net
903
+
904
+ nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
905
+
906
+ k = torch.cat((nk, k), dim = -2)
907
+ v = torch.cat((nv, v), dim = -2)
908
+
909
+ # masking
910
+
911
+ max_neg_value = -torch.finfo(x.dtype).max
912
+
913
+ if exists(mask):
914
+ mask = F.pad(mask, (1, 0), value = True)
915
+ mask = rearrange(mask, 'b n -> b n 1')
916
+ k = k.masked_fill(~mask, max_neg_value)
917
+ v = v.masked_fill(~mask, 0.)
918
+
919
+ # linear attention
920
+
921
+ q = q.softmax(dim = -1)
922
+ k = k.softmax(dim = -2)
923
+
924
+ q = q * self.scale
925
+
926
+ context = einsum('b n d, b n e -> b d e', k, v)
927
+ out = einsum('b n d, b d e -> b n e', q, context)
928
+ out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
929
+ return self.to_out(out)
930
+
931
+ class LinearAttention(nn.Module):
932
+ def __init__(
933
+ self,
934
+ dim,
935
+ dim_head = 32,
936
+ heads = 8,
937
+ dropout = 0.05,
938
+ context_dim = None,
939
+ **kwargs
940
+ ):
941
+ super().__init__()
942
+ self.scale = dim_head ** -0.5
943
+ self.heads = heads
944
+ inner_dim = dim_head * heads
945
+ self.norm = ChanLayerNorm(dim)
946
+
947
+ self.nonlin = nn.SiLU()
948
+
949
+ self.to_q = nn.Sequential(
950
+ nn.Dropout(dropout),
951
+ Conv2d(dim, inner_dim, 1, bias = False),
952
+ Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
953
+ )
954
+
955
+ self.to_k = nn.Sequential(
956
+ nn.Dropout(dropout),
957
+ Conv2d(dim, inner_dim, 1, bias = False),
958
+ Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
959
+ )
960
+
961
+ self.to_v = nn.Sequential(
962
+ nn.Dropout(dropout),
963
+ Conv2d(dim, inner_dim, 1, bias = False),
964
+ Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
965
+ )
966
+
967
+ self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None
968
+
969
+ self.to_out = nn.Sequential(
970
+ Conv2d(inner_dim, dim, 1, bias = False),
971
+ ChanLayerNorm(dim)
972
+ )
973
+
974
+ def forward(self, fmap, context = None):
975
+ h, x, y = self.heads, *fmap.shape[-2:]
976
+
977
+ fmap = self.norm(fmap)
978
+ q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
979
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))
980
+
981
+ if exists(context):
982
+ assert exists(self.to_context)
983
+ ck, cv = self.to_context(context).chunk(2, dim = -1)
984
+ ck, cv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (ck, cv))
985
+ k = torch.cat((k, ck), dim = -2)
986
+ v = torch.cat((v, cv), dim = -2)
987
+
988
+ q = q.softmax(dim = -1)
989
+ k = k.softmax(dim = -2)
990
+
991
+ q = q * self.scale
992
+
993
+ context = einsum('b n d, b n e -> b d e', k, v)
994
+ out = einsum('b n d, b d e -> b n e', q, context)
995
+ out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
996
+
997
+ out = self.nonlin(out)
998
+ return self.to_out(out)
999
+
1000
+ class GlobalContext(nn.Module):
1001
+ """ basically a superior form of squeeze-excitation that is attention-esque """
1002
+
1003
+ def __init__(
1004
+ self,
1005
+ *,
1006
+ dim_in,
1007
+ dim_out
1008
+ ):
1009
+ super().__init__()
1010
+ self.to_k = Conv2d(dim_in, 1, 1)
1011
+ hidden_dim = max(3, dim_out // 2)
1012
+
1013
+ self.net = nn.Sequential(
1014
+ Conv2d(dim_in, hidden_dim, 1),
1015
+ nn.SiLU(),
1016
+ Conv2d(hidden_dim, dim_out, 1),
1017
+ nn.Sigmoid()
1018
+ )
1019
+
1020
+ def forward(self, x):
1021
+ context = self.to_k(x)
1022
+ x, context = map(lambda t: rearrange(t, 'b n ... -> b n (...)'), (x, context))
1023
+ out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
1024
+ out = rearrange(out, '... -> ... 1 1')
1025
+ return self.net(out)
1026
+
1027
+ def FeedForward(dim, mult = 2):
1028
+ hidden_dim = int(dim * mult)
1029
+ return nn.Sequential(
1030
+ LayerNorm(dim),
1031
+ nn.Linear(dim, hidden_dim, bias = False),
1032
+ nn.GELU(),
1033
+ LayerNorm(hidden_dim),
1034
+ nn.Linear(hidden_dim, dim, bias = False)
1035
+ )
1036
+
1037
+ class TimeTokenShift(nn.Module):
1038
+ def forward(self, x):
1039
+ if x.ndim != 5:
1040
+ return x
1041
+
1042
+ x, x_shift = x.chunk(2, dim = 1)
1043
+ x_shift = F.pad(x_shift, (0, 0, 0, 0, 1, -1), value = 0.)
1044
+ return torch.cat((x, x_shift), dim = 1)
1045
+
1046
+ def ChanFeedForward(dim, mult = 2, time_token_shift = True): # in the paper, the feedforwards accompanying the self-attention layers appear to use twice the channel width
1047
+ hidden_dim = int(dim * mult)
1048
+ return Sequential(
1049
+ ChanLayerNorm(dim),
1050
+ Conv2d(dim, hidden_dim, 1, bias = False),
1051
+ nn.GELU(),
1052
+ TimeTokenShift() if time_token_shift else None,
1053
+ ChanLayerNorm(hidden_dim),
1054
+ Conv2d(hidden_dim, dim, 1, bias = False)
1055
+ )
1056
+
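# editor's note: hypothetical sketch, not part of the uploaded file (it assumes the Sequential
# helper defined earlier in this module). TimeTokenShift splits the hidden channels in half and
# delays one half by a single frame (zero-padded at the start), so each position's feedforward
# also sees features from the previous frame - the token-shift trick referenced by the
# ff_time_token_shift flag on Unet3D below.

def _example_time_token_shift():
    ff = ChanFeedForward(dim = 32, mult = 2, time_token_shift = True)
    video_feats = torch.randn(1, 32, 4, 8, 8)      # (batch, channels, frames, height, width)
    return ff(video_feats)                         # (1, 32, 4, 8, 8)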
1057
+ class TransformerBlock(nn.Module):
1058
+ def __init__(
1059
+ self,
1060
+ dim,
1061
+ *,
1062
+ depth = 1,
1063
+ heads = 8,
1064
+ dim_head = 32,
1065
+ ff_mult = 2,
1066
+ ff_time_token_shift = True,
1067
+ context_dim = None
1068
+ ):
1069
+ super().__init__()
1070
+ self.layers = nn.ModuleList([])
1071
+
1072
+ for _ in range(depth):
1073
+ self.layers.append(nn.ModuleList([
1074
+ Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
1075
+ ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
1076
+ ]))
1077
+
1078
+ def forward(self, x, context = None):
1079
+ for attn, ff in self.layers:
1080
+ x = rearrange(x, 'b c ... -> b ... c')
1081
+ x, ps = pack([x], 'b * c')
1082
+
1083
+ x = attn(x, context = context) + x
1084
+
1085
+ x, = unpack(x, ps, 'b * c')
1086
+ x = rearrange(x, 'b ... c -> b c ...')
1087
+
1088
+ x = ff(x) + x
1089
+ return x
1090
+
1091
+ class LinearAttentionTransformerBlock(nn.Module):
1092
+ def __init__(
1093
+ self,
1094
+ dim,
1095
+ *,
1096
+ depth = 1,
1097
+ heads = 8,
1098
+ dim_head = 32,
1099
+ ff_mult = 2,
1100
+ ff_time_token_shift = True,
1101
+ context_dim = None,
1102
+ **kwargs
1103
+ ):
1104
+ super().__init__()
1105
+ self.layers = nn.ModuleList([])
1106
+
1107
+ for _ in range(depth):
1108
+ self.layers.append(nn.ModuleList([
1109
+ LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
1110
+ ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
1111
+ ]))
1112
+
1113
+ def forward(self, x, context = None):
1114
+ for attn, ff in self.layers:
1115
+ x = attn(x, context = context) + x
1116
+ x = ff(x) + x
1117
+ return x
1118
+
1119
+ class CrossEmbedLayer(nn.Module):
1120
+ def __init__(
1121
+ self,
1122
+ dim_in,
1123
+ kernel_sizes,
1124
+ dim_out = None,
1125
+ stride = 2
1126
+ ):
1127
+ super().__init__()
1128
+ assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
1129
+ dim_out = default(dim_out, dim_in)
1130
+
1131
+ kernel_sizes = sorted(kernel_sizes)
1132
+ num_scales = len(kernel_sizes)
1133
+
1134
+ # calculate the dimension at each scale
1135
+ dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
1136
+ dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
1137
+
1138
+ self.convs = nn.ModuleList([])
1139
+ for kernel, dim_scale in zip(kernel_sizes, dim_scales):
1140
+ self.convs.append(Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
1141
+
1142
+ def forward(self, x):
1143
+ fmaps = tuple(map(lambda conv: conv(x), self.convs))
1144
+ return torch.cat(fmaps, dim = 1)
1145
+
1146
+ class UpsampleCombiner(nn.Module):
1147
+ def __init__(
1148
+ self,
1149
+ dim,
1150
+ *,
1151
+ enabled = False,
1152
+ dim_ins = tuple(),
1153
+ dim_outs = tuple()
1154
+ ):
1155
+ super().__init__()
1156
+ dim_outs = cast_tuple(dim_outs, len(dim_ins))
1157
+ assert len(dim_ins) == len(dim_outs)
1158
+
1159
+ self.enabled = enabled
1160
+
1161
+ if not self.enabled:
1162
+ self.dim_out = dim
1163
+ return
1164
+
1165
+ self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
1166
+ self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
1167
+
1168
+ def forward(self, x, fmaps = None):
1169
+ target_size = x.shape[-1]
1170
+
1171
+ fmaps = default(fmaps, tuple())
1172
+
1173
+ if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
1174
+ return x
1175
+
1176
+ fmaps = [resize_video_to(fmap, target_size) for fmap in fmaps]
1177
+ outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
1178
+ return torch.cat((x, *outs), dim = 1)
1179
+
1180
+ class DynamicPositionBias(nn.Module):
1181
+ def __init__(
1182
+ self,
1183
+ dim,
1184
+ *,
1185
+ heads,
1186
+ depth
1187
+ ):
1188
+ super().__init__()
1189
+ self.mlp = nn.ModuleList([])
1190
+
1191
+ self.mlp.append(nn.Sequential(
1192
+ nn.Linear(1, dim),
1193
+ LayerNorm(dim),
1194
+ nn.SiLU()
1195
+ ))
1196
+
1197
+ for _ in range(max(depth - 1, 0)):
1198
+ self.mlp.append(nn.Sequential(
1199
+ nn.Linear(dim, dim),
1200
+ LayerNorm(dim),
1201
+ nn.SiLU()
1202
+ ))
1203
+
1204
+ self.mlp.append(nn.Linear(dim, heads))
1205
+
1206
+ def forward(self, n, device, dtype):
1207
+ i = torch.arange(n, device = device)
1208
+ j = torch.arange(n, device = device)
1209
+
1210
+ indices = rearrange(i, 'i -> i 1') - rearrange(j, 'j -> 1 j')
1211
+ indices += (n - 1)
1212
+
1213
+ pos = torch.arange(-n + 1, n, device = device, dtype = dtype)
1214
+ pos = rearrange(pos, '... -> ... 1')
1215
+
1216
+ for layer in self.mlp:
1217
+ pos = layer(pos)
1218
+
1219
+ bias = pos[indices]
1220
+ bias = rearrange(bias, 'i j h -> h i j')
1221
+ return bias
1222
+
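# editor's note: hypothetical shape sketch, not part of the uploaded file. DynamicPositionBias
# produces a continuous, MLP-generated relative position bias over frame offsets, which the
# temporal Attention layers add to their pre-softmax similarities.

def _example_dynamic_position_bias():
    bias_mlp = DynamicPositionBias(dim = 64, heads = 8, depth = 2)
    bias = bias_mlp(16, device = torch.device('cpu'), dtype = torch.float32)
    return bias                                    # (8, 16, 16): one bias matrix per head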
1223
+ class Unet3D(nn.Module):
1224
+ def __init__(
1225
+ self,
1226
+ *,
1227
+ dim,
1228
+ text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
1229
+ num_resnet_blocks = 1,
1230
+ cond_dim = None,
1231
+ num_image_tokens = 4,
1232
+ num_time_tokens = 2,
1233
+ learned_sinu_pos_emb_dim = 16,
1234
+ out_dim = None,
1235
+ dim_mults = (1, 2, 4, 8),
1236
+ temporal_strides = 1,
1237
+ cond_images_channels = 0,
1238
+ channels = 3,
1239
+ channels_out = None,
1240
+ attn_dim_head = 64,
1241
+ attn_heads = 8,
1242
+ ff_mult = 2.,
1243
+ ff_time_token_shift = True, # do a token shift along the time axis at the hidden layer within feedforwards - borrowed from its successful use in RWKV (Peng et al.) and other token-shift video transformer works
1244
+ lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
1245
+ layer_attns = False,
1246
+ layer_attns_depth = 1,
1247
+ layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
1248
+ attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
1249
+ time_rel_pos_bias_depth = 2,
1250
+ time_causal_attn = True,
1251
+ layer_cross_attns = True,
1252
+ use_linear_attn = False,
1253
+ use_linear_cross_attn = False,
1254
+ cond_on_text = True,
1255
+ max_text_len = 256,
1256
+ init_dim = None,
1257
+ resnet_groups = 8,
1258
+ init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed
1259
+ init_cross_embed = True,
1260
+ init_cross_embed_kernel_sizes = (3, 7, 15),
1261
+ cross_embed_downsample = False,
1262
+ cross_embed_downsample_kernel_sizes = (2, 4),
1263
+ attn_pool_text = True,
1264
+ attn_pool_num_latents = 32,
1265
+ dropout = 0.,
1266
+ memory_efficient = False,
1267
+ init_conv_to_final_conv_residual = False,
1268
+ use_global_context_attn = True,
1269
+ scale_skip_connection = True,
1270
+ final_resnet_block = True,
1271
+ final_conv_kernel_size = 3,
1272
+ self_cond = False,
1273
+ combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully
1274
+ pixel_shuffle_upsample = True, # may address checkboard artifacts
1275
+ resize_mode = 'nearest'
1276
+ ):
1277
+ super().__init__()
1278
+
1279
+ # guide researchers
1280
+
1281
+ assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'
1282
+
1283
+ if dim < 128:
1284
+ print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')
1285
+
1286
+ # save locals to take care of some hyperparameters for cascading DDPM
1287
+
1288
+ self._locals = locals()
1289
+ self._locals.pop('self', None)
1290
+ self._locals.pop('__class__', None)
1291
+
1292
+ self.self_cond = self_cond
1293
+
1294
+ # determine dimensions
1295
+
1296
+ self.channels = channels
1297
+ self.channels_out = default(channels_out, channels)
1298
+
1299
+ # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
1300
+ # (2) in self conditioning, one appends the predicted x0 (x_start)
1301
+ init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
1302
+ init_dim = default(init_dim, dim)
1303
+
1304
+ # optional image conditioning
1305
+
1306
+ self.has_cond_image = cond_images_channels > 0
1307
+ self.cond_images_channels = cond_images_channels
1308
+
1309
+ init_channels += cond_images_channels
1310
+
1311
+ # initial convolution
1312
+
1313
+ self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
1314
+
1315
+ dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
1316
+ in_out = list(zip(dims[:-1], dims[1:]))
1317
+
1318
+ # time conditioning
1319
+
1320
+ cond_dim = default(cond_dim, dim)
1321
+ time_cond_dim = dim * 4 * (2 if lowres_cond else 1)
1322
+
1323
+ # embedding time for log(snr) noise from continuous version
1324
+
1325
+ sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
1326
+ sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1
1327
+
1328
+ self.to_time_hiddens = nn.Sequential(
1329
+ sinu_pos_emb,
1330
+ nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
1331
+ nn.SiLU()
1332
+ )
1333
+
1334
+ self.to_time_cond = nn.Sequential(
1335
+ nn.Linear(time_cond_dim, time_cond_dim)
1336
+ )
1337
+
1338
+ # project to time tokens as well as time hiddens
1339
+
1340
+ self.to_time_tokens = nn.Sequential(
1341
+ nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
1342
+ Rearrange('b (r d) -> b r d', r = num_time_tokens)
1343
+ )
1344
+
1345
+ # low res aug noise conditioning
1346
+
1347
+ self.lowres_cond = lowres_cond
1348
+
1349
+ if lowres_cond:
1350
+ self.to_lowres_time_hiddens = nn.Sequential(
1351
+ LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
1352
+ nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
1353
+ nn.SiLU()
1354
+ )
1355
+
1356
+ self.to_lowres_time_cond = nn.Sequential(
1357
+ nn.Linear(time_cond_dim, time_cond_dim)
1358
+ )
1359
+
1360
+ self.to_lowres_time_tokens = nn.Sequential(
1361
+ nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
1362
+ Rearrange('b (r d) -> b r d', r = num_time_tokens)
1363
+ )
1364
+
1365
+ # normalizations
1366
+
1367
+ self.norm_cond = nn.LayerNorm(cond_dim)
1368
+
1369
+ # text encoding conditioning (optional)
1370
+
1371
+ self.text_to_cond = None
1372
+
1373
+ if cond_on_text:
1374
+ assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
1375
+ self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
1376
+
1377
+ # finer control over whether to condition on text encodings
1378
+
1379
+ self.cond_on_text = cond_on_text
1380
+
1381
+ # attention pooling
1382
+
1383
+ self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents) if attn_pool_text else None
1384
+
1385
+ # for classifier free guidance
1386
+
1387
+ self.max_text_len = max_text_len
1388
+
1389
+ self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
1390
+ self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))
1391
+
1392
+ # for non-attention based text conditioning at all points in the network where time is also conditioned
1393
+
1394
+ self.to_text_non_attn_cond = None
1395
+
1396
+ if cond_on_text:
1397
+ self.to_text_non_attn_cond = nn.Sequential(
1398
+ nn.LayerNorm(cond_dim),
1399
+ nn.Linear(cond_dim, time_cond_dim),
1400
+ nn.SiLU(),
1401
+ nn.Linear(time_cond_dim, time_cond_dim)
1402
+ )
1403
+
1404
+ # attention related params
1405
+
1406
+ attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
1407
+
1408
+ num_layers = len(in_out)
1409
+
1410
+ # temporal attention - attention across video frames
1411
+
1412
+ temporal_peg_padding = (0, 0, 0, 0, 2, 0) if time_causal_attn else (0, 0, 0, 0, 1, 1)
1413
+ temporal_peg = lambda dim: Residual(nn.Sequential(Pad(temporal_peg_padding), nn.Conv3d(dim, dim, (3, 1, 1), groups = dim)))
1414
+
1415
+ temporal_attn = lambda dim: RearrangeTimeCentric(Residual(Attention(dim, **{**attn_kwargs, 'causal': time_causal_attn, 'init_zero': True, 'rel_pos_bias': True})))
1416
+
1417
+ # resnet block klass
1418
+
1419
+ num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
1420
+ resnet_groups = cast_tuple(resnet_groups, num_layers)
1421
+
1422
+ resnet_klass = partial(ResnetBlock, **attn_kwargs)
1423
+
1424
+ layer_attns = cast_tuple(layer_attns, num_layers)
1425
+ layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
1426
+ layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)
1427
+
1428
+ assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
1429
+
1430
+ # temporal downsample config
1431
+
1432
+ temporal_strides = cast_tuple(temporal_strides, num_layers)
1433
+ self.total_temporal_divisor = functools.reduce(operator.mul, temporal_strides, 1)
1434
+
1435
+ # downsample klass
1436
+
1437
+ downsample_klass = Downsample
1438
+
1439
+ if cross_embed_downsample:
1440
+ downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
1441
+
1442
+ # initial resnet block (for memory efficient unet)
1443
+
1444
+ self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None
1445
+
1446
+ self.init_temporal_peg = temporal_peg(init_dim)
1447
+ self.init_temporal_attn = temporal_attn(init_dim)
1448
+
1449
+ # scale for resnet skip connections
1450
+
1451
+ self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
1452
+
1453
+ # layers
1454
+
1455
+ self.downs = nn.ModuleList([])
1456
+ self.ups = nn.ModuleList([])
1457
+ num_resolutions = len(in_out)
1458
+
1459
+ layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, temporal_strides]
1460
+ reversed_layer_params = list(map(reversed, layer_params))
1461
+
1462
+ # downsampling layers
1463
+
1464
+ skip_connect_dims = [] # keep track of skip connection dimensions
1465
+
1466
+ for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(in_out, *layer_params)):
1467
+ is_last = ind >= (num_resolutions - 1)
1468
+
1469
+ layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
1470
+ layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
1471
+
1472
+ transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)
1473
+
1474
+ current_dim = dim_in
1475
+
1476
+ # whether to pre-downsample, from memory efficient unet
1477
+
1478
+ pre_downsample = None
1479
+
1480
+ if memory_efficient:
1481
+ pre_downsample = downsample_klass(dim_in, dim_out)
1482
+ current_dim = dim_out
1483
+
1484
+ skip_connect_dims.append(current_dim)
1485
+
1486
+ # whether to do post-downsample, for non-memory efficient unet
1487
+
1488
+ post_downsample = None
1489
+ if not memory_efficient:
1490
+ post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(Conv2d(dim_in, dim_out, 3, padding = 1), Conv2d(dim_in, dim_out, 1))
1491
+
1492
+ self.downs.append(nn.ModuleList([
1493
+ pre_downsample,
1494
+ resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
1495
+ nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
1496
+ transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
1497
+ temporal_peg(current_dim),
1498
+ temporal_attn(current_dim),
1499
+ TemporalDownsample(current_dim, stride = temporal_stride) if temporal_stride > 1 else None,
1500
+ post_downsample
1501
+ ]))
1502
+
1503
+ # middle layers
1504
+
1505
+ mid_dim = dims[-1]
1506
+
1507
+ self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
1508
+ self.mid_attn = EinopsToAndFrom('b c f h w', 'b (f h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
1509
+ self.mid_temporal_peg = temporal_peg(mid_dim)
1510
+ self.mid_temporal_attn = temporal_attn(mid_dim)
1511
+ self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
1512
+
1513
+ # upsample klass
1514
+
1515
+ upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample
1516
+
1517
+ # upsampling layers
1518
+
1519
+ upsample_fmap_dims = []
1520
+
1521
+ for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
1522
+ is_last = ind == (len(in_out) - 1)
1523
+ layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
1524
+ layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
1525
+ transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)
1526
+
1527
+ skip_connect_dim = skip_connect_dims.pop()
1528
+
1529
+ upsample_fmap_dims.append(dim_out)
1530
+
1531
+ self.ups.append(nn.ModuleList([
1532
+ resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
1533
+ nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
1534
+ transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
1535
+ temporal_peg(dim_out),
1536
+ temporal_attn(dim_out),
1537
+ TemporalPixelShuffleUpsample(dim_out, stride = temporal_stride) if temporal_stride > 1 else None,
1538
+ upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
1539
+ ]))
1540
+
1541
+ # whether to combine feature maps from all upsample blocks before final resnet block out
1542
+
1543
+ self.upsample_combiner = UpsampleCombiner(
1544
+ dim = dim,
1545
+ enabled = combine_upsample_fmaps,
1546
+ dim_ins = upsample_fmap_dims,
1547
+ dim_outs = dim
1548
+ )
1549
+
1550
+ # whether to do a final residual from initial conv to the final resnet block out
1551
+
1552
+ self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
1553
+ final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)
1554
+
1555
+ # final optional resnet block and convolution out
1556
+
1557
+ self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None
1558
+
1559
+ final_conv_dim_in = dim if final_resnet_block else final_conv_dim
1560
+ final_conv_dim_in += (channels if lowres_cond else 0)
1561
+
1562
+ self.final_conv = Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)
1563
+
1564
+ zero_init_(self.final_conv)
1565
+
1566
+ # resize mode
1567
+
1568
+ self.resize_mode = resize_mode
1569
+
1570
+ # if the current settings for the unet are not correct
1571
+ # for cascading DDPM, then reinit the unet with the right settings
1572
+ def cast_model_parameters(
1573
+ self,
1574
+ *,
1575
+ lowres_cond,
1576
+ text_embed_dim,
1577
+ channels,
1578
+ channels_out,
1579
+ cond_on_text
1580
+ ):
1581
+ if lowres_cond == self.lowres_cond and \
1582
+ channels == self.channels and \
1583
+ cond_on_text == self.cond_on_text and \
1584
+ text_embed_dim == self._locals['text_embed_dim'] and \
1585
+ channels_out == self.channels_out:
1586
+ return self
1587
+
1588
+ updated_kwargs = dict(
1589
+ lowres_cond = lowres_cond,
1590
+ text_embed_dim = text_embed_dim,
1591
+ channels = channels,
1592
+ channels_out = channels_out,
1593
+ cond_on_text = cond_on_text
1594
+ )
1595
+
1596
+ return self.__class__(**{**self._locals, **updated_kwargs})
1597
+
1598
+ # methods for returning the full unet config as well as its parameter state
1599
+
1600
+ def to_config_and_state_dict(self):
1601
+ return self._locals, self.state_dict()
1602
+
1603
+ # class method for rehydrating the unet from its config and state dict
1604
+
1605
+ @classmethod
1606
+ def from_config_and_state_dict(klass, config, state_dict):
1607
+ unet = klass(**config)
1608
+ unet.load_state_dict(state_dict)
1609
+ return unet
1610
+
1611
+ # methods for persisting unet to disk
1612
+
1613
+ def persist_to_file(self, path):
1614
+ path = Path(path)
1615
+ path.parents[0].mkdir(exist_ok = True, parents = True)
1616
+
1617
+ config, state_dict = self.to_config_and_state_dict()
1618
+ pkg = dict(config = config, state_dict = state_dict)
1619
+ torch.save(pkg, str(path))
1620
+
1621
+ # class method for rehydrating the unet from file saved with `persist_to_file`
1622
+
1623
+ @classmethod
1624
+ def hydrate_from_file(klass, path):
1625
+ path = Path(path)
1626
+ assert path.exists()
1627
+ pkg = torch.load(str(path))
1628
+
1629
+ assert 'config' in pkg and 'state_dict' in pkg
1630
+ config, state_dict = pkg['config'], pkg['state_dict']
1631
+
1632
+ return klass.from_config_and_state_dict(config, state_dict) # use the calling class (Unet3D here), not the 2d Unet
1633
+
1634
+ # forward with classifier free guidance
1635
+
1636
+ def forward_with_cond_scale(
1637
+ self,
1638
+ *args,
1639
+ cond_scale = 1.,
1640
+ **kwargs
1641
+ ):
1642
+ logits = self.forward(*args, **kwargs)
1643
+
1644
+ if cond_scale == 1:
1645
+ return logits
1646
+
1647
+ null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
1648
+ return null_logits + (logits - null_logits) * cond_scale
1649
+
1650
+ def forward(
1651
+ self,
1652
+ x,
1653
+ time,
1654
+ *,
1655
+ lowres_cond_img = None,
1656
+ lowres_noise_times = None,
1657
+ text_embeds = None,
1658
+ text_mask = None,
1659
+ cond_images = None,
1660
+ cond_video_frames = None,
1661
+ post_cond_video_frames = None,
1662
+ self_cond = None,
1663
+ cond_drop_prob = 0.,
1664
+ ignore_time = False
1665
+ ):
1666
+ assert x.ndim == 5, 'input to 3d unet must have 5 dimensions (batch, channels, time, height, width)'
1667
+
1668
+ batch_size, frames, device, dtype = x.shape[0], x.shape[2], x.device, x.dtype
1669
+
1670
+ assert ignore_time or divisible_by(frames, self.total_temporal_divisor), f'number of input frames {frames} must be divisible by {self.total_temporal_divisor}'
1671
+
1672
+ # add self conditioning if needed
1673
+
1674
+ if self.self_cond:
1675
+ self_cond = default(self_cond, lambda: torch.zeros_like(x))
1676
+ x = torch.cat((x, self_cond), dim = 1)
1677
+
1678
+ # add low resolution conditioning, if present
1679
+
1680
+ assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
1681
+ assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'
1682
+
1683
+ if exists(lowres_cond_img):
1684
+ x = torch.cat((x, lowres_cond_img), dim = 1)
1685
+
1686
+ if exists(cond_video_frames):
1687
+ lowres_cond_img = torch.cat((cond_video_frames, lowres_cond_img), dim = 2)
1688
+ cond_video_frames = torch.cat((cond_video_frames, cond_video_frames), dim = 1)
1689
+
1690
+ if exists(post_cond_video_frames):
1691
+ lowres_cond_img = torch.cat((lowres_cond_img, post_cond_video_frames), dim = 2)
1692
+ post_cond_video_frames = torch.cat((post_cond_video_frames, post_cond_video_frames), dim = 1)
1693
+
1694
+ # conditioning on video frames as a prompt
1695
+
1696
+ num_preceding_frames = 0
1697
+ if exists(cond_video_frames):
1698
+ cond_video_frames_len = cond_video_frames.shape[2]
1699
+
1700
+ assert divisible_by(cond_video_frames_len, self.total_temporal_divisor)
1701
+
1702
+ cond_video_frames = resize_video_to(cond_video_frames, x.shape[-1])
1703
+ x = torch.cat((cond_video_frames, x), dim = 2)
1704
+
1705
+ num_preceding_frames = cond_video_frames_len
1706
+
1707
+ # conditioning on succeeding video frames as a prompt
1708
+
1709
+ num_succeeding_frames = 0
1710
+ if exists(post_cond_video_frames):
1711
+ cond_video_frames_len = post_cond_video_frames.shape[2]
1712
+
1713
+ assert divisible_by(cond_video_frames_len, self.total_temporal_divisor)
1714
+
1715
+ post_cond_video_frames = resize_video_to(post_cond_video_frames, x.shape[-1])
1716
+ x = torch.cat((post_cond_video_frames, x), dim = 2)
1717
+
1718
+ num_succeeding_frames = cond_video_frames_len
1719
+
1720
+ # condition on input image
1721
+
1722
+ assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested the unet to condition on an image but did not supply a conditioning image, or you supplied a conditioning image without enabling image conditioning on the unet'
1723
+
1724
+ if exists(cond_images):
1725
+ assert cond_images.ndim == 4, 'conditioning images must have 4 dimensions only; if you want to condition on frames of video, use `cond_video_frames` instead'
1726
+ assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialization of the unet'
1727
+
1728
+ cond_images = repeat(cond_images, 'b c h w -> b c f h w', f = x.shape[2])
1729
+ cond_images = resize_video_to(cond_images, x.shape[-1], mode = self.resize_mode)
1730
+
1731
+ x = torch.cat((cond_images, x), dim = 1)
1732
+
1733
+ # ignoring time in pseudo 3d resnet blocks
1734
+
1735
+ conv_kwargs = dict(
1736
+ ignore_time = ignore_time
1737
+ )
1738
+
1739
+ # initial convolution
1740
+
1741
+ x = self.init_conv(x)
1742
+
1743
+ if not ignore_time:
1744
+ x = self.init_temporal_peg(x)
1745
+ x = self.init_temporal_attn(x)
1746
+
1747
+ # init conv residual
1748
+
1749
+ if self.init_conv_to_final_conv_residual:
1750
+ init_conv_residual = x.clone()
1751
+
1752
+ # time conditioning
1753
+
1754
+ time_hiddens = self.to_time_hiddens(time)
1755
+
1756
+ # derive time tokens
1757
+
1758
+ time_tokens = self.to_time_tokens(time_hiddens)
1759
+ t = self.to_time_cond(time_hiddens)
1760
+
1761
+ # add lowres time conditioning to time hiddens
1762
+ # and add lowres time tokens along sequence dimension for attention
1763
+
1764
+ if self.lowres_cond:
1765
+ lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
1766
+ lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
1767
+ lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)
1768
+
1769
+ t = t + lowres_t
1770
+ time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)
1771
+
1772
+ # text conditioning
1773
+
1774
+ text_tokens = None
1775
+
1776
+ if exists(text_embeds) and self.cond_on_text:
1777
+
1778
+ # conditional dropout
1779
+
1780
+ text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
1781
+
1782
+ text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
1783
+ text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')
1784
+
1785
+ # calculate text embeds
1786
+
1787
+ text_tokens = self.text_to_cond(text_embeds)
1788
+
1789
+ text_tokens = text_tokens[:, :self.max_text_len]
1790
+
1791
+ if exists(text_mask):
1792
+ text_mask = text_mask[:, :self.max_text_len]
1793
+
1794
+ text_tokens_len = text_tokens.shape[1]
1795
+ remainder = self.max_text_len - text_tokens_len
1796
+
1797
+ if remainder > 0:
1798
+ text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
1799
+
1800
+ if exists(text_mask):
1801
+ if remainder > 0:
1802
+ text_mask = F.pad(text_mask, (0, remainder), value = False)
1803
+
1804
+ text_mask = rearrange(text_mask, 'b n -> b n 1')
1805
+ text_keep_mask_embed = text_mask & text_keep_mask_embed
1806
+
1807
+ null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
1808
+
1809
+ text_tokens = torch.where(
1810
+ text_keep_mask_embed,
1811
+ text_tokens,
1812
+ null_text_embed
1813
+ )
1814
+
1815
+ if exists(self.attn_pool):
1816
+ text_tokens = self.attn_pool(text_tokens)
1817
+
1818
+ # extra non-attention conditioning by projecting and then summing text embeddings to time
1819
+ # termed as text hiddens
1820
+
1821
+ mean_pooled_text_tokens = text_tokens.mean(dim = -2)
1822
+
1823
+ text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)
1824
+
1825
+ null_text_hidden = self.null_text_hidden.to(t.dtype)
1826
+
1827
+ text_hiddens = torch.where(
1828
+ text_keep_mask_hidden,
1829
+ text_hiddens,
1830
+ null_text_hidden
1831
+ )
1832
+
1833
+ t = t + text_hiddens
1834
+
1835
+ # main conditioning tokens (c)
1836
+
1837
+ c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)
1838
+
1839
+ # normalize conditioning tokens
1840
+
1841
+ c = self.norm_cond(c)
1842
+
1843
+ # initial resnet block (for memory efficient unet)
1844
+
1845
+ if exists(self.init_resnet_block):
1846
+ x = self.init_resnet_block(x, t, **conv_kwargs)
1847
+
1848
+ # go through the layers of the unet, down and up
1849
+
1850
+ hiddens = []
1851
+
1852
+ for pre_downsample, init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, temporal_downsample, post_downsample in self.downs:
1853
+ if exists(pre_downsample):
1854
+ x = pre_downsample(x)
1855
+
1856
+ x = init_block(x, t, c, **conv_kwargs)
1857
+
1858
+ for resnet_block in resnet_blocks:
1859
+ x = resnet_block(x, t, **conv_kwargs)
1860
+ hiddens.append(x)
1861
+
1862
+ x = attn_block(x, c)
1863
+
1864
+ if not ignore_time:
1865
+ x = temporal_peg(x)
1866
+ x = temporal_attn(x)
1867
+
1868
+ hiddens.append(x)
1869
+
1870
+ if exists(temporal_downsample) and not ignore_time:
1871
+ x = temporal_downsample(x)
1872
+
1873
+ if exists(post_downsample):
1874
+ x = post_downsample(x)
1875
+
1876
+ x = self.mid_block1(x, t, c, **conv_kwargs)
1877
+
1878
+ if exists(self.mid_attn):
1879
+ x = self.mid_attn(x)
1880
+
1881
+ if not ignore_time:
1882
+ x = self.mid_temporal_peg(x)
1883
+ x = self.mid_temporal_attn(x)
1884
+
1885
+ x = self.mid_block2(x, t, c, **conv_kwargs)
1886
+
1887
+ add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)
1888
+
1889
+ up_hiddens = []
1890
+
1891
+ for init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, temporal_upsample, upsample in self.ups:
1892
+ if exists(temporal_upsample) and not ignore_time:
1893
+ x = temporal_upsample(x)
1894
+
1895
+ x = add_skip_connection(x)
1896
+ x = init_block(x, t, c, **conv_kwargs)
1897
+
1898
+ for resnet_block in resnet_blocks:
1899
+ x = add_skip_connection(x)
1900
+ x = resnet_block(x, t, **conv_kwargs)
1901
+
1902
+ x = attn_block(x, c)
1903
+
1904
+ if not ignore_time:
1905
+ x = temporal_peg(x)
1906
+ x = temporal_attn(x)
1907
+
1908
+ up_hiddens.append(x.contiguous())
1909
+
1910
+ x = upsample(x)
1911
+
1912
+ # whether to combine all feature maps from upsample blocks
1913
+
1914
+ x = self.upsample_combiner(x, up_hiddens)
1915
+
1916
+ # final top-most residual if needed
1917
+
1918
+ if self.init_conv_to_final_conv_residual:
1919
+ x = torch.cat((x, init_conv_residual), dim = 1)
1920
+
1921
+ if exists(self.final_res_block):
1922
+ x = self.final_res_block(x, t, **conv_kwargs)
1923
+
1924
+ if exists(lowres_cond_img):
1925
+ x = torch.cat((x, lowres_cond_img), dim = 1)
1926
+
1927
+ out = self.final_conv(x)
1928
+
1929
+ if num_preceding_frames > 0:
1930
+ out = out[:, :, num_preceding_frames:]
1931
+
1932
+ if num_succeeding_frames > 0:
1933
+ out = out[:, :, :-num_succeeding_frames]
1934
+
1935
+ return out
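# editor's note: hypothetical end-to-end sketch, not part of the uploaded file. It wires a tiny
# Unet3D to randomly generated text embeddings of the default T5 width and runs one denoising
# forward pass; forward_with_cond_scale would be used at sampling time for classifier-free
# guidance. The small dim triggers the printed advisory above but still runs; all names below
# other than the module's own are ours.

def _example_unet3d_forward():
    unet = Unet3D(
        dim = 64,
        dim_mults = (1, 2),
        num_resnet_blocks = 1,
        layer_attns = False,
        layer_cross_attns = (False, True)
    )
    video = torch.randn(1, 3, 4, 32, 32)           # (batch, channels, frames, height, width)
    times = torch.rand(1)                          # continuous noise times
    text_embeds = torch.randn(1, 8, get_encoded_dim(DEFAULT_T5_NAME))
    text_mask = torch.ones(1, 8, dtype = torch.bool)
    return unet(video, times, text_embeds = text_embeds, text_mask = text_mask)  # (1, 3, 4, 32, 32)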
t5.py ADDED
@@ -0,0 +1,119 @@
1
+ import torch
2
+ import transformers
3
+ from typing import List
4
+ from transformers import T5Tokenizer, T5EncoderModel, T5Config
5
+ from einops import rearrange
6
+
7
+ transformers.logging.set_verbosity_error()
8
+
9
+ def exists(val):
10
+ return val is not None
11
+
12
+ def default(val, d):
13
+ if exists(val):
14
+ return val
15
+ return d() if callable(d) else d
16
+
17
+ # config
18
+
19
+ MAX_LENGTH = 256
20
+
21
+ DEFAULT_T5_NAME = 'google/t5-v1_1-base'
22
+
23
+ T5_CONFIGS = {}
24
+
25
+ # singleton globals
26
+
27
+ def get_tokenizer(name):
28
+ tokenizer = T5Tokenizer.from_pretrained(name, model_max_length=MAX_LENGTH)
29
+ return tokenizer
30
+
31
+ def get_model(name):
32
+ model = T5EncoderModel.from_pretrained(name)
33
+ return model
34
+
35
+ def get_model_and_tokenizer(name):
36
+ global T5_CONFIGS
37
+
38
+ if name not in T5_CONFIGS:
39
+ T5_CONFIGS[name] = dict()
40
+ if "model" not in T5_CONFIGS[name]:
41
+ T5_CONFIGS[name]["model"] = get_model(name)
42
+ if "tokenizer" not in T5_CONFIGS[name]:
43
+ T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)
44
+
45
+ return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
46
+
47
+ def get_encoded_dim(name):
48
+ if name not in T5_CONFIGS:
49
+ # avoids loading the model if we only want to get the dim
50
+ config = T5Config.from_pretrained(name)
51
+ T5_CONFIGS[name] = dict(config=config)
52
+ elif "config" in T5_CONFIGS[name]:
53
+ config = T5_CONFIGS[name]["config"]
54
+ elif "model" in T5_CONFIGS[name]:
55
+ config = T5_CONFIGS[name]["model"].config
56
+ else:
57
+ assert False
58
+ return config.d_model
59
+
60
+ # encoding text
61
+
62
+ def t5_tokenize(
63
+ texts: List[str],
64
+ name = DEFAULT_T5_NAME
65
+ ):
66
+ t5, tokenizer = get_model_and_tokenizer(name)
67
+
68
+ if torch.cuda.is_available():
69
+ t5 = t5.cuda()
70
+
71
+ device = next(t5.parameters()).device
72
+
73
+ encoded = tokenizer.batch_encode_plus(
74
+ texts,
75
+ return_tensors = "pt",
76
+ padding = 'longest',
77
+ max_length = MAX_LENGTH,
78
+ truncation = True
79
+ )
80
+
81
+ input_ids = encoded.input_ids.to(device)
82
+ attn_mask = encoded.attention_mask.to(device)
83
+ return input_ids, attn_mask
84
+
85
+ def t5_encode_tokenized_text(
86
+ token_ids,
87
+ attn_mask = None,
88
+ pad_id = None,
89
+ name = DEFAULT_T5_NAME
90
+ ):
91
+ assert exists(attn_mask) or exists(pad_id)
92
+ t5, _ = get_model_and_tokenizer(name)
93
+
94
+ attn_mask = default(attn_mask, lambda: (token_ids != pad_id).long())
95
+
96
+ t5.eval()
97
+
98
+ with torch.no_grad():
99
+ output = t5(input_ids = token_ids, attention_mask = attn_mask)
100
+ encoded_text = output.last_hidden_state.detach()
101
+
102
+ attn_mask = attn_mask.bool()
103
+
104
+ encoded_text = encoded_text.masked_fill(~rearrange(attn_mask, '... -> ... 1'), 0.) # force all embeddings at padding positions to be equal to 0.
105
+ return encoded_text
106
+
107
+ def t5_encode_text(
108
+ texts: List[str],
109
+ name = DEFAULT_T5_NAME,
110
+ return_attn_mask = False
111
+ ):
112
+ token_ids, attn_mask = t5_tokenize(texts, name = name)
113
+ encoded_text = t5_encode_tokenized_text(token_ids, attn_mask = attn_mask, name = name)
114
+
115
+ if return_attn_mask:
116
+ attn_mask = attn_mask.bool()
117
+ return encoded_text, attn_mask
118
+
119
+ return encoded_text
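
A minimal usage sketch for the helpers above (not part of the commit), assuming the files are importable as the imagen_pytorch package the other modules here use; the first call downloads the T5 weights from the Hugging Face hub, so network access is assumed:

from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

texts = ['a photo of a golden retriever', 'an oil painting of a lighthouse']

# returns (batch, seq_len, dim) embeddings with padding positions zeroed out
embeds, mask = t5_encode_text(texts, name = DEFAULT_T5_NAME, return_attn_mask = True)

print(embeds.shape)                      # (2, seq_len, get_encoded_dim(DEFAULT_T5_NAME))
print(mask.shape)                        # (2, seq_len) boolean mask, True on real tokens
print(get_encoded_dim(DEFAULT_T5_NAME))  # 768 for google/t5-v1_1-base
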
trainer.py ADDED
@@ -0,0 +1,992 @@
1
+ import os
2
+ import time
3
+ import copy
4
+ from pathlib import Path
5
+ from math import ceil
6
+ from contextlib import contextmanager, nullcontext
7
+ from functools import partial, wraps
8
+ from collections.abc import Iterable
9
+
10
+ import torch
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+ from torch.utils.data import random_split, DataLoader
14
+ from torch.optim import Adam
15
+ from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
16
+ from torch.cuda.amp import autocast, GradScaler
17
+
18
+ import pytorch_warmup as warmup
19
+
20
+ from imagen_pytorch.imagen_pytorch import Imagen, NullUnet
21
+ from imagen_pytorch.elucidated_imagen import ElucidatedImagen
22
+ from imagen_pytorch.data import cycle
23
+
24
+ from imagen_pytorch.version import __version__
25
+ from packaging import version
26
+
27
+ import numpy as np
28
+
29
+ from ema_pytorch import EMA
30
+
31
+ from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs
32
+
33
+ from fsspec.core import url_to_fs
34
+ from fsspec.implementations.local import LocalFileSystem
35
+
36
+ # helper functions
37
+
38
+ def exists(val):
39
+ return val is not None
40
+
41
+ def default(val, d):
42
+ if exists(val):
43
+ return val
44
+ return d() if callable(d) else d
45
+
46
+ def cast_tuple(val, length = 1):
47
+ if isinstance(val, list):
48
+ val = tuple(val)
49
+
50
+ return val if isinstance(val, tuple) else ((val,) * length)
51
+
52
+ def find_first(fn, arr):
53
+ for ind, el in enumerate(arr):
54
+ if fn(el):
55
+ return ind
56
+ return -1
57
+
58
+ def pick_and_pop(keys, d):
59
+ values = list(map(lambda key: d.pop(key), keys))
60
+ return dict(zip(keys, values))
61
+
62
+ def group_dict_by_key(cond, d):
63
+ return_val = [dict(),dict()]
64
+ for key in d.keys():
65
+ match = bool(cond(key))
66
+ ind = int(not match)
67
+ return_val[ind][key] = d[key]
68
+ return (*return_val,)
69
+
70
+ def string_begins_with(prefix, str):
71
+ return str.startswith(prefix)
72
+
73
+ def group_by_key_prefix(prefix, d):
74
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
75
+
76
+ def groupby_prefix_and_trim(prefix, d):
77
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
78
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
79
+ return kwargs_without_prefix, kwargs
80
+
81
+ def num_to_groups(num, divisor):
82
+ groups = num // divisor
83
+ remainder = num % divisor
84
+ arr = [divisor] * groups
85
+ if remainder > 0:
86
+ arr.append(remainder)
87
+ return arr
88
+
89
+ # url to fs, bucket, path - for checkpointing to cloud
90
+
91
+ def url_to_bucket(url):
92
+ if '://' not in url:
93
+ return url
94
+
95
+ prefix, suffix = url.split('://')
96
+
97
+ if prefix in {'gs', 's3'}:
98
+ return suffix.split('/')[0]
99
+ else:
100
+ raise ValueError(f'storage type prefix "{prefix}" is not supported yet')
101
+
102
+ # decorators
103
+
104
+ def eval_decorator(fn):
105
+ def inner(model, *args, **kwargs):
106
+ was_training = model.training
107
+ model.eval()
108
+ out = fn(model, *args, **kwargs)
109
+ model.train(was_training)
110
+ return out
111
+ return inner
112
+
113
+ def cast_torch_tensor(fn, cast_fp16 = False):
114
+ @wraps(fn)
115
+ def inner(model, *args, **kwargs):
116
+ device = kwargs.pop('_device', model.device)
117
+ cast_device = kwargs.pop('_cast_device', True)
118
+
119
+ should_cast_fp16 = cast_fp16 and model.cast_half_at_training
120
+
121
+ kwargs_keys = kwargs.keys()
122
+ all_args = (*args, *kwargs.values())
123
+ split_kwargs_index = len(all_args) - len(kwargs_keys)
124
+ all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
125
+
126
+ if cast_device:
127
+ all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
128
+
129
+ if should_cast_fp16:
130
+ all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(t, torch.Tensor) and t.dtype != torch.bool else t, all_args))
131
+
132
+ args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
133
+ kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
134
+
135
+ out = fn(model, *args, **kwargs)
136
+ return out
137
+ return inner
138
+
139
+ # gradient accumulation functions
140
+
141
+ def split_iterable(it, split_size):
142
+ accum = []
143
+ for ind in range(ceil(len(it) / split_size)):
144
+ start_index = ind * split_size
145
+ accum.append(it[start_index: (start_index + split_size)])
146
+ return accum
147
+
148
+ def split(t, split_size = None):
149
+ if not exists(split_size):
150
+ return t
151
+
152
+ if isinstance(t, torch.Tensor):
153
+ return t.split(split_size, dim = 0)
154
+
155
+ if isinstance(t, Iterable):
156
+ return split_iterable(t, split_size)
157
+
158
+ raise TypeError(f'cannot split object of type {type(t)}')
159
+
160
+ def find_first(cond, arr):
161
+ for el in arr:
162
+ if cond(el):
163
+ return el
164
+ return None
165
+
166
+ def split_args_and_kwargs(*args, split_size = None, **kwargs):
167
+ all_args = (*args, *kwargs.values())
168
+ len_all_args = len(all_args)
169
+ first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
170
+ assert exists(first_tensor)
171
+
172
+ batch_size = len(first_tensor)
173
+ split_size = default(split_size, batch_size)
174
+ num_chunks = ceil(batch_size / split_size)
175
+
176
+ dict_len = len(kwargs)
177
+ dict_keys = kwargs.keys()
178
+ split_kwargs_index = len_all_args - dict_len
179
+
180
+ split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
181
+ chunk_sizes = num_to_groups(batch_size, split_size)
182
+
183
+ for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
184
+ chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
185
+ chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
186
+ chunk_size_frac = chunk_size / batch_size
187
+ yield chunk_size_frac, (chunked_args, chunked_kwargs)
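
For illustration (not part of the commit), a sketch of how this gradient-accumulation helper chunks a batch; the tensors are dummies and only their shapes matter, assuming split_args_and_kwargs is in scope as defined above:

import torch

images = torch.randn(10, 3, 64, 64)
texts = ['a placeholder caption'] * 10

for frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(images, texts = texts, split_size = 4):
    (chunked_images,) = chunked_args
    print(frac, chunked_images.shape, len(chunked_kwargs['texts']))
    # 0.4 torch.Size([4, 3, 64, 64]) 4
    # 0.4 torch.Size([4, 3, 64, 64]) 4
    # 0.2 torch.Size([2, 3, 64, 64]) 2
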
188
+
189
+ # imagen trainer
190
+
191
+ def imagen_sample_in_chunks(fn):
192
+ @wraps(fn)
193
+ def inner(self, *args, max_batch_size = None, **kwargs):
194
+ if not exists(max_batch_size):
195
+ return fn(self, *args, **kwargs)
196
+
197
+ if self.imagen.unconditional:
198
+ batch_size = kwargs.get('batch_size')
199
+ batch_sizes = num_to_groups(batch_size, max_batch_size)
200
+ outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
201
+ else:
202
+ outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
203
+
204
+ if isinstance(outputs[0], torch.Tensor):
205
+ return torch.cat(outputs, dim = 0)
206
+
207
+ return list(map(lambda t: torch.cat(t, dim = 0), list(zip(*outputs))))
208
+
209
+ return inner
210
+
211
+
212
+ def restore_parts(state_dict_target, state_dict_from):
213
+ for name, param in state_dict_from.items():
214
+
215
+ if name not in state_dict_target:
216
+ continue
217
+
218
+ if param.size() == state_dict_target[name].size():
219
+ state_dict_target[name].copy_(param)
220
+ else:
221
+ print(f"layer {name} ({param.size()}) does not match target size ({state_dict_target[name].size()}), skipping")
222
+
223
+ return state_dict_target
224
+
225
+
226
+ class ImagenTrainer(nn.Module):
227
+ locked = False
228
+
229
+ def __init__(
230
+ self,
231
+ imagen = None,
232
+ imagen_checkpoint_path = None,
233
+ use_ema = True,
234
+ lr = 1e-4,
235
+ eps = 1e-8,
236
+ beta1 = 0.9,
237
+ beta2 = 0.99,
238
+ max_grad_norm = None,
239
+ group_wd_params = True,
240
+ warmup_steps = None,
241
+ cosine_decay_max_steps = None,
242
+ only_train_unet_number = None,
243
+ fp16 = False,
244
+ precision = None,
245
+ split_batches = True,
246
+ dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'),
247
+ verbose = True,
248
+ split_valid_fraction = 0.025,
249
+ split_valid_from_train = False,
250
+ split_random_seed = 42,
251
+ checkpoint_path = None,
252
+ checkpoint_every = None,
253
+ checkpoint_fs = None,
254
+ fs_kwargs: dict = None,
255
+ max_checkpoints_keep = 20,
256
+ **kwargs
257
+ ):
258
+ super().__init__()
259
+ assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)'
260
+ assert exists(imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config'
261
+
262
+ # determine filesystem, using fsspec, for saving to local filesystem or cloud
263
+
264
+ self.fs = checkpoint_fs
265
+
266
+ if not exists(self.fs):
267
+ fs_kwargs = default(fs_kwargs, {})
268
+ self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs)
269
+
270
+ assert isinstance(imagen, (Imagen, ElucidatedImagen))
271
+ ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
272
+
273
+ # elucidated or not
274
+
275
+ self.is_elucidated = isinstance(imagen, ElucidatedImagen)
276
+
277
+ # create accelerator instance
278
+
279
+ accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs)
280
+
281
+ assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator'
282
+ accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no')
283
+
284
+ self.accelerator = Accelerator(**{
285
+ 'split_batches': split_batches,
286
+ 'mixed_precision': accelerator_mixed_precision,
287
+ 'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)]
288
+ , **accelerate_kwargs})
289
+
290
+ ImagenTrainer.locked = self.is_distributed
291
+
292
+ # cast data to fp16 at training time if needed
293
+
294
+ self.cast_half_at_training = accelerator_mixed_precision == 'fp16'
295
+
296
+ # grad scaler must be managed outside of accelerator
297
+
298
+ grad_scaler_enabled = fp16
299
+
300
+ # imagen, unets and ema unets
301
+
302
+ self.imagen = imagen
303
+ self.num_unets = len(self.imagen.unets)
304
+
305
+ self.use_ema = use_ema and self.is_main
306
+ self.ema_unets = nn.ModuleList([])
307
+
308
+ # keep track of what unet is being trained on
309
+ # only going to allow 1 unet training at a time
310
+
311
+ self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on
312
+
313
+ # data related functions
314
+
315
+ self.train_dl_iter = None
316
+ self.train_dl = None
317
+
318
+ self.valid_dl_iter = None
319
+ self.valid_dl = None
320
+
321
+ self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names
322
+
323
+ # auto splitting validation from training, if dataset is passed in
324
+
325
+ self.split_valid_from_train = split_valid_from_train
326
+
327
+ assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1'
328
+ self.split_valid_fraction = split_valid_fraction
329
+ self.split_random_seed = split_random_seed
330
+
331
+ # be able to finely customize learning rate, weight decay
332
+ # per unet
333
+
334
+ lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps))
335
+
336
+ for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)):
337
+
338
+ optimizer = Adam(
339
+ unet.parameters(),
340
+ lr = unet_lr,
341
+ eps = unet_eps,
342
+ betas = (beta1, beta2),
343
+ **kwargs
344
+ )
345
+
346
+ if self.use_ema:
347
+ self.ema_unets.append(EMA(unet, **ema_kwargs))
348
+
349
+ scaler = GradScaler(enabled = grad_scaler_enabled)
350
+
351
+ scheduler = warmup_scheduler = None
352
+
353
+ if exists(unet_cosine_decay_max_steps):
354
+ scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
355
+
356
+ if exists(unet_warmup_steps):
357
+ warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps)
358
+
359
+ if not exists(scheduler):
360
+ scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
361
+
362
+ # set on object
363
+
364
+ setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
365
+ setattr(self, f'scaler{ind}', scaler)
366
+ setattr(self, f'scheduler{ind}', scheduler)
367
+ setattr(self, f'warmup{ind}', warmup_scheduler)
368
+
369
+ # gradient clipping if needed
370
+
371
+ self.max_grad_norm = max_grad_norm
372
+
373
+ # step tracker and misc
374
+
375
+ self.register_buffer('steps', torch.tensor([0] * self.num_unets))
376
+
377
+ self.verbose = verbose
378
+
379
+ # automatic set devices based on what accelerator decided
380
+
381
+ self.imagen.to(self.device)
382
+ self.to(self.device)
383
+
384
+ # checkpointing
385
+
386
+ assert not (exists(checkpoint_path) ^ exists(checkpoint_every))
387
+ self.checkpoint_path = checkpoint_path
388
+ self.checkpoint_every = checkpoint_every
389
+ self.max_checkpoints_keep = max_checkpoints_keep
390
+
391
+ self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main
392
+
393
+ if exists(checkpoint_path) and self.can_checkpoint:
394
+ bucket = url_to_bucket(checkpoint_path)
395
+
396
+ if not self.fs.exists(bucket):
397
+ self.fs.mkdir(bucket)
398
+
399
+ self.load_from_checkpoint_folder()
400
+
401
+ # only allowing training for unet
402
+
403
+ self.only_train_unet_number = only_train_unet_number
404
+ self.prepared = False
405
+
406
+
407
+ def prepare(self):
408
+ assert not self.prepared, 'The trainer is already prepared'
409
+ self.validate_and_set_unet_being_trained(self.only_train_unet_number)
410
+ self.prepared = True
411
+ # computed values
412
+
413
+ @property
414
+ def device(self):
415
+ return self.accelerator.device
416
+
417
+ @property
418
+ def is_distributed(self):
419
+ return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
420
+
421
+ @property
422
+ def is_main(self):
423
+ return self.accelerator.is_main_process
424
+
425
+ @property
426
+ def is_local_main(self):
427
+ return self.accelerator.is_local_main_process
428
+
429
+ @property
430
+ def unwrapped_unet(self):
431
+ return self.accelerator.unwrap_model(self.unet_being_trained)
432
+
433
+ # optimizer helper functions
434
+
435
+ def get_lr(self, unet_number):
436
+ unet_number = self.validate_unet_number(unet_number)
437
+ unet_index = unet_number - 1
438
+
439
+ optim = getattr(self, f'optim{unet_index}')
440
+
441
+ return optim.param_groups[0]['lr']
442
+
443
+ # function for allowing only one unet from being trained at a time
444
+
445
+ def validate_and_set_unet_being_trained(self, unet_number = None):
446
+ if exists(unet_number):
447
+ self.validate_unet_number(unet_number)
448
+
449
+ assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'
450
+
451
+ self.only_train_unet_number = unet_number
452
+ self.imagen.only_train_unet_number = unet_number
453
+
454
+ if not exists(unet_number):
455
+ return
456
+
457
+ self.wrap_unet(unet_number)
458
+
459
+ def wrap_unet(self, unet_number):
460
+ if hasattr(self, 'one_unet_wrapped'):
461
+ return
462
+
463
+ unet = self.imagen.get_unet(unet_number)
464
+ unet_index = unet_number - 1
465
+
466
+ optimizer = getattr(self, f'optim{unet_index}')
467
+ scheduler = getattr(self, f'scheduler{unet_index}')
468
+
469
+ if self.train_dl:
470
+ self.unet_being_trained, self.train_dl, optimizer = self.accelerator.prepare(unet, self.train_dl, optimizer)
471
+ else:
472
+ self.unet_being_trained, optimizer = self.accelerator.prepare(unet, optimizer)
473
+
474
+ if exists(scheduler):
475
+ scheduler = self.accelerator.prepare(scheduler)
476
+
477
+ setattr(self, f'optim{unet_index}', optimizer)
478
+ setattr(self, f'scheduler{unet_index}', scheduler)
479
+
480
+ self.one_unet_wrapped = True
481
+
482
+ # hacking accelerator due to not having separate gradscaler per optimizer
483
+
484
+ def set_accelerator_scaler(self, unet_number):
485
+ def patch_optimizer_step(accelerated_optimizer, method):
486
+ def patched_step(*args, **kwargs):
487
+ accelerated_optimizer._accelerate_step_called = True
488
+ return method(*args, **kwargs)
489
+ return patched_step
490
+
491
+ unet_number = self.validate_unet_number(unet_number)
492
+ scaler = getattr(self, f'scaler{unet_number - 1}')
493
+
494
+ self.accelerator.scaler = scaler
495
+ for optimizer in self.accelerator._optimizers:
496
+ optimizer.scaler = scaler
497
+ optimizer._accelerate_step_called = False
498
+ optimizer._optimizer_original_step_method = optimizer.optimizer.step
499
+ optimizer._optimizer_patched_step_method = patch_optimizer_step(optimizer, optimizer.optimizer.step)
500
+
501
+ # helper print
502
+
503
+ def print(self, msg):
504
+ if not self.is_main:
505
+ return
506
+
507
+ if not self.verbose:
508
+ return
509
+
510
+ return self.accelerator.print(msg)
511
+
512
+ # validating the unet number
513
+
514
+ def validate_unet_number(self, unet_number = None):
515
+ if self.num_unets == 1:
516
+ unet_number = default(unet_number, 1)
517
+
518
+ assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}'
519
+ return unet_number
520
+
521
+ # number of training steps taken
522
+
523
+ def num_steps_taken(self, unet_number = None):
524
+ if self.num_unets == 1:
525
+ unet_number = default(unet_number, 1)
526
+
527
+ return self.steps[unet_number - 1].item()
528
+
529
+ def print_untrained_unets(self):
530
+ print_final_error = False
531
+
532
+ for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
533
+ if steps > 0 or isinstance(unet, NullUnet):
534
+ continue
535
+
536
+ self.print(f'unet {ind + 1} has not been trained')
537
+ print_final_error = True
538
+
539
+ if print_final_error:
540
+ self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')
541
+
542
+ # data related functions
543
+
544
+ def add_train_dataloader(self, dl = None):
545
+ if not exists(dl):
546
+ return
547
+
548
+ assert not exists(self.train_dl), 'training dataloader was already added'
549
+ assert not self.prepared, 'You need to add the dataset before preparation'
550
+ self.train_dl = dl
551
+
552
+ def add_valid_dataloader(self, dl):
553
+ if not exists(dl):
554
+ return
555
+
556
+ assert not exists(self.valid_dl), 'validation dataloader was already added'
557
+ assert not self.prepared, 'You need to add the dataset before preparation'
558
+ self.valid_dl = dl
559
+
560
+ def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs):
561
+ if not exists(ds):
562
+ return
563
+
564
+ assert not exists(self.train_dl), 'training dataloader was already added'
565
+
566
+ valid_ds = None
567
+ if self.split_valid_from_train:
568
+ train_size = int((1 - self.split_valid_fraction) * len(ds))
569
+ valid_size = len(ds) - train_size
570
+
571
+ ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed))
572
+ self.print(f'training with dataset of {len(ds)} samples and validating with randomly split {len(valid_ds)} samples')
573
+
574
+ dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
575
+ self.add_train_dataloader(dl)
576
+
577
+ if not self.split_valid_from_train:
578
+ return
579
+
580
+ self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs)
581
+
582
+ def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
583
+ if not exists(ds):
584
+ return
585
+
586
+ assert not exists(self.valid_dl), 'validation dataloader was already added'
587
+
588
+ dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
589
+ self.add_valid_dataloader(dl)
590
+
591
+ def create_train_iter(self):
592
+ assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'
593
+
594
+ if exists(self.train_dl_iter):
595
+ return
596
+
597
+ self.train_dl_iter = cycle(self.train_dl)
598
+
599
+ def create_valid_iter(self):
600
+ assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'
601
+
602
+ if exists(self.valid_dl_iter):
603
+ return
604
+
605
+ self.valid_dl_iter = cycle(self.valid_dl)
606
+
607
+ def train_step(self, *, unet_number = None, **kwargs):
608
+ if not self.prepared:
609
+ self.prepare()
610
+ self.create_train_iter()
611
+
612
+ kwargs = {'unet_number': unet_number, **kwargs}
613
+ loss = self.step_with_dl_iter(self.train_dl_iter, **kwargs)
614
+ self.update(unet_number = unet_number)
615
+ return loss
616
+
617
+ @torch.no_grad()
618
+ @eval_decorator
619
+ def valid_step(self, **kwargs):
620
+ if not self.prepared:
621
+ self.prepare()
622
+ self.create_valid_iter()
623
+ context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext
624
+ with context():
625
+ loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
626
+ return loss
627
+
628
+ def step_with_dl_iter(self, dl_iter, **kwargs):
629
+ dl_tuple_output = cast_tuple(next(dl_iter))
630
+ model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
631
+ loss = self.forward(**{**kwargs, **model_input})
632
+ return loss
633
+
634
+ # checkpointing functions
635
+
636
+ @property
637
+ def all_checkpoints_sorted(self):
638
+ glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
639
+ checkpoints = self.fs.glob(glob_pattern)
640
+ sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
641
+ return sorted_checkpoints
642
+
643
+ def load_from_checkpoint_folder(self, last_total_steps = -1):
644
+ if last_total_steps != -1:
645
+ filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
646
+ self.load(filepath)
647
+ return
648
+
649
+ sorted_checkpoints = self.all_checkpoints_sorted
650
+
651
+ if len(sorted_checkpoints) == 0:
652
+ self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
653
+ return
654
+
655
+ last_checkpoint = sorted_checkpoints[0]
656
+ self.load(last_checkpoint)
657
+
658
+ def save_to_checkpoint_folder(self):
659
+ self.accelerator.wait_for_everyone()
660
+
661
+ if not self.can_checkpoint:
662
+ return
663
+
664
+ total_steps = int(self.steps.sum().item())
665
+ filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')
666
+
667
+ self.save(filepath)
668
+
669
+ if self.max_checkpoints_keep <= 0:
670
+ return
671
+
672
+ sorted_checkpoints = self.all_checkpoints_sorted
673
+ checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]
674
+
675
+ for checkpoint in checkpoints_to_discard:
676
+ self.fs.rm(checkpoint)
677
+
678
+ # saving and loading functions
679
+
680
+ def save(
681
+ self,
682
+ path,
683
+ overwrite = True,
684
+ without_optim_and_sched = False,
685
+ **kwargs
686
+ ):
687
+ self.accelerator.wait_for_everyone()
688
+
689
+ if not self.can_checkpoint:
690
+ return
691
+
692
+ fs = self.fs
693
+
694
+ assert not (fs.exists(path) and not overwrite)
695
+
696
+ self.reset_ema_unets_all_one_device()
697
+
698
+ save_obj = dict(
699
+ model = self.imagen.state_dict(),
700
+ version = __version__,
701
+ steps = self.steps.cpu(),
702
+ **kwargs
703
+ )
704
+
705
+ save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()
706
+
707
+ for ind in save_optim_and_sched_iter:
708
+ scaler_key = f'scaler{ind}'
709
+ optimizer_key = f'optim{ind}'
710
+ scheduler_key = f'scheduler{ind}'
711
+ warmup_scheduler_key = f'warmup{ind}'
712
+
713
+ scaler = getattr(self, scaler_key)
714
+ optimizer = getattr(self, optimizer_key)
715
+ scheduler = getattr(self, scheduler_key)
716
+ warmup_scheduler = getattr(self, warmup_scheduler_key)
717
+
718
+ if exists(scheduler):
719
+ save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}
720
+
721
+ if exists(warmup_scheduler):
722
+ save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}
723
+
724
+ save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
725
+
726
+ if self.use_ema:
727
+ save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
728
+
729
+ # determine if imagen config is available
730
+
731
+ if hasattr(self.imagen, '_config'):
732
+ self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"<prompt>\""')
733
+
734
+ save_obj = {
735
+ **save_obj,
736
+ 'imagen_type': 'elucidated' if self.is_elucidated else 'original',
737
+ 'imagen_params': self.imagen._config
738
+ }
739
+
740
+ #save to path
741
+
742
+ with fs.open(path, 'wb') as f:
743
+ torch.save(save_obj, f)
744
+
745
+ self.print(f'checkpoint saved to {path}')
746
+
747
+ def load(self, path, only_model = False, strict = True, noop_if_not_exist = False):
748
+ fs = self.fs
749
+
750
+ if noop_if_not_exist and not fs.exists(path):
751
+ self.print(f'trainer checkpoint not found at {str(path)}')
752
+ return
753
+
754
+ assert fs.exists(path), f'{path} does not exist'
755
+
756
+ self.reset_ema_unets_all_one_device()
757
+
758
+ # to avoid extra GPU memory usage in main process when using Accelerate
759
+
760
+ with fs.open(path) as f:
761
+ loaded_obj = torch.load(f, map_location='cpu')
762
+
763
+ if version.parse(__version__) != version.parse(loaded_obj['version']):
764
+ self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')
765
+
766
+ try:
767
+ self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
768
+ except RuntimeError:
769
+ print("Failed loading state dict. Trying partial load")
770
+ self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
771
+ loaded_obj['model']))
772
+
773
+ if only_model:
774
+ return loaded_obj
775
+
776
+ self.steps.copy_(loaded_obj['steps'])
777
+
778
+ for ind in range(0, self.num_unets):
779
+ scaler_key = f'scaler{ind}'
780
+ optimizer_key = f'optim{ind}'
781
+ scheduler_key = f'scheduler{ind}'
782
+ warmup_scheduler_key = f'warmup{ind}'
783
+
784
+ scaler = getattr(self, scaler_key)
785
+ optimizer = getattr(self, optimizer_key)
786
+ scheduler = getattr(self, scheduler_key)
787
+ warmup_scheduler = getattr(self, warmup_scheduler_key)
788
+
789
+ if exists(scheduler) and scheduler_key in loaded_obj:
790
+ scheduler.load_state_dict(loaded_obj[scheduler_key])
791
+
792
+ if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
793
+ warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])
794
+
795
+ if exists(optimizer):
796
+ try:
797
+ optimizer.load_state_dict(loaded_obj[optimizer_key])
798
+ scaler.load_state_dict(loaded_obj[scaler_key])
799
+ except Exception:
800
+ self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers')
801
+
802
+ if self.use_ema:
803
+ assert 'ema' in loaded_obj
804
+ try:
805
+ self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
806
+ except RuntimeError:
807
+ print("Failed loading state dict. Trying partial load")
808
+ self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
809
+ loaded_obj['ema']))
810
+
811
+ self.print(f'checkpoint loaded from {path}')
812
+ return loaded_obj
813
+
814
+ # managing ema unets and their devices
815
+
816
+ @property
817
+ def unets(self):
818
+ return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
819
+
820
+ def get_ema_unet(self, unet_number = None):
821
+ if not self.use_ema:
822
+ return
823
+
824
+ unet_number = self.validate_unet_number(unet_number)
825
+ index = unet_number - 1
826
+
827
+ if isinstance(self.unets, nn.ModuleList):
828
+ unets_list = [unet for unet in self.ema_unets]
829
+ delattr(self, 'ema_unets')
830
+ self.ema_unets = unets_list
831
+
832
+ if index != self.ema_unet_being_trained_index:
833
+ for unet_index, unet in enumerate(self.ema_unets):
834
+ unet.to(self.device if unet_index == index else 'cpu')
835
+
836
+ self.ema_unet_being_trained_index = index
837
+ return self.ema_unets[index]
838
+
839
+ def reset_ema_unets_all_one_device(self, device = None):
840
+ if not self.use_ema:
841
+ return
842
+
843
+ device = default(device, self.device)
844
+ self.ema_unets = nn.ModuleList([*self.ema_unets])
845
+ self.ema_unets.to(device)
846
+
847
+ self.ema_unet_being_trained_index = -1
848
+
849
+ @torch.no_grad()
850
+ @contextmanager
851
+ def use_ema_unets(self):
852
+ if not self.use_ema:
853
+ output = yield
854
+ return output
855
+
856
+ self.reset_ema_unets_all_one_device()
857
+ self.imagen.reset_unets_all_one_device()
858
+
859
+ self.unets.eval()
860
+
861
+ trainable_unets = self.imagen.unets
862
+ self.imagen.unets = self.unets # swap in exponential moving averaged unets for sampling
863
+
864
+ output = yield
865
+
866
+ self.imagen.unets = trainable_unets # restore original training unets
867
+
868
+ # cast the ema_model unets back to original device
869
+ for ema in self.ema_unets:
870
+ ema.restore_ema_model_device()
871
+
872
+ return output
873
+
874
+ def print_unet_devices(self):
875
+ self.print('unet devices:')
876
+ for i, unet in enumerate(self.imagen.unets):
877
+ device = next(unet.parameters()).device
878
+ self.print(f'\tunet {i}: {device}')
879
+
880
+ if not self.use_ema:
881
+ return
882
+
883
+ self.print('\nema unet devices:')
884
+ for i, ema_unet in enumerate(self.ema_unets):
885
+ device = next(ema_unet.parameters()).device
886
+ self.print(f'\tema unet {i}: {device}')
887
+
888
+ # overriding state dict functions
889
+
890
+ def state_dict(self, *args, **kwargs):
891
+ self.reset_ema_unets_all_one_device()
892
+ return super().state_dict(*args, **kwargs)
893
+
894
+ def load_state_dict(self, *args, **kwargs):
895
+ self.reset_ema_unets_all_one_device()
896
+ return super().load_state_dict(*args, **kwargs)
897
+
898
+ # encoding text functions
899
+
900
+ def encode_text(self, text, **kwargs):
901
+ return self.imagen.encode_text(text, **kwargs)
902
+
903
+ # forwarding functions and gradient step updates
904
+
905
+ def update(self, unet_number = None):
906
+ unet_number = self.validate_unet_number(unet_number)
907
+ self.validate_and_set_unet_being_trained(unet_number)
908
+ self.set_accelerator_scaler(unet_number)
909
+
910
+ index = unet_number - 1
911
+ unet = self.unet_being_trained
912
+
913
+ optimizer = getattr(self, f'optim{index}')
914
+ scaler = getattr(self, f'scaler{index}')
915
+ scheduler = getattr(self, f'scheduler{index}')
916
+ warmup_scheduler = getattr(self, f'warmup{index}')
917
+
918
+ # set the grad scaler on the accelerator, since we are managing one per u-net
919
+
920
+ if exists(self.max_grad_norm):
921
+ self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
922
+
923
+ optimizer.step()
924
+ optimizer.zero_grad()
925
+
926
+ if self.use_ema:
927
+ ema_unet = self.get_ema_unet(unet_number)
928
+ ema_unet.update()
929
+
930
+ # scheduler, if needed
931
+
932
+ maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening()
933
+
934
+ with maybe_warmup_context:
935
+ if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # recommended in the docs
936
+ scheduler.step()
937
+
938
+ self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps))
939
+
940
+ if not exists(self.checkpoint_path):
941
+ return
942
+
943
+ total_steps = int(self.steps.sum().item())
944
+
945
+ if total_steps % self.checkpoint_every:
946
+ return
947
+
948
+ self.save_to_checkpoint_folder()
949
+
950
+ @torch.no_grad()
951
+ @cast_torch_tensor
952
+ @imagen_sample_in_chunks
953
+ def sample(self, *args, **kwargs):
954
+ context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets
955
+
956
+ self.print_untrained_unets()
957
+
958
+ if not self.is_main:
959
+ kwargs['use_tqdm'] = False
960
+
961
+ with context():
962
+ output = self.imagen.sample(*args, device = self.device, **kwargs)
963
+
964
+ return output
965
+
966
+ @partial(cast_torch_tensor, cast_fp16 = True)
967
+ def forward(
968
+ self,
969
+ *args,
970
+ unet_number = None,
971
+ max_batch_size = None,
972
+ **kwargs
973
+ ):
974
+ unet_number = self.validate_unet_number(unet_number)
975
+ self.validate_and_set_unet_being_trained(unet_number)
976
+ self.set_accelerator_scaler(unet_number)
977
+
978
+ assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'
979
+
980
+ total_loss = 0.
981
+
982
+ for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
983
+ with self.accelerator.autocast():
984
+ loss = self.imagen(*chunked_args, unet = self.unet_being_trained, unet_number = unet_number, **chunked_kwargs)
985
+ loss = loss * chunk_size_frac
986
+
987
+ total_loss += loss.item()
988
+
989
+ if self.training:
990
+ self.accelerator.backward(loss)
991
+
992
+ return total_loss
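
A minimal end-to-end training sketch for the trainer above (not part of the commit); `imagen` is assumed to be an already constructed Imagen or ElucidatedImagen instance and `dataset` a torch Dataset whose items line up with dl_tuple_output_keywords_names:

trainer = ImagenTrainer(
    imagen = imagen,                    # assumed: built elsewhere, e.g. via ImagenConfig / ElucidatedImagenConfig
    lr = 1e-4,
    use_ema = True,
    split_valid_from_train = True,      # auto-split a small validation set from the dataset
    checkpoint_path = './checkpoints',  # optional, must be paired with checkpoint_every
    checkpoint_every = 1000
)

trainer.add_train_dataset(dataset, batch_size = 16)

for step in range(100_000):
    loss = trainer.train_step(unet_number = 1)

    if step % 500 == 0:
        valid_loss = trainer.valid_step(unet_number = 1, use_ema_unets = True)
        trainer.print(f'step {step}: loss {loss:.4f} | valid loss {valid_loss:.4f}')

trainer.save('./imagen.pt')             # also loadable later via utils.load_imagen_from_checkpoint
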
utils.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch
2
+ from torch import nn
3
+ from functools import reduce
4
+ from pathlib import Path
5
+
6
+ from imagen_pytorch.configs import ImagenConfig, ElucidatedImagenConfig
7
+ from ema_pytorch import EMA
8
+
9
+ def exists(val):
10
+ return val is not None
11
+
12
+ def safeget(dictionary, keys, default = None):
13
+ return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
14
+
15
+ def load_imagen_from_checkpoint(
16
+ checkpoint_path,
17
+ load_weights = True,
18
+ load_ema_if_available = False
19
+ ):
20
+ model_path = Path(checkpoint_path)
21
+ full_model_path = str(model_path.resolve())
22
+ assert model_path.exists(), f'checkpoint not found at {full_model_path}'
23
+ loaded = torch.load(str(model_path), map_location='cpu')
24
+
25
+ imagen_params = safeget(loaded, 'imagen_params')
26
+ imagen_type = safeget(loaded, 'imagen_type')
27
+
28
+ if imagen_type == 'original':
29
+ imagen_klass = ImagenConfig
30
+ elif imagen_type == 'elucidated':
31
+ imagen_klass = ElucidatedImagenConfig
32
+ else:
33
+ raise ValueError(f'unknown imagen type {imagen_type} - you need to instantiate your Imagen with configurations, using classes ImagenConfig or ElucidatedImagenConfig')
34
+
35
+ assert exists(imagen_params) and exists(imagen_type), 'imagen type and configuration not saved in this checkpoint'
36
+
37
+ imagen = imagen_klass(**imagen_params).create()
38
+
39
+ if not load_weights:
40
+ return imagen
41
+
42
+ has_ema = 'ema' in loaded
43
+ should_load_ema = has_ema and load_ema_if_available
44
+
45
+ imagen.load_state_dict(loaded['model'])
46
+
47
+ if not should_load_ema:
48
+ print('loading non-EMA version of unets')
49
+ return imagen
50
+
51
+ ema_unets = nn.ModuleList([])
52
+ for unet in imagen.unets:
53
+ ema_unets.append(EMA(unet))
54
+
55
+ ema_unets.load_state_dict(loaded['ema'])
56
+
57
+ for unet, ema_unet in zip(imagen.unets, ema_unets):
58
+ unet.load_state_dict(ema_unet.ema_model.state_dict())
59
+
60
+ print('loaded EMA version of unets')
61
+ return imagen
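
A minimal usage sketch (not part of the commit), assuming the checkpoint was written by ImagenTrainer.save so that 'imagen_type' and 'imagen_params' are stored alongside the weights; the path is hypothetical:

from imagen_pytorch.utils import load_imagen_from_checkpoint

imagen = load_imagen_from_checkpoint(
    './imagen.pt',                  # hypothetical checkpoint path
    load_weights = True,
    load_ema_if_available = True    # falls back to the non-EMA weights when no 'ema' key is saved
)

imagen = imagen.cuda()  # optionally move to GPU before sampling, if one is available
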
version.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = '1.25.12'