import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import einsum
from einops import rearrange
import torch.distributed as dist
from utils.commons.hparams import hparams


class ClusteringVectorQuantiser(nn.Module):
    """
    Improved version of the vector quantiser, with dynamic re-initialisation
    of otherwise unoptimised "dead" codebook entries.
    num_embed: number of codebook entries
    embed_dim: dimensionality of each codebook entry
    beta: weight of the commitment loss
    distance: distance metric used to look up the closest code ('l2' or 'cos')
    anchor: anchor sampling method ('closest', 'random' or 'probrandom')
    first_batch: if true, run the offline variant (initialise only from the first batch)
    contras_loss: if true, add a contrastive loss to further improve performance
    """
    def __init__(self, num_embed=1024, embed_dim=512, beta=0.25, distance='l2',
                 anchor='closest', first_batch=False, contras_loss=True):
        super().__init__()
        self.num_embed = num_embed
        self.embed_dim = embed_dim
        self.beta = beta
        self.distance = distance
        self.anchor = anchor
        self.first_batch = first_batch
        self.contras_loss = contras_loss
        self.decay = 0.99
        self.init = False
        self.pool = FeaturePool(self.num_embed, self.embed_dim)
        self.embedding = nn.Embedding(self.num_embed, self.embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.num_embed, 1.0 / self.num_embed)
        self.register_buffer("embed_prob", torch.zeros(self.num_embed))

    def forward(self, z, mask=None, temp=None, rescale_logits=False, return_logits=False):
        if mask is not None:
            assert mask.shape[:2] == z.shape[:2], (mask.shape, z.shape)
            assert mask.shape[-1] == 1, (mask.shape,)
            z = z * mask
        assert temp is None or temp == 1.0, "Only for interface compatibility with Gumbel"
        assert rescale_logits == False, "Only for interface compatibility with Gumbel"
        assert return_logits == False, "Only for interface compatibility with Gumbel"
        # flatten z to (batch * time, embed_dim)
        # z = rearrange(z, 'b c h w -> b h w c').contiguous()
        assert z.shape[-1] == self.embed_dim
        z_flattened = z.view(-1, self.embed_dim)

        # calculate the distance between z and every codebook entry
        if self.distance == 'l2':
            # negative l2 distance from z to embeddings e_j: -(z - e)^2 = -z^2 - e^2 + 2 e * z
            d = - torch.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) - \
                torch.sum(self.embedding.weight ** 2, dim=1) + \
                2 * torch.einsum('bd, dn-> bn', z_flattened.detach(), rearrange(self.embedding.weight, 'n d-> d n'))
        elif self.distance == 'cos':
            # cosine similarity from z to embeddings e_j (larger = closer)
            normed_z_flattened = F.normalize(z_flattened, dim=1).detach()
            normed_codebook = F.normalize(self.embedding.weight, dim=1)
            d = torch.einsum('bd,dn->bn', normed_z_flattened, rearrange(normed_codebook, 'n d -> d n'))

        # encoding: after sorting ascending, the last index is the closest code
        sort_distance, indices = d.sort(dim=1)
        encoding_indices = indices[:, -1]
        encodings = torch.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device)
        encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)

        # quantise and unflatten
        z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape)

        # compute the codebook and commitment losses
        loss = self.beta * (z_q.detach() - z) ** 2 + (z_q - z.detach()) ** 2
        if mask is not None:
            loss = (loss * mask).sum() / mask.sum() / self.embed_dim
        else:
            loss = loss.mean()

        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        # average usage of each code in this batch
        avg_probs = torch.mean(encodings, dim=0)

        # online clustered re-initialisation for unoptimised codes
        if self.training:
            # exponential moving average of code usage
            self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1 - self.decay)
            # running average updates
            if self.anchor in ['closest', 'random', 'probrandom'] and (not self.init):
                # closest sampling
                if self.anchor == 'closest':
                    sort_distance, indices = d.sort(dim=0)
                    random_feat = z_flattened.detach()[indices[-1, :]]
                # feature-pool based random sampling
                elif self.anchor == 'random':
                    random_feat = self.pool.query(z_flattened.detach())
                # probability-based random sampling
                elif self.anchor == 'probrandom':
                    norm_distance = F.softmax(d.t(), dim=1)
                    prob = torch.multinomial(norm_distance, num_samples=1).view(-1)
                    random_feat = z_flattened.detach()[prob]
                # decay parameter based on the average usage: rarely used codes get a larger update
                decay = torch.exp(-(self.embed_prob * self.num_embed * 10) / (1 - self.decay) - 1e-3).unsqueeze(1).repeat(1, self.embed_dim)
                if hparams.get('reduce_cvq_embed') and dist.is_initialized():
                    # keep the sampled anchor features in sync across all GPUs
                    dist.all_reduce(random_feat.data, op=dist.ReduceOp.SUM)
                    random_feat.data /= dist.get_world_size()
                self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay
                if self.first_batch:
                    self.init = True
            # contrastive loss
            if self.contras_loss:
                sort_distance, indices = d.sort(dim=0)
                dis_pos = sort_distance[-max(1, int(sort_distance.size(0) / self.num_embed)):, :].mean(dim=0, keepdim=True)
                dis_neg = sort_distance[:int(sort_distance.size(0) * 1 / 2), :]
                dis = torch.cat([dis_pos, dis_neg], dim=0).t() / 0.07
                contra_loss = F.cross_entropy(dis, torch.zeros((dis.size(0),), dtype=torch.long, device=dis.device))
                loss += contra_loss

        encoding_indices = encoding_indices.reshape(z.shape[:-1])
        return z_q, loss, encoding_indices

    def get_codebook_entry(self, encoding_indices):
        # look up the quantised latent vectors for the given code indices
        z_q = self.embedding(encoding_indices)
        return z_q
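

# Hedged usage sketch (not part of the original file): one plausible way a
# surrounding model could apply the quantiser to encoder outputs and combine
# the codebook loss with its own objective.  `encoder_out`, `nonpadding` and
# `recon_loss` are hypothetical names; only the (B, T, embed_dim) input and
# (B, T, 1) mask shapes are taken from forward() above.
def example_quantise_step(quantiser, encoder_out, nonpadding, recon_loss):
    """encoder_out: (B, T, embed_dim); nonpadding: (B, T, 1), 1 for valid frames."""
    z_q, vq_loss, codes = quantiser(encoder_out, mask=nonpadding)
    # z_q passes straight-through gradients back to the encoder; the codebook /
    # commitment (and optional contrastive) loss is simply added to the total.
    total_loss = recon_loss + vq_loss
    return z_q, codes, total_loss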


class FeaturePool:
    """
    This class implements a feature buffer that stores previously encoded features.
    The buffer lets us initialise the codebook from a history of generated features
    rather than only from the ones produced by the latest encoder.
    """
    def __init__(self, pool_size, dim=64):
        """
        Initialise the FeaturePool class.
        Parameters:
            pool_size (int) -- the size of the feature buffer
            dim (int)       -- dimensionality of the stored features
        """
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.nums_features = 0
            self.features = (torch.rand((pool_size, dim)) * 2 - 1) / pool_size

    def query(self, features):
        """
        Return features from the pool (and update it with the incoming batch).
        """
        self.features = self.features.to(features.device)
        if self.nums_features < self.pool_size:
            if features.size(0) > self.pool_size:
                # if the batch is large enough, directly update the whole pool
                random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
                self.features = features[random_feat_id]
                self.nums_features = self.pool_size
            else:
                # if the mini-batch is not large enough, just store it for the next update
                num = self.nums_features + features.size(0)
                self.features[self.nums_features:num] = features
                self.nums_features = num
        else:
            if features.size(0) > int(self.pool_size):
                random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
                self.features = features[random_feat_id]
            else:
                random_id = torch.randperm(self.pool_size)
                self.features[random_id[:features.size(0)]] = features
        return self.features
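

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original file); the batch,
    # frame count and codebook sizes below are illustrative assumptions only.
    quantiser = ClusteringVectorQuantiser(num_embed=64, embed_dim=32)
    z = torch.randn(2, 16, 32)       # (batch, frames, embed_dim)
    mask = torch.ones(2, 16, 1)      # 1 = valid frame
    z_q, loss, codes = quantiser(z, mask=mask)
    print(z_q.shape, loss.item(), codes.shape)          # (2, 16, 32), scalar, (2, 16)
    # decode the discrete indices back to codebook vectors
    print(quantiser.get_codebook_entry(codes).shape)    # (2, 16, 32)
    # the FeaturePool can also be queried directly (used when anchor='random')
    print(quantiser.pool.query(torch.randn(8, 32)).shape)  # (64, 32)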