""" Adapted from https://github.com/SongweiGe/TATS"""
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import numpy as np
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from vq_gan_3d.utils import shift_dim
import subprocess
import re
command = 'nvidia-smi'
class Codebook(nn.Module):
    def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0):
        super().__init__()
        self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
        self.register_buffer('N', torch.zeros(n_codes))
        self.register_buffer('z_avg', self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        self._need_init = True
        self.no_random_restart = no_random_restart
        self.restart_thres = restart_thres

    def _tile(self, x):
        # Repeat x (with a small amount of noise) until it has at least
        # n_codes rows, so that n_codes vectors can be sampled from it.
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def _init_embeddings(self, z):
        # z: [b, c, t, h, w]
        # Initialize the codebook from (tiled) encoder outputs of the first
        # training batch instead of keeping the random initialization.
        self._need_init = False
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
        y = self._tile(flat_inputs)
        d = y.shape[0]

        _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
        if dist.is_initialized():
            dist.broadcast(_k_rand, 0)

        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))

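    # forward() quantizes z to its nearest codebook entries, computes the
    # commitment loss and straight-through embeddings, and, in training mode,
    # refreshes the codebook with an exponential-moving-average update.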
    def forward(self, z, gpu_id):
        # z: [b, c, t, h, w]
        if self._need_init and self.training:
            self._init_embeddings(z)
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)  # [bthw, c]

        # Squared Euclidean distance from every input vector to every code:
        # ||x||^2 - 2 x.e + ||e||^2
        distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
            - 2 * flat_inputs @ self.embeddings.t() \
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)  # [bthw, n_codes]
        encoding_indices = torch.argmin(distances, dim=1)
        # print(subprocess.check_output(command).decode("utf-8"))

        # Memory optimization: free the distance matrix and build the one-hot
        # encodings in chunks rather than with a single large F.one_hot call.
        del distances
        # encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)  # [bthw, ncode]
        encode_onehot = torch.empty(0, self.n_codes).to(gpu_id)
        for ts in torch.tensor_split(encoding_indices, 4):
            encode_onehot = torch.cat(
                (encode_onehot, F.one_hot(ts, self.n_codes).type_as(flat_inputs)), 0)
            torch.cuda.empty_cache()
        encoding_indices = encoding_indices.view(
            z.shape[0], *z.shape[2:])  # [b, t, h, w]

        embeddings = F.embedding(
            encoding_indices, self.embeddings)  # [b, t, h, w, c]
        embeddings = shift_dim(embeddings, -1, 1)  # [b, c, t, h, w]

        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())

        # EMA codebook update
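        # N is an exponential moving average (decay 0.99) of per-code usage
        # counts and z_avg an EMA of the summed encoder outputs assigned to
        # each code; the codebook is re-estimated from these running
        # statistics rather than being trained by gradient descent.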
        if self.training:
            n_total = encode_onehot.sum(dim=0)
            encode_sum = flat_inputs.t() @ encode_onehot
            if dist.is_initialized():
                dist.all_reduce(n_total)
                dist.all_reduce(encode_sum)

            self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
            self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)

            # Re-estimate each code as its EMA input sum divided by its
            # Laplace-smoothed EMA usage count.
            n = self.N.sum()
            weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
            encode_normalized = self.z_avg / weights.unsqueeze(1)
            self.embeddings.data.copy_(encode_normalized)

            y = self._tile(flat_inputs)
            _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
            if dist.is_initialized():
                dist.broadcast(_k_rand, 0)

            if not self.no_random_restart:
                # Random restart: re-seed codes whose usage fell below the
                # threshold with random vectors drawn from the current batch.
                usage = (self.N.view(self.n_codes, 1)
                         >= self.restart_thres).float()
                self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))

            del encode_sum
            del n_total
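        # Straight-through estimator: the forward pass returns the quantized
        # embeddings while gradients flow to z unchanged (the quantization
        # step is bypassed in the backward pass).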
        embeddings_st = (embeddings - z).detach() + z

        avg_probs = torch.mean(encode_onehot, dim=0)
        # Perplexity of the code assignments (exp of the entropy); higher
        # values indicate more uniform codebook usage.
        perplexity = torch.exp(-torch.sum(avg_probs *
                                          torch.log(avg_probs + 1e-10)))

        # memory optimization
        del encode_onehot
        torch.cuda.empty_cache()

        return dict(embeddings=embeddings_st, encodings=encoding_indices,
                    commitment_loss=commitment_loss, perplexity=perplexity)

    def dictionary_lookup(self, encodings):
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings

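
# Minimal usage sketch (not part of the original module). The shapes and
# hyper-parameters below are illustrative assumptions only.
if __name__ == '__main__':
    codebook = Codebook(n_codes=1024, embedding_dim=8)
    z = torch.randn(2, 8, 4, 16, 16)   # [b, c, t, h, w] encoder output (example)
    out = codebook(z, gpu_id='cpu')    # gpu_id only places the one-hot buffer
    print(out['encodings'].shape)      # [b, t, h, w] code indices
    print(out['embeddings'].shape)     # [b, c, t, h, w] quantized latents
    print(out['commitment_loss'].item(), out['perplexity'].item())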