"""
MIT License

Copyright (c) 2021 Wilson Yan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

This file is copied from https://github.com/wilson1yan/VideoGPT/blob/master/videogpt/vqvae.py
We adapted it to Hugging Face AutoModel for easier model loading.
"""

import math
import os

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from ._utils import shift_dim
from .attention import MultiHeadAttention
from .configuration_vqvae import VQVAEConfig


class VQVAE(PreTrainedModel):
    config_class = VQVAEConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding_dim = config.embedding_dim
        self.n_codes = config.n_codes

        self.encoder = Encoder(config.n_hiddens, config.n_res_layers, config.downsample)
        self.decoder = Decoder(config.n_hiddens, config.n_res_layers, config.downsample)

        self.pre_vq_conv = SamePadConv3d(config.n_hiddens, config.embedding_dim, 1)
        self.post_vq_conv = SamePadConv3d(config.embedding_dim, config.n_hiddens, 1)

        self.codebook = Codebook(config.n_codes, config.embedding_dim)

    @property
    def latent_shape(self):
        # The original VideoGPT code read these fields from `self.args`, which
        # does not exist after the Hugging Face adaptation; they are assumed
        # to live on `self.config` instead.
        input_shape = (self.config.sequence_length, self.config.resolution,
                       self.config.resolution)
        return tuple([s // d for s, d in zip(input_shape,
                                             self.config.downsample)])
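
    # Example (illustrative): sequence_length=16, resolution=128, and
    # downsample=(4, 4, 4) give latent_shape == (4, 32, 32).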

    def encode(self, x, include_embeddings=False):
        h = self.pre_vq_conv(self.encoder(x))
        vq_output = self.codebook(h)
        if include_embeddings:
            return vq_output['encodings'], vq_output['embeddings']
        else:
            return vq_output['encodings']

    def decode(self, encodings):
        h = F.embedding(encodings, self.codebook.embeddings)
        h = self.post_vq_conv(shift_dim(h, -1, 1))
        return self.decoder(h)

    def decode_from_embeddings(self, embeddings):
        # Map (possibly off-codebook) embeddings to their nearest codebook
        # indices, then decode as usual.
        encodings = self.codebook.search_indices(embeddings)
        return self.decode(encodings)

    def forward(self, x):
        z = self.pre_vq_conv(self.encoder(x))
        vq_output = self.codebook(z)
        x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings']))
        # Scale the MSE by a constant (0.06, as in the original VideoGPT code)
        # so the reconstruction loss is roughly normalized by the data variance.
        recon_loss = F.mse_loss(x_recon, x) / 0.06

        return recon_loss, x_recon, vq_output

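# Training sketch (a minimal example, not part of the original file): the
# commitment loss returned by the codebook is added to the reconstruction
# loss, as in the VideoGPT training recipe.
#
#   recon_loss, x_recon, vq_output = model(x)
#   loss = recon_loss + vq_output['commitment_loss']
#   loss.backward()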


class AxialBlock(nn.Module):
    def __init__(self, n_hiddens, n_head):
        super().__init__()
        kwargs = dict(shape=(0,) * 3, dim_q=n_hiddens,
                      dim_kv=n_hiddens, n_head=n_head,
                      n_layer=1, causal=False, attn_type='axial')
        # Inputs are shifted to channels-last (B, T, H, W, C) in forward(),
        # so axial_dim -2/-3/-4 attend over width, height, and time.
        self.attn_w = MultiHeadAttention(attn_kwargs=dict(axial_dim=-2),
                                         **kwargs)
        self.attn_h = MultiHeadAttention(attn_kwargs=dict(axial_dim=-3),
                                         **kwargs)
        self.attn_t = MultiHeadAttention(attn_kwargs=dict(axial_dim=-4),
                                         **kwargs)

    def forward(self, x):
        x = shift_dim(x, 1, -1)
        x = self.attn_w(x, x, x) + self.attn_h(x, x, x) + self.attn_t(x, x, x)
        x = shift_dim(x, -1, 1)
        return x


class AttentionResidualBlock(nn.Module):
    def __init__(self, n_hiddens):
        super().__init__()
        self.block = nn.Sequential(
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU(),
            SamePadConv3d(n_hiddens, n_hiddens // 2, 3, bias=False),
            nn.BatchNorm3d(n_hiddens // 2),
            nn.ReLU(),
            SamePadConv3d(n_hiddens // 2, n_hiddens, 1, bias=False),
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU(),
            AxialBlock(n_hiddens, 2)
        )

    def forward(self, x):
        return x + self.block(x)


class Codebook(nn.Module):
    def __init__(self, n_codes, embedding_dim):
        super().__init__()
        # Codebook state lives in buffers: the code embeddings plus the EMA
        # statistics (cluster counts N and running code sums z_avg) used for
        # the EMA codebook update in forward().
        self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
        self.register_buffer('N', torch.zeros(n_codes))
        self.register_buffer('z_avg', self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True

    def _tile(self, x):
        # Repeat the rows (with small noise) until there are at least n_codes
        # of them, so random rows can be drawn to (re)initialize codes.
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def _init_embeddings(self, z):
        # Data-dependent initialization: seed the codebook with random encoder
        # outputs (broadcast from rank 0 so all workers share the same codes).
        self._need_init = False
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        y = self._tile(flat_inputs)

        _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))

    def search_indices(self, z):
        # Nearest-neighbor lookup via the expansion
        # ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2 (the ||x||^2 term does not
        # change the argmin but is kept for parity with forward()).
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
                    - 2 * flat_inputs @ self.embeddings.t() \
                    + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)

        encoding_indices = torch.argmin(distances, dim=1)
        encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])
        return encoding_indices
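
    # Example (illustrative): for z of shape (B, C, T', H', W') with
    # C == embedding_dim, search_indices returns int64 indices of shape
    # (B, T', H', W'): one codebook id per latent position.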

    def forward(self, z):
        # z: (B, C, T, H, W) with C == embedding_dim.
        if self._need_init and self.training:
            self._init_embeddings(z)
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
                    - 2 * flat_inputs @ self.embeddings.t() \
                    + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)

        encoding_indices = torch.argmin(distances, dim=1)
        encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
        encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])

        embeddings = F.embedding(encoding_indices, self.embeddings)
        embeddings = shift_dim(embeddings, -1, 1)

        # Commitment loss pulls encoder outputs toward their (detached) codes.
        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())

        # EMA codebook update (the codes themselves receive no gradient).
        if self.training:
            n_total = encode_onehot.sum(dim=0)
            encode_sum = flat_inputs.t() @ encode_onehot
            if dist.is_initialized():
                dist.all_reduce(n_total)
                dist.all_reduce(encode_sum)

            self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
            self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)

            # Laplace-smoothed cluster sizes; each code becomes the (EMA)
            # mean of the encoder outputs assigned to it.
            n = self.N.sum()
            weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
            encode_normalized = self.z_avg / weights.unsqueeze(1)
            self.embeddings.data.copy_(encode_normalized)

            # Reinitialize dead codes (EMA count < 1) with random encoder
            # outputs, broadcast so they stay in sync across workers.
            y = self._tile(flat_inputs)
            _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
            if dist.is_initialized():
                dist.broadcast(_k_rand, 0)

            usage = (self.N.view(self.n_codes, 1) >= 1).float()
            self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))

        # Straight-through estimator: the forward pass uses the quantized
        # embeddings, the backward pass routes gradients to z unchanged.
        embeddings_st = (embeddings - z).detach() + z

        # Perplexity of the code-usage distribution (== n_codes when uniform).
        avg_probs = torch.mean(encode_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return dict(embeddings=embeddings_st, encodings=encoding_indices,
                    commitment_loss=commitment_loss, perplexity=perplexity)

    def dictionary_lookup(self, encodings):
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings
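

# Round-trip sketch (illustrative): the discrete indices from VQVAE.encode()
# can be stored or modeled autoregressively, then mapped back to pixels.
#
#   codes = model.encode(videos)        # (B, T', H', W') int64 code indices
#   videos_recon = model.decode(codes)  # (B, 3, T, H, W)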


class Encoder(nn.Module):
    def __init__(self, n_hiddens, n_res_layers, downsample):
        super().__init__()
        n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
        self.convs = nn.ModuleList()
        max_ds = n_times_downsample.max()
        for i in range(max_ds):
            in_channels = 3 if i == 0 else n_hiddens
            stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
            conv = SamePadConv3d(in_channels, n_hiddens, 4, stride=stride)
            self.convs.append(conv)
            n_times_downsample -= 1
        # Assuming at least one downsampling conv, the features entering
        # conv_last always have n_hiddens channels. (The original code reused
        # the loop variable `in_channels` here, which is 3 when max_ds == 1
        # and undefined when max_ds == 0.)
        self.conv_last = SamePadConv3d(n_hiddens, n_hiddens, kernel_size=3)

        self.res_stack = nn.Sequential(
            *[AttentionResidualBlock(n_hiddens)
              for _ in range(n_res_layers)],
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU()
        )

    def forward(self, x):
        h = x
        for conv in self.convs:
            h = F.relu(conv(h))
        h = self.conv_last(h)
        h = self.res_stack(h)
        return h
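

# Shape sketch (illustrative): with downsample=(4, 4, 4), i.e. two stride-2
# convs per axis, a (B, 3, 16, 128, 128) clip becomes a
# (B, n_hiddens, 4, 32, 32) feature volume.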


class Decoder(nn.Module):
    def __init__(self, n_hiddens, n_res_layers, upsample):
        super().__init__()
        self.res_stack = nn.Sequential(
            *[AttentionResidualBlock(n_hiddens)
              for _ in range(n_res_layers)],
            nn.BatchNorm3d(n_hiddens),
            nn.ReLU()
        )

        n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
        max_us = n_times_upsample.max()
        self.convts = nn.ModuleList()
        for i in range(max_us):
            out_channels = 3 if i == max_us - 1 else n_hiddens
            us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
            convt = SamePadConvTranspose3d(n_hiddens, out_channels, 4,
                                           stride=us)
            self.convts.append(convt)
            n_times_upsample -= 1

    def forward(self, x):
        h = self.res_stack(x)
        for i, convt in enumerate(self.convts):
            h = convt(h)
            if i < len(self.convts) - 1:
                h = F.relu(h)
        return h


class SamePadConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        # "Same" padding: pad each dim by k - s in total (the extra pixel goes
        # on the left when the total is odd), in the reversed per-dim order
        # that F.pad expects.
        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())
        self.pad_input = pad_input

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=0, bias=bias)

    def forward(self, x):
        return self.conv(F.pad(x, self.pad_input))
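

# Shape check (illustrative): with kernel_size=4 and stride=2, total_pad = 2
# per dim, so an input of even length L is padded to L + 2 and the conv emits
# (L + 2 - 4) // 2 + 1 = L / 2 positions: exact stride-2 downsampling.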


class SamePadConvTranspose3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())
        self.pad_input = pad_input

        # Asymmetric input padding plus padding=k-1 on the transposed conv
        # yields an output exactly `stride` times the input size.
        self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
                                        stride=stride, bias=bias,
                                        padding=tuple([k - 1 for k in kernel_size]))

    def forward(self, x):
        return self.convt(F.pad(x, self.pad_input))
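
# Shape check (illustrative): with kernel_size=4 and stride=2, the input is
# padded from L to L + 2, and the transposed conv (padding=3) emits
# (L + 2 - 1) * 2 - 2 * 3 + 4 = 2 * L positions: exact 2x upsampling.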