Pijush2023 committed
Commit 2288ea6 · verified · 1 Parent(s): 3d83510

Delete CHATTS/model/dvae.py

Files changed (1)
  1. CHATTS/model/dvae.py +0 -155
CHATTS/model/dvae.py DELETED
@@ -1,155 +0,0 @@
- import math
- from einops import rearrange
- from vector_quantize_pytorch import GroupedResidualFSQ
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- class ConvNeXtBlock(nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         intermediate_dim: int,
-         kernel, dilation,
-         layer_scale_init_value: float = 1e-6,
-     ):
-         # ConvNeXt Block copied from Vocos.
-         super().__init__()
-         self.dwconv = nn.Conv1d(dim, dim,
-             kernel_size=kernel, padding=dilation*(kernel//2),
-             dilation=dilation, groups=dim
-         )  # depthwise conv
-
-         self.norm = nn.LayerNorm(dim, eps=1e-6)
-         self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
-         self.act = nn.GELU()
-         self.pwconv2 = nn.Linear(intermediate_dim, dim)
-         self.gamma = (
-             nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
-             if layer_scale_init_value > 0
-             else None
-         )
-
-     def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
-         residual = x
-         x = self.dwconv(x)
-         x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
-         x = self.norm(x)
-         x = self.pwconv1(x)
-         x = self.act(x)
-         x = self.pwconv2(x)
-         if self.gamma is not None:
-             x = self.gamma * x
-         x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
-
-         x = residual + x
-         return x
-
-
- class GFSQ(nn.Module):
-
-     def __init__(self,
-         dim, levels, G, R, eps=1e-5, transpose=True
-     ):
-         super(GFSQ, self).__init__()
-         self.quantizer = GroupedResidualFSQ(
-             dim=dim,
-             levels=levels,
-             num_quantizers=R,
-             groups=G,
-         )
-         self.n_ind = math.prod(levels)
-         self.eps = eps
-         self.transpose = transpose
-         self.G = G
-         self.R = R
-
-     def _embed(self, x):
-         if self.transpose:
-             x = x.transpose(1, 2)
-         x = rearrange(
-             x, "b t (g r) -> g b t r", g=self.G, r=self.R,
-         )
-         feat = self.quantizer.get_output_from_indices(x)
-         return feat.transpose(1, 2) if self.transpose else feat
-
-     def forward(self, x):
-         if self.transpose:
-             x = x.transpose(1, 2)
-         feat, ind = self.quantizer(x)
-         ind = rearrange(
-             ind, "g b t r -> b t (g r)",
-         )
-         embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
-         e_mean = torch.mean(embed_onehot, dim=[0, 1])
-         e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
-         perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
-
-         return (
-             torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
-             feat.transpose(1, 2) if self.transpose else feat,
-             perplexity,
-             None,
-             ind.transpose(1, 2) if self.transpose else ind,
-         )
-
- class DVAEDecoder(nn.Module):
-     def __init__(self, idim, odim,
-         n_layer=12, bn_dim=64, hidden=256,
-         kernel=7, dilation=2, up=False
-     ):
-         super().__init__()
-         self.up = up
-         self.conv_in = nn.Sequential(
-             nn.Conv1d(idim, bn_dim, 3, 1, 1), nn.GELU(),
-             nn.Conv1d(bn_dim, hidden, 3, 1, 1)
-         )
-         self.decoder_block = nn.ModuleList([
-             ConvNeXtBlock(hidden, hidden * 4, kernel, dilation)
-             for _ in range(n_layer)])
-         self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
-
-     def forward(self, input, conditioning=None):
-         # B, T, C
-         x = input.transpose(1, 2)
-         x = self.conv_in(x)
-         for f in self.decoder_block:
-             x = f(x, conditioning)
-
-         x = self.conv_out(x)
-         return x.transpose(1, 2)
-
-
- class DVAE(nn.Module):
-     def __init__(
-         self, decoder_config, vq_config, dim=512
-     ):
-         super().__init__()
-         self.register_buffer('coef', torch.randn(1, 100, 1))
-
-         self.decoder = DVAEDecoder(**decoder_config)
-         self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
-         if vq_config is not None:
-             self.vq_layer = GFSQ(**vq_config)
-         else:
-             self.vq_layer = None
-
-     def forward(self, inp):
-
-         if self.vq_layer is not None:
-             vq_feats = self.vq_layer._embed(inp)
-         else:
-             vq_feats = inp.detach().clone()
-
-         temp = torch.chunk(vq_feats, 2, dim=1)  # flatten trick :)
-         temp = torch.stack(temp, -1)
-         vq_feats = temp.reshape(*temp.shape[:2], -1)
-
-         vq_feats = vq_feats.transpose(1, 2)
-         dec_out = self.decoder(input=vq_feats)
-         dec_out = self.out_conv(dec_out.transpose(1, 2))
-         mel = dec_out * self.coef
-
-         return mel
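
A minimal usage sketch of the module this commit removes, for readers browsing the diff. It exercises only the vq_config=None path, and every config value (idim, odim, n_layer, bn_dim, hidden, and the input shape) is an assumption chosen to satisfy the tensor arithmetic in DVAE.forward, not the configuration ChatTTS actually ships.

    import torch

    from CHATTS.model.dvae import DVAE  # import path as it existed before this deletion

    # Illustrative decoder config: idim must be half the input channel count,
    # because DVAE.forward splits the channels in two and concatenates them along time.
    decoder_config = dict(idim=512, odim=512, n_layer=12, bn_dim=64, hidden=256)
    dvae = DVAE(decoder_config=decoder_config, vq_config=None, dim=512)

    feats = torch.randn(1, 1024, 50)   # (B, C, T) continuous features; C = 2 * idim
    mel = dvae(feats)                  # -> (1, 100, 100): 100 mel bins, 2*T frames
    print(mel.shape)

With a non-None vq_config the same forward pass instead calls GFSQ._embed, so the input would be grouped FSQ code indices of shape (B, G*R, T) rather than continuous features.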