import scipy
from scipy import linalg
from torch.nn import functional as F
import torch
from torch import nn
import numpy as np
from modules.audio2motion.transformer_models import FFTBlocks
import modules.audio2motion.utils as utils
from modules.audio2motion.flow_base import Glow, WN, ResidualCouplingBlock
import torch.distributions as dist
from modules.audio2motion.cnn_models import LambdaLayer, LayerNorm

from vector_quantize_pytorch import VectorQuantize


class FVAEEncoder(nn.Module):
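    """Posterior encoder: a strided Conv1d pre-net downsamples the input by prod(strides),
    a WN stack conditioned on g refines it, and a 1x1 conv predicts the mean and log-std
    of a diagonal Gaussian over the latent frames."""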
    def __init__(self, in_channels, hidden_channels, latent_channels, kernel_size,
                 n_layers, gin_channels=0, p_dropout=0, strides=[4]):
        super().__init__()
        self.strides = strides
        self.hidden_size = hidden_channels
        self.pre_net = nn.Sequential(*[
            nn.Conv1d(in_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
            if i == 0 else
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
            for i, s in enumerate(strides)
        ])
        self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
        self.out_proj = nn.Conv1d(hidden_channels, latent_channels * 2, 1)
        self.latent_channels = latent_channels

    def forward(self, x, x_mask, g):
        x = self.pre_net(x)
        # downsample the mask to the latent frame rate and trim to the conv output length
        x_mask = x_mask[:, :, ::np.prod(self.strides)][:, :, :x.shape[-1]]
        x = x * x_mask
        x = self.wn(x, x_mask, g) * x_mask
        x = self.out_proj(x)
        m, logs = torch.split(x, self.latent_channels, dim=1)
        # reparameterization trick: z ~ N(m, exp(logs)^2)
        z = m + torch.randn_like(m) * torch.exp(logs)
        return z, m, logs, x_mask


class FVAEDecoder(nn.Module):
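    """Decoder mirroring FVAEEncoder: transposed convolutions upsample the latents back
    by prod(strides), followed by a WN stack conditioned on g and a 1x1 output projection."""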
    def __init__(self, latent_channels, hidden_channels, out_channels, kernel_size,
                 n_layers, gin_channels=0, p_dropout=0,
                 strides=[4]):
        super().__init__()
        self.strides = strides
        self.hidden_size = hidden_channels
        self.pre_net = nn.Sequential(*[
            nn.ConvTranspose1d(latent_channels, hidden_channels, kernel_size=s, stride=s)
            if i == 0 else
            nn.ConvTranspose1d(hidden_channels, hidden_channels, kernel_size=s, stride=s)
            for i, s in enumerate(strides)
        ])
        self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
        self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1)

    def forward(self, x, x_mask, g):
        x = self.pre_net(x)
        x = x * x_mask
        x = self.wn(x, x_mask, g) * x_mask
        x = self.out_proj(x)
        return x


class VQVAE(nn.Module):
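    """VQ-VAE over motion sequences.

    The encoder yields latents at 1/prod(strides) of the input frame rate, which are
    vector-quantized (codebook_size=256; vectors are projected down to codebook_dim=16
    before the nearest-neighbor lookup). The decoder reconstructs the input from the
    quantized latents, conditioned on g (e.g. encoded mel features). At inference time,
    codebook entries are sampled uniformly at random instead of encoding x.
    """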
    def __init__(self,
                 in_out_channels=64, hidden_channels=256, latent_size=16,
                 kernel_size=3, enc_n_layers=5, dec_n_layers=5, gin_channels=80, strides=[4,],
                 sqz_prior=False):
        super().__init__()
        self.in_out_channels = in_out_channels
        self.strides = strides
        self.hidden_size = hidden_channels
        self.latent_size = latent_size
        self.g_pre_net = nn.Sequential(*[
            nn.Conv1d(gin_channels, gin_channels, kernel_size=s * 2, stride=s, padding=s // 2)
            for i, s in enumerate(strides)
        ])
        self.encoder = FVAEEncoder(in_out_channels, hidden_channels, hidden_channels, kernel_size,
                                   enc_n_layers, gin_channels, strides=strides)
        self.vq = VectorQuantize(dim=hidden_channels, codebook_size=256, codebook_dim=16)

        self.decoder = FVAEDecoder(hidden_channels, hidden_channels, in_out_channels, kernel_size,
                                   dec_n_layers, gin_channels, strides=strides)
        self.prior_dist = dist.Normal(0, 1)
        self.sqz_prior = sqz_prior

    def forward(self, x=None, x_mask=None, g=None, infer=False, **kwargs):
        """

        :param x: [B, T,  C_in_out]
        :param x_mask: [B, T]
        :param g: [B, T, C_g]
        :return:
        """
        x_mask = x_mask[:, None, :] # [B, 1, T]
        g = g.transpose(1, 2) # [B, C_g, T]
        # downsample the conditioning to the latent frame rate
        g_sqz = self.g_pre_net(g)

        if not infer:
            x = x.transpose(1, 2) # [B, C, T]
            z_q, m_q, logs_q, x_mask_sqz = self.encoder(x, x_mask, g_sqz)
            if self.sqz_prior:
                # optionally squeeze the latent sequence further before quantization
                z_q = F.interpolate(z_q, scale_factor=1/8)
            # vector-quantize the frame-level latents; commit_loss is the VQ commitment loss
            z_p, idx, commit_loss = self.vq(z_q.transpose(1, 2))
            if self.sqz_prior:
                z_p = F.interpolate(z_p.transpose(1, 2), scale_factor=8).transpose(1, 2)

            x_recon = self.decoder(z_p.transpose(1, 2), x_mask, g)
            return x_recon.transpose(1, 2), commit_loss, z_p.transpose(1, 2), m_q.transpose(1, 2), logs_q.transpose(1, 2)
        else:
            bs, t = g_sqz.shape[0], g_sqz.shape[2]
            if self.sqz_prior:
                t = t // 8
            # sample codebook indices uniformly at random as a simple prior over the latents
            latent_shape = [int(bs * t)]
            latent_idx = torch.randint(0, 256, latent_shape).to(self.vq.codebook.device)  # 256 == codebook_size
            z_p = self.vq.codebook[latent_idx]
            z_p = z_p.reshape([bs, t, -1])
            # project the codebook_dim (16) vectors back up to the model dimension
            z_p = self.vq.project_out(z_p)
            if self.sqz_prior:
                z_p = F.interpolate(z_p.transpose(1, 2), scale_factor=8).transpose(1, 2)

            # x_mask=1: no padding mask is applied at inference time
            x_recon = self.decoder(z_p.transpose(1, 2), 1, g)
            return x_recon.transpose(1, 2), z_p.transpose(1, 2)


class VQVAEModel(nn.Module):
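    """Mel-conditioned VQ-VAE over facial expression and head pose sequences.

    The mel spectrogram is downsampled by 2x and encoded to 64-dim conditioning features;
    the inner VQVAE reconstructs the concatenated [exp, pose] sequence
    (in_out_dim=71 = 64 expression dims + 7 pose dims by default).
    """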
    def __init__(self, in_out_dim=71, sqz_prior=False, enc_no_cond=False):
        super().__init__()
        self.mel_encoder = nn.Sequential(*[
                nn.Conv1d(80, 64, 3, 1, 1, bias=False),
                nn.BatchNorm1d(64),
                nn.GELU(),
                nn.Conv1d(64, 64, 3, 1, 1, bias=False)
            ]) 
        self.in_dim, self.out_dim = in_out_dim, in_out_dim
        self.sqz_prior = sqz_prior
        self.enc_no_cond = enc_no_cond
        self.vae = VQVAE(in_out_channels=in_out_dim, hidden_channels=256, latent_size=16, kernel_size=5,
            enc_n_layers=8, dec_n_layers=4, gin_channels=64, strides=[4,], sqz_prior=sqz_prior)
        self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2))  # halve the mel frame rate so mel_feat aligns with the motion frames
    
    @property
    def device(self):
        return next(self.vae.parameters()).device

    def forward(self, batch, ret, log_dict=None, train=True):
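        """
        :param batch: dict with 'y_mask' [B, T] and 'mel' [B, T_mel, 80]; training additionally
                      requires 'exp' [B, T, 64] and 'pose' [B, T, 7]
        :param ret: dict collecting 'pred', 'mask' and 'loss_commit' during training
        :return: (x_recon, loss_commit, m_q, logs_q) when train=True, else x_recon
        """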
        infer = not train
        mask = batch['y_mask'].to(self.device)
        mel = batch['mel'].to(self.device)
        mel = self.downsampler(mel)

        mel_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2)
        if not infer:
            exp = batch['exp'].to(self.device)
            pose = batch['pose'].to(self.device)
            if self.in_dim == 71:
                x = torch.cat([exp, pose], dim=-1) # [B, T, C=64+7]
            elif self.in_dim == 64:
                x = exp
            elif self.in_dim == 7:
                x = pose
            else:
                raise ValueError(f"unsupported in_dim: {self.in_dim}")
            if self.enc_no_cond:
                x_recon, loss_commit, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=torch.zeros_like(mel_feat), infer=False)
            else:
                x_recon, loss_commit, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=mel_feat, infer=False)
            loss_commit = loss_commit.reshape([])
            ret['pred'] = x_recon
            ret['mask'] = mask
            ret['loss_commit'] = loss_commit
            return x_recon, loss_commit, m_q, logs_q
        else:
            x_recon, z_p = self.vae(x=None, x_mask=mask, g=mel_feat, infer=True)
            return x_recon

    def num_params(self, model, print_out=True, model_name="model"):
        parameters = filter(lambda p: p.requires_grad, model.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        if print_out:
            print(f'| {model_name} Trainable Parameters: {parameters:.3f}M')
        return parameters
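

# Minimal usage sketch (not part of the original module): smoke-test the training-mode
# forward pass with random tensors. Assumptions: batch keys 'y_mask', 'mel', 'exp', 'pose'
# as consumed above, mel at twice the motion frame rate (so the 0.5x downsampler aligns it
# with exp/pose), and T divisible by 4 so the decoder's upsampled length matches the mask.
if __name__ == "__main__":
    model = VQVAEModel(in_out_dim=71)
    model.num_params(model.vae, model_name="vqvae")
    B, T = 2, 200
    batch = {
        'y_mask': torch.ones([B, T]),
        'mel': torch.randn([B, 2 * T, 80]),
        'exp': torch.randn([B, T, 64]),
        'pose': torch.randn([B, T, 7]),
    }
    ret = {}
    x_recon, loss_commit, m_q, logs_q = model(batch, ret, train=True)
    print(x_recon.shape, loss_commit.item())  # expected: torch.Size([2, 200, 71]) and a scalar commitment loss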