""" Adapted from https://github.com/SongweiGe/TATS"""
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from vq_gan_3d.utils import shift_dim


class Codebook(nn.Module):
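    """EMA vector-quantization codebook.

    Assigns each encoder output to its nearest code vector and maintains the
    codebook with exponential-moving-average statistics (plus optional random
    restarts for rarely used codes) rather than a codebook loss term.
    """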
def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0):
super().__init__()
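        # Code vectors, per-code EMA usage counts, and EMA sums of the encoder
        # outputs assigned to each code.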
self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
self.register_buffer('N', torch.zeros(n_codes))
self.register_buffer('z_avg', self.embeddings.data.clone())
self.n_codes = n_codes
self.embedding_dim = embedding_dim
self._need_init = True
self.no_random_restart = no_random_restart
self.restart_thres = restart_thres
def _tile(self, x):
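        """Repeat x (with small added noise) until it has at least n_codes rows."""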
d, ew = x.shape
if d < self.n_codes:
n_repeats = (self.n_codes + d - 1) // d
std = 0.01 / np.sqrt(ew)
x = x.repeat(n_repeats, 1)
x = x + torch.randn_like(x) * std
return x
def _init_embeddings(self, z):
# z: [b, c, t, h, w]
self._need_init = False
flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
y = self._tile(flat_inputs)
d = y.shape[0]
_k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
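        # Keep the initial codebook identical across distributed ranks.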
if dist.is_initialized():
dist.broadcast(_k_rand, 0)
self.embeddings.data.copy_(_k_rand)
self.z_avg.data.copy_(_k_rand)
self.N.data.copy_(torch.ones(self.n_codes))
def forward(self, z, gpu_id):
        # z: [b, c, t, h, w]; gpu_id: device (or CUDA index) used for the
        # chunked one-hot buffer built below.
if self._need_init and self.training:
self._init_embeddings(z)
flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [bthw, c]
        # Squared Euclidean distance from each input vector to every code:
        # ||x||^2 - 2 * x @ e^T + ||e||^2
        distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
            - 2 * flat_inputs @ self.embeddings.t() \
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)  # [bthw, n_codes]
encoding_indices = torch.argmin(distances, dim=1)
        # Memory optimization: free the full distance matrix and build the
        # [bthw, n_codes] one-hot encodings in chunks rather than with a single
        # large F.one_hot call.
        del distances
        encode_onehot = torch.empty(
            0, self.n_codes, dtype=flat_inputs.dtype).to(gpu_id)
        for ts in torch.tensor_split(encoding_indices, 4):
            encode_onehot = torch.cat(
                (encode_onehot, F.one_hot(ts, self.n_codes).type_as(flat_inputs)), 0)
            torch.cuda.empty_cache()
        encoding_indices = encoding_indices.view(
            z.shape[0], *z.shape[2:])  # [b, t, h, w]
embeddings = F.embedding(
encoding_indices, self.embeddings) # [b, t, h, w, c]
embeddings = shift_dim(embeddings, -1, 1) # [b, c, t, h, w]
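        # Commitment loss only; the codebook itself is updated below via EMA,
        # so no separate codebook loss term is needed.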
commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
        # EMA codebook update (decay 0.99): track per-code usage counts (N) and
        # running sums of assigned encoder vectors (z_avg), then re-normalize.
if self.training:
n_total = encode_onehot.sum(dim=0)
encode_sum = flat_inputs.t() @ encode_onehot
if dist.is_initialized():
dist.all_reduce(n_total)
dist.all_reduce(encode_sum)
self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
n = self.N.sum()
weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
encode_normalized = self.z_avg / weights.unsqueeze(1)
self.embeddings.data.copy_(encode_normalized)
y = self._tile(flat_inputs)
_k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
if dist.is_initialized():
dist.broadcast(_k_rand, 0)
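            # Random restart: codes whose EMA usage count has dropped below
            # restart_thres are re-initialized from randomly sampled inputs.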
if not self.no_random_restart:
usage = (self.N.view(self.n_codes, 1)
>= self.restart_thres).float()
self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
del encode_sum
del n_total
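        # Straight-through estimator: forward the quantized embeddings,
        # but let gradients flow back to z unchanged.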
embeddings_st = (embeddings - z).detach() + z
        # Code-usage perplexity: exp(entropy of the average one-hot encodings),
        # a diagnostic for how many codes are actively used.
        avg_probs = torch.mean(encode_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs *
                                          torch.log(avg_probs + 1e-10)))
# memory optimization
del encode_onehot
torch.cuda.empty_cache()
return dict(embeddings=embeddings_st, encodings=encoding_indices,
commitment_loss=commitment_loss, perplexity=perplexity)
def dictionary_lookup(self, encodings):
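        """Embed integer code indices back into code vectors."""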
embeddings = F.embedding(encodings, self.embeddings)
return embeddings
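

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the upstream TATS code: the shapes,
    # n_codes/embedding_dim values, and passing the tensor's device as gpu_id are
    # illustrative assumptions, not settings used by this repository.
    codebook = Codebook(n_codes=512, embedding_dim=8)
    z = torch.randn(2, 8, 4, 16, 16)   # [b, c, t, h, w]
    out = codebook(z, gpu_id=z.device)
    print(out['embeddings'].shape)     # torch.Size([2, 8, 4, 16, 16])
    print(out['encodings'].shape)      # torch.Size([2, 4, 16, 16])
    print(float(out['perplexity']))    # rough measure of how many codes are used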