# vqvae / modeling_vqvae.py
# frankleeeee's picture
# Upload VQVAE
# 8ff4a33 verified
# raw
# history blame
# 11.8 kB
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import gdown
from .attention import MultiHeadAttention
from ._utils import shift_dim
from transformers import PreTrainedModel
from typing import Tuple
from .configuration_vqvae import VQVAEConfig
# Lookup table from pretrained-checkpoint name to an opaque file ID.
# NOTE(review): presumably these are Google Drive IDs meant to be passed to
# download() below -- no caller is visible in this file, confirm before use.
_VQVAE = {
    # trained on 16 frames of 64 x 64 images
    'bair_stride4x2x2': '1iIAYJ2Qqrx5Q94s5eIXQYJgAydzvT_8L',
    # trained on 16 frames of 128 x 128 images
    'ucf101_stride4x4x4': '1uuB_8WzHP_bbBmfuaIV7PK_Itl3DyHY5',
    # trained on 16 frames of 128 x 128 images
    'kinetics_stride4x4x4': '1DOvOZnFAIQmux6hG7pN_HkyJZy3lXbCB',
    # trained on 16 frames of 128 x 128 images
    'kinetics_stride2x4x4': '1jvtjjtrtE4cy6pl7DK_zWFEPY3RZt2pB',
}
def download(id, fname, root=None):
    """
    Download the VQVAE weights from Google Drive.

    The file is only fetched when it is not already present at
    ``<root>/<fname>``; an existing file is returned as-is without any
    network access.

    Args:
        id (str): the ID of the file to download
        fname (str): the name of the file to save
        root (str): the directory to save the file to; defaults to
            ``~/.cache/sora`` when None

    Returns:
        str: the path of the (cached or freshly downloaded) file

    Raises:
        RuntimeError: if the download finished without producing the file
    """
    if root is None:
        root = os.path.expanduser('~/.cache/sora')
    os.makedirs(root, exist_ok=True)

    destination = os.path.join(root, fname)
    if os.path.exists(destination):
        # Cache hit: never re-download an existing file.
        return destination

    gdown.download(id=id, output=destination, quiet=False)

    # Check the filesystem rather than gdown's return value so that a failed
    # download is reported here instead of surfacing later as a confusing
    # "file not found" in whatever tries to load the returned path.
    if not os.path.exists(destination):
        raise RuntimeError(
            f"Download of file id {id!r} did not produce {destination!r}"
        )
    return destination
