import os
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import gdown

from .attention import MultiHeadAttention
from ._utils import shift_dim
from transformers import PreTrainedModel
from .configuration_vqvae import VQVAEConfig

# Google Drive file IDs for the pretrained VQ-VAE checkpoints.
_VQVAE = {
    'bair_stride4x2x2': '1iIAYJ2Qqrx5Q94s5eIXQYJgAydzvT_8L',      # trained on 16 frames of 64 x 64 images
    'ucf101_stride4x4x4': '1uuB_8WzHP_bbBmfuaIV7PK_Itl3DyHY5',    # trained on 16 frames of 128 x 128 images
    'kinetics_stride4x4x4': '1DOvOZnFAIQmux6hG7pN_HkyJZy3lXbCB',  # trained on 16 frames of 128 x 128 images
    'kinetics_stride2x4x4': '1jvtjjtrtE4cy6pl7DK_zWFEPY3RZt2pB',  # trained on 16 frames of 128 x 128 images
}


def download(id, fname, root=None):
    """Download the VQ-VAE weights from Google Drive.

    Args:
        id (str): the Google Drive ID of the file to download
        fname (str): the name of the file to save
        root (str): the directory to save the file to; defaults to ~/.cache/sora

    Returns:
        str: the path to the downloaded (or previously cached) file
    """
    if root is None:
        root = os.path.expanduser('~/.cache/sora')
    os.makedirs(root, exist_ok=True)
    destination = os.path.join(root, fname)
    if os.path.exists(destination):
        return destination
    gdown.download(id=id, output=destination, quiet=False)
    return destination


class VQVAE(PreTrainedModel):
    config_class = VQVAEConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding_dim = config.embedding_dim
        self.n_codes = config.n_codes

        self.encoder = Encoder(config.n_hiddens, config.n_res_layers, config.downsample)
        self.decoder = Decoder(config.n_hiddens, config.n_res_layers, config.downsample)

        self.pre_vq_conv = SamePadConv3d(config.n_hiddens, config.embedding_dim, 1)
        self.post_vq_conv = SamePadConv3d(config.embedding_dim, config.n_hiddens, 1)

        self.codebook = Codebook(config.n_codes, config.embedding_dim)

    @property
    def latent_shape(self):
        # Each spatiotemporal axis is divided by its downsampling factor.
        # Note: this reads from self.config; the original referenced a
        # nonexistent self.args.
        input_shape = (self.config.sequence_length, self.config.resolution,
                       self.config.resolution)
        return tuple([s // d for s, d in zip(input_shape, self.config.downsample)])

    def encode(self, x, include_embeddings=False):
        h = self.pre_vq_conv(self.encoder(x))
        vq_output = self.codebook(h)
        if include_embeddings:
            return vq_output['encodings'], vq_output['embeddings']
        else:
            return vq_output['encodings']

    def decode(self, encodings):
        h = F.embedding(encodings, self.codebook.embeddings)
        h = self.post_vq_conv(shift_dim(h, -1, 1))
        return self.decoder(h)

    def forward(self, x):
        z = self.pre_vq_conv(self.encoder(x))
        vq_output = self.codebook(z)
        x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings']))
        # Scale the reconstruction loss by a fixed data-variance constant.
        recon_loss = F.mse_loss(x_recon, x) / 0.06
        return recon_loss, x_recon, vq_output


class AxialBlock(nn.Module):
    def __init__(self, n_hiddens, n_head):
        super().__init__()
        kwargs = dict(shape=(0,) * 3, dim_q=n_hiddens, dim_kv=n_hiddens,
                      n_head=n_head, n_layer=1, causal=False, attn_type='axial')
        # One attention pass per axis: width (-2), height (-3), and time (-4).
        self.attn_w = MultiHeadAttention(attn_kwargs=dict(axial_dim=-2), **kwargs)
        self.attn_h = MultiHeadAttention(attn_kwargs=dict(axial_dim=-3), **kwargs)
        self.attn_t = MultiHeadAttention(attn_kwargs=dict(axial_dim=-4), **kwargs)

    def forward(self, x):
        x = shift_dim(x, 1, -1)
        x = self.attn_w(x, x, x) + self.attn_h(x, x, x) + self.attn_t(x, x, x)
        x = shift_dim(x, -1, 1)
        return x


class AttentionResidualBlock(nn.Module):
    def __init__(self, n_hiddens):
        super().__init__()
        self.block = nn.Sequential(
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU(),
            SamePadConv3d(n_hiddens, n_hiddens // 2, 3, bias=False),
            nn.BatchNorm3d(n_hiddens // 2),
            nn.ReLU(),
            SamePadConv3d(n_hiddens // 2, n_hiddens, 1, bias=False),
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU(),
            AxialBlock(n_hiddens, 2)
        )

    def forward(self, x):
        return x + self.block(x)
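
# Example (sketch, not part of the original module): AttentionResidualBlock is
# shape-preserving, so any number of blocks can be stacked in the encoder and
# decoder. The shapes below are illustrative, not the pretrained settings.
#
#     block = AttentionResidualBlock(n_hiddens=240)
#     x = torch.randn(1, 240, 4, 16, 16)   # [batch, channels, time, height, width]
#     assert block(x).shape == x.shape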
class Codebook(nn.Module):
    def __init__(self, n_codes, embedding_dim):
        super().__init__()
        # Buffers, not parameters: the codebook is updated via EMA, not by gradients.
        self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
        self.register_buffer('N', torch.zeros(n_codes))
        self.register_buffer('z_avg', self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True

    def _tile(self, x):
        # Repeat x (with small noise) until there are at least n_codes rows to sample from.
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def _init_embeddings(self, z):
        # z: [b, c, t, h, w]
        # Initialize the codebook from random encoder outputs of the first training batch.
        self._need_init = False
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        y = self._tile(flat_inputs)

        _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))

    def forward(self, z):
        # z: [b, c, t, h, w]
        if self._need_init and self.training:
            self._init_embeddings(z)
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        # Squared Euclidean distance to every code: ||x||^2 - 2<x, e> + ||e||^2.
        distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
            - 2 * flat_inputs @ self.embeddings.t() \
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)

        encoding_indices = torch.argmin(distances, dim=1)
        encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
        encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])

        embeddings = F.embedding(encoding_indices, self.embeddings)
        embeddings = shift_dim(embeddings, -1, 1)

        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())

        # EMA codebook update
        if self.training:
            n_total = encode_onehot.sum(dim=0)
            encode_sum = flat_inputs.t() @ encode_onehot
            if dist.is_initialized():
                dist.all_reduce(n_total)
                dist.all_reduce(encode_sum)

            self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
            self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)

            # Laplace-smoothed cluster sizes, then normalize the running sums.
            n = self.N.sum()
            weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
            encode_normalized = self.z_avg / weights.unsqueeze(1)
            self.embeddings.data.copy_(encode_normalized)

            # Re-seed codes that have fallen out of use with random encoder outputs.
            y = self._tile(flat_inputs)
            _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
            if dist.is_initialized():
                dist.broadcast(_k_rand, 0)

            usage = (self.N.view(self.n_codes, 1) >= 1).float()
            self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))

        # Straight-through estimator: gradients bypass the quantization step.
        embeddings_st = (embeddings - z).detach() + z

        avg_probs = torch.mean(encode_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return dict(embeddings=embeddings_st, encodings=encoding_indices,
                    commitment_loss=commitment_loss, perplexity=perplexity)

    def dictionary_lookup(self, encodings):
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings


class Encoder(nn.Module):
    def __init__(self, n_hiddens, n_res_layers, downsample):
        super().__init__()
        n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
        self.convs = nn.ModuleList()
        max_ds = n_times_downsample.max()
        for i in range(max_ds):
            in_channels = 3 if i == 0 else n_hiddens
            # Halve each axis whose downsampling budget is not yet used up.
            stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
            conv = SamePadConv3d(in_channels, n_hiddens, 4, stride=stride)
            self.convs.append(conv)
            n_times_downsample -= 1
        # The loop always outputs n_hiddens channels, so conv_last takes
        # n_hiddens inputs (the original reused the loop's in_channels, which
        # is wrong when max_ds == 1).
        self.conv_last = SamePadConv3d(n_hiddens, n_hiddens, kernel_size=3)

        self.res_stack = nn.Sequential(
            *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)],
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU()
        )

    def forward(self, x):
        h = x
        for conv in self.convs:
            h = F.relu(conv(h))
        h = self.conv_last(h)
        h = self.res_stack(h)
        return h
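
# Example (sketch, not part of the original module): with downsample = (2, 4, 4)
# the per-axis log2 budgets are (1, 2, 2), so the encoder builds two strided
# convs, strides (2, 2, 2) then (1, 2, 2). A hypothetical input of shape
# [1, 3, 16, 128, 128] would come out as [1, n_hiddens, 8, 32, 32].
#
#     enc = Encoder(n_hiddens=240, n_res_layers=4, downsample=(2, 4, 4))
#     h = enc(torch.randn(1, 3, 16, 128, 128))   # -> [1, 240, 8, 32, 32]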
class Decoder(nn.Module):
    def __init__(self, n_hiddens, n_res_layers, upsample):
        super().__init__()
        self.res_stack = nn.Sequential(
            *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)],
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU()
        )

        n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
        max_us = n_times_upsample.max()
        self.convts = nn.ModuleList()
        for i in range(max_us):
            out_channels = 3 if i == max_us - 1 else n_hiddens
            us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
            convt = SamePadConvTranspose3d(n_hiddens, out_channels, 4, stride=us)
            self.convts.append(convt)
            n_times_upsample -= 1

    def forward(self, x):
        h = self.res_stack(x)
        for i, convt in enumerate(self.convts):
            h = convt(h)
            if i < len(self.convts) - 1:
                h = F.relu(h)
        return h


# Does not support dilation
class SamePadConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        # assumes that the input shape is divisible by stride
        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())
        self.pad_input = pad_input

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=0, bias=bias)

    def forward(self, x):
        return self.conv(F.pad(x, self.pad_input))


class SamePadConvTranspose3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())
        self.pad_input = pad_input

        self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
                                        stride=stride, bias=bias,
                                        padding=tuple([k - 1 for k in kernel_size]))

    def forward(self, x):
        return self.convt(F.pad(x, self.pad_input))
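

# Usage sketch (not part of the original module): a minimal end-to-end check,
# assuming VQVAEConfig accepts the keyword arguments referenced above. The
# hyperparameter values are illustrative, not the pretrained settings.
if __name__ == "__main__":
    config = VQVAEConfig(          # assumed constructor signature
        embedding_dim=256,
        n_codes=2048,
        n_hiddens=240,
        n_res_layers=4,
        downsample=(4, 4, 4),
        sequence_length=16,
        resolution=128,
    )
    model = VQVAE(config).eval()

    x = torch.randn(1, 3, 16, 128, 128)    # [b, c, t, h, w]
    with torch.no_grad():
        encodings = model.encode(x)        # discrete code indices, [1, 4, 32, 32]
        x_recon = model.decode(encodings)  # reconstruction, same shape as x
    print(encodings.shape, x_recon.shape)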