import math
import copy
from random import random
from beartype.typing import List, Union
from beartype import beartype
from tqdm.auto import tqdm
from functools import partial, wraps
from contextlib import contextmanager, nullcontext
from collections import namedtuple
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
from torch import nn, einsum
from torch.cuda.amp import autocast
from torch.special import expm1
import torchvision.transforms as T

import kornia.augmentation as K

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

from imagen_pytorch.imagen_video import Unet3D, resize_video_to, scale_video_time

# helper functions

def exists(val):
    """Return True when ``val`` is not None."""
    return val is not None

def identity(t, *args, **kwargs):
    """Return ``t`` unchanged, ignoring any extra arguments."""
    return t

def divisible_by(numer, denom):
    """Return True when ``numer`` is an exact multiple of ``denom``."""
    return (numer % denom) == 0

def first(arr, d = None):
    """Return the first element of ``arr``, or ``d`` when ``arr`` is empty."""
    if len(arr) == 0:
        return d
    return arr[0]

def maybe(fn):
    """Wrap a one-argument ``fn`` so that a None input is passed through untouched."""
    @wraps(fn)
    def inner(x):
        if not exists(x):
            return x
        return fn(x)
    return inner

def once(fn):
    """Wrap ``fn`` so only its first invocation runs; later calls return None."""
    called = False
    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return
        called = True
        return fn(*args, **kwargs)
    return inner

print_once = once(print)

def default(val, d):
    """Return ``val`` if it is not None, otherwise ``d`` (calling it first if it is callable)."""
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(val, length = None):
    """Coerce ``val`` to a tuple.

    Lists become tuples, tuples pass through, and any other value is repeated
    ``length`` times (once if ``length`` is None). When ``length`` is given,
    the result's length is asserted to match it.
    """
    if isinstance(val, list):
        val = tuple(val)

    output = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length):
        assert len(output) == length

    return output

def compact(input_dict):
    """Return a copy of ``input_dict`` without the entries whose value is None."""
    return {key: value for key, value in input_dict.items() if exists(value)}

def maybe_transform_dict_key(input_dict, key, fn):
    """Return a shallow copy of ``input_dict`` with ``fn`` applied to ``input_dict[key]``.

    The original dict is returned unchanged (not copied) when ``key`` is absent.
    """
    if key not in input_dict:
        return input_dict

    copied_dict = input_dict.copy()
    copied_dict[key] = fn(copied_dict[key])
    return copied_dict

def cast_uint8_images_to_float(images):
    """Scale uint8 images into [0, 1] by dividing by 255; any other dtype is returned as-is."""
    if not images.dtype == torch.uint8:
        return images
    return images / 255

def module_device(module):
    """Return the device of the first parameter of ``module`` (raises StopIteration if it has none)."""
    return next(module.parameters()).device

def zero_init_(m):
    """Zero the ``weight`` (and ``bias``, when present) of layer ``m`` in place."""
    nn.init.zeros_(m.weight)
    if exists(m.bias):
        nn.init.zeros_(m.bias)

def eval_decorator(fn):
    """Run ``fn(model, ...)`` with ``model`` in eval mode, then restore its previous training mode."""
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # restore the caller's mode even if fn raises
            model.train(was_training)
    return inner

def pad_tuple_to_length(t, length, fillvalue = None):
    """Right-pad tuple ``t`` with ``fillvalue`` up to ``length``; longer tuples are returned unchanged."""
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

# helper classes

class Identity(nn.Module):
    """No-op module: accepts and ignores any constructor or forward arguments beyond ``x``."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

# tensor helpers

# numerically safe log: clamps t to at least eps before taking the log.
# (kept free of docstrings because TorchScript compiles it from the scripted
# alpha_cosine_log_snr below)
def log(t, eps: float = 1e-12):
    return torch.log(t.clamp(min = eps))

def l2norm(t):
    """L2-normalize ``t`` along its last dimension."""
    return F.normalize(t, dim = -1)

def right_pad_dims_to(x, t):
    """Append singleton dims to ``t`` until it has as many dims as ``x`` (for broadcasting)."""
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))

def masked_mean(t, *, dim, mask = None):
    """Mean of ``t`` over ``dim``, counting only positions where ``mask`` is True.

    ``mask`` is expanded with a trailing axis ('b n -> b n 1'), so it must be
    2-D; presumably ``t`` is (batch, seq, feature) with ``dim = 1`` — confirm
    for other uses. The denominator is clamped to avoid division by zero when
    a row is fully masked.
    """
    if not exists(mask):
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)

    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)

def resize_image_to(
    image,
    target_image_size,
    clamp_range = None,
    mode = 'nearest'
):
    """Interpolate ``image`` to ``target_image_size``, optionally clamping the result.

    The image is returned untouched if its last dimension already equals the target size.
    """
    orig_image_size = image.shape[-1]

    if orig_image_size == target_image_size:
        return image

    out = F.interpolate(image, target_image_size, mode = mode)

    if exists(clamp_range):
        out = out.clamp(*clamp_range)

    return out

def calc_all_frame_dims(
    downsample_factors: List[int],
    frames
):
    """Return, per downsample factor, a 1-tuple with the downsampled frame count.

    When ``frames`` is None, one empty tuple per factor is returned instead.
    Each factor must divide ``frames`` exactly.
    """
    if not exists(frames):
        return (tuple(),) * len(downsample_factors)

    all_frame_dims = []

    for divisor in downsample_factors:
        assert divisible_by(frames, divisor)
        all_frame_dims.append((frames // divisor,))

    return all_frame_dims

def safe_get_tuple_index(tup, index, default = None):
    """Return ``tup[index]``, or ``default`` if the index is past the end."""
    if len(tup) <= index:
        return default
    return tup[index]

# image normalization functions
# ddpms expect images to be in the range of -1 to 1

def normalize_neg_one_to_one(img):
    """Map values from [0, 1] to [-1, 1]."""
    return img * 2 - 1

def unnormalize_zero_to_one(normed_img):
    """Map values from [-1, 1] back to [0, 1]."""
    return (normed_img + 1) * 0.5

# classifier free guidance functions

def prob_mask_like(shape, prob, device):
    """Boolean tensor of ``shape`` where each entry is True with probability ``prob``.

    ``prob`` of exactly 1 or 0 short-circuits to all-True / all-False.
    """
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

# gaussian diffusion with continuous time helper functions and classes
# large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py

# log-SNR of the "linear" noise schedule for continuous time t
@torch.jit.script
def beta_linear_log_snr(t):
    return -torch.log(expm1(1e-4 + 10 * (t ** 2)))

# log-SNR of the cosine noise schedule for continuous time t, with offset s
@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008):
    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version

def log_snr_to_alpha_sigma(log_snr):
    """Convert log-SNR to (alpha, sigma) = (sqrt(sigmoid(log_snr)), sqrt(sigmoid(-log_snr)))."""
    return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))

class GaussianDiffusionContinuousTimes(nn.Module):
    """Continuous-time Gaussian diffusion driven by a log-SNR noise schedule.

    ``noise_schedule`` selects ``"linear"`` or ``"cosine"``; ``timesteps``
    sets how finely [0, 1] is discretized when sampling.
    """

    def __init__(self, *, noise_schedule, timesteps = 1000):
        super().__init__()

        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        self.num_timesteps = timesteps

    def get_times(self, batch_size, noise_level, *, device):
        """Return a float32 tensor of shape (batch_size,) filled with ``noise_level``."""
        return torch.full((batch_size,), noise_level, device = device, dtype = torch.float32)

    def sample_random_times(self, batch_size, *, device):
        """Sample ``batch_size`` times uniformly from [0, 1)."""
        return torch.zeros((batch_size,), device = device).float().uniform_(0, 1)

    def get_condition(self, times):
        """Return the log-SNR for ``times`` (None passes through)."""
        return maybe(self.log_snr)(times)

    def get_sampling_timesteps(self, batch, *, device):
        """Return ``num_timesteps`` tensors, each (2, batch), of (t, t_next) pairs going from 1 down to 0."""
        times = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
        times = repeat(times, 't -> b t', b = batch)
        times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
        times = times.unbind(dim = -1)
        return times

    def q_posterior(self, x_start, x_t, t, *, t_next = None):
        """Posterior q(x_{t_next} | x_t, x_start): returns (mean, variance, clipped log variance).

        ``t_next`` defaults to one discretization step earlier than ``t``, floored at 0.
        https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material
        """
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        log_snr = self.log_snr(t)
        log_snr_next = self.log_snr(t_next)
        log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c - as defined near eq 33
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        # following (eq. 33)
        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        """Noise ``x_start`` to time ``t``: returns (alpha * x_start + sigma * noise, log_snr, alpha, sigma).

        A float ``t`` is broadcast to the batch; ``noise`` defaults to standard normal.
        """
        dtype = x_start.dtype

        if isinstance(t, float):
            batch = x_start.shape[0]
            t = torch.full((batch,), t, device = x_start.device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_start))
        log_snr = self.log_snr(t).type(dtype)
        log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)

        return alpha * x_start + sigma * noise, log_snr, alpha, sigma

    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
        """Re-noise ``x_from`` (at time ``from_t``) to time ``to_t``; float times are broadcast to the batch."""
        shape, device, dtype = x_from.shape, x_from.device, x_from.dtype
        batch = shape[0]

        if isinstance(from_t, float):
            from_t = torch.full((batch,), from_t, device = device, dtype = dtype)

        if isinstance(to_t, float):
            to_t = torch.full((batch,), to_t, device = device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_from))

        log_snr = self.log_snr(from_t)
        log_snr_padded_dim = right_pad_dims_to(x_from, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)

        log_snr_to = self.log_snr(to_t)
        log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to)
        alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_padded_dim_to)

        return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha

    def predict_start_from_v(self, x_t, t, v):
        """Recover x_start from a v-prediction: alpha * x_t - sigma * v."""
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        return alpha * x_t - sigma * v

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_start from a noise prediction: (x_t - sigma * noise) / alpha (alpha clamped >= 1e-8)."""
        log_snr = self.log_snr(t)
        log_snr = right_pad_dims_to(x_t, log_snr)
        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)

