import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# https://github.com/bshall/VectorQuantizedVAE/blob/master/model.py
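# This module implements an EMA-updated vector-quantization codebook
# (VQEmbeddingEMA) with stale-code replacement, plus a residual VQ wrapper
# (RVQEmbedding) that stacks several such codebooks.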
class VQEmbeddingEMA(nn.Module):
def __init__(self, nband, num_code, code_dim, decay=0.99, layer=0):
super(VQEmbeddingEMA, self).__init__()
self.nband = nband
self.num_code = num_code
self.code_dim = code_dim
self.decay = decay
self.layer = layer
self.stale_tolerance = 50
self.eps = torch.finfo(torch.float32).eps
        if layer == 0:
            # first layer: random codes, unit-normalized along the code dimension
            embedding = torch.empty(nband, num_code, code_dim).normal_()
            embedding = embedding / (embedding.pow(2).sum(-1) + self.eps).sqrt().unsqueeze(-1)
        else:
            # later layers quantize residuals: start from small random codes and
            # keep code 0 as an all-zero code
            embedding = torch.empty(nband, num_code, code_dim).normal_() / code_dim
            embedding[:, 0] = embedding[:, 0] * 0
self.register_buffer("embedding", embedding)
self.register_buffer("ema_weight", self.embedding.clone())
self.register_buffer("ema_count", torch.zeros(self.nband, self.num_code))
self.register_buffer("stale_counter", torch.zeros(nband, self.num_code))
    def forward(self, input):
        # input: (B, C, N, T) with N == code_dim; only a single band is supported
        num_valid_bands = 1
        B, C, N, T = input.shape
        assert N == self.code_dim
        assert C == num_valid_bands
        input_detach = input.detach().permute(0, 3, 1, 2).contiguous().view(B * T, num_valid_bands, self.code_dim)  # B*T, nband, dim
embedding = self.embedding[:num_valid_bands,:,:].contiguous()
        # squared Euclidean distance, expanded as ||x||^2 + ||e||^2 - 2 * x.e
        eu_dis = input_detach.pow(2).sum(2).unsqueeze(2) + embedding.pow(2).sum(2).unsqueeze(0)  # B*T, nband, num_code
        eu_dis = eu_dis - 2 * torch.stack([input_detach[:, i].mm(embedding[i].T) for i in range(num_valid_bands)], 1)  # B*T, nband, num_code
# best codes
indices = torch.argmin(eu_dis, dim=-1) # B*T, nband
quantized = []
for i in range(num_valid_bands):
quantized.append(torch.gather(embedding[i], 0, indices[:,i].unsqueeze(-1).expand(-1, self.code_dim))) # B*T, dim
quantized = torch.stack(quantized, 1)
quantized = quantized.view(B, T, C, N).permute(0,2,3,1).contiguous() # B, C, N, T
        # perplexity = exp(entropy) of the average code assignment; higher means
        # more uniform codebook usage
        encodings = F.one_hot(indices, self.num_code).float()  # B*T, nband, num_code
        avg_probs = encodings.mean(0)  # nband, num_code
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + self.eps), -1)).mean()
if self.training:
# EMA update for codebook
self.ema_count[:num_valid_bands] = self.decay * self.ema_count[:num_valid_bands] + (1 - self.decay) * torch.sum(encodings, dim=0) # nband, num_code
update_direction = encodings.permute(1,2,0).bmm(input_detach.permute(1,0,2)) # nband, num_code, dim
self.ema_weight[:num_valid_bands] = self.decay * self.ema_weight[:num_valid_bands] + (1 - self.decay) * update_direction # nband, num_code, dim
# Laplace smoothing on the counters
# make sure the denominator will never be zero
n = torch.sum(self.ema_count[:num_valid_bands], dim=-1, keepdim=True) # nband, 1
self.ema_count[:num_valid_bands] = (self.ema_count[:num_valid_bands] + self.eps) / (n + self.num_code * self.eps) * n # nband, num_code
self.embedding[:num_valid_bands] = self.ema_weight[:num_valid_bands] / self.ema_count[:num_valid_bands].unsqueeze(-1)
            # code usage: reset the stale counter for codes used in this batch,
            # increment it for codes that went unused
            stale_codes = (encodings.sum(0) == 0).float()  # nband, num_code
            self.stale_counter[:num_valid_bands] = self.stale_counter[:num_valid_bands] * stale_codes + stale_codes
            print("Layer {}: {} frames encoded, {:.1f}% codes unused this batch, {:.1f}% stale".format(
                self.layer, encodings.sum(),
                stale_codes.sum() / torch.numel(stale_codes) * 100.,
                (self.stale_counter > self.stale_tolerance // 2).sum() / torch.numel(self.stale_counter) * 100.))
            # randomly replace codes that haven't been used for stale_tolerance batches
            replace_code = (self.stale_counter[:num_valid_bands] == self.stale_tolerance).float()  # nband, num_code
            if replace_code.sum(-1).max() > 0:
                # draw replacement vectors from a random permutation of the batch,
                # tiling the batch if it holds fewer than num_code frames
                random_input_idx = torch.randperm(input_detach.shape[0])
                random_input = input_detach[random_input_idx].view(input_detach.shape)
                if random_input.shape[0] < self.num_code:
                    random_input = torch.cat([random_input] * (self.num_code // random_input.shape[0] + 1), 0)
                random_input = random_input[:self.num_code, :].contiguous().transpose(0, 1)  # nband, num_code, dim
self.embedding[:num_valid_bands] = self.embedding[:num_valid_bands] * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
self.ema_weight[:num_valid_bands] = self.ema_weight[:num_valid_bands] * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
self.ema_count[:num_valid_bands] = self.ema_count[:num_valid_bands] * (1 - replace_code)
self.stale_counter[:num_valid_bands] = self.stale_counter[:num_valid_bands] * (1 - replace_code)
            # code constraints: keep layer-0 codes unit-norm after the EMA update
            if self.layer == 0:
                self.embedding[:num_valid_bands] = self.embedding[:num_valid_bands] / (self.embedding[:num_valid_bands].pow(2).sum(-1) + self.eps).sqrt().unsqueeze(-1)
# else:
# # make sure there is always a zero code
# self.embedding[:,0] = self.embedding[:,0] * 0
# self.ema_weight[:,0] = self.ema_weight[:,0] * 0
return quantized, indices.reshape(B, T, -1), perplexity
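
# Hypothetical usage sketch (not part of the original pipeline): a minimal
# smoke test for a single EMA codebook; the sizes below are illustrative only.
def _demo_vq_embedding_ema():
    vq = VQEmbeddingEMA(nband=1, num_code=1024, code_dim=64, layer=0)
    vq.train()  # exercise the EMA and stale-code update path
    x = torch.randn(2, 1, 64, 100)  # B, C=1, N=code_dim, T
    quantized, indices, perplexity = vq(x)
    assert quantized.shape == x.shape
    assert indices.shape == (2, 100, 1)  # B, T, nband
    print("demo perplexity:", perplexity.item())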
class RVQEmbedding(nn.Module):
    def __init__(self, nband, code_dim, decay=0.99, num_codes=(1024, 1024)):
super(RVQEmbedding, self).__init__()
self.nband = nband
self.code_dim = code_dim
self.decay = decay
self.eps = torch.finfo(torch.float32).eps
        self.min_max = [10000, -10000]  # running min/max of observed frame norms (debug)
self.bins = [256+i*8 for i in range(32)]
self.VQEmbedding = nn.ModuleList([])
for i in range(len(num_codes)):
self.VQEmbedding.append(VQEmbeddingEMA(nband, num_codes[i], code_dim, decay, layer=i))
def forward(self, input):
        # track the dynamic range of the frame norms for debugging
        norm_value = torch.norm(input, p=2, dim=-2)  # B, C, T
        if norm_value.min() < self.min_max[0]:
            self.min_max[0] = norm_value.min().cpu().item()
        if norm_value.max() > self.min_max[-1]:
            self.min_max[-1] = norm_value.max().cpu().item()
        print("Norm min-max observed: {}".format(self.min_max))
        # quantize each frame norm into 8 bins of width 20 starting at 256,
        # mapping it to the bin center
        norm_value = (((norm_value - 256) / 20).clamp(min=0, max=7).int() * 20 + 256 + 10).float()
        print("Quantized norm min-max: {}, {}".format(norm_value.min(), norm_value.max()))
        # unit-normalize the frames before vector quantization
        input = F.normalize(input, p=2, dim=-2)
quantized_list = []
perplexity_list = []
indices_list = []
        residual_input = input
        for i in range(len(self.VQEmbedding)):
            # quantize the current residual; quantized_list[i] will hold the
            # cumulative reconstruction after stage i
            this_quantized, this_indices, this_perplexity = self.VQEmbedding[i](residual_input)
perplexity_list.append(this_perplexity)
indices_list.append(this_indices)
residual_input = residual_input - this_quantized
if i == 0:
quantized_list.append(this_quantized)
else:
quantized_list.append(quantized_list[-1] + this_quantized)
        quantized_list = torch.stack(quantized_list, -1)  # B, C, N, T, num_stages
        perplexity_list = torch.stack(perplexity_list, -1)
        indices_list = torch.stack(indices_list, -1)  # B, T, nband, num_stages
        # commitment-style latent loss: pull the (normalized) input toward each
        # cumulative reconstruction, with the codebooks detached
        latent_loss = 0
        for i in range(quantized_list.shape[-1]):
            latent_loss = latent_loss + F.mse_loss(input, quantized_list.detach()[:, :, :, :, i])
# TODO: remove unit norm
quantized_list = quantized_list / (quantized_list.pow(2).sum(2) + self.eps).sqrt().unsqueeze(2) # unit norm
return quantized_list, norm_value, indices_list, latent_loss
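
# Hypothetical end-to-end sketch (an assumption, not from the original repo):
# run a two-stage residual VQ over random features; sizes are illustrative.
if __name__ == "__main__":
    _demo_vq_embedding_ema()
    rvq = RVQEmbedding(nband=1, code_dim=64, num_codes=[1024, 1024])
    x = torch.randn(2, 1, 64, 100)  # B, C, N=code_dim, T
    quantized, norm_value, indices, latent_loss = rvq(x)
    print(quantized.shape)   # (2, 1, 64, 100, 2): cumulative reconstructions
    print(norm_value.shape)  # (2, 1, 100): quantized frame norms
    print(indices.shape)     # (2, 100, 1, 2): code index per frame, band, stage
    print("latent loss:", latent_loss.item())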