Delete CHATTS/model/dvae.py
CHATTS/model/dvae.py (DELETED: +0 -155)
@@ -1,155 +0,0 @@
```python
import math

from einops import rearrange
from vector_quantize_pytorch import GroupedResidualFSQ

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvNeXtBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel, dilation,
        layer_scale_init_value: float = 1e-6,
    ):
        # ConvNeXt Block copied from Vocos.
        super().__init__()
        self.dwconv = nn.Conv1d(
            dim, dim,
            kernel_size=kernel, padding=dilation * (kernel // 2),
            dilation=dilation, groups=dim,
        )  # depthwise conv

        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x
```
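For context on what was removed: the block is shape-preserving on (B, C, T) input, thanks to the residual connection and the padding choice. A minimal smoke test; the sizes here are illustrative, not the shipped ChatTTS values:

```python
import torch

# Hypothetical sizes, chosen only to exercise the block.
# padding = dilation * (kernel // 2) keeps the time axis length unchanged.
block = ConvNeXtBlock(dim=256, intermediate_dim=1024, kernel=7, dilation=2)
x = torch.randn(2, 256, 100)
assert block(x).shape == (2, 256, 100)  # (B, C, T) -> (B, C, T)
```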
```python
class GFSQ(nn.Module):

    def __init__(self,
                 dim, levels, G, R, eps=1e-5, transpose=True
                 ):
        super(GFSQ, self).__init__()
        self.quantizer = GroupedResidualFSQ(
            dim=dim,
            levels=levels,
            num_quantizers=R,
            groups=G,
        )
        self.n_ind = math.prod(levels)
        self.eps = eps
        self.transpose = transpose
        self.G = G
        self.R = R

    def _embed(self, x):
        if self.transpose:
            x = x.transpose(1, 2)
        x = rearrange(
            x, "b t (g r) -> g b t r", g=self.G, r=self.R,
        )
        feat = self.quantizer.get_output_from_indices(x)
        return feat.transpose(1, 2) if self.transpose else feat

    def forward(self, x):
        if self.transpose:
            x = x.transpose(1, 2)
        feat, ind = self.quantizer(x)
        ind = rearrange(
            ind, "g b t r -> b t (g r)",
        )
        embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
        e_mean = torch.mean(embed_onehot, dim=[0, 1])
        e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))

        return (
            torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
            feat.transpose(1, 2) if self.transpose else feat,
            perplexity,
            None,
            ind.transpose(1, 2) if self.transpose else ind,
        )
```
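A sketch of how this wrapper is driven, assuming vector_quantize_pytorch's GroupedResidualFSQ behaves as the module above expects (indices stacked as (g, b, t, r)); the dim/levels/G/R values are hypothetical, chosen only to make the shapes concrete:

```python
import torch

# Hypothetical config: 2 groups x 2 residual quantizers over a 1024-dim
# feature, with a 5^4 = 625-entry FSQ codebook per quantizer.
vq = GFSQ(dim=1024, levels=[5, 5, 5, 5], G=2, R=2)

feats = torch.randn(2, 1024, 80)  # (B, dim, T), since transpose=True
loss, quantized, perplexity, _, ind = vq(feats)

assert quantized.shape == (2, 1024, 80)  # quantized features, same layout
assert ind.shape == (2, 4, 80)           # (B, G*R, T) codebook indices
assert perplexity.shape == (4,)          # one codebook-usage value per (group, quantizer)
assert loss.sum() == 0                   # FSQ has no commitment loss; zeros returned as placeholder
```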
```python
class DVAEDecoder(nn.Module):
    def __init__(self, idim, odim,
                 n_layer=12, bn_dim=64, hidden=256,
                 kernel=7, dilation=2, up=False
                 ):
        super().__init__()
        self.up = up
        self.conv_in = nn.Sequential(
            nn.Conv1d(idim, bn_dim, 3, 1, 1), nn.GELU(),
            nn.Conv1d(bn_dim, hidden, 3, 1, 1)
        )
        self.decoder_block = nn.ModuleList([
            ConvNeXtBlock(hidden, hidden * 4, kernel, dilation)
            for _ in range(n_layer)])
        self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)

    def forward(self, input, conditioning=None):
        # B, T, C
        x = input.transpose(1, 2)
        x = self.conv_in(x)
        for f in self.decoder_block:
            x = f(x, conditioning)

        x = self.conv_out(x)
        return x.transpose(1, 2)
```
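The decoder's contract is time-major in and out: (B, T, idim) -> (B, T, odim), with the ConvNeXt stack operating channel-major in between. A hypothetical shape check (idim/odim below are assumptions, not the released configuration):

```python
import torch

# Illustrative sizes; n_layer/bn_dim/hidden fall back to the defaults above.
dec = DVAEDecoder(idim=512, odim=512)
z = torch.randn(2, 160, 512)          # (B, T, idim)
assert dec(z).shape == (2, 160, 512)  # (B, T, odim); T is preserved
```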
```python
class DVAE(nn.Module):
    def __init__(
        self, decoder_config, vq_config, dim=512
    ):
        super().__init__()
        self.register_buffer('coef', torch.randn(1, 100, 1))

        self.decoder = DVAEDecoder(**decoder_config)
        self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
        if vq_config is not None:
            self.vq_layer = GFSQ(**vq_config)
        else:
            self.vq_layer = None

    def forward(self, inp):
        if self.vq_layer is not None:
            vq_feats = self.vq_layer._embed(inp)
        else:
            vq_feats = inp.detach().clone()

        temp = torch.chunk(vq_feats, 2, dim=1)  # flatten trick :)
        temp = torch.stack(temp, -1)
        vq_feats = temp.reshape(*temp.shape[:2], -1)

        vq_feats = vq_feats.transpose(1, 2)
        dec_out = self.decoder(input=vq_feats)
        dec_out = self.out_conv(dec_out.transpose(1, 2))
        mel = dec_out * self.coef

        return mel
```
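Putting the deleted module together, a hedged end-to-end sketch with vq_config=None, so the input is treated as continuous features rather than codebook indices; the decoder_config values are assumptions, not the released checkpoint's settings:

```python
import torch

# Hypothetical configuration; real values come from the ChatTTS config files.
decoder_config = dict(idim=512, odim=512)
dvae = DVAE(decoder_config=decoder_config, vq_config=None, dim=512)

# The "flatten trick": chunk the 2*idim channels into two halves, stack them
# on a new trailing axis, and reshape, which interleaves the halves along
# time. (B, 1024, T) becomes (B, 512, 2T) before decoding, and out_conv then
# projects the decoded features to 100 mel bins, scaled by the coef buffer.
feats = torch.randn(1, 1024, 80)
mel = dvae(feats)
assert mel.shape == (1, 100, 160)  # (B, n_mels, 2*T)
```

Note the channel-to-time reshape is why idim is half the input channel count: each decoder frame covers one interleaved half-channel slice, doubling the temporal resolution of the output mel.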