# norms and residuals

class LayerNorm(nn.Module):
    """Bias-free layer norm over dimension ``dim`` with a learned gain.

    The gain has shape (feats, 1, ..., 1) so it broadcasts when ``dim`` is not
    the last axis. With ``stable=True`` the input is first divided by its
    (detached) max along ``dim``. eps is 1e-5 for float32 input, else 1e-3.
    """

    def __init__(self, feats, stable = False, dim = -1):
        super().__init__()
        self.stable = stable
        self.dim = dim

        self.g = nn.Parameter(torch.ones(feats, *((1,) * (-dim - 1))))

    def forward(self, x):
        dtype, dim = x.dtype, self.dim

        if self.stable:
            x = x / x.amax(dim = dim, keepdim = True).detach()

        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = dim, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = dim, keepdim = True)

        return (x - mean) * (var + eps).rsqrt().type(dtype) * self.g.type(dtype)

# layer norm over dim -3 (the channel axis of a (b, c, h, w) feature map)
ChanLayerNorm = partial(LayerNorm, dim = -3)

class Always():
    """Callable that ignores its arguments and always returns ``val``."""

    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val

class Residual(nn.Module):
    """Wrap ``fn`` with a residual connection: fn(x, **kwargs) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class Parallel(nn.Module):
    """Apply every module in ``fns`` to the same input and sum the outputs."""

    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        outputs = [fn(x) for fn in self.fns]
        return sum(outputs)

# attention pooling

class PerceiverAttention(nn.Module):
    """Cross-attention from a set of latents to a sequence, with the latents' own keys/values appended.

    Queries and keys are l2-normalized, rescaled by learned per-dim scales,
    and their dot product is multiplied by a fixed ``scale``.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, latents, mask = None):
        """``mask`` (if given) marks valid positions of ``x``; latent positions are always attendable."""
        x = self.norm(x)
        latents = self.norm_latents(latents)

        h = self.heads

        q = self.to_q(latents)

        # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # qk rmsnorm

        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # similarities and masking

        sim = einsum('... i d, ... j d  -> ... i j', q, k) * self.scale

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            # keys are [x, latents]; pad the mask with True for the appended latents
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention (softmax computed in float32 for stability)

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)

class PerceiverResampler(nn.Module):
    """Resample a variable-length sequence into a fixed set of latents via stacked PerceiverAttention.

    Optionally prepends ``num_latents_mean_pooled`` extra latents derived from
    the mean of the sequence.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, mask = None):
        """``mask`` (optional, (batch, seq) bool) marks valid positions of ``x``."""
        n, device = x.shape[1], x.device
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        x_with_pos = x + pos_emb

        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])

        if exists(self.to_latents_from_mean_pooled_seq):
            # pool only over valid positions; with no mask every position counts
            pool_mask = default(mask, lambda: torch.ones(x.shape[:2], device = device, dtype = torch.bool))
            meanpooled_seq = masked_mean(x, dim = 1, mask = pool_mask)
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents

# attention

class Attention(nn.Module):
    """Multi-query attention: ``heads`` query heads share a single key/value head.

    A learned null key/value is always prepended. When ``context_dim`` is set,
    projected context keys/values are prepended as well.
    """

    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        context_dim = None,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # single key/value head (dim_head wide), shared across all query heads
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context = None, mask = None, attn_bias = None):
        b = x.shape[0]

        x = self.norm(x)

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # add text conditioning, if present

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # qk rmsnorm

        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # calculate query / key similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale

        # relative positional encoding (T5 style)

        if exists(attn_bias):
            sim = sim + attn_bias

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            # pad on the left for the prepended null key
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        # aggregate values

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# decoder

def Upsample(dim, dim_out = None):
    """2x nearest-neighbour upsample followed by a 3x3 conv to ``dim_out`` channels."""
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, dim_out, 3, padding = 1)
    )

class PixelShuffleUpsample(nn.Module):
    """2x upsample via 1x1 conv + SiLU + PixelShuffle.

    code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
    """

    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        """Kaiming-init a quarter of the filters and repeat each one 4x; zero the bias."""
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)

def Downsample(dim, dim_out = None):
    """2x downsample by folding each 2x2 patch into channels, then a 1x1 conv."""
    # https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample
    # named SP-conv in the paper, but basically a pixel unshuffle
    dim_out = default(dim_out, dim)
    return nn.Sequential(
        Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
        nn.Conv2d(dim * 4, dim_out, 1)
    )

class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of a 1-D tensor into ``dim`` features (sin half, then cos half)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim = -1)

class LearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding with learned frequencies; output is [x, sin, cos] of width dim + 1.

    following @crowsonkb 's lead with learned sinusoidal pos emb
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8
    """

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered

class Block(nn.Module):
    """GroupNorm (optional) -> optional FiLM-style scale/shift -> SiLU -> 3x3 conv."""

    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
        self.activation = nn.SiLU()
        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)

    def forward(self, x, scale_shift = None):
        x = self.groupnorm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.activation(x)
        return self.project(x)

class ResnetBlock(nn.Module):
    """Two ``Block``s with a residual path, optional time-conditioned scale/shift, optional cross attention and global context gating.

    ``squeeze_excite`` is accepted for interface compatibility but is not used here.
    """

    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_mlp = None

        if exists(time_cond_dim):
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.cross_attn = None

        if exists(cond_dim):
            attn_klass = CrossAttention if not linear_attn else LinearCrossAttention

            self.cross_attn = attn_klass(
                dim = dim_out,
                context_dim = cond_dim,
                **attn_kwargs
            )

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()

    def forward(self, x, time_emb = None, cond = None):

        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x)

        if exists(self.cross_attn):
            assert exists(cond)
            # flatten the spatial grid into a token sequence for attention
            h = rearrange(h, 'b c h w -> b h w c')
            h, ps = pack([h], 'b * c')
            h = self.cross_attn(h, context = cond) + h
            h, = unpack(h, ps, 'b * c')
            h = rearrange(h, 'b h w c -> b c h w')

        h = self.block2(h, scale_shift = scale_shift)

        h = h * self.gca(h)

        return h + self.res_conv(x)

class CrossAttention(nn.Module):
    """Multi-head cross attention from ``x`` to ``context`` with a learned null key/value prepended.

    Uses l2-normalized (cosine-sim) queries/keys with learned per-dim scales.
    """

    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        """``mask`` (optional) marks valid positions of ``context``; the null key is always attendable."""
        b = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # cosine sim attention

        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # similarities

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class LinearCrossAttention(CrossAttention):
    """Linear-complexity variant of ``CrossAttention`` (softmax over q features and over k positions)."""

    def forward(self, x, context, mask = None):
        """``mask`` (optional, (batch, context_len) bool) marks valid positions of ``context``."""
        b = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads,  b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking

        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            # k / v are folded as (b h); repeat the mask per head in the same order
            mask = repeat(mask, 'b n -> (b h) n 1', h = self.heads)
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)

class LinearAttention(nn.Module):
    """Linear self attention over a (b, c, x, y) feature map, optionally extended with context keys/values.

    q/k/v each come from dropout -> 1x1 conv -> depthwise 3x3 conv.
    """

    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv2d(dim, inner_dim, 1, bias = False),
            nn.Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        h, x, y = self.heads, *fmap.shape[-2:]

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (ck, cv))
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)

class GlobalContext(nn.Module):
    """Basically a superior form of squeeze-excitation that is attention-esque.

    A 1x1 conv scores every spatial position; the softmax-weighted sum of the
    features passes through a small conv MLP ending in a sigmoid. The result
    is a (b, dim_out, 1, 1) gate.
    """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        self.to_k = nn.Conv2d(dim_in, 1, 1)
        hidden_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        context = self.to_k(x)
        x, context = map(lambda t: rearrange(t, 'b n ... -> b n (...)'), (x, context))
        out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
        out = rearrange(out, '... -> ... 1')
        return self.net(out)

def FeedForward(dim, mult = 2):
    """Token-wise MLP: LayerNorm -> Linear(dim*mult) -> GELU -> LayerNorm -> Linear(dim), without biases."""
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, dim, bias = False)
    )

def ChanFeedForward(dim, mult = 2):
    """Channel-wise MLP on feature maps, built from 1x1 convs and channel layer norms."""
    # in paper, it seems for self attention layers they did feedforwards with twice channel width
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        ChanLayerNorm(dim),
        nn.Conv2d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(hidden_dim),
        nn.Conv2d(hidden_dim, dim, 1, bias = False)
    )

class TransformerBlock(nn.Module):
    """Stack of ``depth`` residual (Attention, FeedForward) layers applied to a (b, c, h, w) map as a token sequence."""

    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, context = None):
        x = rearrange(x, 'b c h w -> b h w c')
        x, ps = pack([x], 'b * c')

        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x

        x, = unpack(x, ps, 'b * c')
        x = rearrange(x, 'b h w c -> b c h w')
        return x

class LinearAttentionTransformerBlock(nn.Module):
    """Stack of ``depth`` residual (LinearAttention, ChanFeedForward) layers operating directly on feature maps."""

    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                ChanFeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x

class CrossEmbedLayer(nn.Module):
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        # every kernel size must share the stride's parity so the padding below is symmetric
        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        # NOTE(review): the remainder of CrossEmbedLayer.__init__ follows this block in the file