class VQVAE(PreTrainedModel):
    """Vector-quantised autoencoder packaged as a ``PreTrainedModel``.

    Data flow: ``Encoder`` -> kernel-size-1 conv to ``embedding_dim`` ->
    ``Codebook`` (nearest-code quantisation) -> kernel-size-1 conv back to
    ``n_hiddens`` -> ``Decoder``.

    The config must provide ``embedding_dim``, ``n_codes``, ``n_hiddens``,
    ``n_res_layers`` and ``downsample`` (one factor per downsampled axis).
    """

    config_class = VQVAEConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding_dim = config.embedding_dim
        self.n_codes = config.n_codes

        # The same per-axis factors drive the encoder's downsampling and the
        # decoder's upsampling.
        self.encoder = Encoder(config.n_hiddens, config.n_res_layers, config.downsample)
        self.decoder = Decoder(config.n_hiddens, config.n_res_layers, config.downsample)

        # Kernel-size-1 projections between the conv feature width and the
        # codebook's embedding width.
        self.pre_vq_conv = SamePadConv3d(config.n_hiddens, config.embedding_dim, 1)
        self.post_vq_conv = SamePadConv3d(config.embedding_dim, config.n_hiddens, 1)

        self.codebook = Codebook(config.n_codes, config.embedding_dim)

    @property
    def latent_shape(self):
        """Per-axis size of the code grid for the configured input size.

        Computed as ``(sequence_length, resolution, resolution)`` floor-divided
        element-wise by ``downsample``.

        NOTE(review): reads ``sequence_length`` and ``resolution`` from the
        model config -- confirm ``VQVAEConfig`` defines both.
        """
        # The hyper-parameters live on ``self.config``; this class never
        # defines ``self.args``, which the previous implementation read.
        cfg = self.config
        input_shape = (cfg.sequence_length, cfg.resolution, cfg.resolution)
        return tuple(s // d for s, d in zip(input_shape, cfg.downsample))

    def encode(self, x, include_embeddings=False):
        """Quantise ``x`` and return its code indices.

        Args:
            x: input tensor fed to the encoder
                (presumably [b, 3, t, h, w] -- TODO confirm with callers).
            include_embeddings: when True, also return the codebook's
                ``'embeddings'`` output for the selected codes.

        Returns:
            The ``'encodings'`` tensor, or the tuple
            ``(encodings, embeddings)`` when ``include_embeddings`` is True.
        """
        h = self.pre_vq_conv(self.encoder(x))
        vq_output = self.codebook(h)
        if include_embeddings:
            return vq_output['encodings'], vq_output['embeddings']
        return vq_output['encodings']

    def decode(self, encodings):
        """Reconstruct an output tensor from integer code indices."""
        # Look up the code vectors; the embedding axis comes out last, so it
        # is shifted to position 1 before the convolution.
        h = F.embedding(encodings, self.codebook.embeddings)
        h = self.post_vq_conv(shift_dim(h, -1, 1))
        return self.decoder(h)

    def forward(self, x):
        """Run the full encode -> quantise -> decode pass.

        Returns:
            tuple: ``(recon_loss, x_recon, vq_output)`` where ``recon_loss``
            is the MSE between ``x_recon`` and ``x`` divided by the fixed
            constant 0.06, and ``vq_output`` is the dict produced by the
            codebook.
        """
        z = self.pre_vq_conv(self.encoder(x))
        vq_output = self.codebook(z)
        x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings']))
        recon_loss = F.mse_loss(x_recon, x) / 0.06
        return recon_loss, x_recon, vq_output
class AxialBlock(nn.Module):
    """Sum of three axial attention passes over a channels-first tensor.

    One ``MultiHeadAttention`` module is built per axis (``axial_dim`` -2, -3
    and -4 of the channels-last view); their outputs are added together.
    """

    def __init__(self, n_hiddens, n_head):
        super().__init__()
        # Settings shared by the three attention modules; only the axis differs.
        shared = {
            'shape': (0, 0, 0),
            'dim_q': n_hiddens,
            'dim_kv': n_hiddens,
            'n_head': n_head,
            'n_layer': 1,
            'causal': False,
            'attn_type': 'axial',
        }
        self.attn_w = MultiHeadAttention(attn_kwargs={'axial_dim': -2}, **shared)
        self.attn_h = MultiHeadAttention(attn_kwargs={'axial_dim': -3}, **shared)
        self.attn_t = MultiHeadAttention(attn_kwargs={'axial_dim': -4}, **shared)

    def forward(self, x):
        """Apply self-attention along each axis and sum the results."""
        # Attention runs on a channels-last view; the layout is restored on
        # the way out.
        seq = shift_dim(x, 1, -1)
        along_w = self.attn_w(seq, seq, seq)
        along_h = self.attn_h(seq, seq, seq)
        along_t = self.attn_t(seq, seq, seq)
        return shift_dim(along_w + along_h + along_t, -1, 1)
class AttentionResidualBlock(nn.Module):
    """Residual block: bottleneck convolutions followed by axial attention.

    The residual branch is BN-ReLU-conv(k=3, to n_hiddens // 2),
    BN-ReLU-conv(k=1, back to n_hiddens), BN-ReLU, then a 2-head
    ``AxialBlock``. Its output is added to the input.
    """

    def __init__(self, n_hiddens):
        super().__init__()
        bottleneck = n_hiddens // 2
        layers = [
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU(),
            SamePadConv3d(n_hiddens, bottleneck, 3, bias=False),
            nn.BatchNorm3d(bottleneck),
            nn.ReLU(),
            SamePadConv3d(bottleneck, n_hiddens, 1, bias=False),
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU(),
            AxialBlock(n_hiddens, 2),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        residual = self.block(x)
        return x + residual
class Codebook(nn.Module):
    """Vector-quantisation codebook maintained by exponential moving averages.

    All state is held in buffers (no trainable parameters):
      * ``embeddings`` -- [n_codes, embedding_dim] current code vectors
      * ``N``          -- [n_codes] EMA of how often each code is selected
      * ``z_avg``      -- [n_codes, embedding_dim] EMA of the summed inputs
                          assigned to each code

    NOTE(review): ``_need_init`` is a plain attribute, not a buffer, so it is
    not part of ``state_dict``. A newly constructed module therefore runs the
    data-dependent initialisation on its first training-mode forward even if
    its buffers were loaded from a checkpoint -- confirm that is intended.
    """

    def __init__(self, n_codes, embedding_dim):
        super().__init__()
        initial_codes = torch.randn(n_codes, embedding_dim)
        self.register_buffer('embeddings', initial_codes)
        self.register_buffer('N', torch.zeros(n_codes))
        self.register_buffer('z_avg', self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True

    def _tile(self, x):
        """Return at least ``n_codes`` rows built from the 2-D tensor ``x``.

        When ``x`` already has ``n_codes`` rows or more it is returned
        unchanged. Otherwise it is repeated along dim 0 and perturbed with
        Gaussian noise of std ``0.01 / sqrt(width)``.
        """
        n_rows, width = x.shape
        if n_rows >= self.n_codes:
            return x
        n_repeats = (self.n_codes + n_rows - 1) // n_rows  # ceil division
        noise_std = 0.01 / np.sqrt(width)
        tiled = x.repeat(n_repeats, 1)
        return tiled + torch.randn_like(tiled) * noise_std

    def _sample_codes(self, flat_inputs):
        """Draw ``n_codes`` random rows from ``flat_inputs`` (tiled if short).

        Under ``torch.distributed`` the draw is overwritten with rank 0's so
        that every process ends up with the same vectors.
        """
        pool = self._tile(flat_inputs)
        picked = pool[torch.randperm(pool.shape[0])][:self.n_codes]
        if dist.is_initialized():
            dist.broadcast(picked, 0)
        return picked

    def _init_embeddings(self, z):
        """Seed the codebook from random input vectors of the batch ``z``."""
        # z: [b, c, t, h, w]
        self._need_init = False
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        seeds = self._sample_codes(flat_inputs)
        self.embeddings.data.copy_(seeds)
        self.z_avg.data.copy_(seeds)
        self.N.data.copy_(torch.ones(self.n_codes))

    def _ema_update(self, flat_inputs, encode_onehot):
        """Update ``N``, ``z_avg`` and ``embeddings`` from one batch.

        Args:
            flat_inputs: [n_vectors, embedding_dim] encoder outputs.
            encode_onehot: [n_vectors, n_codes] one-hot code assignment.
        """
        n_total = encode_onehot.sum(dim=0)
        encode_sum = flat_inputs.t() @ encode_onehot
        if dist.is_initialized():
            dist.all_reduce(n_total)
            dist.all_reduce(encode_sum)

        # EMA with decay 0.99 of the assignment counts and the input sums.
        self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
        self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)

        # Smoothed counts (epsilon 1e-7) keep the division finite for codes
        # whose count is zero.
        n = self.N.sum()
        weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
        self.embeddings.data.copy_(self.z_avg / weights.unsqueeze(1))

        # Codes whose EMA count fell below 1 are replaced by random input
        # vectors from this batch.
        replacements = self._sample_codes(flat_inputs)
        usage = (self.N.view(self.n_codes, 1) >= 1).float()
        self.embeddings.data.mul_(usage).add_(replacements * (1 - usage))

    def forward(self, z):
        """Quantise ``z`` to its nearest codes.

        In training mode this also initialises the codebook on the first call
        and performs the EMA codebook update.

        Returns:
            dict with keys ``embeddings`` (quantised tensor with a
            straight-through gradient to ``z``), ``encodings`` (code indices
            shaped like ``z`` without its channel axis), ``commitment_loss``
            and ``perplexity``.
        """
        # z: [b, c, t, h, w]
        if self._need_init and self.training:
            self._init_embeddings(z)

        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)

        # Squared Euclidean distance, expanded as |x|^2 - 2 x.e + |e|^2.
        codes_t = self.embeddings.t()
        input_sq = (flat_inputs ** 2).sum(dim=1, keepdim=True)
        cross = 2 * flat_inputs @ codes_t
        code_sq = (codes_t ** 2).sum(dim=0, keepdim=True)
        distances = input_sq - cross + code_sq

        flat_indices = torch.argmin(distances, dim=1)
        encode_onehot = F.one_hot(flat_indices, self.n_codes).type_as(flat_inputs)
        encoding_indices = flat_indices.view(z.shape[0], *z.shape[2:])

        quantized = F.embedding(encoding_indices, self.embeddings)
        quantized = shift_dim(quantized, -1, 1)

        commitment_loss = 0.25 * F.mse_loss(z, quantized.detach())

        # EMA codebook update
        if self.training:
            self._ema_update(flat_inputs, encode_onehot)

        # Straight-through estimator: forward value is the quantised tensor,
        # gradient flows to z unchanged.
        embeddings_st = (quantized - z).detach() + z

        avg_probs = torch.mean(encode_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return {
            'embeddings': embeddings_st,
            'encodings': encoding_indices,
            'commitment_loss': commitment_loss,
            'perplexity': perplexity,
        }

    def dictionary_lookup(self, encodings):
        """Return the code vectors for integer indices ``encodings``."""
        return F.embedding(encodings, self.embeddings)
class Encoder(nn.Module):
    """Strided 3D-conv encoder followed by a stack of attention residual blocks.

    Args:
        n_hiddens: channel width of every feature map after the first conv.
        n_res_layers: number of ``AttentionResidualBlock`` modules.
        downsample: per-axis downsampling factors. Each factor ``d``
            contributes ``int(log2(d))`` stride-2 stages on its axis.

    The input is expected to have 3 channels (hard-coded for the first conv).
    """

    def __init__(self, n_hiddens, n_res_layers, downsample):
        super().__init__()
        n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
        max_ds = n_times_downsample.max()

        self.convs = nn.ModuleList()
        for i in range(max_ds):
            in_channels = 3 if i == 0 else n_hiddens
            # An axis is strided only while it still has halvings left to do.
            stride = tuple(2 if d > 0 else 1 for d in n_times_downsample)
            self.convs.append(SamePadConv3d(in_channels, n_hiddens, 4, stride=stride))
            n_times_downsample -= 1

        # conv_last consumes the output of the strided stack, which has
        # n_hiddens channels as soon as at least one stage exists, and the raw
        # 3-channel input otherwise. Reusing the loop variable here gave 3
        # channels for a single stage and was undefined for zero stages.
        last_in_channels = n_hiddens if len(self.convs) > 0 else 3
        self.conv_last = SamePadConv3d(last_in_channels, n_hiddens, kernel_size=3)

        self.res_stack = nn.Sequential(
            *[AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)],
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU(),
        )

    def forward(self, x):
        """Encode ``x`` into an ``n_hiddens``-channel feature map."""
        h = x
        for conv in self.convs:
            h = F.relu(conv(h))
        h = self.conv_last(h)
        return self.res_stack(h)
class Decoder(nn.Module):
    """Attention residual stack followed by strided transposed 3D convolutions.

    Args:
        n_hiddens: channel width of the input and of all intermediate maps.
        n_res_layers: number of ``AttentionResidualBlock`` modules.
        upsample: per-axis upsampling factors. Each factor ``d`` contributes
            ``int(log2(d))`` stride-2 stages on its axis.

    The last transposed convolution outputs 3 channels.
    """

    def __init__(self, n_hiddens, n_res_layers, upsample):
        super().__init__()
        res_layers = [AttentionResidualBlock(n_hiddens) for _ in range(n_res_layers)]
        res_layers += [nn.BatchNorm3d(n_hiddens), nn.ReLU()]
        self.res_stack = nn.Sequential(*res_layers)

        # Doublings still owed on each axis; decremented after every stage.
        remaining = np.array([int(math.log2(factor)) for factor in upsample])
        n_stages = remaining.max()

        self.convts = nn.ModuleList()
        for stage in range(n_stages):
            is_final = stage == n_stages - 1
            out_channels = 3 if is_final else n_hiddens
            stride = tuple(2 if left > 0 else 1 for left in remaining)
            self.convts.append(
                SamePadConvTranspose3d(n_hiddens, out_channels, 4, stride=stride)
            )
            remaining -= 1

    def forward(self, x):
        """Decode the feature map ``x``; no ReLU follows the final stage."""
        h = self.res_stack(x)
        final_idx = len(self.convts) - 1
        for idx, convt in enumerate(self.convts):
            h = convt(h)
            if idx < final_idx:
                h = F.relu(h)
        return h
# Does not support dilation
class SamePadConv3d(nn.Module):
    """``nn.Conv3d`` with explicit input padding of ``kernel - stride`` per axis.

    The padding is applied with ``F.pad`` before an unpadded convolution.
    When a total is odd, the extra element goes on the leading side. For
    inputs divisible by the stride, the output size is ``input // stride``
    on each axis.

    ``kernel_size`` and ``stride`` may each be an int (used for all three
    axes) or a 3-sequence.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True):
        super().__init__()
        kernel = (kernel_size,) * 3 if isinstance(kernel_size, int) else kernel_size
        strides = (stride,) * 3 if isinstance(stride, int) else stride

        # assumes that the input shape is divisible by stride
        total_pad = [k - s for k, s in zip(kernel, strides)]
        # F.pad takes (before, after) pairs starting from the LAST dimension,
        # hence the reversed traversal.
        self.pad_input = tuple(
            amount
            for p in reversed(total_pad)
            for amount in (p // 2 + p % 2, p // 2)
        )

        self.conv = nn.Conv3d(in_channels, out_channels, kernel,
                              stride=strides, padding=0, bias=bias)

    def forward(self, x):
        """Pad ``x`` and convolve it."""
        padded = F.pad(x, self.pad_input)
        return self.conv(padded)
class SamePadConvTranspose3d(nn.Module):
    """``nn.ConvTranspose3d`` preceded by explicit input padding.

    The input is padded by ``kernel - stride`` per axis (the extra element of
    an odd total goes on the leading side). The transposed convolution then
    uses ``padding = kernel - 1`` on every axis. With kernel 4 this yields an
    output of ``input * stride`` for strides 1 and 2.

    ``kernel_size`` and ``stride`` may each be an int (used for all three
    axes) or a 3-sequence.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True):
        super().__init__()
        kernel = (kernel_size,) * 3 if isinstance(kernel_size, int) else kernel_size
        strides = (stride,) * 3 if isinstance(stride, int) else stride

        total_pad = [k - s for k, s in zip(kernel, strides)]
        # F.pad takes (before, after) pairs starting from the LAST dimension,
        # hence the reversed traversal.
        self.pad_input = tuple(
            amount
            for p in reversed(total_pad)
            for amount in (p // 2 + p % 2, p // 2)
        )

        crop = tuple(k - 1 for k in kernel)
        self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel,
                                        stride=strides, bias=bias,
                                        padding=crop)

    def forward(self, x):
        """Pad ``x`` and apply the transposed convolution."""
        padded = F.pad(x, self.pad_input)
        return self.convt(padded)