import math
import copy
import operator
import functools
from typing import List
from tqdm.auto import tqdm
from functools import partial, wraps
from contextlib import contextmanager, nullcontext
from collections import namedtuple
from pathlib import Path

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
from einops_exts.torch import EinopsToAndFrom

from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# helper functions

def exists(val):
    return val is not None

def identity(t, *args, **kwargs):
    return t

def first(arr, d = None):
    if len(arr) == 0:
        return d
    return arr[0]

def divisible_by(numer, denom):
    return (numer % denom) == 0

def maybe(fn):
    @wraps(fn)
    def inner(x):
        if not exists(x):
            return x
        return fn(x)
    return inner

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(val, length = None):
    if isinstance(val, list):
        val = tuple(val)

    output = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length):
        assert len(output) == length

    return output

def cast_uint8_images_to_float(images):
    if not images.dtype == torch.uint8:
        return images
    return images / 255

def module_device(module):
    return next(module.parameters()).device

def zero_init_(m):
    nn.init.zeros_(m.weight)
    if exists(m.bias):
        nn.init.zeros_(m.bias)

def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

def pad_tuple_to_length(t, length, fillvalue = None):
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

# helper classes

class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

def Sequential(*modules):
    return nn.Sequential(*filter(exists, modules))

# tensor helpers

def log(t, eps: float = 1e-12):
    return torch.log(t.clamp(min = eps))

def l2norm(t):
    return F.normalize(t, dim = -1)

def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))

def masked_mean(t, *, dim, mask = None):
    if not exists(mask):
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)
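    # sum over the kept positions and divide by the (clamped) count of unmasked entries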
return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5) def resize_video_to( video, target_image_size, target_frames = None, clamp_range = None, mode = 'nearest' ): orig_video_size = video.shape[-1] frames = video.shape[2] target_frames = default(target_frames, frames) target_shape = (target_frames, target_image_size, target_image_size) if tuple(video.shape[-3:]) == target_shape: return video out = F.interpolate(video, target_shape, mode = mode) if exists(clamp_range): out = out.clamp(*clamp_range) return out def scale_video_time( video, downsample_scale = 1, mode = 'nearest' ): if downsample_scale == 1: return video image_size, frames = video.shape[-1], video.shape[-3] assert divisible_by(frames, downsample_scale), f'trying to temporally downsample a conditioning video frames of length {frames} by {downsample_scale}, however it is not neatly divisible' target_frames = frames // downsample_scale resized_video = resize_video_to( video, image_size, target_frames = target_frames, mode = mode ) return resized_video # classifier free guidance functions def prob_mask_like(shape, prob, device): if prob == 1: return torch.ones(shape, device = device, dtype = torch.bool) elif prob == 0: return torch.zeros(shape, device = device, dtype = torch.bool) else: return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob # norms and residuals class LayerNorm(nn.Module): def __init__(self, dim, stable = False): super().__init__() self.stable = stable self.g = nn.Parameter(torch.ones(dim)) def forward(self, x): if self.stable: x = x / x.amax(dim = -1, keepdim = True).detach() eps = 1e-5 if x.dtype == torch.float32 else 1e-3 var = torch.var(x, dim = -1, unbiased = False, keepdim = True) mean = torch.mean(x, dim = -1, keepdim = True) return (x - mean) * (var + eps).rsqrt() * self.g class ChanLayerNorm(nn.Module): def __init__(self, dim, stable = False): super().__init__() self.stable = stable self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) def forward(self, x): if self.stable: x = x / x.amax(dim = 1, keepdim = True).detach() eps = 1e-5 if x.dtype == torch.float32 else 1e-3 var = torch.var(x, dim = 1, unbiased = False, keepdim = True) mean = torch.mean(x, dim = 1, keepdim = True) return (x - mean) * (var + eps).rsqrt() * self.g class Always(): def __init__(self, val): self.val = val def __call__(self, *args, **kwargs): return self.val class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, **kwargs): return self.fn(x, **kwargs) + x class Parallel(nn.Module): def __init__(self, *fns): super().__init__() self.fns = nn.ModuleList(fns) def forward(self, x): outputs = [fn(x) for fn in self.fns] return sum(outputs) # rearranging class RearrangeTimeCentric(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x): x = rearrange(x, 'b c f ... -> b ... f c') x, ps = pack([x], '* f c') x = self.fn(x) x, = unpack(x, ps, '* f c') x = rearrange(x, 'b ... 
f c -> b c f ...') return x # attention pooling class PerceiverAttention(nn.Module): def __init__( self, *, dim, dim_head = 64, heads = 8, scale = 8 ): super().__init__() self.scale = scale self.heads = heads inner_dim = dim_head * heads self.norm = nn.LayerNorm(dim) self.norm_latents = nn.LayerNorm(dim) self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) self.q_scale = nn.Parameter(torch.ones(dim_head)) self.k_scale = nn.Parameter(torch.ones(dim_head)) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim, bias = False), nn.LayerNorm(dim) ) def forward(self, x, latents, mask = None): x = self.norm(x) latents = self.norm_latents(latents) b, h = x.shape[0], self.heads q = self.to_q(latents) # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to kv_input = torch.cat((x, latents), dim = -2) k, v = self.to_kv(kv_input).chunk(2, dim = -1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) # qk rmsnorm q, k = map(l2norm, (q, k)) q = q * self.q_scale k = k * self.k_scale # similarities and masking sim = einsum('... i d, ... j d -> ... i j', q, k) * self.scale if exists(mask): max_neg_value = -torch.finfo(sim.dtype).max mask = F.pad(mask, (0, latents.shape[-2]), value = True) mask = rearrange(mask, 'b j -> b 1 1 j') sim = sim.masked_fill(~mask, max_neg_value) # attention attn = sim.softmax(dim = -1) out = einsum('... i j, ... j d -> ... i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)', h = h) return self.to_out(out) class PerceiverResampler(nn.Module): def __init__( self, *, dim, depth, dim_head = 64, heads = 8, num_latents = 64, num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence max_seq_len = 512, ff_mult = 4 ): super().__init__() self.pos_emb = nn.Embedding(max_seq_len, dim) self.latents = nn.Parameter(torch.randn(num_latents, dim)) self.to_latents_from_mean_pooled_seq = None if num_latents_mean_pooled > 0: self.to_latents_from_mean_pooled_seq = nn.Sequential( LayerNorm(dim), nn.Linear(dim, dim * num_latents_mean_pooled), Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled) ) self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads), FeedForward(dim = dim, mult = ff_mult) ])) def forward(self, x, mask = None): n, device = x.shape[1], x.device pos_emb = self.pos_emb(torch.arange(n, device = device)) x_with_pos = x + pos_emb latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0]) if exists(self.to_latents_from_mean_pooled_seq): meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool)) meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) latents = torch.cat((meanpooled_latents, latents), dim = -2) for attn, ff in self.layers: latents = attn(x_with_pos, latents, mask = mask) + latents latents = ff(latents) + latents return latents # main contribution from make-a-video - pseudo conv3d # axial space-time convolutions, but made causal to keep in line with the design decisions of imagen-video paper class Conv3d(nn.Module): def __init__( self, dim, dim_out = None, kernel_size = 3, *, temporal_kernel_size = None, **kwargs ): super().__init__() dim_out = default(dim_out, dim) temporal_kernel_size = default(temporal_kernel_size, kernel_size) self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, 
padding = kernel_size // 2) self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size) if kernel_size > 1 else None self.kernel_size = kernel_size if exists(self.temporal_conv): nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity nn.init.zeros_(self.temporal_conv.bias.data) def forward( self, x, ignore_time = False ): b, c, *_, h, w = x.shape is_video = x.ndim == 5 ignore_time &= is_video if is_video: x = rearrange(x, 'b c f h w -> (b f) c h w') x = self.spatial_conv(x) if is_video: x = rearrange(x, '(b f) c h w -> b c f h w', b = b) if ignore_time or not exists(self.temporal_conv): return x x = rearrange(x, 'b c f h w -> (b h w) c f') # causal temporal convolution - time is causal in imagen-video if self.kernel_size > 1: x = F.pad(x, (self.kernel_size - 1, 0)) x = self.temporal_conv(x) x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w) return x # attention class Attention(nn.Module): def __init__( self, dim, *, dim_head = 64, heads = 8, causal = False, context_dim = None, rel_pos_bias = False, rel_pos_bias_mlp_depth = 2, init_zero = False, scale = 8 ): super().__init__() self.scale = scale self.causal = causal self.rel_pos_bias = DynamicPositionBias(dim = dim, heads = heads, depth = rel_pos_bias_mlp_depth) if rel_pos_bias else None self.heads = heads inner_dim = dim_head * heads self.norm = LayerNorm(dim) self.null_attn_bias = nn.Parameter(torch.randn(heads)) self.null_kv = nn.Parameter(torch.randn(2, dim_head)) self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_kv = nn.Linear(dim, dim_head * 2, bias = False) self.q_scale = nn.Parameter(torch.ones(dim_head)) self.k_scale = nn.Parameter(torch.ones(dim_head)) self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None self.to_out = nn.Sequential( nn.Linear(inner_dim, dim, bias = False), LayerNorm(dim) ) if init_zero: nn.init.zeros_(self.to_out[-1].g) def forward( self, x, context = None, mask = None, attn_bias = None ): b, n, device = *x.shape[:2], x.device x = self.norm(x) q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)) q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) # add null key / value for classifier free guidance in prior net nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2)) k = torch.cat((nk, k), dim = -2) v = torch.cat((nv, v), dim = -2) # add text conditioning, if present if exists(context): assert exists(self.to_context) ck, cv = self.to_context(context).chunk(2, dim = -1) k = torch.cat((ck, k), dim = -2) v = torch.cat((cv, v), dim = -2) # qk rmsnorm q, k = map(l2norm, (q, k)) q = q * self.q_scale k = k * self.k_scale # calculate query / key similarities sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale # relative positional encoding (T5 style) if not exists(attn_bias) and exists(self.rel_pos_bias): attn_bias = self.rel_pos_bias(n, device = device, dtype = q.dtype) if exists(attn_bias): null_attn_bias = repeat(self.null_attn_bias, 'h -> h n 1', n = n) attn_bias = torch.cat((null_attn_bias, attn_bias), dim = -1) sim = sim + attn_bias # masking max_neg_value = -torch.finfo(sim.dtype).max if self.causal: i, j = sim.shape[-2:] causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) sim = sim.masked_fill(causal_mask, max_neg_value) if exists(mask): mask = F.pad(mask, (1, 0), value = True) mask = rearrange(mask, 'b j -> b 1 1 j') sim = sim.masked_fill(~mask, max_neg_value) # attention attn = 
sim.softmax(dim = -1) # aggregate values out = einsum('b h i j, b j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out) # pseudo conv2d that uses conv3d but with kernel size of 1 across frames dimension def Conv2d(dim_in, dim_out, kernel, stride = 1, padding = 0, **kwargs): kernel = cast_tuple(kernel, 2) stride = cast_tuple(stride, 2) padding = cast_tuple(padding, 2) if len(kernel) == 2: kernel = (1, *kernel) if len(stride) == 2: stride = (1, *stride) if len(padding) == 2: padding = (0, *padding) return nn.Conv3d(dim_in, dim_out, kernel, stride = stride, padding = padding, **kwargs) class Pad(nn.Module): def __init__(self, padding, value = 0.): super().__init__() self.padding = padding self.value = value def forward(self, x): return F.pad(x, self.padding, value = self.value) # decoder def Upsample(dim, dim_out = None): dim_out = default(dim_out, dim) return nn.Sequential( nn.Upsample(scale_factor = 2, mode = 'nearest'), Conv2d(dim, dim_out, 3, padding = 1) ) class PixelShuffleUpsample(nn.Module): def __init__(self, dim, dim_out = None): super().__init__() dim_out = default(dim_out, dim) conv = Conv2d(dim, dim_out * 4, 1) self.net = nn.Sequential( conv, nn.SiLU() ) self.pixel_shuffle = nn.PixelShuffle(2) self.init_conv_(conv) def init_conv_(self, conv): o, i, f, h, w = conv.weight.shape conv_weight = torch.empty(o // 4, i, f, h, w) nn.init.kaiming_uniform_(conv_weight) conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...') conv.weight.data.copy_(conv_weight) nn.init.zeros_(conv.bias.data) def forward(self, x): out = self.net(x) frames = x.shape[2] out = rearrange(out, 'b c f h w -> (b f) c h w') out = self.pixel_shuffle(out) return rearrange(out, '(b f) c h w -> b c f h w', f = frames) def Downsample(dim, dim_out = None): dim_out = default(dim_out, dim) return nn.Sequential( Rearrange('b c f (h p1) (w p2) -> b (c p1 p2) f h w', p1 = 2, p2 = 2), Conv2d(dim * 4, dim_out, 1) ) # temporal up and downsamples class TemporalPixelShuffleUpsample(nn.Module): def __init__(self, dim, dim_out = None, stride = 2): super().__init__() self.stride = stride dim_out = default(dim_out, dim) conv = nn.Conv1d(dim, dim_out * stride, 1) self.net = nn.Sequential( conv, nn.SiLU() ) self.pixel_shuffle = Rearrange('b (c r) n -> b c (n r)', r = stride) self.init_conv_(conv) def init_conv_(self, conv): o, i, f = conv.weight.shape conv_weight = torch.empty(o // self.stride, i, f) nn.init.kaiming_uniform_(conv_weight) conv_weight = repeat(conv_weight, 'o ... 
-> (o r) ...', r = self.stride) conv.weight.data.copy_(conv_weight) nn.init.zeros_(conv.bias.data) def forward(self, x): b, c, f, h, w = x.shape x = rearrange(x, 'b c f h w -> (b h w) c f') out = self.net(x) out = self.pixel_shuffle(out) return rearrange(out, '(b h w) c f -> b c f h w', h = h, w = w) def TemporalDownsample(dim, dim_out = None, stride = 2): dim_out = default(dim_out, dim) return nn.Sequential( Rearrange('b c (f p) h w -> b (c p) f h w', p = stride), Conv2d(dim * stride, dim_out, 1) ) # positional embedding class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, x): half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb) emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j') return torch.cat((emb.sin(), emb.cos()), dim = -1) class LearnedSinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() assert (dim % 2) == 0 half_dim = dim // 2 self.weights = nn.Parameter(torch.randn(half_dim)) def forward(self, x): x = rearrange(x, 'b -> b 1') freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) fouriered = torch.cat((x, fouriered), dim = -1) return fouriered class Block(nn.Module): def __init__( self, dim, dim_out, groups = 8, norm = True ): super().__init__() self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity() self.activation = nn.SiLU() self.project = Conv3d(dim, dim_out, 3, padding = 1) def forward( self, x, scale_shift = None, ignore_time = False ): x = self.groupnorm(x) if exists(scale_shift): scale, shift = scale_shift x = x * (scale + 1) + shift x = self.activation(x) return self.project(x, ignore_time = ignore_time) class ResnetBlock(nn.Module): def __init__( self, dim, dim_out, *, cond_dim = None, time_cond_dim = None, groups = 8, linear_attn = False, use_gca = False, squeeze_excite = False, **attn_kwargs ): super().__init__() self.time_mlp = None if exists(time_cond_dim): self.time_mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_cond_dim, dim_out * 2) ) self.cross_attn = None if exists(cond_dim): attn_klass = CrossAttention if not linear_attn else LinearCrossAttention self.cross_attn = attn_klass( dim = dim_out, context_dim = cond_dim, **attn_kwargs ) self.block1 = Block(dim, dim_out, groups = groups) self.block2 = Block(dim_out, dim_out, groups = groups) self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1) self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else Identity() def forward( self, x, time_emb = None, cond = None, ignore_time = False ): scale_shift = None if exists(self.time_mlp) and exists(time_emb): time_emb = self.time_mlp(time_emb) time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') scale_shift = time_emb.chunk(2, dim = 1) h = self.block1(x, ignore_time = ignore_time) if exists(self.cross_attn): assert exists(cond) h = rearrange(h, 'b c ... -> b ... c') h, ps = pack([h], 'b * c') h = self.cross_attn(h, context = cond) + h h, = unpack(h, ps, 'b * c') h = rearrange(h, 'b ... 
c -> b c ...') h = self.block2(h, scale_shift = scale_shift, ignore_time = ignore_time) h = h * self.gca(h) return h + self.res_conv(x) class CrossAttention(nn.Module): def __init__( self, dim, *, context_dim = None, dim_head = 64, heads = 8, norm_context = False, scale = 8 ): super().__init__() self.scale = scale self.heads = heads inner_dim = dim_head * heads context_dim = default(context_dim, dim) self.norm = LayerNorm(dim) self.norm_context = LayerNorm(context_dim) if norm_context else Identity() self.null_kv = nn.Parameter(torch.randn(2, dim_head)) self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) self.q_scale = nn.Parameter(torch.ones(dim_head)) self.k_scale = nn.Parameter(torch.ones(dim_head)) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim, bias = False), LayerNorm(dim) ) def forward(self, x, context, mask = None): b, n, device = *x.shape[:2], x.device x = self.norm(x) context = self.norm_context(context) q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v)) # add null key / value for classifier free guidance in prior net nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2)) k = torch.cat((nk, k), dim = -2) v = torch.cat((nv, v), dim = -2) # qk rmsnorm q, k = map(l2norm, (q, k)) q = q * self.q_scale k = k * self.k_scale # similarities sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale # masking max_neg_value = -torch.finfo(sim.dtype).max if exists(mask): mask = F.pad(mask, (1, 0), value = True) mask = rearrange(mask, 'b j -> b 1 1 j') sim = sim.masked_fill(~mask, max_neg_value) attn = sim.softmax(dim = -1, dtype = torch.float32) out = einsum('b h i j, b h j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out) class LinearCrossAttention(CrossAttention): def forward(self, x, context, mask = None): b, n, device = *x.shape[:2], x.device x = self.norm(x) context = self.norm_context(context) q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v)) # add null key / value for classifier free guidance in prior net nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2)) k = torch.cat((nk, k), dim = -2) v = torch.cat((nv, v), dim = -2) # masking max_neg_value = -torch.finfo(x.dtype).max if exists(mask): mask = F.pad(mask, (1, 0), value = True) mask = rearrange(mask, 'b n -> b n 1') k = k.masked_fill(~mask, max_neg_value) v = v.masked_fill(~mask, 0.) 
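        # the block below is the linear ("efficient") attention variant: queries are softmaxed over
        # the feature dimension and keys over the sequence dimension, keys and values are aggregated
        # into a small (d x e) context matrix, and the queries read that matrix out - so the cost
        # grows linearly with the number of tokens instead of quadratically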
# linear attention q = q.softmax(dim = -1) k = k.softmax(dim = -2) q = q * self.scale context = einsum('b n d, b n e -> b d e', k, v) out = einsum('b n d, b d e -> b n e', q, context) out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads) return self.to_out(out) class LinearAttention(nn.Module): def __init__( self, dim, dim_head = 32, heads = 8, dropout = 0.05, context_dim = None, **kwargs ): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads inner_dim = dim_head * heads self.norm = ChanLayerNorm(dim) self.nonlin = nn.SiLU() self.to_q = nn.Sequential( nn.Dropout(dropout), Conv2d(dim, inner_dim, 1, bias = False), Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim) ) self.to_k = nn.Sequential( nn.Dropout(dropout), Conv2d(dim, inner_dim, 1, bias = False), Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim) ) self.to_v = nn.Sequential( nn.Dropout(dropout), Conv2d(dim, inner_dim, 1, bias = False), Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim) ) self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None self.to_out = nn.Sequential( Conv2d(inner_dim, dim, 1, bias = False), ChanLayerNorm(dim) ) def forward(self, fmap, context = None): h, x, y = self.heads, *fmap.shape[-2:] fmap = self.norm(fmap) q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v)) q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v)) if exists(context): assert exists(self.to_context) ck, cv = self.to_context(context).chunk(2, dim = -1) ck, cv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (ck, cv)) k = torch.cat((k, ck), dim = -2) v = torch.cat((v, cv), dim = -2) q = q.softmax(dim = -1) k = k.softmax(dim = -2) q = q * self.scale context = einsum('b n d, b n e -> b d e', k, v) out = einsum('b n d, b d e -> b n e', q, context) out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y) out = self.nonlin(out) return self.to_out(out) class GlobalContext(nn.Module): """ basically a superior form of squeeze-excitation that is attention-esque """ def __init__( self, *, dim_in, dim_out ): super().__init__() self.to_k = Conv2d(dim_in, 1, 1) hidden_dim = max(3, dim_out // 2) self.net = nn.Sequential( Conv2d(dim_in, hidden_dim, 1), nn.SiLU(), Conv2d(hidden_dim, dim_out, 1), nn.Sigmoid() ) def forward(self, x): context = self.to_k(x) x, context = map(lambda t: rearrange(t, 'b n ... -> b n (...)'), (x, context)) out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x) out = rearrange(out, '... -> ... 1 1') return self.net(out) def FeedForward(dim, mult = 2): hidden_dim = int(dim * mult) return nn.Sequential( LayerNorm(dim), nn.Linear(dim, hidden_dim, bias = False), nn.GELU(), LayerNorm(hidden_dim), nn.Linear(hidden_dim, dim, bias = False) ) class TimeTokenShift(nn.Module): def forward(self, x): if x.ndim != 5: return x x, x_shift = x.chunk(2, dim = 1) x_shift = F.pad(x_shift, (0, 0, 0, 0, 1, -1), value = 0.) 
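        # the pad above shifts half the channels one frame forward in time (a zero frame is
        # prepended and the last frame dropped), so each position also sees the previous frame's
        # features - the token shift trick from RWKV, applied along the time axis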
return torch.cat((x, x_shift), dim = 1) def ChanFeedForward(dim, mult = 2, time_token_shift = True): # in paper, it seems for self attention layers they did feedforwards with twice channel width hidden_dim = int(dim * mult) return Sequential( ChanLayerNorm(dim), Conv2d(dim, hidden_dim, 1, bias = False), nn.GELU(), TimeTokenShift() if time_token_shift else None, ChanLayerNorm(hidden_dim), Conv2d(hidden_dim, dim, 1, bias = False) ) class TransformerBlock(nn.Module): def __init__( self, dim, *, depth = 1, heads = 8, dim_head = 32, ff_mult = 2, ff_time_token_shift = True, context_dim = None ): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim), ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift) ])) def forward(self, x, context = None): for attn, ff in self.layers: x = rearrange(x, 'b c ... -> b ... c') x, ps = pack([x], 'b * c') x = attn(x, context = context) + x x, = unpack(x, ps, 'b * c') x = rearrange(x, 'b ... c -> b c ...') x = ff(x) + x return x class LinearAttentionTransformerBlock(nn.Module): def __init__( self, dim, *, depth = 1, heads = 8, dim_head = 32, ff_mult = 2, ff_time_token_shift = True, context_dim = None, **kwargs ): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim), ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift) ])) def forward(self, x, context = None): for attn, ff in self.layers: x = attn(x, context = context) + x x = ff(x) + x return x class CrossEmbedLayer(nn.Module): def __init__( self, dim_in, kernel_sizes, dim_out = None, stride = 2 ): super().__init__() assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)]) dim_out = default(dim_out, dim_in) kernel_sizes = sorted(kernel_sizes) num_scales = len(kernel_sizes) # calculate the dimension at each scale dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)] dim_scales = [*dim_scales, dim_out - sum(dim_scales)] self.convs = nn.ModuleList([]) for kernel, dim_scale in zip(kernel_sizes, dim_scales): self.convs.append(Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2)) def forward(self, x): fmaps = tuple(map(lambda conv: conv(x), self.convs)) return torch.cat(fmaps, dim = 1) class UpsampleCombiner(nn.Module): def __init__( self, dim, *, enabled = False, dim_ins = tuple(), dim_outs = tuple() ): super().__init__() dim_outs = cast_tuple(dim_outs, len(dim_ins)) assert len(dim_ins) == len(dim_outs) self.enabled = enabled if not self.enabled: self.dim_out = dim return self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)]) self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0) def forward(self, x, fmaps = None): target_size = x.shape[-1] fmaps = default(fmaps, tuple()) if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0: return x fmaps = [resize_video_to(fmap, target_size) for fmap in fmaps] outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)] return torch.cat((x, *outs), dim = 1) class DynamicPositionBias(nn.Module): def __init__( self, dim, *, heads, depth ): super().__init__() self.mlp = nn.ModuleList([]) self.mlp.append(nn.Sequential( nn.Linear(1, dim), LayerNorm(dim), nn.SiLU() )) for _ in range(max(depth - 1, 0)): 
self.mlp.append(nn.Sequential( nn.Linear(dim, dim), LayerNorm(dim), nn.SiLU() )) self.mlp.append(nn.Linear(dim, heads)) def forward(self, n, device, dtype): i = torch.arange(n, device = device) j = torch.arange(n, device = device) indices = rearrange(i, 'i -> i 1') - rearrange(j, 'j -> 1 j') indices += (n - 1) pos = torch.arange(-n + 1, n, device = device, dtype = dtype) pos = rearrange(pos, '... -> ... 1') for layer in self.mlp: pos = layer(pos) bias = pos[indices] bias = rearrange(bias, 'i j h -> h i j') return bias class Unet3D(nn.Module): def __init__( self, *, dim, text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME), num_resnet_blocks = 1, cond_dim = None, num_image_tokens = 4, num_time_tokens = 2, learned_sinu_pos_emb_dim = 16, out_dim = None, dim_mults = (1, 2, 4, 8), temporal_strides = 1, cond_images_channels = 0, channels = 3, channels_out = None, attn_dim_head = 64, attn_heads = 8, ff_mult = 2., ff_time_token_shift = True, # this would do a token shift along time axis, at the hidden layer within feedforwards - from successful use in RWKV (Peng et al), and other token shift video transformer works lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ layer_attns = False, layer_attns_depth = 1, layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) time_rel_pos_bias_depth = 2, time_causal_attn = True, layer_cross_attns = True, use_linear_attn = False, use_linear_cross_attn = False, cond_on_text = True, max_text_len = 256, init_dim = None, resnet_groups = 8, init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed init_cross_embed = True, init_cross_embed_kernel_sizes = (3, 7, 15), cross_embed_downsample = False, cross_embed_downsample_kernel_sizes = (2, 4), attn_pool_text = True, attn_pool_num_latents = 32, dropout = 0., memory_efficient = False, init_conv_to_final_conv_residual = False, use_global_context_attn = True, scale_skip_connection = True, final_resnet_block = True, final_conv_kernel_size = 3, self_cond = False, combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully pixel_shuffle_upsample = True, # may address checkboard artifacts resize_mode = 'nearest' ): super().__init__() # guide researchers assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8' if dim < 128: print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/') # save locals to take care of some hyperparameters for cascading DDPM self._locals = locals() self._locals.pop('self', None) self._locals.pop('__class__', None) self.self_cond = self_cond # determine dimensions self.channels = channels self.channels_out = default(channels_out, channels) # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis # (2) in self conditioning, one appends the predict x0 (x_start) init_channels = channels * (1 + int(lowres_cond) + int(self_cond)) init_dim = default(init_dim, dim) # optional image conditioning self.has_cond_image = cond_images_channels > 0 self.cond_images_channels = 
cond_images_channels init_channels += cond_images_channels # initial convolution self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) # time conditioning cond_dim = default(cond_dim, dim) time_cond_dim = dim * 4 * (2 if lowres_cond else 1) # embedding time for log(snr) noise from continuous version sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim) sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1 self.to_time_hiddens = nn.Sequential( sinu_pos_emb, nn.Linear(sinu_pos_emb_input_dim, time_cond_dim), nn.SiLU() ) self.to_time_cond = nn.Sequential( nn.Linear(time_cond_dim, time_cond_dim) ) # project to time tokens as well as time hiddens self.to_time_tokens = nn.Sequential( nn.Linear(time_cond_dim, cond_dim * num_time_tokens), Rearrange('b (r d) -> b r d', r = num_time_tokens) ) # low res aug noise conditioning self.lowres_cond = lowres_cond if lowres_cond: self.to_lowres_time_hiddens = nn.Sequential( LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim), nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim), nn.SiLU() ) self.to_lowres_time_cond = nn.Sequential( nn.Linear(time_cond_dim, time_cond_dim) ) self.to_lowres_time_tokens = nn.Sequential( nn.Linear(time_cond_dim, cond_dim * num_time_tokens), Rearrange('b (r d) -> b r d', r = num_time_tokens) ) # normalizations self.norm_cond = nn.LayerNorm(cond_dim) # text encoding conditioning (optional) self.text_to_cond = None if cond_on_text: assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True' self.text_to_cond = nn.Linear(text_embed_dim, cond_dim) # finer control over whether to condition on text encodings self.cond_on_text = cond_on_text # attention pooling self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents) if attn_pool_text else None # for classifier free guidance self.max_text_len = max_text_len self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim)) # for non-attention based text conditioning at all points in the network where time is also conditioned self.to_text_non_attn_cond = None if cond_on_text: self.to_text_non_attn_cond = nn.Sequential( nn.LayerNorm(cond_dim), nn.Linear(cond_dim, time_cond_dim), nn.SiLU(), nn.Linear(time_cond_dim, time_cond_dim) ) # attention related params attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head) num_layers = len(in_out) # temporal attention - attention across video frames temporal_peg_padding = (0, 0, 0, 0, 2, 0) if time_causal_attn else (0, 0, 0, 0, 1, 1) temporal_peg = lambda dim: Residual(nn.Sequential(Pad(temporal_peg_padding), nn.Conv3d(dim, dim, (3, 1, 1), groups = dim))) temporal_attn = lambda dim: RearrangeTimeCentric(Residual(Attention(dim, **{**attn_kwargs, 'causal': time_causal_attn, 'init_zero': True, 'rel_pos_bias': True}))) # resnet block klass num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers) resnet_groups = cast_tuple(resnet_groups, num_layers) resnet_klass = partial(ResnetBlock, **attn_kwargs) layer_attns = cast_tuple(layer_attns, num_layers) layer_attns_depth = cast_tuple(layer_attns_depth, num_layers) layer_cross_attns = 
cast_tuple(layer_cross_attns, num_layers) assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))]) # temporal downsample config temporal_strides = cast_tuple(temporal_strides, num_layers) self.total_temporal_divisor = functools.reduce(operator.mul, temporal_strides, 1) # downsample klass downsample_klass = Downsample if cross_embed_downsample: downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes) # initial resnet block (for memory efficient unet) self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None self.init_temporal_peg = temporal_peg(init_dim) self.init_temporal_attn = temporal_attn(init_dim) # scale for resnet skip connections self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5) # layers self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) num_resolutions = len(in_out) layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, temporal_strides] reversed_layer_params = list(map(reversed, layer_params)) # downsampling layers skip_connect_dims = [] # keep track of skip connection dimensions for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(in_out, *layer_params)): is_last = ind >= (num_resolutions - 1) layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity) current_dim = dim_in # whether to pre-downsample, from memory efficient unet pre_downsample = None if memory_efficient: pre_downsample = downsample_klass(dim_in, dim_out) current_dim = dim_out skip_connect_dims.append(current_dim) # whether to do post-downsample, for non-memory efficient unet post_downsample = None if not memory_efficient: post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(Conv2d(dim_in, dim_out, 3, padding = 1), Conv2d(dim_in, dim_out, 1)) self.downs.append(nn.ModuleList([ pre_downsample, resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups), nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]), transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs), temporal_peg(current_dim), temporal_attn(current_dim), TemporalDownsample(current_dim, stride = temporal_stride) if temporal_stride > 1 else None, post_downsample ])) # middle layers mid_dim = dims[-1] self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) self.mid_attn = EinopsToAndFrom('b c f h w', 'b (f h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None self.mid_temporal_peg = temporal_peg(mid_dim) self.mid_temporal_attn = temporal_attn(mid_dim) self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) # 
upsample klass upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample # upsampling layers upsample_fmap_dims = [] for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(reversed(in_out), *reversed_layer_params)): is_last = ind == (len(in_out) - 1) layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity) skip_connect_dim = skip_connect_dims.pop() upsample_fmap_dims.append(dim_out) self.ups.append(nn.ModuleList([ resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups), nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]), transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs), temporal_peg(dim_out), temporal_attn(dim_out), TemporalPixelShuffleUpsample(dim_out, stride = temporal_stride) if temporal_stride > 1 else None, upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity() ])) # whether to combine feature maps from all upsample blocks before final resnet block out self.upsample_combiner = UpsampleCombiner( dim = dim, enabled = combine_upsample_fmaps, dim_ins = upsample_fmap_dims, dim_outs = dim ) # whether to do a final residual from initial conv to the final resnet block out self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0) # final optional resnet block and convolution out self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None final_conv_dim_in = dim if final_resnet_block else final_conv_dim final_conv_dim_in += (channels if lowres_cond else 0) self.final_conv = Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2) zero_init_(self.final_conv) # resize mode self.resize_mode = resize_mode # if the current settings for the unet are not correct # for cascading DDPM, then reinit the unet with the right settings def cast_model_parameters( self, *, lowres_cond, text_embed_dim, channels, channels_out, cond_on_text ): if lowres_cond == self.lowres_cond and \ channels == self.channels and \ cond_on_text == self.cond_on_text and \ text_embed_dim == self._locals['text_embed_dim'] and \ channels_out == self.channels_out: return self updated_kwargs = dict( lowres_cond = lowres_cond, text_embed_dim = text_embed_dim, channels = channels, channels_out = channels_out, cond_on_text = cond_on_text ) return self.__class__(**{**self._locals, **updated_kwargs}) # methods for returning the full unet config as well as its parameter state def to_config_and_state_dict(self): return self._locals, self.state_dict() # class method for rehydrating the unet from its config and state dict @classmethod def from_config_and_state_dict(klass, config, state_dict): unet = klass(**config) unet.load_state_dict(state_dict) return 
unet # methods for persisting unet to disk def persist_to_file(self, path): path = Path(path) path.parents[0].mkdir(exist_ok = True, parents = True) config, state_dict = self.to_config_and_state_dict() pkg = dict(config = config, state_dict = state_dict) torch.save(pkg, str(path)) # class method for rehydrating the unet from file saved with `persist_to_file` @classmethod def hydrate_from_file(klass, path): path = Path(path) assert path.exists() pkg = torch.load(str(path)) assert 'config' in pkg and 'state_dict' in pkg config, state_dict = pkg['config'], pkg['state_dict'] return Unet.from_config_and_state_dict(config, state_dict) # forward with classifier free guidance def forward_with_cond_scale( self, *args, cond_scale = 1., **kwargs ): logits = self.forward(*args, **kwargs) if cond_scale == 1: return logits null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) return null_logits + (logits - null_logits) * cond_scale def forward( self, x, time, *, lowres_cond_img = None, lowres_noise_times = None, text_embeds = None, text_mask = None, cond_images = None, cond_video_frames = None, post_cond_video_frames = None, self_cond = None, cond_drop_prob = 0., ignore_time = False ): assert x.ndim == 5, 'input to 3d unet must have 5 dimensions (batch, channels, time, height, width)' batch_size, frames, device, dtype = x.shape[0], x.shape[2], x.device, x.dtype assert ignore_time or divisible_by(frames, self.total_temporal_divisor), f'number of input frames {frames} must be divisible by {self.total_temporal_divisor}' # add self conditioning if needed if self.self_cond: self_cond = default(self_cond, lambda: torch.zeros_like(x)) x = torch.cat((x, self_cond), dim = 1) # add low resolution conditioning, if present assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present' if exists(lowres_cond_img): x = torch.cat((x, lowres_cond_img), dim = 1) if exists(cond_video_frames): lowres_cond_img = torch.cat((cond_video_frames, lowres_cond_img), dim = 2) cond_video_frames = torch.cat((cond_video_frames, cond_video_frames), dim = 1) if exists(post_cond_video_frames): lowres_cond_img = torch.cat((lowres_cond_img, post_cond_video_frames), dim = 2) post_cond_video_frames = torch.cat((post_cond_video_frames, post_cond_video_frames), dim = 1) # conditioning on video frames as a prompt num_preceding_frames = 0 if exists(cond_video_frames): cond_video_frames_len = cond_video_frames.shape[2] assert divisible_by(cond_video_frames_len, self.total_temporal_divisor) cond_video_frames = resize_video_to(cond_video_frames, x.shape[-1]) x = torch.cat((cond_video_frames, x), dim = 2) num_preceding_frames = cond_video_frames_len # conditioning on video frames as a prompt num_succeeding_frames = 0 if exists(post_cond_video_frames): cond_video_frames_len = post_cond_video_frames.shape[2] assert divisible_by(cond_video_frames_len, self.total_temporal_divisor) post_cond_video_frames = resize_video_to(post_cond_video_frames, x.shape[-1]) x = torch.cat((post_cond_video_frames, x), dim = 2) num_succeeding_frames = cond_video_frames_len # condition on input image assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa' if exists(cond_images): assert cond_images.ndim == 4, 'conditioning images must have 4 dimensions only, if you want to 
condition on frames of video, use `cond_video_frames` instead' assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet' cond_images = repeat(cond_images, 'b c h w -> b c f h w', f = x.shape[2]) cond_images = resize_video_to(cond_images, x.shape[-1], mode = self.resize_mode) x = torch.cat((cond_images, x), dim = 1) # ignoring time in pseudo 3d resnet blocks conv_kwargs = dict( ignore_time = ignore_time ) # initial convolution x = self.init_conv(x) if not ignore_time: x = self.init_temporal_peg(x) x = self.init_temporal_attn(x) # init conv residual if self.init_conv_to_final_conv_residual: init_conv_residual = x.clone() # time conditioning time_hiddens = self.to_time_hiddens(time) # derive time tokens time_tokens = self.to_time_tokens(time_hiddens) t = self.to_time_cond(time_hiddens) # add lowres time conditioning to time hiddens # and add lowres time tokens along sequence dimension for attention if self.lowres_cond: lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times) lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens) lowres_t = self.to_lowres_time_cond(lowres_time_hiddens) t = t + lowres_t time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2) # text conditioning text_tokens = None if exists(text_embeds) and self.cond_on_text: # conditional dropout text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device) text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1') text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1') # calculate text embeds text_tokens = self.text_to_cond(text_embeds) text_tokens = text_tokens[:, :self.max_text_len] if exists(text_mask): text_mask = text_mask[:, :self.max_text_len] text_tokens_len = text_tokens.shape[1] remainder = self.max_text_len - text_tokens_len if remainder > 0: text_tokens = F.pad(text_tokens, (0, 0, 0, remainder)) if exists(text_mask): if remainder > 0: text_mask = F.pad(text_mask, (0, remainder), value = False) text_mask = rearrange(text_mask, 'b n -> b n 1') text_keep_mask_embed = text_mask & text_keep_mask_embed null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working text_tokens = torch.where( text_keep_mask_embed, text_tokens, null_text_embed ) if exists(self.attn_pool): text_tokens = self.attn_pool(text_tokens) # extra non-attention conditioning by projecting and then summing text embeddings to time # termed as text hiddens mean_pooled_text_tokens = text_tokens.mean(dim = -2) text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens) null_text_hidden = self.null_text_hidden.to(t.dtype) text_hiddens = torch.where( text_keep_mask_hidden, text_hiddens, null_text_hidden ) t = t + text_hiddens # main conditioning tokens (c) c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2) # normalize conditioning tokens c = self.norm_cond(c) # initial resnet block (for memory efficient unet) if exists(self.init_resnet_block): x = self.init_resnet_block(x, t, **conv_kwargs) # go through the layers of the unet, down and up hiddens = [] for pre_downsample, init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, temporal_downsample, post_downsample in self.downs: if exists(pre_downsample): x = pre_downsample(x) x = init_block(x, t, c, **conv_kwargs) for resnet_block in resnet_blocks: x = resnet_block(x, t, **conv_kwargs) hiddens.append(x) x 
= attn_block(x, c) if not ignore_time: x = temporal_peg(x) x = temporal_attn(x) hiddens.append(x) if exists(temporal_downsample) and not ignore_time: x = temporal_downsample(x) if exists(post_downsample): x = post_downsample(x) x = self.mid_block1(x, t, c, **conv_kwargs) if exists(self.mid_attn): x = self.mid_attn(x) if not ignore_time: x = self.mid_temporal_peg(x) x = self.mid_temporal_attn(x) x = self.mid_block2(x, t, c, **conv_kwargs) add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1) up_hiddens = [] for init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, temporal_upsample, upsample in self.ups: if exists(temporal_upsample) and not ignore_time: x = temporal_upsample(x) x = add_skip_connection(x) x = init_block(x, t, c, **conv_kwargs) for resnet_block in resnet_blocks: x = add_skip_connection(x) x = resnet_block(x, t, **conv_kwargs) x = attn_block(x, c) if not ignore_time: x = temporal_peg(x) x = temporal_attn(x) up_hiddens.append(x.contiguous()) x = upsample(x) # whether to combine all feature maps from upsample blocks x = self.upsample_combiner(x, up_hiddens) # final top-most residual if needed if self.init_conv_to_final_conv_residual: x = torch.cat((x, init_conv_residual), dim = 1) if exists(self.final_res_block): x = self.final_res_block(x, t, **conv_kwargs) if exists(lowres_cond_img): x = torch.cat((x, lowres_cond_img), dim = 1) out = self.final_conv(x) if num_preceding_frames > 0: out = out[:, :, num_preceding_frames:] if num_succeeding_frames > 0: out = out[:, :, :-num_succeeding_frames] return out
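
# a minimal usage sketch (not part of the original module): the shapes and hyperparameters below
# are illustrative assumptions, chosen small so the forward pass is cheap. the unet expects video
# as (batch, channels, frames, height, width), one diffusion time per batch element, and optional
# T5 text embeddings whose width matches get_encoded_dim(DEFAULT_T5_NAME)

if __name__ == '__main__':
    unet = Unet3D(
        dim = 64,                                    # base channel width (128+ recommended for real training)
        dim_mults = (1, 2, 4, 8),
        cond_on_text = True
    )

    text_dim = get_encoded_dim(DEFAULT_T5_NAME)

    videos      = torch.randn(1, 3, 8, 32, 32)       # height / width must be divisible by 2 ** (len(dim_mults) - 1)
    times       = torch.rand(1)                      # illustrative continuous noise times in [0, 1]
    text_embeds = torch.randn(1, 16, text_dim)       # e.g. the output of t5_encode_text(...)
    text_mask   = torch.ones(1, 16, dtype = torch.bool)

    pred = unet(
        videos,
        times,
        text_embeds = text_embeds,
        text_mask = text_mask
    )

    assert pred.shape == videos.shape                # the unet predicts a tensor shaped like its input video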