num_scales = len(kernel_sizes) # calculate the dimension at each scale dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)] dim_scales = [*dim_scales, dim_out - sum(dim_scales)] self.convs = nn.ModuleList([]) for kernel, dim_scale in zip(kernel_sizes, dim_scales): self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2)) def forward(self, x): fmaps = tuple(map(lambda conv: conv(x), self.convs)) return torch.cat(fmaps, dim = 1) class UpsampleCombiner(nn.Module): def __init__( self, dim, *, enabled = False, dim_ins = tuple(), dim_outs = tuple() ): super().__init__() dim_outs = cast_tuple(dim_outs, len(dim_ins)) assert len(dim_ins) == len(dim_outs) self.enabled = enabled if not self.enabled: self.dim_out = dim return self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)]) self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0) def forward(self, x, fmaps = None): target_size = x.shape[-1] fmaps = default(fmaps, tuple()) if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0: return x fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps] outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)] return torch.cat((x, *outs), dim = 1) class Unet(nn.Module): def __init__( self, *, dim, text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME), num_resnet_blocks = 1, cond_dim = None, num_image_tokens = 4, num_time_tokens = 2, learned_sinu_pos_emb_dim = 16, out_dim = None, dim_mults=(1, 2, 4, 8), cond_images_channels = 0, channels = 3, channels_out = None, attn_dim_head = 64, attn_heads = 8, ff_mult = 2., lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ layer_attns = True, layer_attns_depth = 1, layer_mid_attns_depth = 1, layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 attend_at_middle = True, # whether to 
have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) layer_cross_attns = True, use_linear_attn = False, use_linear_cross_attn = False, cond_on_text = True, max_text_len = 256, init_dim = None, resnet_groups = 8, init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed init_cross_embed = True, init_cross_embed_kernel_sizes = (3, 7, 15), cross_embed_downsample = False, cross_embed_downsample_kernel_sizes = (2, 4), attn_pool_text = True, attn_pool_num_latents = 32, dropout = 0., memory_efficient = False, init_conv_to_final_conv_residual = False, use_global_context_attn = True, scale_skip_connection = True, final_resnet_block = True, final_conv_kernel_size = 3, self_cond = False, resize_mode = 'nearest', combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully pixel_shuffle_upsample = True, # may address checkboard artifacts ): super().__init__() # guide researchers assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8' if dim < 128: print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/') # save locals to take care of some hyperparameters for cascading DDPM self._locals = locals() self._locals.pop('self', None) self._locals.pop('__class__', None) # determine dimensions self.channels = channels self.channels_out = default(channels_out, channels) # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis # (2) in self conditioning, one appends the predict x0 (x_start) init_channels = channels * (1 + int(lowres_cond) + int(self_cond)) init_dim = default(init_dim, dim) self.self_cond = self_cond # optional image conditioning self.has_cond_image 
= cond_images_channels > 0 self.cond_images_channels = cond_images_channels init_channels += cond_images_channels # initial convolution self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) # time conditioning cond_dim = default(cond_dim, dim) time_cond_dim = dim * 4 * (2 if lowres_cond else 1) # embedding time for log(snr) noise from continuous version sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim) sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1 self.to_time_hiddens = nn.Sequential( sinu_pos_emb, nn.Linear(sinu_pos_emb_input_dim, time_cond_dim), nn.SiLU() ) self.to_time_cond = nn.Sequential( nn.Linear(time_cond_dim, time_cond_dim) ) # project to time tokens as well as time hiddens self.to_time_tokens = nn.Sequential( nn.Linear(time_cond_dim, cond_dim * num_time_tokens), Rearrange('b (r d) -> b r d', r = num_time_tokens) ) # low res aug noise conditioning self.lowres_cond = lowres_cond if lowres_cond: self.to_lowres_time_hiddens = nn.Sequential( LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim), nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim), nn.SiLU() ) self.to_lowres_time_cond = nn.Sequential( nn.Linear(time_cond_dim, time_cond_dim) ) self.to_lowres_time_tokens = nn.Sequential( nn.Linear(time_cond_dim, cond_dim * num_time_tokens), Rearrange('b (r d) -> b r d', r = num_time_tokens) ) # normalizations self.norm_cond = nn.LayerNorm(cond_dim) # text encoding conditioning (optional) self.text_to_cond = None if cond_on_text: assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True' self.text_to_cond = nn.Linear(text_embed_dim, cond_dim) # finer control over whether to condition on text encodings self.cond_on_text = 
cond_on_text # attention pooling self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents) if attn_pool_text else None # for classifier free guidance self.max_text_len = max_text_len self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim)) # for non-attention based text conditioning at all points in the network where time is also conditioned self.to_text_non_attn_cond = None if cond_on_text: self.to_text_non_attn_cond = nn.Sequential( nn.LayerNorm(cond_dim), nn.Linear(cond_dim, time_cond_dim), nn.SiLU(), nn.Linear(time_cond_dim, time_cond_dim) ) # attention related params attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head) num_layers = len(in_out) # resnet block klass num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers) resnet_groups = cast_tuple(resnet_groups, num_layers) resnet_klass = partial(ResnetBlock, **attn_kwargs) layer_attns = cast_tuple(layer_attns, num_layers) layer_attns_depth = cast_tuple(layer_attns_depth, num_layers) layer_cross_attns = cast_tuple(layer_cross_attns, num_layers) use_linear_attn = cast_tuple(use_linear_attn, num_layers) use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers) assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))]) # downsample klass downsample_klass = Downsample if cross_embed_downsample: downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes) # initial resnet block (for memory efficient unet) self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None # scale for resnet skip connections self.skip_connect_scale = 1. 
if not scale_skip_connection else (2 ** -0.5) # layers self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) num_resolutions = len(in_out) layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn] reversed_layer_params = list(map(reversed, layer_params)) # downsampling layers skip_connect_dims = [] # keep track of skip connection dimensions for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)): is_last = ind >= (num_resolutions - 1) layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None if layer_attn: transformer_block_klass = TransformerBlock elif layer_use_linear_attn: transformer_block_klass = LinearAttentionTransformerBlock else: transformer_block_klass = Identity current_dim = dim_in # whether to pre-downsample, from memory efficient unet pre_downsample = None if memory_efficient: pre_downsample = downsample_klass(dim_in, dim_out) current_dim = dim_out skip_connect_dims.append(current_dim) # whether to do post-downsample, for non-memory efficient unet post_downsample = None if not memory_efficient: post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(nn.Conv2d(dim_in, dim_out, 3, padding = 1), nn.Conv2d(dim_in, dim_out, 1)) self.downs.append(nn.ModuleList([ pre_downsample, resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups), nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]), transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs), post_downsample ])) # middle 
layers mid_dim = dims[-1] self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) self.mid_attn = TransformerBlock(mid_dim, depth = layer_mid_attns_depth, **attn_kwargs) if attend_at_middle else None self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) # upsample klass upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample # upsampling layers upsample_fmap_dims = [] for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)): is_last = ind == (len(in_out) - 1) layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None if layer_attn: transformer_block_klass = TransformerBlock elif layer_use_linear_attn: transformer_block_klass = LinearAttentionTransformerBlock else: transformer_block_klass = Identity skip_connect_dim = skip_connect_dims.pop() upsample_fmap_dims.append(dim_out) self.ups.append(nn.ModuleList([ resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups), nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]), transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs), upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity() ])) # whether to combine feature maps from all upsample blocks before final resnet block out self.upsample_combiner = UpsampleCombiner( dim = dim, enabled = combine_upsample_fmaps, dim_ins = upsample_fmap_dims, dim_outs = dim ) # whether to do a 
final residual from initial conv to the final resnet block out self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0) # final optional resnet block and convolution out self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None final_conv_dim_in = dim if final_resnet_block else final_conv_dim final_conv_dim_in += (channels if lowres_cond else 0) self.final_conv = nn.Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2) zero_init_(self.final_conv) # resize mode self.resize_mode = resize_mode # if the current settings for the unet are not correct # for cascading DDPM, then reinit the unet with the right settings def cast_model_parameters( self, *, lowres_cond, text_embed_dim, channels, channels_out, cond_on_text ): if lowres_cond == self.lowres_cond and \ channels == self.channels and \ cond_on_text == self.cond_on_text and \ text_embed_dim == self._locals['text_embed_dim'] and \ channels_out == self.channels_out: return self updated_kwargs = dict( lowres_cond = lowres_cond, text_embed_dim = text_embed_dim, channels = channels, channels_out = channels_out, cond_on_text = cond_on_text ) return self.__class__(**{**self._locals, **updated_kwargs}) # methods for returning the full unet config as well as its parameter state def to_config_and_state_dict(self): return self._locals, self.state_dict() # class method for rehydrating the unet from its config and state dict @classmethod def from_config_and_state_dict(klass, config, state_dict): unet = klass(**config) unet.load_state_dict(state_dict) return unet # methods for persisting unet to disk def persist_to_file(self, path): path = Path(path) path.parents[0].mkdir(exist_ok = True, parents = True) config, state_dict = self.to_config_and_state_dict() pkg = 
dict(config = config, state_dict = state_dict) torch.save(pkg, str(path)) # class method for rehydrating the unet from file saved with `persist_to_file` @classmethod def hydrate_from_file(klass, path): path = Path(path) assert path.exists() pkg = torch.load(str(path)) assert 'config' in pkg and 'state_dict' in pkg config, state_dict = pkg['config'], pkg['state_dict'] return Unet.from_config_and_state_dict(config, state_dict) # forward with classifier free guidance def forward_with_cond_scale( self, *args, cond_scale = 1., **kwargs ): logits = self.forward(*args, **kwargs) if cond_scale == 1: return logits null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) return null_logits + (logits - null_logits) * cond_scale def forward( self, x, time, *, lowres_cond_img = None, lowres_noise_times = None, text_embeds = None, text_mask = None, cond_images = None, self_cond = None, cond_drop_prob = 0. ): batch_size, device = x.shape[0], x.device # condition on self if self.self_cond: self_cond = default(self_cond, lambda: torch.zeros_like(x)) x = torch.cat((x, self_cond), dim = 1) # add low resolution conditioning, if present assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present' if exists(lowres_cond_img): x = torch.cat((x, lowres_cond_img), dim = 1) # condition on input image assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa' if exists(cond_images): assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet' cond_images = resize_image_to(cond_images, x.shape[-1], mode = self.resize_mode) x = torch.cat((cond_images, x), dim = 1) # 
initial convolution x = self.init_conv(x) # init conv residual if self.init_conv_to_final_conv_residual: init_conv_residual = x.clone() # time conditioning time_hiddens = self.to_time_hiddens(time) # derive time tokens time_tokens = self.to_time_tokens(time_hiddens) t = self.to_time_cond(time_hiddens) # add lowres time conditioning to time hiddens # and add lowres time tokens along sequence dimension for attention if self.lowres_cond: lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times) lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens) lowres_t = self.to_lowres_time_cond(lowres_time_hiddens) t = t + lowres_t time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2) # text conditioning text_tokens = None if exists(text_embeds) and self.cond_on_text: # conditional dropout text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device) text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1') text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1') # calculate text embeds text_tokens = self.text_to_cond(text_embeds) text_tokens = text_tokens[:, :self.max_text_len] if exists(text_mask): text_mask = text_mask[:, :self.max_text_len] text_tokens_len = text_tokens.shape[1] remainder = self.max_text_len - text_tokens_len if remainder > 0: text_tokens = F.pad(text_tokens, (0, 0, 0, remainder)) if exists(text_mask): if remainder > 0: text_mask = F.pad(text_mask, (0, remainder), value = False) text_mask = rearrange(text_mask, 'b n -> b n 1') text_keep_mask_embed = text_mask & text_keep_mask_embed null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working text_tokens = torch.where( text_keep_mask_embed, text_tokens, null_text_embed ) if exists(self.attn_pool): text_tokens = self.attn_pool(text_tokens) # extra non-attention conditioning by projecting and then summing text embeddings to time # termed as text hiddens mean_pooled_text_tokens = 
text_tokens.mean(dim = -2) text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens) null_text_hidden = self.null_text_hidden.to(t.dtype) text_hiddens = torch.where( text_keep_mask_hidden, text_hiddens, null_text_hidden ) t = t + text_hiddens # main conditioning tokens (c) c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2) # normalize conditioning tokens c = self.norm_cond(c) # initial resnet block (for memory efficient unet) if exists(self.init_resnet_block): x = self.init_resnet_block(x, t) # go through the layers of the unet, down and up hiddens = [] for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs: if exists(pre_downsample): x = pre_downsample(x) x = init_block(x, t, c) for resnet_block in resnet_blocks: x = resnet_block(x, t) hiddens.append(x) x = attn_block(x, c) hiddens.append(x) if exists(post_downsample): x = post_downsample(x) x = self.mid_block1(x, t, c) if exists(self.mid_attn): x = self.mid_attn(x) x = self.mid_block2(x, t, c) add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1) up_hiddens = [] for init_block, resnet_blocks, attn_block, upsample in self.ups: x = add_skip_connection(x) x = init_block(x, t, c) for resnet_block in resnet_blocks: x = add_skip_connection(x) x = resnet_block(x, t) x = attn_block(x, c) up_hiddens.append(x.contiguous()) x = upsample(x) # whether to combine all feature maps from upsample blocks x = self.upsample_combiner(x, up_hiddens) # final top-most residual if needed if self.init_conv_to_final_conv_residual: x = torch.cat((x, init_conv_residual), dim = 1) if exists(self.final_res_block): x = self.final_res_block(x, t) if exists(lowres_cond_img): x = torch.cat((x, lowres_cond_img), dim = 1) return self.final_conv(x) # null unet class NullUnet(nn.Module): def __init__(self, *args, **kwargs): super().__init__() self.lowres_cond = False self.dummy_parameter = nn.Parameter(torch.tensor([0.])) 
def cast_model_parameters(self, *args, **kwargs): return self def forward(self, x, *args, **kwargs): return x # predefined unets, with configs lining up with hyperparameters in appendix of paper class BaseUnet64(Unet): def __init__(self, *args, **kwargs): default_kwargs = dict( dim = 512, dim_mults = (1, 2, 3, 4), num_resnet_blocks = 3, layer_attns = (False, True, True, True), layer_cross_attns = (False, True, True, True), attn_heads = 8, ff_mult = 2., memory_efficient = False ) super().__init__(*args, **{**default_kwargs, **kwargs}) class SRUnet256(Unet): def __init__(self, *args, **kwargs): default_kwargs = dict( dim = 128, dim_mults = (1, 2, 4, 8), num_resnet_blocks = (2, 4, 8, 8), layer_attns = (False, False, False, True), layer_cross_attns = (False, False, False, True), attn_heads = 8, ff_mult = 2., memory_efficient = True ) super().__init__(*args, **{**default_kwargs, **kwargs}) class SRUnet1024(Unet): def __init__(self, *args, **kwargs): default_kwargs = dict( dim = 128, dim_mults = (1, 2, 4, 8), num_resnet_blocks = (2, 4, 8, 8), layer_attns = False, layer_cross_attns = (False, False, False, True), attn_heads = 8, ff_mult = 2., memory_efficient = True ) super().__init__(*args, **{**default_kwargs, **kwargs}) # main imagen ddpm class, which is a cascading DDPM from Ho et al. 
class Imagen(nn.Module): def __init__( self, unets, *, image_sizes, # for cascading ddpm, image size at each stage text_encoder_name = DEFAULT_T5_NAME, text_embed_dim = None, channels = 3, timesteps = 1000, cond_drop_prob = 0.1, loss_type = 'l2', noise_schedules = 'cosine', pred_objectives = 'noise', random_crop_sizes = None, lowres_noise_schedule = 'linear', lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find condition_on_text = True, auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader dynamic_thresholding = True, dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper only_train_unet_number = None, temporal_downsample_factor = 1, resize_cond_video_frames = True, resize_mode = 'nearest', min_snr_loss_weight = True, # https://arxiv.org/abs/2303.09556 min_snr_gamma = 5 ): super().__init__() # loss if loss_type == 'l1': loss_fn = F.l1_loss elif loss_type == 'l2': loss_fn = F.mse_loss elif loss_type == 'huber': loss_fn = F.smooth_l1_loss else: raise NotImplementedError() self.loss_type = loss_type self.loss_fn = loss_fn # conditioning hparams self.condition_on_text = condition_on_text self.unconditional = not condition_on_text # channels self.channels = channels # automatically take care of ensuring that first unet is unconditional # while the rest of the unets are conditioned on the low resolution image produced by previous unet unets = cast_tuple(unets) num_unets = len(unets) # 
determine noise schedules per unet timesteps = cast_tuple(timesteps, num_unets) # make sure noise schedule defaults to 'cosine', 'cosine', and then 'linear' for rest of super-resoluting unets noise_schedules = cast_tuple(noise_schedules) noise_schedules = pad_tuple_to_length(noise_schedules, 2, 'cosine') noise_schedules = pad_tuple_to_length(noise_schedules, num_unets, 'linear') # construct noise schedulers noise_scheduler_klass = GaussianDiffusionContinuousTimes self.noise_schedulers = nn.ModuleList([]) for timestep, noise_schedule in zip(timesteps, noise_schedules): noise_scheduler = noise_scheduler_klass(noise_schedule = noise_schedule, timesteps = timestep) self.noise_schedulers.append(noise_scheduler) # randomly cropping for upsampler training self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets) assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example' # lowres augmentation noise schedule self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule) # ddpm objectives - predicting noise by default self.pred_objectives = cast_tuple(pred_objectives, num_unets) # get text encoder self.text_encoder_name = text_encoder_name self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name)) self.encode_text = partial(t5_encode_text, name = text_encoder_name) # construct unets self.unets = nn.ModuleList([]) self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment self.only_train_unet_number = only_train_unet_number for ind, one_unet in enumerate(unets): assert isinstance(one_unet, (Unet, Unet3D, NullUnet)) is_first = ind == 0 one_unet = one_unet.cast_model_parameters( lowres_cond = not is_first, cond_on_text = self.condition_on_text, text_embed_dim = self.text_embed_dim if self.condition_on_text else None, 
channels = self.channels, channels_out = self.channels ) self.unets.append(one_unet) # unet image sizes image_sizes = cast_tuple(image_sizes) self.image_sizes = image_sizes assert num_unets == len(image_sizes), f'you did not supply the correct number of u-nets ({len(unets)}) for resolutions {image_sizes}' self.sample_channels = cast_tuple(self.channels, num_unets) # determine whether we are training on images or video is_video = any([isinstance(unet, Unet3D) for unet in self.unets]) self.is_video = is_video self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1')) self.resize_to = resize_video_to if is_video else resize_image_to self.resize_to = partial(self.resize_to, mode = resize_mode) # temporal interpolation temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets) self.temporal_downsample_factor = temporal_downsample_factor self.resize_cond_video_frames = resize_cond_video_frames self.temporal_downsample_divisor = temporal_downsample_factor[0] assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1' assert tuple(sorted(temporal_downsample_factor, reverse = True)) == temporal_downsample_factor, 'temporal downsample factor must be in order of descending' # cascading ddpm related stuff lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets)) assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True' self.lowres_sample_noise_level = lowres_sample_noise_level self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level # classifier free guidance self.cond_drop_prob = cond_drop_prob self.can_classifier_guidance = cond_drop_prob > 0. 
# normalize and unnormalize image functions self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity self.input_image_range = (0. if auto_normalize_img else -1., 1.) # dynamic thresholding self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets) self.dynamic_thresholding_percentile = dynamic_thresholding_percentile # min snr loss weight min_snr_loss_weight = cast_tuple(min_snr_loss_weight, num_unets) min_snr_gamma = cast_tuple(min_snr_gamma, num_unets) assert len(min_snr_loss_weight) == len(min_snr_gamma) == num_unets self.min_snr_gamma = tuple((gamma if use_min_snr else None) for use_min_snr, gamma in zip(min_snr_loss_weight, min_snr_gamma)) # one temp parameter for keeping track of device self.register_buffer('_temp', torch.tensor([0.]), persistent = False) # default to device of unets passed in self.to(next(self.unets.parameters()).device) def force_unconditional_(self): self.condition_on_text = False self.unconditional = True for unet in self.unets: unet.cond_on_text = False @property def device(self): return self._temp.device def get_unet(self, unet_number): assert 0 < unet_number <= len(self.unets) index = unet_number - 1 if isinstance(self.unets, nn.ModuleList): unets_list = [unet for unet in self.unets] delattr(self, 'unets') self.unets = unets_list if index != self.unet_being_trained_index: for unet_index, unet in enumerate(self.unets): unet.to(self.device if unet_index == index else 'cpu') self.unet_being_trained_index = index return self.unets[index] def reset_unets_all_one_device(self, device = None): device = default(device, self.device) self.unets = nn.ModuleList([*self.unets]) self.unets.to(device) self.unet_being_trained_index = -1 @contextmanager def one_unet_in_gpu(self, unet_number = None, unet = None): assert exists(unet_number) ^ exists(unet) if exists(unet_number): unet = self.unets[unet_number - 1] cpu = 
torch.device('cpu') devices = [module_device(unet) for unet in self.unets] self.unets.to(cpu) unet.to(self.device) yield for unet, device in zip(self.unets, devices): unet.to(device) # overriding state dict functions def state_dict(self, *args, **kwargs): self.reset_unets_all_one_device() return super().state_dict(*args, **kwargs) def load_state_dict(self, *args, **kwargs): self.reset_unets_all_one_device() return super().load_state_dict(*args, **kwargs) # gaussian diffusion methods def p_mean_variance( self, unet, x, t, *, noise_scheduler, text_embeds = None, text_mask = None, cond_images = None, cond_video_frames = None, post_cond_video_frames = None, lowres_cond_img = None, self_cond = None, lowres_noise_times = None, cond_scale = 1., model_output = None, t_next = None, pred_objective = 'noise', dynamic_threshold = True ): assert not (cond_scale != 1. and not self.can_classifier_guidance), 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)' video_kwargs = dict() if self.is_video: video_kwargs = dict( cond_video_frames = cond_video_frames, post_cond_video_frames = post_cond_video_frames, ) pred = default(model_output, lambda: unet.forward_with_cond_scale( x, noise_scheduler.get_condition(t), text_embeds = text_embeds, text_mask = text_mask, cond_images = cond_images, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times), **video_kwargs )) if pred_objective == 'noise': x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) elif pred_objective == 'x_start': x_start = pred elif pred_objective == 'v': x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred) else: raise ValueError(f'unknown objective {pred_objective}') if dynamic_threshold: # following pseudocode in appendix # s is the dynamic threshold, determined by percentile of absolute values 
of reconstructed sample per batch element s = torch.quantile( rearrange(x_start, 'b ... -> b (...)').abs(), self.dynamic_thresholding_percentile, dim = -1 ) s.clamp_(min = 1.) s = right_pad_dims_to(x_start, s) x_start = x_start.clamp(-s, s) / s else: x_start.clamp_(-1., 1.) mean_and_variance = noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next) return mean_and_variance, x_start @torch.no_grad() def p_sample( self, unet, x, t, *, noise_scheduler, t_next = None, text_embeds = None, text_mask = None, cond_images = None, cond_video_frames = None, post_cond_video_frames = None, cond_scale = 1., self_cond = None, lowres_cond_img = None, lowres_noise_times = None, pred_objective = 'noise', dynamic_threshold = True ): b, *_, device = *x.shape, x.device video_kwargs = dict() if self.is_video: video_kwargs = dict( cond_video_frames = cond_video_frames, post_cond_video_frames = post_cond_video_frames, ) (model_mean, _, model_log_variance), x_start = self.p_mean_variance( unet, x = x, t = t, t_next = t_next, noise_scheduler = noise_scheduler, text_embeds = text_embeds, text_mask = text_mask, cond_images = cond_images, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_times = lowres_noise_times, pred_objective = pred_objective, dynamic_threshold = dynamic_threshold, **video_kwargs ) noise = torch.randn_like(x) # no noise when t == 0 is_last_sampling_timestep = (t_next == 0) if isinstance(noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0) nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(b, *((1,) * (len(x.shape) - 1))) pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise return pred, x_start @torch.no_grad() def p_sample_loop( self, unet, shape, *, noise_scheduler, lowres_cond_img = None, lowres_noise_times = None, text_embeds = None, text_mask = None, cond_images = None, cond_video_frames = None, post_cond_video_frames = None, inpaint_images = None, 
inpaint_videos = None, inpaint_masks = None, inpaint_resample_times = 5, init_images = None, skip_steps = None, cond_scale = 1, pred_objective = 'noise', dynamic_threshold = True, use_tqdm = True ): device = self.device batch = shape[0] img = torch.randn(shape, device = device) # video is_video = len(shape) == 5 frames = shape[-3] if is_video else None resize_kwargs = dict(target_frames = frames) if exists(frames) else dict() # for initialization with an image or video if exists(init_images): img += init_images # keep track of x0, for self conditioning x_start = None # prepare inpainting inpaint_images = default(inpaint_videos, inpaint_images) has_inpainting = exists(inpaint_images) and exists(inpaint_masks) resample_times = inpaint_resample_times if has_inpainting else 1 if has_inpainting: inpaint_images = self.normalize_img(inpaint_images) inpaint_images = self.resize_to(inpaint_images, shape[-1], **resize_kwargs) inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1], **resize_kwargs).bool() # time timesteps = noise_scheduler.get_sampling_timesteps(batch, device = device) # whether to skip any steps skip_steps = default(skip_steps, 0) timesteps = timesteps[skip_steps:] # video conditioning kwargs video_kwargs = dict() if self.is_video: video_kwargs = dict( cond_video_frames = cond_video_frames, post_cond_video_frames = post_cond_video_frames, ) for times, times_next in tqdm(timesteps, desc = 'sampling loop time step', total = len(timesteps), disable = not use_tqdm): is_last_timestep = times_next == 0 for r in reversed(range(resample_times)): is_last_resample_step = r == 0 if has_inpainting: noised_inpaint_images, *_ = noise_scheduler.q_sample(inpaint_images, t = times) img = img * ~inpaint_masks + noised_inpaint_images * inpaint_masks self_cond = x_start if unet.self_cond else None img, x_start = self.p_sample( unet, img, times, t_next = times_next, text_embeds = text_embeds, text_mask = text_mask, cond_images = 
cond_images, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_times = lowres_noise_times, noise_scheduler = noise_scheduler, pred_objective = pred_objective, dynamic_threshold = dynamic_threshold, **video_kwargs ) if has_inpainting and not (is_last_resample_step or torch.all(is_last_timestep)): renoised_img = noise_scheduler.q_sample_from_to(img, times_next, times) img = torch.where( self.right_pad_dims_to_datatype(is_last_timestep), img, renoised_img ) img.clamp_(-1., 1.) # final inpainting if has_inpainting: img = img * ~inpaint_masks + inpaint_images * inpaint_masks unnormalize_img = self.unnormalize_img(img) return unnormalize_img @torch.no_grad() @eval_decorator @beartype def sample( self, texts: List[str] = None, text_masks = None, text_embeds = None, video_frames = None, cond_images = None, cond_video_frames = None, post_cond_video_frames = None, inpaint_videos = None, inpaint_images = None, inpaint_masks = None, inpaint_resample_times = 5, init_images = None, skip_steps = None, batch_size = 1, cond_scale = 1., lowres_sample_noise_level = None, start_at_unet_number = 1, start_image_or_video = None, stop_at_unet_number = None, return_all_unet_outputs = False, return_pil_images = False, device = None, use_tqdm = True, use_one_unet_in_gpu = True ): device = default(device, self.device) self.reset_unets_all_one_device(device = device) cond_images = maybe(cast_uint8_images_to_float)(cond_images) if exists(texts) and not exists(text_embeds) and not self.unconditional: assert all([*map(len, texts)]), 'text cannot be empty' with autocast(enabled = False): text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True) text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks)) if not self.unconditional: assert exists(text_embeds), 'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training' text_masks = default(text_masks, 
lambda: torch.any(text_embeds != 0., dim = -1)) batch_size = text_embeds.shape[0] # inpainting inpaint_images = default(inpaint_videos, inpaint_images) if exists(inpaint_images): if self.unconditional: if batch_size == 1: # assume researcher wants to broadcast along inpainted images batch_size = inpaint_images.shape[0] assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=)``' assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on' assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified' assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented' assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting' outputs = [] is_cuda = next(self.parameters()).is_cuda device = next(self.parameters()).device lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level) num_unets = len(self.unets) # condition scaling cond_scale = cast_tuple(cond_scale, num_unets) # add frame dimension for video if self.is_video and exists(inpaint_images): video_frames = inpaint_images.shape[2] if inpaint_masks.ndim == 3: inpaint_masks = repeat(inpaint_masks, 'b h w -> b f h w', f = video_frames) assert inpaint_masks.shape[1] == video_frames assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video' all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames) frames_to_resize_kwargs = 
lambda frames: dict(target_frames = frames) if exists(frames) else dict() # for initial image and skipping steps init_images = cast_tuple(init_images, num_unets) init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images] skip_steps = cast_tuple(skip_steps, num_unets) # handle starting at a unet greater than 1, for training only-upscaler training if start_at_unet_number > 1: assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets' assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling' prev_image_size = self.image_sizes[start_at_unet_number - 2] prev_frame_size = all_frame_dims[start_at_unet_number - 2][0] if self.is_video else None img = self.resize_to(start_image_or_video, prev_image_size, **frames_to_resize_kwargs(prev_frame_size)) # go through each unet in cascade for unet_number, unet, channel, image_size, frame_dims, noise_scheduler, pred_objective, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, all_frame_dims, self.noise_schedulers, self.pred_objectives, self.dynamic_thresholding, cond_scale, init_images, skip_steps), disable = not use_tqdm): if unet_number < start_at_unet_number: continue assert not isinstance(unet, NullUnet), 'one cannot sample from null / placeholder unets' context = self.one_unet_in_gpu(unet = unet) if is_cuda and use_one_unet_in_gpu else nullcontext() with context: # video kwargs video_kwargs = dict() if self.is_video: video_kwargs = dict( cond_video_frames = cond_video_frames, post_cond_video_frames = post_cond_video_frames, ) video_kwargs = compact(video_kwargs) if self.is_video and self.resize_cond_video_frames: downsample_scale = self.temporal_downsample_factor[unet_number - 1] temporal_downsample_fn = 
partial(scale_video_time, downsample_scale = downsample_scale) video_kwargs = maybe_transform_dict_key(video_kwargs, 'cond_video_frames', temporal_downsample_fn) video_kwargs = maybe_transform_dict_key(video_kwargs, 'post_cond_video_frames', temporal_downsample_fn) # low resolution conditioning lowres_cond_img = lowres_noise_times = None shape = (batch_size, channel, *frame_dims, image_size, image_size) resize_kwargs = dict(target_frames = frame_dims[0]) if self.is_video else dict() if unet.lowres_cond: lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device) lowres_cond_img = self.resize_to(img, image_size, **resize_kwargs) lowres_cond_img = self.normalize_img(lowres_cond_img) lowres_cond_img, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img)) # init images or video if exists(unet_init_images): unet_init_images = self.resize_to(unet_init_images, image_size, **resize_kwargs) # shape of stage shape = (batch_size, self.channels, *frame_dims, image_size, image_size) img = self.p_sample_loop( unet, shape, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, inpaint_resample_times = inpaint_resample_times, init_images = unet_init_images, skip_steps = unet_skip_steps, cond_scale = unet_cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_times = lowres_noise_times, noise_scheduler = noise_scheduler, pred_objective = pred_objective, dynamic_threshold = dynamic_threshold, use_tqdm = use_tqdm, **video_kwargs ) outputs.append(img) if exists(stop_at_unet_number) and stop_at_unet_number == unet_number: break output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs if not return_pil_images: return outputs[output_index] if not return_all_unet_outputs: outputs = outputs[-1:] assert not 
self.is_video, 'converting sampled video tensor to video file is not supported yet' pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim = 0))), outputs)) return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png) @beartype def p_losses( self, unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel], x_start, times, *, noise_scheduler, lowres_cond_img = None, lowres_aug_times = None, text_embeds = None, text_mask = None, cond_images = None, noise = None, times_next = None, pred_objective = 'noise', min_snr_gamma = None, random_crop_size = None, **kwargs ): is_video = x_start.ndim == 5 noise = default(noise, lambda: torch.randn_like(x_start)) # normalize to [-1, 1] x_start = self.normalize_img(x_start) lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img) # random cropping during training # for upsamplers if exists(random_crop_size): if is_video: frames = x_start.shape[2] x_start, lowres_cond_img, noise = map(lambda t: rearrange(t, 'b c f h w -> (b f) c h w'), (x_start, lowres_cond_img, noise)) aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.) 
# make sure low res conditioner and image both get augmented the same way # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop x_start = aug(x_start) lowres_cond_img = aug(lowres_cond_img, params = aug._params) noise = aug(noise, params = aug._params) if is_video: x_start, lowres_cond_img, noise = map(lambda t: rearrange(t, '(b f) c h w -> b c f h w', f = frames), (x_start, lowres_cond_img, noise)) # get x_t x_noisy, log_snr, alpha, sigma = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise) # also noise the lowres conditioning image # at sample time, they then fix the noise level of 0.1 - 0.3 lowres_cond_img_noisy = None if exists(lowres_cond_img): lowres_aug_times = default(lowres_aug_times, times) lowres_cond_img_noisy, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img)) # time condition noise_cond = noise_scheduler.get_condition(times) # unet kwargs unet_kwargs = dict( text_embeds = text_embeds, text_mask = text_mask, cond_images = cond_images, lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times), lowres_cond_img = lowres_cond_img_noisy, cond_drop_prob = self.cond_drop_prob, **kwargs ) # self condition if needed # Because 'unet' can be an instance of DistributedDataParallel coming from the # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to # access the member 'module' of the wrapped unet instance. 
self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond if self_cond and random() < 0.5: with torch.no_grad(): pred = unet.forward( x_noisy, noise_cond, **unet_kwargs ).detach() x_start = noise_scheduler.predict_start_from_noise(x_noisy, t = times, noise = pred) if pred_objective == 'noise' else pred unet_kwargs = {**unet_kwargs, 'self_cond': x_start} # get prediction pred = unet.forward( x_noisy, noise_cond, **unet_kwargs ) # prediction objective if pred_objective == 'noise': target = noise elif pred_objective == 'x_start': target = x_start elif pred_objective == 'v': # derivation detailed in Appendix D of Progressive Distillation paper # https://arxiv.org/abs/2202.00512 # this makes distillation viable as well as solve an issue with color shifting in upresoluting unets, noted in imagen-video target = alpha * noise - sigma * x_start else: raise ValueError(f'unknown objective {pred_objective}') # losses losses = self.loss_fn(pred, target, reduction = 'none') losses = reduce(losses, 'b ... 
-> b', 'mean') # min snr loss reweighting snr = log_snr.exp() maybe_clipped_snr = snr.clone() if exists(min_snr_gamma): maybe_clipped_snr.clamp_(max = min_snr_gamma) if pred_objective == 'noise': loss_weight = maybe_clipped_snr / snr elif pred_objective == 'x_start': loss_weight = maybe_clipped_snr elif pred_objective == 'v': loss_weight = maybe_clipped_snr / (snr + 1) losses = losses * loss_weight return losses.mean() @beartype def forward( self, images, # rename to images or video unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None, texts: List[str] = None, text_embeds = None, text_masks = None, unet_number = None, cond_images = None, **kwargs ): if self.is_video and images.ndim == 4: images = rearrange(images, 'b c h w -> b c 1 h w') kwargs.update(ignore_time = True) assert images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}' assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' unet_number = default(unet_number, 1) assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you can only train on unet #{self.only_train_unet_number}' images = cast_uint8_images_to_float(images) cond_images = maybe(cast_uint8_images_to_float)(cond_images) assert images.dtype == torch.float or images.dtype == torch.half, f'images tensor needs to be floats but {images.dtype} dtype found instead' unet_index = unet_number - 1 unet = default(unet, lambda: self.get_unet(unet_number)) assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained' noise_scheduler = self.noise_schedulers[unet_index] min_snr_gamma = self.min_snr_gamma[unet_index] pred_objective = self.pred_objectives[unet_index] target_image_size = self.image_sizes[unet_index] random_crop_size = 
self.random_crop_sizes[unet_index] prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None b, c, *_, h, w, device, is_video = *images.shape, images.device, images.ndim == 5 assert images.shape[1] == self.channels assert h >= target_image_size and w >= target_image_size frames = images.shape[2] if is_video else None all_frame_dims = tuple(safe_get_tuple_index(el, 0) for el in calc_all_frame_dims(self.temporal_downsample_factor, frames)) ignore_time = kwargs.get('ignore_time', False) target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None prev_frame_size = all_frame_dims[unet_index - 1] if is_video and not ignore_time and unet_index > 0 else None frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict() times = noise_scheduler.sample_random_times(b, device = device) if exists(texts) and not exists(text_embeds) and not self.unconditional: assert all([*map(len, texts)]), 'text cannot be empty' assert len(texts) == len(images), 'number of text captions does not match up with the number of images given' with autocast(enabled = False): text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True) text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks)) if not self.unconditional: text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1)) assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified' assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented' assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' # handle video frame conditioning if self.is_video and self.resize_cond_video_frames: downsample_scale = self.temporal_downsample_factor[unet_index] 
temporal_downsample_fn = partial(scale_video_time, downsample_scale = downsample_scale) kwargs = maybe_transform_dict_key(kwargs, 'cond_video_frames', temporal_downsample_fn) kwargs = maybe_transform_dict_key(kwargs, 'post_cond_video_frames', temporal_downsample_fn) # handle low resolution conditioning lowres_cond_img = lowres_aug_times = None if exists(prev_image_size): lowres_cond_img = self.resize_to(images, prev_image_size, **frames_to_resize_kwargs(prev_frame_size), clamp_range = self.input_image_range) lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, **frames_to_resize_kwargs(target_frame_size), clamp_range = self.input_image_range) if self.per_sample_random_aug_noise_level: lowres_aug_times = self.lowres_noise_schedule.sample_random_times(b, device = device) else: lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device) lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = b) images = self.resize_to(images, target_image_size, **frames_to_resize_kwargs(target_frame_size)) return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, min_snr_gamma = min_snr_gamma, random_crop_size = random_crop_size, **kwargs)