import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# https://github.com/bshall/VectorQuantizedVAE/blob/master/model.py
class VQEmbeddingEMA(nn.Module):
    def __init__(self, nband, num_code, code_dim, decay=0.99, layer=0):
        super(VQEmbeddingEMA, self).__init__()

        self.nband = nband
        self.num_code = num_code
        self.code_dim = code_dim
        self.decay = decay
        self.layer = layer
        self.stale_tolerance = 50
        self.eps = torch.finfo(torch.float32).eps

        if layer == 0:
            # first quantizer: initialise code vectors on the unit hypersphere
            embedding = torch.empty(nband, num_code, code_dim).normal_()
            embedding = embedding / (embedding.pow(2).sum(-1) + self.eps).sqrt().unsqueeze(-1)  # TODO
        else:
            # later quantizers: small random codes; index 0 is reserved as an all-zero code
            embedding = torch.empty(nband, num_code, code_dim).normal_() / code_dim
            embedding[:,0] = embedding[:,0] * 0  # TODO
        self.register_buffer("embedding", embedding)
        self.register_buffer("ema_weight", self.embedding.clone())
        self.register_buffer("ema_count", torch.zeros(self.nband, self.num_code))
        self.register_buffer("stale_counter", torch.zeros(nband, self.num_code))

    def forward(self, input):
        # NOTE: only a single band is quantised here; the codebook buffers keep a full `nband` axis
        num_valid_bands = 1
        B, C, N, T = input.shape
        assert N == self.code_dim
        assert C == num_valid_bands

        input_detach = input.detach().permute(0,3,1,2).contiguous().view(B*T, num_valid_bands, self.code_dim)  # B*T, C, dim
        embedding = self.embedding[:num_valid_bands,:,:].contiguous()
        # distance
        eu_dis = input_detach.pow(2).sum(2).unsqueeze(2) + embedding.pow(2).sum(2).unsqueeze(0)  # B*T, nband, num_code
        eu_dis = eu_dis - 2 * torch.stack([input_detach[:,i].mm(embedding[i].T) for i in range(num_valid_bands)], 1)  # B*T, nband, num_code

        # best codes
        indices = torch.argmin(eu_dis, dim=-1)  # B*T, nband
        quantized = []
        for i in range(num_valid_bands):
            quantized.append(torch.gather(embedding[i], 0, indices[:,i].unsqueeze(-1).expand(-1, self.code_dim)))  # B*T, dim
        quantized = torch.stack(quantized, 1)
        quantized = quantized.view(B, T, C, N).permute(0,2,3,1).contiguous()  # B, C, N, T

        # calculate perplexity
        encodings = F.one_hot(indices, self.num_code).float()  # B*T, nband, num_code
        avg_probs = encodings.mean(0)  # nband, num_code
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + self.eps), -1)).mean()

        if self.training:
            # EMA update for codebook
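            # Standard VQ-VAE EMA codebook statistics, as implemented below:
            #   count_i  <- decay * count_i  + (1 - decay) * (number of vectors assigned to code i)
            #   weight_i <- decay * weight_i + (1 - decay) * (sum of vectors assigned to code i)
            #   code_i   <- weight_i / count_i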
            
            self.ema_count[:num_valid_bands] = self.decay * self.ema_count[:num_valid_bands] + (1 - self.decay) * torch.sum(encodings, dim=0)  # nband, num_code

            update_direction = encodings.permute(1,2,0).bmm(input_detach.permute(1,0,2))  # nband, num_code, dim
            self.ema_weight[:num_valid_bands] = self.decay * self.ema_weight[:num_valid_bands] + (1 - self.decay) * update_direction  # nband, num_code, dim

            # Laplace smoothing on the counters
            # make sure the denominator will never be zero
            n = torch.sum(self.ema_count[:num_valid_bands], dim=-1, keepdim=True)  # nband, 1
            self.ema_count[:num_valid_bands] = (self.ema_count[:num_valid_bands] + self.eps) / (n + self.num_code * self.eps) * n  # nband, num_code

            self.embedding[:num_valid_bands] = self.ema_weight[:num_valid_bands] / self.ema_count[:num_valid_bands].unsqueeze(-1)

            # calculate code usage
            stale_codes = (encodings.sum(0) == 0).float()  # nband, num_code
            self.stale_counter[:num_valid_bands] = self.stale_counter[:num_valid_bands] * stale_codes + stale_codes
            print("Lyaer {}, Ratio of unused vector : {}, {:.1f}, {:.1f}".format(self.layer, encodings.sum(), stale_codes.sum()/torch.numel(stale_codes)*100., (self.stale_counter > self.stale_tolerance //2).sum() /torch.numel(self.stale_counter)*100.))

            # random replace codes that haven't been used for a while
            replace_code = (self.stale_counter[:num_valid_bands] == self.stale_tolerance).float() # nband, num_code
            if replace_code.sum(-1).max() > 0:
                random_input_idx = torch.randperm(input_detach.shape[0])
                random_input = input_detach[random_input_idx].view(input_detach.shape)
                if random_input.shape[0] < self.num_code:
                    random_input = torch.cat([random_input]*(self.num_code // random_input.shape[0] + 1), 0)
                random_input = random_input[:self.num_code,:].contiguous().transpose(0,1)  # nband, num_code, dim

                self.embedding[:num_valid_bands] = self.embedding[:num_valid_bands] * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
                self.ema_weight[:num_valid_bands] = self.ema_weight[:num_valid_bands] * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
                self.ema_count[:num_valid_bands] = self.ema_count[:num_valid_bands] * (1 - replace_code)
                self.stale_counter[:num_valid_bands] = self.stale_counter[:num_valid_bands] * (1 - replace_code)

            # TODO:
            # code constraints
            if self.layer == 0:
                self.embedding[:num_valid_bands] = self.embedding[:num_valid_bands] / (self.embedding[:num_valid_bands].pow(2).sum(-1) + self.eps).sqrt().unsqueeze(-1)
            # else:
            #     # make sure there is always a zero code
            #     self.embedding[:,0] = self.embedding[:,0] * 0
            #     self.ema_weight[:,0] = self.ema_weight[:,0] * 0

        return quantized, indices.reshape(B, T, -1), perplexity

class RVQEmbedding(nn.Module):
    def __init__(self, nband, code_dim, decay=0.99, num_codes=[1024, 1024]):
        super(RVQEmbedding, self).__init__()

        self.nband = nband
        self.code_dim = code_dim
        self.decay = decay
        self.eps = torch.finfo(torch.float32).eps
        self.min_max = [10000, -10000]
        self.bins = [256+i*8 for i in range(32)]

        self.VQEmbedding = nn.ModuleList([])
        for i in range(len(num_codes)):
            self.VQEmbedding.append(VQEmbeddingEMA(nband, num_codes[i], code_dim, decay, layer=i))

    def forward(self, input):
        # per-frame L2 norm, used as a coarsely quantised gain term
        norm_value = torch.norm(input, p=2, dim=-2)  # B, C, T
        if norm_value.min() < self.min_max[0]:
            self.min_max[0] = norm_value.min().cpu().item()
        if norm_value.max() > self.min_max[-1]:
            self.min_max[-1] = norm_value.max().cpu().item()
        print("Min-max : {}".format(self.min_max))
        # quantise the norm into 8 uniform bins of width 20 starting at 256, then map to bin centres
        norm_value = (((norm_value - 256) / 20).clamp(min=0, max=7).int() * 20 + 256 + 10).float()
        print("Min-max : {}, {}".format(norm_value.min(), norm_value.max()))
        input = torch.nn.functional.normalize(input, p=2, dim=-2)

        quantized_list = []
        perplexity_list = []
        indices_list = []

        residual_input = input
        for i in range(len(self.VQEmbedding)):
            this_quantized, this_indices, this_perplexity = self.VQEmbedding[i](residual_input)
            perplexity_list.append(this_perplexity)
            indices_list.append(this_indices)
            residual_input = residual_input - this_quantized
            if i == 0:
                quantized_list.append(this_quantized)
            else:
                quantized_list.append(quantized_list[-1] + this_quantized)
        
        quantized_list = torch.stack(quantized_list, -1)  # B, C, N, T, num_quantizers (cumulative reconstructions)
        perplexity_list = torch.stack(perplexity_list, -1)
        indices_list = torch.stack(indices_list, -1)  # B, T, C, num_quantizers
        # latent (commitment-style) loss between the encoder output and each cumulative
        # reconstruction; the quantized path is detached so gradients flow to the encoder
        latent_loss = 0
        for i in range(quantized_list.shape[-1]):
            latent_loss = latent_loss + F.mse_loss(input, quantized_list.detach()[:,:,:,:,i])
        # TODO: remove unit norm
        quantized_list = quantized_list / (quantized_list.pow(2).sum(2) + self.eps).sqrt().unsqueeze(2)  # unit norm

        return quantized_list, norm_value, indices_list, latent_loss
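
if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the original training pipeline.
    # The shapes below are assumptions chosen to satisfy the asserts in
    # VQEmbeddingEMA.forward (a single band, C == 1, and N == code_dim).
    torch.manual_seed(0)
    B, C, N, T = 2, 1, 64, 10  # batch, bands, feature dim, frames
    rvq = RVQEmbedding(nband=1, code_dim=N, num_codes=[1024, 1024])
    rvq.eval()  # skip the EMA updates and stale-code replacement
    with torch.no_grad():
        x = torch.randn(B, C, N, T)
        quantized, norm_value, indices, latent_loss = rvq(x)
    print(quantized.shape)   # (B, C, N, T, num_quantizers)
    print(norm_value.shape)  # (B, C, T)
    print(indices.shape)     # (B, T, C, num_quantizers)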